
The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean • vectorized_ldg_mean_kernel_base

Level 2 • Task 44
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """
    Applies transposed convolution, scalar multiplication, and multiple global average pooling operations.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        output_padding (int): Additional size added to output shape
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        multiplier (float): Scalar multiplier value

    Returns:
        torch.Tensor: Scalar output after applying operations
    """
    x = F.conv_transpose2d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    x = x * multiplier
    x = torch.mean(x, dim=[2, 3], keepdim=True)
    x = torch.mean(x, dim=[2, 3], keepdim=True)
    x = torch.mean(x)
    return x


class Model(nn.Module):
    """
    Model that performs a transposed convolution, multiplies by a scalar, applies global average pooling,
    another global average pooling, and then calculates the mean.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super(Model, self).__init__()
        conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = nn.Parameter(conv.weight)
        self.conv_transpose_bias = nn.Parameter(
            conv.bias
            + torch.randn(
                conv.bias.shape, device=conv.bias.device, dtype=conv.bias.dtype
            )
            * 0.02
        )
        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        return fn(
            x,
            stride,
            padding,
            output_padding,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.multiplier,
        )


batch_size = 128
in_channels = 3
out_channels = 16
height, width = 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5


def get_inputs():
    return [
        torch.randn(batch_size, in_channels, height, width),
        stride,
        padding,
        output_padding,
    ]


def get_init_inputs():
    return [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ]
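
The reduction chain above collapses: the first torch.mean over dims [2, 3] leaves 1×1 maps, so the second spatial mean is a no-op and the final torch.mean simply averages the per-(batch, channel) values. The whole pipeline is therefore a single mean over every element of the scaled convolution output. A minimal sanity-check sketch (illustrative only, not part of the harness, assuming the definitions above):

# Illustrative check: the nested means equal one global mean over the scaled conv output.
model = Model(in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier)
x = torch.randn(batch_size, in_channels, height, width)
ref = module_fn(x, stride, padding, output_padding,
                model.conv_transpose_parameter, model.conv_transpose_bias, multiplier)
flat = (F.conv_transpose2d(x, model.conv_transpose_parameter, model.conv_transpose_bias,
                           stride=stride, padding=padding, output_padding=output_padding)
        * multiplier).mean()
assert torch.allclose(ref, flat, atol=1e-6)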
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a transposed convolution, multiplies by a scalar, applies global average pooling, 
    another global average pooling, and then calculates the mean.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.conv_transpose.bias = nn.Parameter(self.conv_transpose.bias + torch.randn(self.conv_transpose.bias.shape, device=self.conv_transpose.bias.device, dtype=self.conv_transpose.bias.dtype) * 0.02)
        self.multiplier = multiplier

    def forward(self, x):
        x = self.conv_transpose(x)
        x = x * self.multiplier
        x = torch.mean(x, dim=[2, 3], keepdim=True)  # First global average pooling
        x = torch.mean(x, dim=[2, 3], keepdim=True)  # Second global average pooling
        x = torch.mean(x)
        return x

batch_size = 128
in_channels = 3
out_channels = 16
height, width = 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5

def get_inputs():
    return [torch.randn(batch_size, in_channels, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier]
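
The harness convention is that get_init_inputs() supplies the constructor arguments and get_inputs() supplies the forward() arguments. A minimal usage sketch (illustrative only, assuming the definitions above):

# Instantiate and run the reference module defined above.
model = Model(*get_init_inputs())
out = model(*get_inputs())
print(out.shape, out.item())  # 0-d tensor holding the overall mean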

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 optimized_spatial_reduction_base 0.18 1.21 0.72
🥇 optimized_spatial_reduction_edit_1 0.18 1.21 0.72
🥉 minimal_sync_reduction_edit_1 0.19 1.19 0.71
4 shared_memory_tiled_reduction_base 0.19 1.17 0.70
5 fused_global_avg_base 0.20 1.13 0.67
6 block_size_experimentation_base 0.20 1.11 0.66
7 optimized_strided_avg_pooling_edit_1 0.20 1.10 0.66
7 aligned_ldg_optimized_kernel_base 0.20 1.10 0.66
9 combined_optimized_mean_kernel_base 0.20 1.10 0.65
9 vectorized_ldg_mean_kernel_base 0.20 1.10 0.65
9 optimized_mean_kernel_base 0.20 1.10 0.65
9 warp_uniform_mean_kernel_base_base 0.20 1.10 0.65
9 unrolled_vectorized_mean_kernel_base 0.20 1.10 0.65
9 atomic_final_reduction_base 0.20 1.10 0.65
15 optimized_sync_reduction_base 0.20 1.09 0.65
15 shared_mem_reduction_optimized_base 0.20 1.09 0.65
15 modular_shared_warp_mean_base_base 0.20 1.09 0.65
15 coalesced_vectorized_mean_kernel_base 0.20 1.09 0.65
15 reduced_sync_shared_memory_base 0.20 1.09 0.65
15 fused_atomic_reduction_base 0.20 1.09 0.65
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Kernel to compute the mean of each (batch, channel) slice using vectorized loads with __ldg()
// and then atomically accumulate the slice means into a global accumulator.

// The kernel uses 128-bit (float4) vectorized loads when the number of elements is divisible by 4,
// otherwise it falls back to scalar loads with __ldg().

// Each block is responsible for one (batch, channel) slice.
// The final result is computed as the average over all (batch, channel) means.

template <unsigned int blockSize>
__global__ void vectorized_ldg_mean_kernel(
    const float* __restrict__ input,
    float* __restrict__ global_accum,
    int H,
    int W,
    int C
) {
    extern __shared__ float shared[]; // shared memory for block reduction
    int num_elements = H * W;
    int batch = blockIdx.x / C;
    int channel = blockIdx.x % C;
    int input_offset = (batch * C + channel) * num_elements;
    float sum = 0.0f;

    // Check if the slice size is divisible by 4 to use vectorized loads
    if ((num_elements & 3) == 0) {
        int num_vec = num_elements >> 2; // equivalent to num_elements / 4
        // Cast the input pointer to float4 pointer for 128-bit aligned loads
        const float4* in_vec = reinterpret_cast<const float4*>(input + input_offset);
        for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
            // Use __ldg for read-only global memory access
            float4 v = __ldg(&in_vec[i]);
            sum += v.x + v.y + v.z + v.w;
        }
    } else {
        for (int i = threadIdx.x; i < num_elements; i += blockDim.x) {
            sum += __ldg(&input[input_offset + i]);
        }
    }

    // Write the partial sum to shared memory
    shared[threadIdx.x] = sum;
    __syncthreads();

    // Intra-block reduction in shared memory
    if (blockSize >= 512) {
        if (threadIdx.x < 256) {
            shared[threadIdx.x] += shared[threadIdx.x + 256];
        }
        __syncthreads();
    }
    if (blockSize >= 256) {
        if (threadIdx.x < 128) {
            shared[threadIdx.x] += shared[threadIdx.x + 128];
        }
        __syncthreads();
    }
    if (blockSize >= 128) {
        if (threadIdx.x < 64) {
            shared[threadIdx.x] += shared[threadIdx.x + 64];
        }
        __syncthreads();
    }
    // Final warp-level reduction on the last 32 threads (warp-synchronous; volatile prevents register caching)
    if (threadIdx.x < 32) {
        volatile float* vsmem = shared;
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 32];
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 16];
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 8];
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 4];
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 2];
        vsmem[threadIdx.x] += vsmem[threadIdx.x + 1];
    }

    // Thread 0 computes the mean for this (batch, channel) slice and atomically adds it to the global accumulator
    if (threadIdx.x == 0) {
        float slice_mean = shared[0] / static_cast<float>(num_elements);
        atomicAdd(global_accum, slice_mean);
    }
}

at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    // Perform transposed convolution
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    );
    
    // Scale the output
    y = y * multiplier;
    
    // Get dimensions (N, C, H, W)
    auto dims = y.sizes();
    int N = dims[0];
    int C = dims[1];
    int H = dims[2];
    int W = dims[3];
    
    // Allocate a scalar accumulator on the device initialized to zero
    auto options = torch::TensorOptions().device(y.device()).dtype(y.dtype());
    at::Tensor accum = torch::zeros({1}, options);
    
    // Launch one block per (batch, channel) slice
    constexpr int blockSize = 256;
    int numBlocks = N * C;
    size_t sharedMemSize = blockSize * sizeof(float);

    // Launch: assumes y is a contiguous float32 NCHW tensor (true for the preceding ops)
    vectorized_ldg_mean_kernel<blockSize><<<numBlocks, blockSize, sharedMemSize>>>(
        y.data_ptr<float>(),
        accum.data_ptr<float>(),
        H, W, C
    );
    
    // Compute the final overall mean: each block contributed its slice mean, so average over N*C slices
    accum = accum / static_cast<float>(N * C);
    return accum;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn, "Module function");
}
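
One way to exercise this kernel is to build the file as a PyTorch C++/CUDA extension and compare its forward entry point against the PyTorch reference. A minimal sketch, assuming the source above is saved as vectorized_ldg_mean_kernel.cu (the file and extension names here are illustrative, not from the original):

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Compile the CUDA source shown above into a loadable extension (hypothetical file name).
ext = load(name="vectorized_ldg_mean_kernel",
           sources=["vectorized_ldg_mean_kernel.cu"],
           verbose=True)

x = torch.randn(128, 3, 32, 32, device="cuda")
weight = torch.randn(3, 16, 3, 3, device="cuda")  # ConvTranspose2d weight: (in_ch, out_ch, kH, kW)
bias = torch.randn(16, device="cuda")

# Same argument order as the pybind11 binding: x, stride, padding, output_padding, weight, bias, multiplier.
out = ext.forward(x, 2, 1, 1, weight, bias, 0.5)

# PyTorch reference for comparison.
ref = (F.conv_transpose2d(x, weight, bias, stride=2, padding=1, output_padding=1) * 0.5).mean()
print(out.item(), ref.item())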
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.534 inst/cycle 0.000 5
Executed Ipc Elapsed 0.428 inst/cycle 0.000 5
Issue Slots Busy 13.460 % 0.116 5
Issued Ipc Active 0.538 inst/cycle 0.000 5
SM Busy 13.460 % 0.116 5
Memory Throughput 2264764848966.874 byte/second 1598666630662445858816.000 5
Mem Busy 38.138 % 0.486 5
Max Bandwidth 67.746 % 1.437 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 2.912 % 0.000 5
Mem Pipes Busy 9.658 % 0.030 5
Warp Cycles Per Issued Instruction 104.392 cycle 0.954 5
Warp Cycles Per Executed Instruction 105.080 cycle 0.966 5
Avg. Active Threads Per Warp 31.680 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.960 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 88.494 % 0.022 5
Achieved Active Warps Per SM 56.638 warp 0.009 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (88.7%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 5147576.24 μs
Device Time 5046688.03 μs
Self CPU Time 54774.43 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 5092801.81 μs
Device Time 5046688.03 μs
Self CPU Time 72554.45 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 5020247.35 μs
Device Time 5046688.03 μs
Self CPU Time 144147.42 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 3992244.02 μs
Device Time 4080959.33 μs
Self CPU Time 858081.93 μs
Self Device Time 4080959.33 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 5550015.48 μs
Device Time 51489.24 μs
Self CPU Time 5550015.48 μs
Self Device Time 51489.24 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 2175584.90 μs
Device Time 2514181.89 μs
Self CPU Time 108362.45 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:20:5: warning: 2 adjacent parameters of 'vectorized_ldg_mean_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
20 | int W,
| ^~~~~~
21 | int C
| ~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:20:9: note: the first parameter in the range is 'W'
20 | int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:21:9: note: the last parameter in the range is 'C'
21 | int C
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:25:17: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
25 | int batch = blockIdx.x / C;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:26:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | int channel = blockIdx.x % C;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:35:22: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
35 | for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:35:53: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
35 | for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:41:22: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
41 | for (int i = threadIdx.x; i < num_elements; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:41:58: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
41 | for (int i = threadIdx.x; i < num_elements; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:87:16: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
87 | at::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:91:16: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
91 | at::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:112:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
112 | int N = dims[0];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:113:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
113 | int C = dims[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:114:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
114 | int H = dims[2];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b7_s1_vectorized_ldg_mean_kernel/base/base.cu:115:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
115 | int W = dims[3];
| ^