← Back to Leaderboard

The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Meanfused_atomic_reduction_base

Level 2 • Task 44
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """Reference forward pass: transposed conv, scale, two global average pools, then a mean.

    Args:
        x: Input batch, expected as (batch_size, in_channels, height, width).
        stride: Stride forwarded to ``F.conv_transpose2d``.
        padding: Padding forwarded to ``F.conv_transpose2d``.
        output_padding: Extra size added to the output shape by ``F.conv_transpose2d``.
        conv_transpose: Transposed-convolution weight tensor.
        conv_transpose_bias: Transposed-convolution bias tensor.
        multiplier: Scalar applied to the convolution output.

    Returns:
        A 0-dim tensor holding the mean of the scaled convolution output.
    """
    conv_out = F.conv_transpose2d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    pooled = conv_out * multiplier
    # Two successive spatial means (keepdim=True). The second one averages a 1x1 map,
    # so it does not change the values; it is kept to mirror the reference model.
    for _ in range(2):
        pooled = torch.mean(pooled, dim=[2, 3], keepdim=True)
    return torch.mean(pooled)


class Model(nn.Module):
    """
    Transposed convolution -> scalar multiply -> global average pool -> global average pool -> mean.

    The convolution weight and bias are stored directly on the module (rather than inside an
    ``nn.ConvTranspose2d``) so that ``forward`` can pass them to a pluggable functional
    implementation ``fn`` (``module_fn`` by default).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super().__init__()
        # Local layer used only to obtain PyTorch's default weight/bias initialization;
        # it is not registered as a submodule.
        template = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        base_bias = template.bias
        # Small random perturbation of the default bias, drawn from the global RNG
        # after the layer's own initialization.
        bias_noise = (
            torch.randn(base_bias.shape, device=base_bias.device, dtype=base_bias.dtype)
            * 0.02
        )
        self.conv_transpose_parameter = nn.Parameter(template.weight)
        self.conv_transpose_bias = nn.Parameter(base_bias + bias_noise)
        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        """Run ``fn`` on ``x`` with this module's weight, bias and multiplier."""
        conv_args = (stride, padding, output_padding)
        params = (self.conv_transpose_parameter, self.conv_transpose_bias, self.multiplier)
        return fn(x, *conv_args, *params)


# Benchmark configuration consumed by get_inputs() / get_init_inputs().
batch_size, in_channels, out_channels = 128, 3, 16
height = width = 32
kernel_size, stride, padding, output_padding = 3, 2, 1, 1
multiplier = 0.5


def get_inputs():
    """Return forward() arguments: a random input batch followed by the conv hyper-parameters."""
    batch = torch.randn(batch_size, in_channels, height, width)
    conv_args = [stride, padding, output_padding]
    return [batch, *conv_args]


def get_init_inputs():
    """Return the positional arguments for ``Model.__init__``, in constructor order."""
    ctor_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    )
    return list(ctor_args)
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a transposed convolution, multiplies by a scalar, applies global average
    pooling twice, and then takes the mean of everything, returning a 0-dim tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        # Perturb the default bias with small Gaussian noise drawn from the global RNG.
        old_bias = self.conv_transpose.bias
        jitter = torch.randn(old_bias.shape, device=old_bias.device, dtype=old_bias.dtype)
        self.conv_transpose.bias = nn.Parameter(old_bias + jitter * 0.02)
        self.multiplier = multiplier

    def forward(self, x):
        scaled = self.conv_transpose(x) * self.multiplier
        # Two global average pools over the spatial dims (the second acts on a 1x1 map).
        for _ in range(2):
            scaled = torch.mean(scaled, dim=[2, 3], keepdim=True)
        return torch.mean(scaled)

# Benchmark configuration consumed by get_inputs() / get_init_inputs().
batch_size, in_channels, out_channels = 128, 3, 16
height = width = 32
kernel_size, stride, padding, output_padding = 3, 2, 1, 1
multiplier = 0.5

def get_inputs():
    """Return forward() arguments: a single random input batch."""
    shape = (batch_size, in_channels, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the positional arguments for ``Model.__init__``, in constructor order."""
    ctor_args = (in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier)
    return list(ctor_args)

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

| Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile) |
|---|---|---|---|---|
| 🥇 | optimized_spatial_reduction_base | 0.18 | 1.21 | 0.72 |
| 🥇 | optimized_spatial_reduction_edit_1 | 0.18 | 1.21 | 0.72 |
| 🥉 | minimal_sync_reduction_edit_1 | 0.19 | 1.19 | 0.71 |
| 4 | shared_memory_tiled_reduction_base | 0.19 | 1.17 | 0.70 |
| 5 | fused_global_avg_base | 0.20 | 1.13 | 0.67 |
| 6 | block_size_experimentation_base | 0.20 | 1.11 | 0.66 |
| 7 | optimized_strided_avg_pooling_edit_1 | 0.20 | 1.10 | 0.66 |
| 7 | aligned_ldg_optimized_kernel_base | 0.20 | 1.10 | 0.66 |
| 9 | combined_optimized_mean_kernel_base | 0.20 | 1.10 | 0.65 |
| 9 | vectorized_ldg_mean_kernel_base | 0.20 | 1.10 | 0.65 |
| 9 | optimized_mean_kernel_base | 0.20 | 1.10 | 0.65 |
| 9 | warp_uniform_mean_kernel_base_base | 0.20 | 1.10 | 0.65 |
| 9 | unrolled_vectorized_mean_kernel_base | 0.20 | 1.10 | 0.65 |
| 9 | atomic_final_reduction_base | 0.20 | 1.10 | 0.65 |
| 15 | optimized_sync_reduction_base | 0.20 | 1.09 | 0.65 |
| 15 | shared_mem_reduction_optimized_base | 0.20 | 1.09 | 0.65 |
| 15 | modular_shared_warp_mean_base_base | 0.20 | 1.09 | 0.65 |
| 15 | coalesced_vectorized_mean_kernel_base | 0.20 | 1.09 | 0.65 |
| 15 | reduced_sync_shared_memory_base | 0.20 | 1.09 | 0.65 |
| 15 | fused_atomic_reduction_base | 0.20 | 1.09 | 0.65 |
#include <cstdint>
#include <limits>

#include <cuda_runtime.h>

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Fused kernel: multiplies every element by `multiplier` and reduces the result into a single
// global sum. Each block handles one (batch, channel) slice of an NCHW-contiguous tensor:
// threads accumulate strided elements of the slice in registers, the per-thread partials are
// tree-reduced through shared memory down to 32 values, warp shuffles finish the reduction,
// and thread 0 atomically adds the block total into *global_sum (the caller must zero it).
//
// Launch requirements: BLOCK_SIZE threads per block, gridDim.x == N * C, and
// BLOCK_SIZE * sizeof(float) bytes of dynamic shared memory.
// Note: the cross-block atomicAdd accumulates in a nondeterministic order, so the final
// float sum may differ in the last bits between runs.
template <int BLOCK_SIZE>
__global__ void fused_reduction_kernel(
    const float* __restrict__ input,
    float multiplier,
    int N,
    int C,
    int H,
    int W,
    float* __restrict__ global_sum
) {
    static_assert(BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
                  "BLOCK_SIZE must be a power of two in [32, 1024]");
    // The grid already enumerates all N * C slices, so N and C are not needed for indexing.
    (void)N;
    (void)C;

    extern __shared__ float sdata[];
    const int tid = static_cast<int>(threadIdx.x);

    // blockIdx.x == batch * C + channel, so the slice starts at blockIdx.x * H * W in NCHW
    // memory. Compute the offset in 64 bits so large tensors do not overflow int.
    const int64_t spatial_size = static_cast<int64_t>(H) * W;
    const float* input_ptr = input + static_cast<int64_t>(blockIdx.x) * spatial_size;

    // Each thread accumulates a strided subset of the slice, fusing the scalar multiply.
    float sum = 0.0f;
    for (int64_t i = tid; i < spatial_size; i += BLOCK_SIZE) {
        sum += input_ptr[i] * multiplier;
    }

    sdata[tid] = sum;
    __syncthreads();

    // Shared-memory tree reduction down to 32 partial sums.
#pragma unroll
    for (int s = BLOCK_SIZE / 2; s >= 32; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Final 32 -> 1 step inside warp 0 using shuffles. A volatile-shared-memory unroll would
    // rely on implicit warp-synchronous execution, which is not guaranteed on GPUs with
    // independent thread scheduling (Volta+); __shfl_down_sync synchronizes explicitly.
    if (tid < 32) {
        float v = sdata[tid];
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            v += __shfl_down_sync(0xffffffffu, v, offset);
        }
        if (tid == 0) {
            atomicAdd(global_sum, v);
        }
    }
}

// Module function: runs the transposed convolution through ATen (cuDNN where available),
// then the fused multiply + global-sum kernel, and returns mean(conv_out * multiplier) as a
// 0-dim float32 tensor on the output's device.
//
// Supported inputs: float32 CUDA tensors whose convolution output is 4-D (N, C, H, W).
// Raises (via TORCH_CHECK) on other dtypes/devices/ranks or if the kernel launch fails.
at::Tensor module_fn(
    const at::Tensor& x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    const at::Tensor& conv_transpose,
    const at::Tensor& conv_transpose_bias,
    double multiplier
) {
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    );

    TORCH_CHECK(y.is_cuda(), "module_fn: expected CUDA tensors");
    TORCH_CHECK(y.scalar_type() == at::kFloat,
                "module_fn: only float32 is supported, got ", y.scalar_type());
    TORCH_CHECK(y.dim() == 4,
                "module_fn: expected a 4-D (N, C, H, W) convolution output, got ", y.dim(), " dims");

    // The kernel walks each (batch, channel) slice as contiguous NCHW memory; a non-contiguous
    // (e.g. channels_last) output would otherwise be read with the wrong layout.
    y = y.contiguous();

    // Degenerate shapes: a zero-block grid is an invalid launch configuration, so defer to
    // ATen, which gives the same result as the reference (NaN for the mean of nothing).
    if (y.numel() == 0) {
        return y.mul(multiplier).mean();
    }

    const int64_t N = y.size(0);
    const int64_t C = y.size(1);
    const int64_t H = y.size(2);
    const int64_t W = y.size(3);
    const int64_t num_slices = N * C;  // one block per (batch, channel) slice

    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(num_slices <= kIntMax && H <= kIntMax && W <= kIntMax,
                "module_fn: tensor dimensions exceed the kernel's 32-bit range");

    // Launch on y's device and on PyTorch's current stream so the kernel is ordered after the
    // convolution and the zero-fill below (the legacy default stream is not, in general).
    const c10::cuda::CUDAGuard device_guard(y.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    auto options = torch::TensorOptions().device(y.device()).dtype(torch::kFloat32);
    at::Tensor global_sum_tensor = torch::zeros({1}, options);

    constexpr int BLOCK_SIZE = 256;
    const size_t shared_mem_size = BLOCK_SIZE * sizeof(float);

    fused_reduction_kernel<BLOCK_SIZE><<<static_cast<unsigned int>(num_slices), BLOCK_SIZE,
                                         shared_mem_size, stream>>>(
        y.data_ptr<float>(),
        static_cast<float>(multiplier),
        static_cast<int>(N),
        static_cast<int>(C),
        static_cast<int>(H),
        static_cast<int>(W),
        global_sum_tensor.data_ptr<float>()
    );
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "fused_reduction_kernel launch failed: ", cudaGetErrorString(launch_err));

    // Divide on the device and return a 0-dim tensor. This avoids a blocking
    // cudaDeviceSynchronize() + .item() host round-trip and the int overflow of N*C*H*W.
    return global_sum_tensor.div(static_cast<double>(y.numel())).squeeze(0);
}

// Python binding: exposes module_fn as `forward` on the compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def(
        "forward",
        &module_fn,
        "Fused elementwise multiplication and reduction");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.942 inst/cycle 0.000 5
Executed Ipc Elapsed 0.746 inst/cycle 0.000 5
Issue Slots Busy 23.782 % 0.012 5
Issued Ipc Active 0.952 inst/cycle 0.000 5
SM Busy 23.782 % 0.012 5
Memory Throughput 2350516163070.786 byte/second 3463603365578768121856.000 5
Mem Busy 39.442 % 0.996 5
Max Bandwidth 70.286 % 2.932 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 2.912 % 0.000 5
Mem Pipes Busy 14.028 % 0.123 5
Warp Cycles Per Issued Instruction 54.058 cycle 0.053 5
Warp Cycles Per Executed Instruction 54.470 cycle 0.054 5
Avg. Active Threads Per Warp 31.890 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.910 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 80.188 % 0.067 5
Achieved Active Warps Per SM 51.318 warp 0.027 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (80.1%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 2202916.34 μs
Device Time 4973883.76 μs
Self CPU Time 59398.28 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 2143518.06 μs
Device Time 4973883.76 μs
Self CPU Time 71182.64 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 2072335.42 μs
Device Time 4973883.76 μs
Self CPU Time 145669.81 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 1626771.35 μs
Device Time 4051138.48 μs
Self CPU Time 669479.52 μs
Self Device Time 4051138.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 561845.51 μs
Device Time 2477233.74 μs
Self CPU Time 226280.73 μs
Self Device Time 2477233.74 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceSynchronize
CPU Time 5329459.31 μs
Device Time 109182.92 μs
Self CPU Time 5329459.31 μs
Self Device Time 109182.92 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45291 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:11:5 bugprone-easily-swappable-parameters
11 | float multiplier,
| ^~~~~~~~~~~~~~~~~
12 | int N,
| ~~~~~~
13 | int C,
| ~~~~~~
14 | int H,
| ~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:11:11: note: the first parameter in the range is 'multiplier'
11 | float multiplier,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:14:9: note: the last parameter in the range is 'H'
14 | int H,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:12:5: note: 'float' and 'int' may be implicitly converted
12 | int N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:19:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
19 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:22:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | const int bid = blockIdx.x; // range: 0 to (N*C - 1)
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:28:30: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
28 | const float* input_ptr = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:28:72: note: make conversion explicit to silence this warning
4 | const float* input_ptr = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:28:72: note: perform multiplication in a wider type
28 | const float* input_ptr = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:75:16: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
75 | at::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:79:16: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
79 | at::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:97:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
97 | int N = dims[0];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:98:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
98 | int C = dims[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:99:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
99 | int H = dims[2];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s0_fused_atomic_reduction/base/base.cu:100:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
100 | int W = dims[3];
| ^