← Back to Leaderboard

The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean • modular_shared_warp_mean_base_base

Level 2 • Task 44
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """
    Run a transposed convolution, scale the result, and reduce it to one mean value.

    The pipeline is: ``conv_transpose2d`` -> multiply by ``multiplier`` ->
    spatial mean over dims (2, 3) applied twice with ``keepdim=True`` ->
    mean over every remaining element.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        output_padding (int): Additional size added to output shape
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        multiplier (float): Scalar multiplier value

    Returns:
        torch.Tensor: 0-dimensional tensor holding the final mean
    """
    spatial_dims = [2, 3]

    scaled = multiplier_applied = None  # placeholders keep the pipeline stages explicit
    scaled = F.conv_transpose2d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    multiplier_applied = scaled * multiplier

    # Two consecutive global average pools. After the first one the spatial
    # dims already have size 1, so the second pass leaves the values as-is.
    pooled = multiplier_applied
    for _ in range(2):
        pooled = torch.mean(pooled, dim=spatial_dims, keepdim=True)

    return torch.mean(pooled)


class Model(nn.Module):
    """
    Parameter container for the functional pipeline: transposed convolution,
    scalar multiplication, two global average poolings, and a final mean.

    The weight and a slightly perturbed bias are taken from a temporary
    ``nn.ConvTranspose2d`` and stored as parameters; ``forward`` hands them to
    a functional implementation (``module_fn`` by default).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super().__init__()

        # Only used to obtain initial weight/bias values; it is deliberately
        # kept as a local so it is not registered as a submodule.
        reference_layer = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = nn.Parameter(reference_layer.weight)

        # Add small Gaussian noise (scaled by 0.02) to the default bias.
        initial_bias = reference_layer.bias
        bias_noise = torch.randn(
            initial_bias.shape,
            device=initial_bias.device,
            dtype=initial_bias.dtype,
        )
        self.conv_transpose_bias = nn.Parameter(initial_bias + bias_noise * 0.02)

        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        # Arguments are passed positionally so that ``fn`` may be any callable
        # with the same parameter order as ``module_fn``.
        call_args = (
            x,
            stride,
            padding,
            output_padding,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.multiplier,
        )
        return fn(*call_args)


# Input tensor geometry used by get_inputs().
batch_size = 128
in_channels = 3
height = 32
width = 32

# Constructor arguments returned by get_init_inputs().
out_channels = 16
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5


def get_inputs():
    """Return the forward-call arguments: a random input batch plus the conv settings."""
    sample = torch.randn(batch_size, in_channels, height, width)
    return [sample, stride, padding, output_padding]


def get_init_inputs():
    """Return the constructor arguments for ``Model`` in positional order."""
    conv_settings = [in_channels, out_channels, kernel_size]
    geometry_settings = [stride, padding, output_padding]
    return conv_settings + geometry_settings + [multiplier]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Transposed convolution followed by a scalar multiply, two global average
    poolings over the spatial dims, and a final mean over all remaining elements.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )

        # Replace the default bias with a copy perturbed by Gaussian noise scaled by 0.02.
        default_bias = self.conv_transpose.bias
        bias_noise = torch.randn(
            default_bias.shape,
            device=default_bias.device,
            dtype=default_bias.dtype,
        )
        self.conv_transpose.bias = nn.Parameter(default_bias + bias_noise * 0.02)

        self.multiplier = multiplier

    def forward(self, x):
        scaled = self.conv_transpose(x) * self.multiplier
        # First global average pooling collapses the spatial dims to 1x1.
        pooled = torch.mean(scaled, dim=[2, 3], keepdim=True)
        # Second global average pooling runs on the already 1x1 spatial dims.
        pooled = torch.mean(pooled, dim=[2, 3], keepdim=True)
        return torch.mean(pooled)

# Input tensor geometry used by get_inputs().
batch_size = 128
in_channels = 3
height = 32
width = 32

# Constructor arguments returned by get_init_inputs().
out_channels = 16
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5

def get_inputs():
    """Return the forward-call arguments: a single random input batch."""
    input_shape = (batch_size, in_channels, height, width)
    return [torch.randn(*input_shape)]

def get_init_inputs():
    """Return the constructor arguments for ``Model`` in positional order."""
    conv_settings = [in_channels, out_channels, kernel_size]
    geometry_settings = [stride, padding, output_padding]
    return conv_settings + geometry_settings + [multiplier]

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_spatial_reduction_base 0.18 1.21 0.72
🥇 optimized_spatial_reduction_edit_1 0.18 1.21 0.72
🥉 minimal_sync_reduction_edit_1 0.19 1.19 0.71
4 shared_memory_tiled_reduction_base 0.19 1.17 0.70
5 fused_global_avg_base 0.20 1.13 0.67
6 block_size_experimentation_base 0.20 1.11 0.66
7 optimized_strided_avg_pooling_edit_1 0.20 1.10 0.66
7 aligned_ldg_optimized_kernel_base 0.20 1.10 0.66
9 combined_optimized_mean_kernel_base 0.20 1.10 0.65
9 vectorized_ldg_mean_kernel_base 0.20 1.10 0.65
9 optimized_mean_kernel_base 0.20 1.10 0.65
9 warp_uniform_mean_kernel_base_base 0.20 1.10 0.65
9 unrolled_vectorized_mean_kernel_base 0.20 1.10 0.65
9 atomic_final_reduction_base 0.20 1.10 0.65
15 optimized_sync_reduction_base 0.20 1.09 0.65
15 shared_mem_reduction_optimized_base 0.20 1.09 0.65
15 modular_shared_warp_mean_base_base 0.20 1.09 0.65
15 coalesced_vectorized_mean_kernel_base 0.20 1.09 0.65
15 reduced_sync_shared_memory_base 0.20 1.09 0.65
15 fused_atomic_reduction_base 0.20 1.09 0.65
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>

// Sums `val` across the 32 lanes of a warp with __shfl_down_sync.
// All 32 lanes must call this together (full participation mask).
// The complete sum ends up in lane 0; other lanes hold partial sums.
__device__ __forceinline__ float warpReduceSum(float val) {
    constexpr unsigned int kFullMask = 0xffffffff;
    #pragma unroll
    for (int delta = 16; delta > 0; delta >>= 1) {
        const float neighbour = __shfl_down_sync(kFullMask, val, delta);
        val += neighbour;
    }
    return val;
}

// Sums `val` across the thread block in two stages: each warp reduces its own
// values, lane 0 of every warp publishes the warp total to `shared`, then
// warp 0 reduces those totals. `shared` must hold at least
// blockSize / warpSize floats. The block-wide sum is valid in thread 0;
// threads outside warp 0 return their own warp's partial result.
template<unsigned int blockSize>
__device__ __forceinline__ float blockReduceSum(float val, float* shared) {
    const int warp_id = threadIdx.x / warpSize;
    const int lane_id = threadIdx.x % warpSize;

    // Stage 1: intra-warp reduction.
    float partial = warpReduceSum(val);

    // Stage 2: one slot per warp in shared memory.
    if (lane_id == 0) {
        shared[warp_id] = partial;
    }
    __syncthreads();

    // Only the first warp continues; no further barriers follow.
    if (warp_id != 0) {
        return partial;
    }

    // Stage 3: lanes beyond the number of warps contribute zero.
    partial = (threadIdx.x < blockSize / warpSize) ? shared[lane_id] : 0.0f;
    return warpReduceSum(partial);
}

// Per-thread partial sum over input[offset .. offset + num_elements).
// Each thread starts at its own threadIdx.x and advances by `thread_stride`,
// accumulating in float regardless of the element type T.
template<typename T>
__device__ __forceinline__ float computeLocalSum(
    const T* input,
    const int offset,
    const int num_elements,
    const int thread_stride
) {
    float acc = 0.0f;
    int i = threadIdx.x;
    while (i < num_elements) {
        acc += static_cast<float>(input[offset + i]);
        i += thread_stride;
    }
    return acc;
}

// Computes the mean of each H*W plane of `input`.
//
// Launch contract (as used by the host code in this file):
//   * one block per plane, so blockIdx.x indexes planes of H*W consecutive floats;
//   * blockDim.x must equal `blockSize`, because threads stride by blockSize;
//   * dynamic shared memory must hold blockSize / warpSize floats.
//
// output[blockIdx.x] receives the mean of plane blockIdx.x.
template <unsigned int blockSize>
__global__ void modular_mean_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int H,
    const int W,
    const int C
) {
    extern __shared__ float shared[];

    // The plane index (batch_idx * C + channel_idx) is exactly blockIdx.x, so
    // C is not needed for addressing; it stays in the signature for callers.
    (void)C;

    const int num_elements = H * W;

    // Compute the plane base in size_t: blockIdx.x * num_elements can exceed
    // INT_MAX for large tensors, which would overflow a 32-bit int offset.
    const float* plane =
        input + static_cast<size_t>(blockIdx.x) * static_cast<size_t>(num_elements);

    // Per-thread partial sum over this plane (offset 0 relative to `plane`).
    float sum = computeLocalSum<float>(
        plane,
        0,
        num_elements,
        blockSize
    );

    // Combine the per-thread partial sums; the total is valid in thread 0.
    sum = blockReduceSum<blockSize>(sum, shared);

    // Thread 0 writes the plane mean.
    if (threadIdx.x == 0) {
        output[blockIdx.x] = sum / static_cast<float>(num_elements);
    }
}

// Host entry point: transposed convolution, scalar multiply, per-plane mean
// via modular_mean_kernel, then the mean of those plane means.
//
// Args:
//   x                   CUDA input tensor for at::conv_transpose2d
//   stride, padding,
//   output_padding      applied symmetrically to both spatial dims
//   conv_transpose      transposed-convolution weight
//   conv_transpose_bias transposed-convolution bias
//   multiplier          scalar applied to the convolution output
//
// Returns a 0-dim tensor. Raises (TORCH_CHECK) if x is not on a CUDA device,
// if the scaled convolution output is not a 4-D float32 tensor, or if its
// sizes do not fit the kernel's 32-bit int parameters.
at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    // The custom kernel dereferences raw device pointers, so a CPU tensor
    // must be rejected up front instead of faulting inside the kernel.
    TORCH_CHECK(x.is_cuda(), "module_fn: x must be a CUDA tensor");

    // Make x's device current so the stream lookup and the launch target it.
    const c10::cuda::CUDAGuard device_guard(x.device());

    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    );

    y = y * multiplier;

    TORCH_CHECK(y.dim() == 4,
                "module_fn: expected a 4-D convolution output, got ", y.dim(), "-D");
    TORCH_CHECK(y.scalar_type() == at::kFloat,
                "module_fn: only float32 tensors are supported, got ", y.scalar_type());

    const int64_t N = y.size(0);
    const int64_t C = y.size(1);
    const int64_t H = y.size(2);
    const int64_t W = y.size(3);

    // The kernel takes H, W, C as int and uses H * W as an int element count;
    // the grid size is the number of planes. Reject sizes that would be
    // truncated by the narrowing casts below.
    const int64_t num_planes = N * C;
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(num_planes <= kIntMax && C <= kIntMax && H * W <= kIntMax,
                "module_fn: tensor dimensions exceed the kernel's 32-bit index range");

    auto temp = torch::empty({num_planes}, y.options());

    constexpr int blockSize = 256;
    // One float of scratch space per warp for the block reduction.
    constexpr size_t sharedMemSize = (blockSize / 32) * sizeof(float);

    // A grid of zero blocks is an invalid launch configuration; with no planes
    // there is nothing to compute and temp.mean() on the empty tensor is returned.
    if (num_planes > 0) {
        // Launch on PyTorch's current stream so the kernel is ordered after
        // the convolution/multiply and before temp.mean().
        modular_mean_kernel<blockSize>
            <<<static_cast<unsigned int>(num_planes), blockSize, sharedMemSize,
               at::cuda::getCurrentCUDAStream()>>>(
                y.data_ptr<float>(),
                temp.data_ptr<float>(),
                static_cast<int>(H),
                static_cast<int>(W),
                static_cast<int>(C)
            );
        // Surface launch-configuration errors immediately.
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Every plane has the same element count, so the mean of the plane means
    // equals the mean over all elements.
    return temp.mean();
}

// Python binding: registers module_fn on the extension module under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn, "Module function");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.840 inst/cycle 0.000 5
Executed Ipc Elapsed 0.668 inst/cycle 0.000 5
Issue Slots Busy 21.098 % 0.111 5
Issued Ipc Active 0.844 inst/cycle 0.000 5
SM Busy 21.098 % 0.111 5
Memory Throughput 2364552373312.274 byte/second 4066705649370971242496.000 5
Mem Busy 39.830 % 1.129 5
Max Bandwidth 70.654 % 3.400 5
L1/TEX Hit Rate 0.024 % 0.000 5
L2 Hit Rate 2.956 % 0.000 5
Mem Pipes Busy 12.734 % 0.112 5
Warp Cycles Per Issued Instruction 61.234 cycle 0.024 5
Warp Cycles Per Executed Instruction 61.582 cycle 0.025 5
Avg. Active Threads Per Warp 31.780 0.000 5
Avg. Not Predicated Off Threads Per Warp 27.840 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 28.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 81.802 % 0.099 5
Achieved Active Warps Per SM 52.352 warp 0.040 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (81.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 2807328.93 μs
Device Time 1936752.52 μs
Self CPU Time 24664.31 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 2782664.63 μs
Device Time 1936752.52 μs
Self CPU Time 32052.84 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 2750611.79 μs
Device Time 1936752.52 μs
Self CPU Time 65649.03 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 1909827.91 μs
Device Time 1563725.31 μs
Self CPU Time 454415.94 μs
Self Device Time 1563725.31 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 1265364.59 μs
Device Time 24157.48 μs
Self CPU Time 1265364.59 μs
Self Device Time 24157.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 106851.21 μs
Device Time 944930.88 μs
Self CPU Time 26441.15 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45294 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
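/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:17:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]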
17 | int lane = threadIdx.x % warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:18:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
18 | int wid = threadIdx.x / warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:39:5: warning: 3 adjacent parameters of 'computeLocalSum' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
39 | const int offset,
| ^~~~~~~~~~~~~~~~~
40 | const int num_elements,
| ~~~~~~~~~~~~~~~~~~~~~~~
41 | const int thread_stride
| ~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:39:15: note: the first parameter in the range is 'offset'
39 | const int offset,
| ^~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:41:15: note: the last parameter in the range is 'thread_stride'
41 | const int thread_stride
| ^~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:44:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
44 | for (int idx = threadIdx.x; idx < num_elements; idx += thread_stride) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:55:5: warning: 2 adjacent parameters of 'modular_mean_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
55 | const int W,
| ^~~~~~~~~~~~
56 | const int C
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:55:15: note: the first parameter in the range is 'W'
55 | const int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:56:15: note: the last parameter in the range is 'C'
56 | const int C
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:61:27: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
61 | const int batch_idx = blockIdx.x / C;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:62:29: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
62 | const int channel_idx = blockIdx.x % C;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:83:16: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
83 | at::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:87:16: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
87 | at::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:105:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
105 | int N = dims[0];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:106:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
106 | int C = dims[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:107:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
107 | int H = dims[2];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:108:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | int W = dims[3];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:111:31: warning: performing an implicit widening conversion to type 'const long' of a multiplication performed in type 'int' [bugprone-implicit-widening-of-multiplication-result]
111 | auto temp = torch::empty({N * C}, options);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:111:31: note: make conversion explicit to silence this warning
111 | auto temp = torch::empty({N * C}, options);
| ^~~~~
| static_cast<const long>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_44/b3_s1_modular_shared_warp_mean_base/base/base.cu:111:31: note: perform multiplication in a wider type
111 | auto temp = torch::empty({N * C}, options);
| ^
| static_cast<const long>( )