
The AI CUDA Engineer 👷

43_Conv3d_Max_LogSumExp_ReLU • optimized_fused_3d_kernel_base

Level 2 • Task 43

Kernel Information

Related Kernels (Level 2, Task 43 • 43_Conv3d_Max_LogSumExp_ReLU)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 strided_conv3d_max_logsumexp_relu_base 0.79 1.05 1.01
🥇 optimized_fused_kernel_base 0.79 1.05 1.01
🥇 fused_optimized_base 0.79 1.05 1.01
🥇 optimized_fused_3d_kernel_base 0.79 1.05 1.01
🥇 coalesced_memory_access_kernel_base 0.79 1.05 1.01
🥇 optimized_fused_3d_kernel_base 0.79 1.05 1.01
🥇 coalesced_memory_access_kernel_base 0.79 1.05 1.01
🥇 block_tuned_fused_kernel_base_base 0.79 1.05 1.01
🥇 unroll_fused_kernel_base_base 0.79 1.05 1.01
🥇 unroll_optimized_kernel_base_base 0.79 1.05 1.01
🥇 warp_uniform_kernel_base_base 0.79 1.05 1.01
🥇 minimal_sync_fused_kernel_base 0.79 1.05 1.01
13 optimized_thread_block_indexing_base 0.79 1.05 1.00
14 coalesced_fused_kernel_base 0.80 1.05 1.00
15 atomic_reduction_43_conv3d_base 0.81 1.03 0.99
16 fused_gridstride_logsumexp_relu_base 0.81 1.03 0.99
17 stride_loop_43_conv3d_base_base 0.81 1.03 0.99
18 43_Conv3d_Max_LogSumExp_ReLU 0.83 1.00 0.96
19 43_conv3d_max_logsumexp_relu_unrolled_optimized_base 0.83 1.00 0.95
19 43_conv3d_max_logsumexp_relu_tuned_blocksize_edit_1 0.83 1.00 0.95
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cfloat>

// Fused kernel: per-voxel LogSumExp over the channel dimension followed by ReLU.
// Each thread stages its channel maximum in dynamic shared memory.
__global__ void optimized_fused_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int N, const int C, const int D, const int H, const int W) {
    
    extern __shared__ float shared_data[];
    
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int num_blocks = gridDim.x;
    const int stride = D * H * W;
    const int total_elements = N * D * H * W;
    
    // Process multiple elements per thread using grid-stride loop
    for (int idx = bid * blockDim.x + tid; idx < total_elements; idx += blockDim.x * num_blocks) {
        // Decode indices
        const int w = idx % W;
        int temp = idx / W;
        const int h = temp % H;
        temp /= H;
        const int d = temp % D;
        const int n = temp / D;
        
        // Running channel maximum and exp-sum for this output element
        float max_val = -FLT_MAX;
        float local_sum = 0.0f;
        
        // First pass: find maximum (coalesced memory access)
        #pragma unroll 4
        for (int c = 0; c < C; ++c) {
            const int input_idx = n * (C * stride) + c * stride + d * (H * W) + h * W + w;
            max_val = fmaxf(max_val, input[input_idx]);
        }
        
        // Stage max_val in this thread's shared-memory slot. Each thread only
        // reads back its own slot, so no __syncthreads() is needed here; a
        // block-wide barrier would be unsafe inside this grid-stride loop,
        // where threads of the same block can run different iteration counts.
        shared_data[tid] = max_val;
        
        // Second pass: compute sum of exponentials
        #pragma unroll 4
        for (int c = 0; c < C; ++c) {
            const int input_idx = n * (C * stride) + c * stride + d * (H * W) + h * W + w;
            local_sum += __expf(input[input_idx] - shared_data[tid]);
        }
        
        // Compute final result with ReLU using intrinsics for better performance
        float result = shared_data[tid] + __logf(local_sum);
        result = fmaxf(0.0f, result);
        
        // Write to output (coalesced write)
        output[idx] = result;
    }
}
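For each output position, the two passes above evaluate the numerically stable form of LogSumExp over the channel dimension, followed by ReLU:

\[
\mathrm{out}_{n,d,h,w} = \max\Bigl(0,\; m + \log\sum_{c=0}^{C-1} e^{x_{n,c,d,h,w} - m}\Bigr),
\qquad m = \max_{c}\, x_{n,c,d,h,w}.
\]

Subtracting the maximum m before exponentiating keeps every exponential in [0, 1] and avoids overflow; adding m back after the logarithm recovers the exact value.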

torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias) {

    // Ensure input tensors are contiguous
    x = x.contiguous();
    conv_weight = conv_weight.contiguous();
    conv_bias = conv_bias.contiguous();

    // Perform 3D convolution using PyTorch
    auto conv_result = torch::conv3d(x, conv_weight, conv_bias, 
                                   {stride, stride, stride}, 
                                   {padding, padding, padding});

    // Perform max pooling using PyTorch
    auto pool_result = torch::max_pool3d(conv_result, {2, 2, 2}, {2, 2, 2});

    const int N = pool_result.size(0);
    const int C = pool_result.size(1);
    const int D = pool_result.size(2);
    const int H = pool_result.size(3);
    const int W = pool_result.size(4);

    auto output = torch::empty({N, 1, D, H, W}, pool_result.options());

    // Optimize kernel launch configuration
    const int block_size = 256;
    const int num_blocks = std::min(65535, (N * D * H * W + block_size - 1) / block_size);
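    // The grid is capped at 65535 blocks; any elements beyond
    // blockDim.x * gridDim.x are covered by the kernel's grid-stride loop.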
    
    optimized_fused_kernel<<<num_blocks, block_size, block_size * sizeof(float)>>>(
        pool_result.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, D, H, W
    );

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Optimized fused 3D operations");
}
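A minimal sketch of how the extension could be exercised from Python, assuming the listing is saved as fused_logsumexp_relu.cu (the file name, module name, and tensor shapes below are illustrative assumptions, not part of the listing): it builds the module with torch.utils.cpp_extension.load and compares the fused path against an eager PyTorch reference.

# Hypothetical harness: builds the extension from this listing (assumed saved as
# "fused_logsumexp_relu.cu") and checks it against an eager PyTorch reference.
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

ext = load(name="fused_logsumexp_relu", sources=["fused_logsumexp_relu.cu"])

# Illustrative shapes: batch 16, 8 -> 16 channels, 32^3 volume, 3x3x3 conv.
N, C_in, C_out, D, H, W = 16, 8, 16, 32, 32, 32
stride, padding = 1, 1
x = torch.randn(N, C_in, D, H, W, device="cuda")
weight = torch.randn(C_out, C_in, 3, 3, 3, device="cuda")
bias = torch.randn(C_out, device="cuda")

out = ext.forward(x, stride, padding, weight, bias)

# Eager reference: conv3d -> max_pool3d -> logsumexp over channels -> ReLU.
ref = F.conv3d(x, weight, bias, stride=stride, padding=padding)
ref = F.max_pool3d(ref, kernel_size=2, stride=2)
ref = F.relu(torch.logsumexp(ref, dim=1, keepdim=True))

# Small differences are expected because the kernel uses the __expf/__logf intrinsics.
print("max abs diff:", (out - ref).abs().max().item())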