← Back to Leaderboard

The AI CUDA Engineer 👷

78_ConvTranspose3d_Max_Max_Sum • fully_unrolled_maxpool_base_base

Level 2 • Task 78

Kernel Information

Related Kernels (Level 2, Task 78 • 78_ConvTranspose3d_Max_Max_Sum)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_maxpool_kernel_base 0.58 1.05 1.21
🥇 adaptive_blocksize_maxpool_opt_base 0.58 1.05 1.21
🥉 minimized_divergence_maxpool_base_base 0.58 1.05 1.21
4 unrolled_78_convtranspose3d_optimized_base 0.59 1.03 1.19
4 modular_maxpool_kernel_base 0.59 1.03 1.19
6 fully_unrolled_maxpool_base_base 0.59 1.03 1.19
7 balanced_load_distribution_maxpool_base 0.59 1.03 1.19
8 manual_unroll_maxpool_base_base 0.59 1.03 1.19
9 coalesced_maxpool_shared_mem_base 0.60 1.02 1.18
10 unrolled_78_convtranspose3d_base 0.61 1.01 1.16
11 78_ConvTranspose3d_Max_Max_Sum 0.61 1.00 1.16
12 unroll_conv3d_max_sum_base 0.61 1.00 1.15
13 modular_conv3d_max_sum_edit_1 0.61 1.00 1.15
13 modular_conv3d_max_sum_base 0.61 1.00 1.15
13 shared_mem_reduction_max_sum_base 0.61 1.00 1.15
13 unroll_conv3d_max_sum_edit_1 0.61 1.00 1.15
17 optimized_stride_max_pool_base 0.61 1.00 1.15
17 shared_mem_reduction_max_sum_edit_1 0.61 1.00 1.15
19 constant_memory_optimization_base_edit_1 0.62 0.99 1.14
19 balanced_workload_distribution_base 0.62 0.99 1.14
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cfloat>
#include <limits>

/**
 * Fused double max-pool: a 2x2x2 / stride-2 max-pool immediately followed by a
 * 3x3x3 / stride-3 max-pool, computed in a single pass over the input.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output element; the
 * grid must cover N * C * D3 * H3 * W3 threads (surplus threads exit early).
 * No shared memory is used.
 *
 * @param input   Float buffer indexed as a dense (N, C, D1, H1, W1) array.
 * @param output  Float buffer indexed as a dense (N, C, D3, H3, W3) array.
 * @param N, C    Batch size and channel count.
 * @param D1,H1,W1  Spatial extents of the input.
 * @param D2,H2,W2  Spatial extents after the 2x2x2 pool.
 * @param D3,H3,W3  Spatial extents after the 3x3x3 pool (output extents).
 *
 * All flat-index arithmetic is done in 64 bits so that volumes with more than
 * 2^31 elements do not wrap around.
 */
__global__ void fully_unrolled_double_maxpool_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int N, const int C,
    const int D1, const int H1, const int W1,  // Dimensions after conv_transpose
    const int D2, const int H2, const int W2,  // Dimensions after first maxpool
    const int D3, const int H3, const int W3)  // Final dimensions
{
    // Widen before multiplying: blockIdx.x * blockDim.x is a 32-bit product otherwise.
    const int64_t total = static_cast<int64_t>(N) * C * D3 * H3 * W3;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    // Decode the flat output index into (n, c, d3, h3, w3), innermost first.
    int64_t rem = idx;
    const int w3 = static_cast<int>(rem % W3); rem /= W3;
    const int h3 = static_cast<int>(rem % H3); rem /= H3;
    const int d3 = static_cast<int>(rem % D3); rem /= D3;
    const int c  = static_cast<int>(rem % C);
    const int n  = static_cast<int>(rem / C);

    // Base of the (n, c) volume in the input; hoisted out of the window loops.
    const float* __restrict__ volume =
        input + (static_cast<int64_t>(n) * C + c) * D1 * H1 * W1;

    // Origin of this output element's 3x3x3 window in the first-pool grid.
    const int start_d2 = d3 * 3;
    const int start_h2 = h3 * 3;
    const int start_w2 = w3 * 3;

    float final_max = -FLT_MAX;

    // The range guards below never trigger when D2 = D1 / 2 and D3 = D2 / 3
    // (floor division); they only protect against inconsistent extents.
    #pragma unroll
    for (int d2_offset = 0; d2_offset < 3; d2_offset++) {
        const int d2 = start_d2 + d2_offset;
        if (d2 >= D2) continue;

        #pragma unroll
        for (int h2_offset = 0; h2_offset < 3; h2_offset++) {
            const int h2 = start_h2 + h2_offset;
            if (h2 >= H2) continue;

            #pragma unroll
            for (int w2_offset = 0; w2_offset < 3; w2_offset++) {
                const int w2 = start_w2 + w2_offset;
                if (w2 >= W2) continue;

                // Recompute the 2x2x2 pool for this first-pool cell on the fly
                // instead of reading a materialised intermediate tensor.
                float local_max = -FLT_MAX;

                const int start_d1 = d2 * 2;
                const int start_h1 = h2 * 2;
                const int start_w1 = w2 * 2;

                #pragma unroll
                for (int d1_offset = 0; d1_offset < 2; d1_offset++) {
                    const int d1 = start_d1 + d1_offset;
                    if (d1 >= D1) continue;

                    #pragma unroll
                    for (int h1_offset = 0; h1_offset < 2; h1_offset++) {
                        const int h1 = start_h1 + h1_offset;
                        if (h1 >= H1) continue;

                        #pragma unroll
                        for (int w1_offset = 0; w1_offset < 2; w1_offset++) {
                            const int w1 = start_w1 + w1_offset;
                            if (w1 >= W1) continue;

                            const int64_t offset =
                                (static_cast<int64_t>(d1) * H1 + h1) * W1 + w1;
                            local_max = fmaxf(local_max, volume[offset]);
                        }
                    }
                }

                final_max = fmaxf(final_max, local_max);
            }
        }
    }

    output[idx] = final_max;
}

/**
 * Forward pass: ConvTranspose3d -> MaxPool3d(2) -> MaxPool3d(3) -> sum over channels.
 *
 * The transposed convolution is delegated to ATen; the two max-pools are fused
 * into one custom kernel; the channel reduction uses Tensor::sum.
 *
 * @param x                    Input tensor (must be on a CUDA device).
 * @param stride               Stride applied to all three spatial dimensions.
 * @param padding              Padding applied to all three spatial dimensions.
 * @param conv_transpose       Transposed-convolution weight (CUDA tensor).
 * @param conv_transpose_bias  Transposed-convolution bias (CUDA tensor).
 * @return Tensor of shape (N, 1, D3, H3, W3), where D3/H3/W3 are the
 *         conv-transpose output extents floor-divided by 2 and then by 3.
 */
torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias) {

    TORCH_CHECK(x.is_cuda(), "Input x must be a CUDA tensor");
    TORCH_CHECK(conv_transpose.is_cuda(), "conv_transpose must be a CUDA tensor");
    TORCH_CHECK(conv_transpose_bias.is_cuda(), "conv_transpose_bias must be a CUDA tensor");

    x = x.contiguous();
    conv_transpose = conv_transpose.contiguous();
    conv_transpose_bias = conv_transpose_bias.contiguous();

    // Apply transposed convolution using ATen op
    x = at::conv_transpose3d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding}
    );

    // The kernel indexes the raw buffer as a dense float (N, C, D, H, W) array,
    // so enforce exactly that before taking data_ptr.
    TORCH_CHECK(x.dim() == 5, "conv_transpose3d output must be 5-D, got ", x.dim(), "-D");
    TORCH_CHECK(x.scalar_type() == at::kFloat,
                "fused max-pool kernel only supports float32, got ", x.scalar_type());
    x = x.contiguous();

    // Get dimensions after conv_transpose
    const auto sizes = x.sizes();
    for (const int64_t extent : sizes) {
        TORCH_CHECK(extent <= std::numeric_limits<int>::max(),
                    "tensor extent ", extent, " does not fit the kernel's int parameters");
    }
    const int N = static_cast<int>(sizes[0]);
    const int C = static_cast<int>(sizes[1]);
    const int D1 = static_cast<int>(sizes[2]);
    const int H1 = static_cast<int>(sizes[3]);
    const int W1 = static_cast<int>(sizes[4]);

    // Calculate dimensions after first maxpool (2x2x2), floor mode
    const int D2 = D1 / 2;
    const int H2 = H1 / 2;
    const int W2 = W1 / 2;

    // Calculate final dimensions after second maxpool (3x3x3), floor mode
    const int D3 = D2 / 3;
    const int H3 = H2 / 3;
    const int W3 = W2 / 3;

    // Allocate output tensor
    auto output = torch::empty({N, C, D3, H3, W3}, x.options());

    // 64-bit element count: the 32-bit product can wrap for large tensors.
    const int64_t total_elements = output.numel();

    // A zero-block grid is an invalid launch configuration, so skip the launch
    // entirely when there is nothing to compute.
    if (total_elements > 0) {
        const int threads = 256;
        const int64_t blocks = (total_elements + threads - 1) / threads;
        TORCH_CHECK(blocks <= std::numeric_limits<int>::max(),
                    "output too large for a 1-D grid: ", blocks, " blocks required");

        // Launch on PyTorch's current stream so the kernel is ordered after the
        // convolution above and before the reduction below.
        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        fully_unrolled_double_maxpool_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            x.data_ptr<float>(),
            output.data_ptr<float>(),
            N, C, D1, H1, W1, D2, H2, W2, D3, H3, W3
        );
        // Surface launch-configuration errors here instead of at some later call.
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Sum over channels
    return output.sum(1, /*keepdim=*/true);
}

// Python binding: exposes the C++ `forward` function as `<module>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Forward pass with fully unrolled maxpool operations");
}