← Back to Leaderboard

The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean • optimized_spatial_reduction_base

Level 2 • Task 44
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """
    Run a transposed convolution, scale the result, and reduce it to one value.

    The pipeline is: ``F.conv_transpose2d`` -> multiply by ``multiplier`` ->
    mean over dims 2 and 3 (kept as size-1 dims) -> the same mean a second
    time -> mean over every remaining element.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        output_padding (int): Additional size added to output shape
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        multiplier (float): Scalar multiplier value

    Returns:
        torch.Tensor: Zero-dimensional tensor holding the final mean
    """
    spatial_dims = [2, 3]

    upsampled = F.conv_transpose2d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    scaled = upsampled * multiplier

    # Two consecutive global average pools; the second one operates on the
    # size-1 spatial dims produced by the first.
    pooled = scaled.mean(dim=spatial_dims, keepdim=True)
    pooled = pooled.mean(dim=spatial_dims, keepdim=True)

    return pooled.mean()


class Model(nn.Module):
    """
    Parameter holder for the functional pipeline: transposed convolution,
    scalar multiply, two global average pools, then a final mean.

    The weight and bias are taken from a temporary ``nn.ConvTranspose2d`` and
    stored as standalone parameters; the computation itself is delegated to
    the callable passed to ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super().__init__()

        # Only used as a source of initial weight/bias values.
        reference_conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        initial_bias = reference_conv.bias

        # Small random perturbation (scaled by 0.02) added to the initial bias.
        bias_noise = (
            torch.randn(
                initial_bias.shape,
                device=initial_bias.device,
                dtype=initial_bias.dtype,
            )
            * 0.02
        )

        self.conv_transpose_parameter = nn.Parameter(reference_conv.weight)
        self.conv_transpose_bias = nn.Parameter(initial_bias + bias_noise)
        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        """Call ``fn`` with the input, conv hyper-parameters, and stored parameters."""
        weight = self.conv_transpose_parameter
        bias = self.conv_transpose_bias
        return fn(x, stride, padding, output_padding, weight, bias, self.multiplier)


# Benchmark configuration consumed by get_inputs() and get_init_inputs() below.
batch_size = 128  # first dimension of the random input tensor
in_channels = 3  # second dimension of the input; also a Model constructor argument
out_channels = 16  # Model constructor argument
height, width = 32, 32  # last two dimensions of the random input tensor
kernel_size = 3  # Model constructor argument
stride = 2  # passed to both the Model constructor and forward()
padding = 1  # passed to both the Model constructor and forward()
output_padding = 1  # passed to both the Model constructor and forward()
multiplier = 0.5  # scalar handed to the Model constructor


def get_inputs():
    """Return the positional arguments for ``Model.forward``: a random input
    tensor followed by stride, padding and output_padding."""
    sample = torch.randn(batch_size, in_channels, height, width)
    return [sample, stride, padding, output_padding]


def get_init_inputs():
    """Return the positional arguments for the ``Model`` constructor."""
    conv_args = [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
    ]
    return conv_args + [multiplier]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Transposed convolution followed by a scalar multiply, two global average
    pools over the spatial dims, and a final mean over everything that is left.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super().__init__()
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        # Perturb the default bias with small random noise (scaled by 0.02).
        bias_noise = torch.randn(
            deconv.bias.shape, device=deconv.bias.device, dtype=deconv.bias.dtype
        ) * 0.02
        deconv.bias = nn.Parameter(deconv.bias + bias_noise)

        self.conv_transpose = deconv
        self.multiplier = multiplier

    def forward(self, x):
        """Return the zero-dimensional mean of the scaled, pooled conv output."""
        spatial_dims = [2, 3]
        scaled = self.conv_transpose(x) * self.multiplier
        pooled = scaled.mean(dim=spatial_dims, keepdim=True)  # first global average pool
        pooled = pooled.mean(dim=spatial_dims, keepdim=True)  # second pool, on size-1 spatial dims
        return pooled.mean()

# Benchmark configuration consumed by get_inputs() and get_init_inputs() below.
batch_size = 128  # first dimension of the random input tensor
in_channels = 3  # second dimension of the input; also a Model constructor argument
out_channels = 16  # Model constructor argument
height, width = 32, 32  # last two dimensions of the random input tensor
kernel_size = 3  # Model constructor argument
stride = 2  # Model constructor argument
padding = 1  # Model constructor argument
output_padding = 1  # Model constructor argument
multiplier = 0.5  # scalar handed to the Model constructor

def get_inputs():
    """Return the single-element argument list for ``Model.forward``: one random input tensor."""
    sample = torch.randn(batch_size, in_channels, height, width)
    return [sample]

def get_init_inputs():
    """Return the positional arguments for the ``Model`` constructor."""
    conv_args = [in_channels, out_channels, kernel_size, stride, padding, output_padding]
    return conv_args + [multiplier]

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_spatial_reduction_base 0.18 1.21 0.72
🥇 optimized_spatial_reduction_edit_1 0.18 1.21 0.72
🥉 minimal_sync_reduction_edit_1 0.19 1.19 0.71
4 shared_memory_tiled_reduction_base 0.19 1.17 0.70
5 fused_global_avg_base 0.20 1.13 0.67
6 block_size_experimentation_base 0.20 1.11 0.66
7 optimized_strided_avg_pooling_edit_1 0.20 1.10 0.66
7 aligned_ldg_optimized_kernel_base 0.20 1.10 0.66
9 combined_optimized_mean_kernel_base 0.20 1.10 0.65
9 vectorized_ldg_mean_kernel_base 0.20 1.10 0.65
9 optimized_mean_kernel_base 0.20 1.10 0.65
9 warp_uniform_mean_kernel_base_base 0.20 1.10 0.65
9 unrolled_vectorized_mean_kernel_base 0.20 1.10 0.65
9 atomic_final_reduction_base 0.20 1.10 0.65
15 optimized_sync_reduction_base 0.20 1.09 0.65
15 shared_mem_reduction_optimized_base 0.20 1.09 0.65
15 modular_shared_warp_mean_base_base 0.20 1.09 0.65
15 coalesced_vectorized_mean_kernel_base 0.20 1.09 0.65
15 reduced_sync_shared_memory_base 0.20 1.09 0.65
15 fused_atomic_reduction_base 0.20 1.09 0.65
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Computes, for each (batch, channel) plane of a contiguous NCHW float tensor,
// the mean of (element * multiplier) over the H*W spatial positions.
//
// Launch contract:
//   - grid.x  == N * C   (one block per plane; block b writes output[b])
//   - block.x == BLOCK_SIZE
//   - dynamic shared memory >= BLOCK_SIZE * sizeof(float)
//
// Parameters:
//   input      - pointer to N*C*H*W floats laid out contiguously as NCHW
//   output     - pointer to N*C floats, one mean per plane
//   N, C, H, W - tensor dimensions
//   multiplier - scalar applied to every element before accumulation
template<int BLOCK_SIZE>
__global__ void optimized_reduction_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int N,
    const int C,
    const int H,
    const int W,
    const float multiplier
) {
    // The tree reduction below halves the active range each step and the final
    // stage uses full-warp shuffles, so the block must be a power of two and
    // contain at least one complete warp.
    static_assert(BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024 &&
                  (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
                  "BLOCK_SIZE must be a power of two in [32, 1024]");

    extern __shared__ float sdata[];

    const int tid = static_cast<int>(threadIdx.x);
    const int bid = static_cast<int>(blockIdx.x);

    // Whole-block uniform exit (taken before any barrier) if the grid was
    // launched with more blocks than there are planes.
    if (static_cast<long long>(bid) >= static_cast<long long>(N) * C) {
        return;
    }

    const int spatial_size = H * W;

    // Plane index is batch_idx * C + channel_idx == bid. The offset is formed
    // in 64 bits so that tensors with more than 2^31 elements do not overflow.
    const float* block_input =
        input + static_cast<size_t>(bid) * static_cast<size_t>(spatial_size);

    // Each thread accumulates a strided subset of the plane.
    float sum = 0.0f;
    #pragma unroll 8
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        sum += block_input[i] * multiplier;
    }

    sdata[tid] = sum;
    __syncthreads();

    // Shared-memory tree reduction down to 64 partial sums.
    if (BLOCK_SIZE >= 1024) {
        if (tid < 512) sdata[tid] += sdata[tid + 512];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 512) {
        if (tid < 256) sdata[tid] += sdata[tid + 256];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 256) {
        if (tid < 128) sdata[tid] += sdata[tid + 128];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 128) {
        if (tid < 64) sdata[tid] += sdata[tid + 64];
        __syncthreads();
    }

    // Final stage in the first warp. Register shuffles with an explicit
    // full-warp mask replace the implicit warp-synchronous shared-memory
    // pattern, which is not guaranteed to be correct when threads of a warp
    // can be scheduled independently.
    if (tid < 32) {
        float partial = sdata[tid];
        if (BLOCK_SIZE >= 64) {
            partial += sdata[tid + 32];
        }
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            partial += __shfl_down_sync(0xffffffffu, partial, offset);
        }

        // Lane 0 holds the block total; normalise to a mean on write.
        if (tid == 0) {
            output[bid] = partial / static_cast<float>(spatial_size);
        }
    }
}

// Host entry point: transposed convolution via ATen, then a custom kernel that
// produces the scaled per-(batch, channel) spatial mean, then a final mean
// over those values.
//
// Parameters:
//   x                   - CUDA input tensor; the convolution output must be
//                         4-D float32
//   stride, padding,
//   output_padding      - applied identically to both spatial dimensions
//   conv_transpose      - transposed-convolution weight
//   conv_transpose_bias - transposed-convolution bias
//   multiplier          - scalar applied to the convolution output
//
// Returns a zero-dimensional tensor. Raises (via TORCH_CHECK) on non-CUDA
// input, non-float32 or non-4-D convolution output, sizes that do not fit the
// kernel's int parameters, or a failed kernel launch.
at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    TORCH_CHECK(x.is_cuda(), "module_fn: x must be a CUDA tensor");

    // Make x's device current so the raw kernel launch targets the right GPU.
    const at::cuda::OptionalCUDAGuard device_guard(at::device_of(x));

    // groups = 1, dilation = 1. The kernel indexes the result as dense NCHW,
    // so force a contiguous layout (a no-op when it already is).
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    ).contiguous();

    TORCH_CHECK(y.dim() == 4,
                "module_fn: expected a 4-D convolution output, got ", y.dim(), "-D");
    TORCH_CHECK(y.scalar_type() == at::kFloat,
                "module_fn: only float32 is supported, got ", y.scalar_type());

    const int64_t n = y.size(0);
    const int64_t c = y.size(1);
    const int64_t h = y.size(2);
    const int64_t w = y.size(3);
    const int64_t planes = n * c;

    // The kernel takes int dimensions and the grid's x extent is limited to
    // 2^31 - 1, so reject anything that would be silently truncated.
    constexpr int64_t kIntMax = 2147483647;
    TORCH_CHECK(planes <= kIntMax && h * w <= kIntMax,
                "module_fn: tensor too large for optimized_reduction_kernel");

    // Every element is written by the kernel, so no zero-fill is needed.
    at::Tensor output = at::empty({n, c}, y.options());

    // A zero-block launch is an invalid configuration; skip it and let
    // mean() handle the empty tensor.
    if (planes > 0) {
        constexpr int BLOCK_SIZE = 512;
        const size_t shared_mem_size = BLOCK_SIZE * sizeof(float);

        // Launch on PyTorch's current stream so the kernel is ordered after
        // the convolution and before the final mean.
        optimized_reduction_kernel<BLOCK_SIZE>
            <<<static_cast<unsigned int>(planes), BLOCK_SIZE, shared_mem_size,
               at::cuda::getCurrentCUDAStream()>>>(
            y.data_ptr<float>(),
            output.data_ptr<float>(),
            static_cast<int>(n),
            static_cast<int>(c),
            static_cast<int>(h),
            static_cast<int>(w),
            static_cast<float>(multiplier)
        );

        const cudaError_t launch_status = cudaGetLastError();
        TORCH_CHECK(launch_status == cudaSuccess,
                    "optimized_reduction_kernel launch failed: ",
                    cudaGetErrorString(launch_status));
    }

    // Mean over the per-plane means.
    return output.mean();
}

// Python binding: exposes module_fn on the extension module under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn, "Module function");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.540 inst/cycle 0.000 5
Executed Ipc Elapsed 1.208 inst/cycle 0.001 5
Issue Slots Busy 38.652 % 0.031 5
Issued Ipc Active 1.546 inst/cycle 0.000 5
SM Busy 38.652 % 0.031 5
Memory Throughput 2350054044898.164 byte/second 3791160778347151622144.000 5
Mem Busy 39.506 % 1.114 5
Max Bandwidth 70.234 % 3.497 5
L1/TEX Hit Rate 0.010 % 0.000 5
L2 Hit Rate 2.926 % 0.000 5
Mem Pipes Busy 21.460 % 0.347 5
Warp Cycles Per Issued Instruction 34.664 cycle 0.004 5
Warp Cycles Per Executed Instruction 34.820 cycle 0.004 5
Avg. Active Threads Per Warp 31.860 0.000 5
Avg. Not Predicated Off Threads Per Warp 27.510 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 5.000 block 0.000 5
Block Limit Shared Mem 10.000 block 0.000 5
Block Limit Warps 4.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 84.340 % 0.019 5
Achieved Active Warps Per SM 53.978 warp 0.007 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (22.3%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (84.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 6596246.32 μs
Device Time 5333382.66 μs
Self CPU Time 54605.28 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 6541641.04 μs
Device Time 5333382.66 μs
Self CPU Time 72271.19 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 6469369.85 μs
Device Time 5333382.66 μs
Self CPU Time 151283.79 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 3965701.33 μs
Device Time 4338178.22 μs
Self CPU Time 739830.54 μs
Self Device Time 4338178.22 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 3929087.15 μs
Device Time 1244.28 μs
Self CPU Time 3929087.15 μs
Self Device Time 1244.28 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 510290.40 μs
Device Time 2619305.19 μs
Self CPU Time 110567.20 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45294 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
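/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:9:5: warning: 3 adjacent parameters of 'optimized_reduction_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]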
9 | const int N,
| ^~~~~~~~~~~~
10 | const int C,
| ~~~~~~~~~~~~
11 | const int H,
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:9:15: note: the first parameter in the range is 'N'
9 | const int N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:11:15: note: the last parameter in the range is 'H'
11 | const int H,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:12:5: warning: 2 adjacent parameters of 'optimized_reduction_kernel' of convertible types are easily swapped by mistake [bugprone-easily-swappable-parameters]
12 | const int W,
| ^~~~~~~~~~~~
13 | const float multiplier
| ~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:12:15: note: the first parameter in the range is 'W'
12 | const int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:13:17: note: the last parameter in the range is 'multiplier'
13 | const float multiplier
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:13:5: note: 'const int' and 'const float' may be implicitly converted: 'const int' (as 'int') -> 'const float' (as 'float'), 'const float' (as 'float') -> 'const int' (as 'int')
13 | const float multiplier
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:17:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
17 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:18:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
18 | const int bid = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:24:32: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
24 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:24:74: note: make conversion explicit to silence this warning
24 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:24:74: note: perform multiplication in a wider type
24 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:70:34: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
70 | output[bid] = sdata[0] / (spatial_size); // Normalize during reduction
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:75:16: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
75 | at::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:79:16: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
79 | at::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:102:24: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | const int blocks = dims[0] * dims[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:108:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:108:18: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:108:27: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b4_s2_optimized_spatial_reduction/base/base.cu:108:36: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | dims[0], dims[1], dims[2], dims[3],
| ^