← Back to Leaderboard

The AI CUDA Engineer 👷

8_Conv3d_Divide_Max_GlobalAvgPool_BiasAdd_Sum — fused_pooling_warp_uniform_base

Level 2 • Task 8
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    divisor: float,
    pool_size: tuple,
    sum_dim: int,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run conv3d -> divide -> max-pool -> global-avg-pool -> bias-add -> sum.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, depth, height, width).
        divisor (float): Scalar the convolution output is divided by.
        pool_size (tuple): Window (depth, height, width) for ``F.max_pool3d``; no
            stride is passed, so PyTorch uses the window size as the stride.
        sum_dim (int): Dimension reduced by the final ``torch.sum``.
        conv_weight (torch.Tensor): Weight passed to ``F.conv3d``.
        conv_bias (torch.Tensor): Bias passed to ``F.conv3d``.
        bias (torch.Tensor): Tensor broadcast-added after global average pooling.

    Returns:
        torch.Tensor: The pipeline result with ``sum_dim`` reduced away.
    """
    convolved = F.conv3d(x, conv_weight, bias=conv_bias)
    pooled = F.max_pool3d(convolved / divisor, pool_size)
    # Collapse every spatial position to a single value per (batch, channel).
    averaged = F.adaptive_avg_pool3d(pooled, (1, 1, 1))
    return torch.sum(averaged + bias, dim=sum_dim)


class Model(nn.Module):
    """
    Model that performs a 3D convolution, divides by a constant, applies max pooling,
    global average pooling, adds a bias term, and sums along a specific dimension.

    The computation itself is delegated to a functional ``fn`` (``module_fn`` by
    default) so that alternative implementations can be swapped in at call time.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        divisor,
        pool_size,
        bias_shape,
        sum_dim,
    ):
        super(Model, self).__init__()
        # Build an nn.Conv3d only to reuse its default parameter initialisation;
        # its weight/bias are then stored directly on this module.
        conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        # Registered before conv_weight/conv_bias to keep the original
        # parameter order (and the original RNG consumption order).
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

        self.conv_weight = conv.weight
        self.conv_bias = conv.bias

        # Hyper-parameters forwarded to ``fn``. Previously forward() read
        # module-level globals with these names, ignoring the constructor args.
        self.divisor = divisor
        self.pool_size = pool_size
        self.sum_dim = sum_dim

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` to ``x`` using this module's parameters and settings."""
        return fn(
            x,
            self.divisor,
            self.pool_size,
            self.sum_dim,
            self.conv_weight,
            self.conv_bias,
            self.bias,
        )


# Module-level benchmark configuration.
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = (3, 3, 3)
divisor = 2.0
pool_size = (2, 2, 2)
bias_shape = (out_channels,) + (1, 1, 1)  # one bias value per output channel
sum_dim = 1


def get_inputs():
    """Return the forward() arguments: a single random input batch."""
    shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Return Model's constructor arguments in positional order."""
    init_args = (
        in_channels,
        out_channels,
        kernel_size,
        divisor,
        pool_size,
        bias_shape,
        sum_dim,
    )
    return list(init_args)
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Conv3d, then division by a constant, max pooling, global average pooling,
    a broadcast bias add, and finally a sum along ``sum_dim``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, divisor, pool_size, bias_shape, sum_dim):
        super().__init__()
        # Submodule/parameter creation order matches the reference so that the
        # RNG draws and parameter ordering are identical.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.divisor = divisor
        self.max_pool = nn.MaxPool3d(pool_size)
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)
        self.sum_dim = sum_dim

    def forward(self, x):
        """Run the full pipeline on ``x`` and return the reduced tensor."""
        y = self.max_pool(self.conv(x) / self.divisor)
        y = self.global_avg_pool(y) + self.bias
        return y.sum(dim=self.sum_dim)

# Benchmark configuration (sizes, layer hyper-parameters).
batch_size = 128
in_channels, out_channels = 3, 16
depth = 16
height = 32
width = 32
kernel_size = (3,) * 3
divisor = 2.0
pool_size = (2,) * 3
bias_shape = (out_channels, 1, 1, 1)
sum_dim = 1

def get_inputs():
    """Return the forward() arguments: one random (N, C, D, H, W) batch."""
    sample = torch.randn(batch_size, in_channels, depth, height, width)
    return [sample]

def get_init_inputs():
    """Return Model's constructor arguments in signature order."""
    ordered = [in_channels, out_channels, kernel_size]
    ordered += [divisor, pool_size, bias_shape, sum_dim]
    return ordered

Kernel Information

Related Kernels (Level 2, Task 8 • 8_Conv3d_Divide_Max_GlobalAvgPool_BiasAdd_Sum)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 fused_stride_loops_base 0.75 1.21 0.91
🥇 fused_pooling_warp_uniform_base 0.75 1.21 0.91
🥇 fused_divide_maxpool_avg_base 0.75 1.21 0.91
🥇 fused_divide_maxpool_avg_edit_1 0.75 1.21 0.91
5 fused_stride_loops_edit_1 0.75 1.21 0.91
5 fused_pooling_warp_uniform_edit_1 0.75 1.21 0.91
5 fused_stride_loops_base 0.75 1.21 0.91
8 block_size_tuned_base_base 0.75 1.20 0.91
8 fused_stride_loops_edit_1 0.75 1.20 0.91
10 fused_pooling_shared_memory_base 0.75 1.20 0.91
10 optimized_stride_boundary_base_base 0.75 1.20 0.91
10 fused_pooling_min_sync_base 0.75 1.20 0.91
13 fused_pooling_min_sync_opt_base 0.75 1.20 0.91
14 fused_pooling_opt_sync_edit_1 0.76 1.19 0.90
15 fused_pooling_uniform_edit_1 0.76 1.19 0.90
15 fused_pooling_uniform_base 0.76 1.19 0.90
15 fused_pooling_opt_sync_base 0.76 1.19 0.90
18 fused_pooling_stride_boundaries_base 0.76 1.19 0.90
18 fused_pooling_unroll_edit_1 0.76 1.19 0.90
20 fused_pooling_unroll_base 0.76 1.19 0.90
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <cfloat>
#include <cuda_runtime.h>

// Fused epilogue: divide -> MaxPool3d -> global average pool -> bias add -> sum.
//
// One thread block handles one (n, c) plane of the contiguous NCDHW tensor
// `conv_out` (blockIdx.x == n * C + c). Threads stride over the
// non-overlapping poolD x poolH x poolW windows. Because outD/outH/outW use
// integer division, trailing elements that do not fill a whole window are
// ignored. Each thread takes the max of each window after scaling by
// 1/divisor; the block then averages those maxima and adds bias[c].
//  * sum_dim == 1: the result is atomicAdd-ed into output[n], so `output`
//    must have N elements and be zero-initialised.
//  * otherwise: the result is written to output[n * C + c] (N * C elements).
//
// Launch requirements (not checked here):
//  * blockDim.x must be a multiple of 32 and at most 1024, because the
//    reductions use full-mask warp shuffles and a 32-entry shared buffer;
//  * numPools must be > 0, otherwise the average is 0/0.
// NOTE(review): fmaxf returns the non-NaN operand, so NaNs inside a window
// are dropped here, whereas torch max_pool3d propagates them.
__global__ void fusedPoolingAndReductionKernel(
    const float* __restrict__ conv_out,
    int N, int C, int D, int H, int W,
    int poolD, int poolH, int poolW,
    const float* __restrict__ bias,
    int sum_dim,
    float* __restrict__ output,
    float divisor
) {
    const int nc = blockIdx.x;
    const int n = nc / C;
    const int c = nc % C;

    const int outD = D / poolD;
    const int outH = H / poolH;
    const int outW = W / poolW;
    const int outHW = outH * outW;
    const int numPools = outD * outHW;
    const float invDiv = 1.0f / divisor;

    // Base of this block's (n, c) plane, computed in 64-bit: the previous
    // all-`int` flat index overflowed once conv_out exceeded INT_MAX elements.
    const long long HW = static_cast<long long>(H) * W;
    const float* __restrict__ plane = conv_out + static_cast<long long>(nc) * D * HW;

    float thread_sum = 0.0f;

    #pragma unroll 1
    for (int idx = threadIdx.x; idx < numPools; idx += blockDim.x) {
        const int d = idx / outHW;
        const int rem = idx - d * outHW;
        const int h = rem / outW;
        const int w = rem - h * outW;

        // First element of this pooling window inside the plane.
        const float* window = plane
            + static_cast<long long>(d * poolD) * HW
            + static_cast<long long>(h * poolH) * W
            + w * poolW;

        float max_val = -FLT_MAX;

        #pragma unroll
        for (int pd = 0; pd < poolD; pd++) {
            #pragma unroll
            for (int ph = 0; ph < poolH; ph++) {
                const float* row = window + pd * HW + static_cast<long long>(ph) * W;
                #pragma unroll
                for (int pw = 0; pw < poolW; pw++) {
                    max_val = fmaxf(max_val, __ldg(row + pw) * invDiv);
                }
            }
        }
        thread_sum += max_val;
    }

    // Stage 1: reduce within each warp.
    const unsigned int lane = threadIdx.x & 31u;
    const unsigned int warp = threadIdx.x >> 5;
    const unsigned int numWarps = (blockDim.x + 31u) >> 5;

    float warp_sum = thread_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        warp_sum += __shfl_down_sync(0xffffffffu, warp_sum, offset);
    }

    __shared__ float warp_results[32];  // one slot per warp; blockDim.x <= 1024
    if (lane == 0) {
        warp_results[warp] = warp_sum;
    }
    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partial sums.
    if (warp == 0) {
        float final_sum = (lane < numWarps) ? warp_results[lane] : 0.0f;

        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            final_sum += __shfl_down_sync(0xffffffffu, final_sum, offset);
        }

        if (lane == 0) {
            const float avg = final_sum / static_cast<float>(numPools) + bias[c];
            if (sum_dim == 1) {
                // Several blocks (one per channel) contribute to the same output[n].
                atomicAdd(&output[n], avg);
            } else {
                output[nc] = avg;  // nc == n * C + c
            }
        }
    }
}

// Host entry point. Runs the 3D convolution through ATen, then launches
// fusedPoolingAndReductionKernel to apply divide -> max-pool ->
// global-average-pool -> bias-add -> sum in a single pass over conv_out.
//
// sum_dim is interpreted against the 5-D pooled tensor; negative values are
// wrapped as torch.sum would wrap them.
//  * sum_dim == 1 (sum over channels): returns shape (N,).
//  * sum_dim in {2, 3, 4} (sum over a singleton dim): returns the per-channel
//    values with shape (N, C).
//  * sum_dim == 0 (sum over the batch): not implemented by the kernel and
//    rejected.
// NOTE(review): the returned shapes drop the singleton dims that the PyTorch
// reference keeps.
torch::Tensor forward_cuda(
    torch::Tensor x,
    double divisor,
    std::vector<int64_t> pool_size,
    int64_t sum_dim,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias,
    torch::Tensor bias
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(conv_weight.is_cuda(), "conv_weight must be a CUDA tensor");
    TORCH_CHECK(conv_bias.is_cuda(), "conv_bias must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 5, "x must be a 5-D (N, C, D, H, W) tensor");
    TORCH_CHECK(bias.device() == x.device(), "bias must be on the same device as x");
    TORCH_CHECK(pool_size.size() == 3, "pool_size must have exactly 3 elements (D, H, W)");
    TORCH_CHECK(pool_size[0] > 0 && pool_size[1] > 0 && pool_size[2] > 0,
                "pool_size entries must be positive");

    // The summed tensor is 5-D, so accept torch-style negative dims.
    TORCH_CHECK(sum_dim >= -5 && sum_dim < 5, "sum_dim must be in [-5, 4]");
    if (sum_dim < 0) {
        sum_dim += 5;
    }
    // The kernel has no path that sums over the batch. Before this check,
    // sum_dim == 0 silently fell into the per-channel branch.
    TORCH_CHECK(sum_dim != 0, "sum_dim == 0 (sum over the batch) is not supported");

    // Make x's device current, so the kernel and the stream below target it.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes conv_out as contiguous NCDHW. The convolution may
    // return another memory format (e.g. for channels-last input), so force it.
    auto conv_out = at::conv3d(x, conv_weight, conv_bias).contiguous();
    TORCH_CHECK(conv_out.scalar_type() == at::kFloat, "only float32 tensors are supported");

    const int N = static_cast<int>(conv_out.size(0));
    const int C = static_cast<int>(conv_out.size(1));
    const int D = static_cast<int>(conv_out.size(2));
    const int H = static_cast<int>(conv_out.size(3));
    const int W = static_cast<int>(conv_out.size(4));

    // At least one full window per dim is needed; otherwise the kernel divides by zero.
    TORCH_CHECK(D >= pool_size[0] && H >= pool_size[1] && W >= pool_size[2],
                "pool_size exceeds the convolution output size");

    // The kernel reads bias[c] for every output channel c.
    auto bias_contig = bias.contiguous();
    TORCH_CHECK(bias_contig.numel() == C,
                "bias must have out_channels (", C, ") elements, got ", bias_contig.numel());

    torch::Tensor output = (sum_dim == 1) ?
        torch::zeros({N}, conv_out.options()) :   // accumulated with atomicAdd
        torch::empty({N, C}, conv_out.options());

    // Launching zero blocks is an invalid configuration; there is nothing to do.
    if (N == 0 || C == 0) {
        return output;
    }

    // Must be a multiple of 32 and <= 1024 (required by the kernel's reductions).
    const int threadsPerBlock = 256;
    const int totalBlocks = N * C;  // one block per (n, c) plane
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The kernel uses a static __shared__ buffer, so no dynamic shared memory is needed.
    fusedPoolingAndReductionKernel<<<totalBlocks, threadsPerBlock, 0, stream>>>(
        conv_out.data_ptr<float>(),
        N, C, D, H, W,
        static_cast<int>(pool_size[0]),
        static_cast<int>(pool_size[1]),
        static_cast<int>(pool_size[2]),
        bias_contig.data_ptr<float>(),
        static_cast<int>(sum_dim),
        output.data_ptr<float>(),
        static_cast<float>(divisor)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Expose forward_cuda to Python as `forward`.
    m.def(
        "forward",
        &forward_cuda,
        "Fused 3D Conv, Pooling, GlobalAvgPool, BiasAdd, and Sum (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.380 inst/cycle 0.000 5
Executed Ipc Elapsed 2.150 inst/cycle 0.000 5
Issue Slots Busy 59.520 % 0.039 5
Issued Ipc Active 2.380 inst/cycle 0.000 5
SM Busy 59.520 % 0.039 5
Memory Throughput 1883734319208.674 byte/second 21128943110373679104.000 5
Mem Busy 34.674 % 0.006 5
Max Bandwidth 56.226 % 0.018 5
L1/TEX Hit Rate 61.038 % 0.000 5
L2 Hit Rate 12.874 % 0.001 5
Mem Pipes Busy 22.582 % 0.003 5
Warp Cycles Per Issued Instruction 24.420 cycle 0.012 5
Warp Cycles Per Executed Instruction 24.432 cycle 0.013 5
Avg. Active Threads Per Warp 31.500 0.000 5
Avg. Not Predicated Off Threads Per Warp 24.890 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 25.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 90.736 % 0.015 5
Achieved Active Warps Per SM 58.070 warp 0.006 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (42.6%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::conv3d
CPU Time 7963454.66 μs
Device Time 8103648.03 μs
Self CPU Time 21889.05 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 7941565.60 μs
Device Time 8103648.03 μs
Self CPU Time 30080.81 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 7911484.80 μs
Device Time 8103648.03 μs
Self CPU Time 66359.11 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution
CPU Time 6815726.12 μs
Device Time 7033711.13 μs
Self CPU Time 251114.31 μs
Self Device Time 7033711.13 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernelExC
CPU Time 6529370.76 μs
Device Time 0.00 μs
Self CPU Time 6529370.76 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn
CPU Time 0.00 μs
Device Time 7033709.82 μs
Self CPU Time 0.00 μs
Self Device Time 7033709.82 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45298 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:8:5 bugprone-easily-swappable-parameters
8 | int N, int C, int D, int H, int W,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:8:9: note: the first parameter in the range is 'N'
8 | int N, int C, int D, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:8:16: note: the last parameter in the range is 'C'
8 | int N, int C, int D, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:8:33: warning: 2 adjacent parameters of 'fusedPoolingAndReductionKernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
8 | int N, int C, int D, int H, int W,
| ^~~~~~
9 | int poolD, int poolH, int poolW,
| ~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:8:37: note: the first parameter in the range is 'W'
8 | int N, int C, int D, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:9:9: note: the last parameter in the range is 'poolD'
9 | int poolD, int poolH, int poolW,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:15:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
15 | const int nc = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:27:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:28:24: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
28 | const int stride = blockDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:93:43: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
93 | const float avg = final_sum / numPools + bias[c];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:105:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
105 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:109:19: warning: the parameter 'conv_weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
109 | torch::Tensor conv_weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:111:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
111 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:120:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
120 | const int N = conv_out.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:121:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
121 | const int C = conv_out.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:122:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
122 | const int D = conv_out.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:123:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
123 | const int H = conv_out.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:124:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
124 | const int W = conv_out.size(4);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:136:9: warning: narrowing conversion from 'value_type' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
136 | pool_size[0], pool_size[1], pool_size[2],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:136:23: warning: narrowing conversion from 'value_type' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
136 | pool_size[0], pool_size[1], pool_size[2],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:136:37: warning: narrowing conversion from 'value_type' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
136 | pool_size[0], pool_size[1], pool_size[2],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_8/b4_s1_fused_pooling_warp_uniform/base/base.cu:138:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
138 | sum_dim,
| ^