
The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean • block_size_experimentation_base

Level 2 • Task 44

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 optimized_spatial_reduction_base 0.18 1.21 0.72
🥇 optimized_spatial_reduction_edit_1 0.18 1.21 0.72
🥉 minimal_sync_reduction_edit_1 0.19 1.19 0.71
4 shared_memory_tiled_reduction_base 0.19 1.17 0.70
5 fused_global_avg_base 0.20 1.13 0.67
6 block_size_experimentation_base 0.20 1.11 0.66
7 optimized_strided_avg_pooling_edit_1 0.20 1.10 0.66
7 aligned_ldg_optimized_kernel_base 0.20 1.10 0.66
9 combined_optimized_mean_kernel_base 0.20 1.10 0.65
9 vectorized_ldg_mean_kernel_base 0.20 1.10 0.65
9 optimized_mean_kernel_base 0.20 1.10 0.65
9 warp_uniform_mean_kernel_base_base 0.20 1.10 0.65
9 unrolled_vectorized_mean_kernel_base 0.20 1.10 0.65
9 atomic_final_reduction_base 0.20 1.10 0.65
15 optimized_sync_reduction_base 0.20 1.09 0.65
15 shared_mem_reduction_optimized_base 0.20 1.09 0.65
15 modular_shared_warp_mean_base_base 0.20 1.09 0.65
15 coalesced_vectorized_mean_kernel_base 0.20 1.09 0.65
15 reduced_sync_shared_memory_base 0.20 1.09 0.65
15 fused_atomic_reduction_base 0.20 1.09 0.65
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Kernel to compute the mean of each (batch, channel) slice using vectorized loads
// and #pragma unroll hints to reduce loop overhead in both the global-memory
// accumulation and the shared-memory reduction.
// Each block processes one (batch, channel) slice and atomically accumulates its
// slice mean into a global accumulator.

// Experimenting with different block sizes to find the optimal configuration.

// Template parameter for block size
template <unsigned int blockSize>
__global__ void block_size_experimentation_kernel(
    const float* __restrict__ input,
    float* __restrict__ global_accum,
    int H,
    int W,
    int C
) {
    extern __shared__ float shared[];  // Shared memory for reduction
    int num_elements = H * W;
    int batch = blockIdx.x / C;
    int channel = blockIdx.x % C;
    int input_offset = (batch * C + channel) * num_elements;
    float sum = 0.0f;

    // Use vectorized float4 loads when the slice size is a multiple of 4
    // (which preserves the 16-byte alignment required by float4).
    if ((num_elements & 3) == 0) {
        int num_vec = num_elements >> 2;  // num_elements / 4
        const float4* in_vec = reinterpret_cast<const float4*>(input + input_offset);
        // Grid-stride loop over float4 elements; the unroll hint trims loop overhead.
        #pragma unroll 4
        for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
            float4 v = __ldg(&in_vec[i]);
            sum += v.x + v.y + v.z + v.w;
        }
    } else {
        // Scalar fallback for slices whose size is not a multiple of 4.
        #pragma unroll 4
        for (int i = threadIdx.x; i < num_elements; i += blockDim.x) {
            sum += __ldg(&input[input_offset + i]);
        }
    }

    // Store the partial sum in shared memory
    shared[threadIdx.x] = sum;
    __syncthreads();

    // Intra-block reduction with manual unrolling
    if (blockSize >= 512) {
        if (threadIdx.x < 256)
            shared[threadIdx.x] += shared[threadIdx.x + 256];
        __syncthreads();
    }
    if (blockSize >= 256) {
        if (threadIdx.x < 128)
            shared[threadIdx.x] += shared[threadIdx.x + 128];
        __syncthreads();
    }
    if (blockSize >= 128) {
        if (threadIdx.x < 64)
            shared[threadIdx.x] += shared[threadIdx.x + 64];
        __syncthreads();
    }

    // Final warp-level reduction (requires blockSize >= 64 so that index
    // threadIdx.x + 32 stays inside the shared-memory buffer). The __syncwarp()
    // calls separate the shared-memory writes from the reads, so the code does not
    // rely on implicit warp-synchronous execution, which Volta+ no longer
    // guarantees. A register-shuffle alternative is sketched after this kernel.
    if (threadIdx.x < 32) {
        float val = shared[threadIdx.x] + shared[threadIdx.x + 32];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            shared[threadIdx.x] = val;
            __syncwarp();
            val += shared[threadIdx.x + offset];
            __syncwarp();
        }
        if (threadIdx.x == 0) {
            shared[0] = val;  // expose the block total to the epilogue below
        }
    }

    // Thread 0 computes the mean for this slice and atomically adds it to the global accumulator
    if (threadIdx.x == 0) {
        float slice_mean = shared[0] / static_cast<float>(num_elements);
        atomicAdd(global_accum, slice_mean);
    }
}
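
// --- Illustrative alternative (not part of the original submission): the final
// 32-thread stage could instead be done entirely in registers with warp shuffles,
// removing the shared-memory traffic of the tail. This helper is a sketch only;
// the kernel above does not call it.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    // Each step folds the upper half of the active lanes into the lower half.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffffu, val, offset);
    }
    return val;  // lane 0 ends up with the warp-wide sum
}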

at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    // Perform transposed convolution using PyTorch's native function
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    );

    // Scale the output
    y = y * multiplier;

    // Get dimensions (N, C, H, W)
    auto dims = y.sizes();
    int N = dims[0];
    int C = dims[1];
    int H = dims[2];
    int W = dims[3];

    // Allocate a scalar accumulator on the device and initialize to zero
    auto options = torch::TensorOptions().device(y.device()).dtype(y.dtype());
    at::Tensor accum = torch::zeros({1}, options);

    // Launch one block per (batch, channel) slice.
    // The block size can be tuned (64, 128, 256, or 512); the kernel's final
    // warp-level reduction assumes at least 64 threads per block.
    constexpr int blockSize = 128;
    int numBlocks = N * C;
    size_t sharedMemSize = blockSize * sizeof(float);

    block_size_experimentation_kernel<blockSize><<<numBlocks, blockSize, sharedMemSize>>>(
        y.data_ptr<float>(),
        accum.data_ptr<float>(),
        H, W, C
    );

    // Compute the final overall mean: average the means of all (batch, channel) slices
    accum = accum / static_cast<float>(N * C);
    return accum;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn, "Module function");
}
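
Below is a hypothetical host-side sanity check, not part of the original submission. It assumes this file is compiled as a regular PyTorch C++/CUDA extension, so module_fn is visible in the same translation unit and <torch/extension.h> (included at the top of the listing) already provides the C++ tensor API. It compares the fused result against an eager composition of conv_transpose2d, the scalar multiply, and a global mean; the two agree because averaging the equal-sized per-(batch, channel) slice means is the same as taking the mean over every element of the scaled output. Shapes and hyperparameters in the sketch are illustrative assumptions.

// Hypothetical verification helper (not from the original submission).
void check_against_eager() {
    torch::manual_seed(0);
    auto x      = torch::randn({16, 32, 32, 32}, torch::kCUDA);  // (N, C_in, H, W)
    auto weight = torch::randn({32, 16, 3, 3},   torch::kCUDA);  // (C_in, C_out, kH, kW)
    auto bias   = torch::randn({16},             torch::kCUDA);

    // Fused path: transposed convolution + scalar multiply + global mean in one call.
    at::Tensor fused = module_fn(x, /*stride=*/2, /*padding=*/1, /*output_padding=*/0,
                                 weight, bias, /*multiplier=*/0.5);

    // Eager reference: since every (batch, channel) slice has the same H * W size,
    // averaging the per-slice means equals the mean over all elements.
    at::Tensor y   = at::conv_transpose2d(x, weight, bias, {2, 2}, {1, 1}, {0, 0});
    at::Tensor ref = (y * 0.5).mean().reshape({1});

    TORCH_CHECK(torch::allclose(fused, ref, /*rtol=*/1e-4, /*atol=*/1e-4),
                "fused kernel and eager reference disagree");
}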