← Back to Leaderboard

The AI CUDA Engineer 👷

42_Max_Pooling_2D • max_pool2d_optimized_strided_base

Level 1 • Task 42

Kernel Information

Related Kernels (Level 1, Task 42 • 42_Max_Pooling_2D)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 textured_modular_unroll_base_base 0.02 1.43 3.04
🥇 max_pool2d_optimized_strided_base 0.02 1.43 3.04
🥇 tuned_block_size_maxpool2d_base_base 0.02 1.43 3.04
🥇 modularized_unroll_base_base 0.02 1.43 3.04
🥇 fused_unroll_constmem_pool_base 0.02 1.43 3.04
🥇 base_unrolled_combo_base 0.02 1.43 3.04
🥇 warp_divergence_optimized_unroll_base 0.02 1.43 3.04
8 max_pool2d_strided_base_base 0.02 1.37 2.91
8 streams_unroll_pipelined_batch_base_base 0.02 1.37 2.91
8 optimized_max_pool2d_base 0.02 1.37 2.91
8 42_Max_Pooling_2D_manual_unroll_base 0.02 1.37 2.91
8 coalesced_1d_unroll_base_base 0.02 1.37 2.91
8 max_pool2d_kernel_manually_unrolled_base_base 0.02 1.37 2.91
8 coalesced_ldg_unrolled_base_base 0.02 1.37 2.91
8 unrolled_coalesced_maxpool_base 0.02 1.37 2.91
8 max_pool2d_combined_base 0.02 1.37 2.91
8 max_pool2d_combined_optimized_base 0.02 1.37 2.91
8 unroll_base_optimized_base 0.02 1.37 2.91
8 max_pool2d_combined_base 0.02 1.37 2.91
8 fully_unrolled_maxpool2d_base_base 0.02 1.37 2.91
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// Returns the maximum over one pooling window of a single (batch, channel)
// plane.
//
// `plane` points at the first element of a row-major
// input_height x input_width plane. (ih_start, iw_start) is the top-left tap
// of the window; it may be negative, and the window may run past the plane,
// when padding is in effect. Taps outside the plane are skipped, so padded
// positions never contribute. If no tap lies inside the plane the result is
// -infinity.
//
// STATIC_K > 0 fixes the window size at compile time so both loops have a
// constant trip count and can be fully unrolled; STATIC_K == 0 falls back to
// the runtime `kernel_size`.
//
// NOTE(review): max() appears to follow fmax semantics (a NaN operand is
// ignored when the other operand is a number), so NaN inputs would not be
// propagated the way torch.nn.functional.max_pool2d propagates them --
// confirm whether parity is required.
template <typename scalar_t, int STATIC_K>
__device__ __forceinline__ scalar_t max_pool2d_window_max(
    const scalar_t* __restrict__ plane,
    const int input_height,
    const int input_width,
    const int ih_start,
    const int iw_start,
    const int kernel_size,
    const int dilation
) {
    const int window = (STATIC_K > 0) ? STATIC_K : kernel_size;
    scalar_t max_val = -std::numeric_limits<scalar_t>::infinity();

    #pragma unroll
    for (int kh = 0; kh < window; ++kh) {
        const int ih = ih_start + kh * dilation;
        if (ih >= 0 && ih < input_height) {
            // The row base depends only on kh, so compute it once per row
            // (in 64-bit so large planes cannot overflow the offset).
            const scalar_t* row = plane + static_cast<int64_t>(ih) * input_width;

            #pragma unroll
            for (int kw = 0; kw < window; ++kw) {
                const int iw = iw_start + kw * dilation;
                if (iw >= 0 && iw < input_width) {
                    max_val = max(max_val, row[iw]);
                }
            }
        }
    }

    return max_val;
}

// 2D max pooling forward kernel.
//
// Layout: `input` is read as [batch_channels][input_height][input_width] and
// `output` is written as [batch_channels][output_height][output_width], both
// densely packed and row-major.
//
// Launch: 1-D grid of 1-D blocks of any size. Each thread walks the flat
// output index space with a grid-stride loop, so the result does not depend
// on the launch configuration. No shared memory is used.
//
// All flat indices are computed in 64-bit: with 32-bit arithmetic the input
// offset (bc * H + ih) * W + iw and the total output element count overflow
// once a tensor reaches 2^31 elements.
template <typename scalar_t>
__global__ void max_pool2d_optimized_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int batch_channels,
    const int input_height,
    const int input_width,
    const int output_height,
    const int output_width,
    const int kernel_size,
    const int stride,
    const int padding,
    const int dilation
) {
    const int64_t output_elements = static_cast<int64_t>(output_height) * output_width;
    const int64_t total_elements = batch_channels * output_elements;
    const int64_t input_plane_elements = static_cast<int64_t>(input_height) * input_width;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total_elements;
         idx += grid_stride) {

        // Decompose the flat output index into (plane, row, column).
        const int64_t bc = idx / output_elements;
        const int64_t within_plane = idx - bc * output_elements;
        const int oh = static_cast<int>(within_plane / output_width);
        const int ow = static_cast<int>(within_plane - static_cast<int64_t>(oh) * output_width);

        const scalar_t* plane = input + bc * input_plane_elements;

        // Top-left tap of this output element's window (may be negative).
        const int ih_start = oh * stride - padding;
        const int iw_start = ow * stride - padding;

        // 2x2 windows get a compile-time-sized, fully unrolled path.
        scalar_t max_val;
        if (kernel_size == 2) {
            max_val = max_pool2d_window_max<scalar_t, 2>(
                plane, input_height, input_width, ih_start, iw_start, kernel_size, dilation);
        } else {
            max_val = max_pool2d_window_max<scalar_t, 0>(
                plane, input_height, input_width, ih_start, iw_start, kernel_size, dilation);
        }

        output[idx] = max_val;
    }
}

// Host entry point for the 2D max pooling forward pass.
//
// Parameters:
//   input       - 4D CUDA tensor indexed as (batch, channels, height, width);
//                 a contiguous copy is made if it is not already contiguous,
//                 because the kernel indexes it as densely packed planes.
//   kernel_size - side length of the square pooling window (> 0).
//   stride      - step between neighbouring windows (> 0).
//   padding     - implicit padding on every side (>= 0); padded positions
//                 never contribute to the maximum.
//   dilation    - spacing between window taps (> 0).
//
// Returns a new tensor of shape (batch, channels, output_height, output_width)
// with the same dtype/device as `input`, where
//   output_dim = (input_dim + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1.
//
// Raises (via TORCH_CHECK) on invalid arguments, when the dilated window does
// not fit inside the padded input, or when the kernel launch is rejected.
// Only floating dtypes covered by AT_DISPATCH_FLOATING_TYPES are dispatched.
torch::Tensor max_pool2d_cuda_forward(
    torch::Tensor input,
    int kernel_size,
    int stride,
    int padding,
    int dilation
) {
    TORCH_CHECK(input.is_cuda(), "max_pool2d_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 4,
                "max_pool2d_forward: expected a 4D (N, C, H, W) input, got ", input.dim(), "D");
    TORCH_CHECK(kernel_size > 0, "max_pool2d_forward: kernel_size must be positive, got ", kernel_size);
    TORCH_CHECK(stride > 0, "max_pool2d_forward: stride must be positive, got ", stride);
    TORCH_CHECK(padding >= 0, "max_pool2d_forward: padding must be non-negative, got ", padding);
    TORCH_CHECK(dilation > 0, "max_pool2d_forward: dilation must be positive, got ", dilation);

    // Launch on the device that owns the input, not whatever device happens
    // to be current.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel computes offsets assuming densely packed row-major planes.
    input = input.contiguous();

    const int batch_size = static_cast<int>(input.size(0));
    const int channels = static_cast<int>(input.size(1));
    const int input_height = static_cast<int>(input.size(2));
    const int input_width = static_cast<int>(input.size(3));

    // Extent covered by one dilated window along a single axis.
    const int effective_kernel = dilation * (kernel_size - 1) + 1;

    // Without this check a too-large window makes the numerator below
    // negative; integer division truncates toward zero and would silently
    // produce a size-1 output filled with -infinity.
    TORCH_CHECK(input_height + 2 * padding >= effective_kernel &&
                input_width + 2 * padding >= effective_kernel,
                "max_pool2d_forward: dilated window (", effective_kernel,
                ") does not fit in padded input (", input_height + 2 * padding,
                " x ", input_width + 2 * padding, ")");

    const int output_height = (input_height + 2 * padding - effective_kernel) / stride + 1;
    const int output_width = (input_width + 2 * padding - effective_kernel) / stride + 1;

    auto output = torch::empty({batch_size, channels, output_height, output_width}, input.options());

    // A zero-block launch is an invalid configuration; nothing to compute.
    const int64_t total_elements = output.numel();
    if (total_elements == 0) {
        return output;
    }

    const int batch_channels = batch_size * channels;
    const int threads = 256;

    // One thread per output element where possible. The kernel iterates with
    // a grid-stride loop, so clamping the grid to the x-dimension limit
    // (2^31 - 1) stays correct for very large outputs.
    constexpr int64_t kMaxGridX = 2147483647;
    const int64_t wanted_blocks = (total_elements + threads - 1) / threads;
    const int blocks = static_cast<int>(wanted_blocks < kMaxGridX ? wanted_blocks : kMaxGridX);

    // Respect the caller's current PyTorch stream instead of the legacy
    // default stream.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool2d_forward", ([&] {
        max_pool2d_optimized_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_channels,
            input_height,
            input_width,
            output_height,
            output_width,
            kernel_size,
            stride,
            padding,
            dilation
        );
    }));

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "max_pool2d_forward: kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}

// Python bindings: registers max_pool2d_cuda_forward under the name "forward"
// on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &max_pool2d_cuda_forward,
        "Max Pool 2D optimized strided forward (CUDA)"
    );
}