← Back to Leaderboard

The AI CUDA Engineer 👷

44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean — shared_memory_tiled_reduction_base

Level 2 • Task 44
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """
    Run a transposed convolution, scale it, and reduce everything to one mean.

    The spatial (H, W) average is taken twice with ``keepdim=True``; after the
    first pass both spatial dims have size 1, so the second pass leaves the
    values unchanged. A final mean then collapses batch and channel dims.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        output_padding (int): Additional size added to the output shape
        conv_transpose (torch.Tensor): Transposed-convolution weight
        conv_transpose_bias (torch.Tensor): Transposed-convolution bias
        multiplier (float): Scalar applied to the convolution output

    Returns:
        torch.Tensor: 0-dim tensor holding the mean of the scaled conv output.
    """
    upsampled = F.conv_transpose2d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    pooled = upsampled * multiplier
    # Two successive global average pools over the spatial dims.
    for _ in range(2):
        pooled = torch.mean(pooled, dim=[2, 3], keepdim=True)
    return torch.mean(pooled)


class Model(nn.Module):
    """
    Transposed convolution -> scalar multiply -> two global average pools -> mean.

    The module only owns the parameters (the transposed-convolution weight and a
    slightly perturbed bias) and the multiplier; the computation itself is
    delegated to ``fn`` (``module_fn`` by default), which receives them
    positionally after the conv hyper-parameters passed to ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super().__init__()
        # A throwaway layer supplies PyTorch's default weight/bias initialisation.
        template = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = nn.Parameter(template.weight)
        default_bias = template.bias
        # Gaussian jitter scaled by 0.02 so the bias differs from the stock init.
        jitter = torch.randn(
            default_bias.shape, device=default_bias.device, dtype=default_bias.dtype
        )
        self.conv_transpose_bias = nn.Parameter(default_bias + jitter * 0.02)
        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        # Passed positionally: fn may be a compiled extension without kwargs support.
        owned = (self.conv_transpose_parameter, self.conv_transpose_bias, self.multiplier)
        return fn(x, stride, padding, output_padding, *owned)


# Benchmark problem configuration.
batch_size = 128
in_channels, out_channels = 3, 16
height = width = 32  # square input images
kernel_size = 3
stride, padding, output_padding = 2, 1, 1
multiplier = 0.5


def get_inputs():
    """Return forward() arguments: a random input batch followed by the conv hyper-parameters."""
    sample = torch.randn(batch_size, in_channels, height, width)
    conv_args = [stride, padding, output_padding]
    return [sample] + conv_args


def get_init_inputs():
    """Return the Model constructor arguments in positional order."""
    ctor_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    )
    return list(ctor_args)
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Transposed convolution, scalar multiply, two global average pools, then a mean.

    The convolution keeps its default weight initialisation; its bias is the
    default bias plus Gaussian noise scaled by 0.02.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        stock_bias = self.conv_transpose.bias
        noise = torch.randn(stock_bias.shape, device=stock_bias.device, dtype=stock_bias.dtype)
        self.conv_transpose.bias = nn.Parameter(stock_bias + noise * 0.02)
        self.multiplier = multiplier

    def forward(self, x):
        scaled = self.conv_transpose(x) * self.multiplier
        pooled = torch.mean(scaled, dim=[2, 3], keepdim=True)  # first global average pool
        pooled = torch.mean(pooled, dim=[2, 3], keepdim=True)  # second global average pool
        return torch.mean(pooled)

# Benchmark problem configuration.
batch_size = 128
in_channels, out_channels = 3, 16
height = width = 32  # square input images
kernel_size = 3
stride, padding, output_padding = 2, 1, 1
multiplier = 0.5

def get_inputs():
    """Return forward() arguments: one random (batch, in_channels, height, width) tensor."""
    shape = (batch_size, in_channels, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the Model constructor arguments in positional order."""
    conv_args = [in_channels, out_channels, kernel_size, stride, padding, output_padding]
    return conv_args + [multiplier]

Kernel Information

Related Kernels (Level 2, Task 44 • 44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_spatial_reduction_base 0.18 1.21 0.72
🥇 optimized_spatial_reduction_edit_1 0.18 1.21 0.72
🥉 minimal_sync_reduction_edit_1 0.19 1.19 0.71
4 shared_memory_tiled_reduction_base 0.19 1.17 0.70
5 fused_global_avg_base 0.20 1.13 0.67
6 block_size_experimentation_base 0.20 1.11 0.66
7 optimized_strided_avg_pooling_edit_1 0.20 1.10 0.66
7 aligned_ldg_optimized_kernel_base 0.20 1.10 0.66
9 combined_optimized_mean_kernel_base 0.20 1.10 0.65
9 vectorized_ldg_mean_kernel_base 0.20 1.10 0.65
9 optimized_mean_kernel_base 0.20 1.10 0.65
9 warp_uniform_mean_kernel_base_base 0.20 1.10 0.65
9 unrolled_vectorized_mean_kernel_base 0.20 1.10 0.65
9 atomic_final_reduction_base 0.20 1.10 0.65
15 optimized_sync_reduction_base 0.20 1.09 0.65
15 shared_mem_reduction_optimized_base 0.20 1.09 0.65
15 modular_shared_warp_mean_base_base 0.20 1.09 0.65
15 coalesced_vectorized_mean_kernel_base 0.20 1.09 0.65
15 reduced_sync_shared_memory_base 0.20 1.09 0.65
15 fused_atomic_reduction_base 0.20 1.09 0.65
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <limits>

// Define tile dimensions for better memory access patterns
#define TILE_DIM 32
#define BLOCK_ROWS 8

// One thread block per (n, c) plane of a contiguous NCHW float tensor.
// Block `bid` computes
//     output[bid] = mean over (h, w) of input[bid-th plane][h][w] * multiplier
// by walking the H x W plane in TILE_DIM x TILE_DIM tiles staged through
// shared memory. Must be launched with blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE>
__global__ void tiled_reduction_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int N,
    const int C,
    const int H,
    const int W,
    const float multiplier
) {
    // Each load step fills BLOCK_ROWS rows of TILE_DIM columns, so the block
    // size must match exactly. The fixed-depth tree reduction below ends with
    // two partial sums, which is only a complete reduction for TILE_DIM == 32.
    static_assert(BLOCK_SIZE == TILE_DIM * BLOCK_ROWS,
                  "tiled_reduction_kernel requires BLOCK_SIZE == TILE_DIM * BLOCK_ROWS");
    static_assert(TILE_DIM % BLOCK_ROWS == 0, "TILE_DIM must be a multiple of BLOCK_ROWS");
    static_assert(TILE_DIM == 32, "row_sums tree reduction is written for TILE_DIM == 32");
    // The plane index is blockIdx.x itself, so N and C are not needed here.
    (void)N;
    (void)C;

    __shared__ float tile[TILE_DIM][TILE_DIM + 1];  // +1 to avoid bank conflicts on row reads
    __shared__ float row_sums[TILE_DIM];

    const int tid = static_cast<int>(threadIdx.x);
    const int bid = static_cast<int>(blockIdx.x);  // == batch_idx * C + channel_idx
    const int tile_row = tid / TILE_DIM;           // 0 .. BLOCK_ROWS-1
    const int tile_col = tid % TILE_DIM;           // 0 .. TILE_DIM-1
    const int64_t spatial_size = static_cast<int64_t>(H) * W;

    // Planes are stored back to back, so plane `bid` starts at bid * H * W.
    // The offset is computed in 64-bit so it cannot overflow int when
    // N * C * H * W exceeds INT_MAX.
    const float* block_input = input + static_cast<int64_t>(bid) * spatial_size;

    float thread_sum = 0.0f;  // accumulated and read by thread 0 only

    // Process input in tiles
    #pragma unroll 2
    for (int by = 0; by < H; by += TILE_DIM) {
        #pragma unroll 2
        for (int bx = 0; bx < W; bx += TILE_DIM) {
            // Load tile into shared memory; out-of-range entries are zeroed so
            // partial edge tiles contribute nothing to the sum.
            const int gx = bx + tile_col;
            #pragma unroll
            for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {
                const int ty = y + tile_row;
                const int gy = by + ty;
                tile[ty][tile_col] = (gy < H && gx < W)
                    ? block_input[static_cast<int64_t>(gy) * W + gx] * multiplier
                    : 0.0f;
            }
            __syncthreads();

            // Reduce each tile row to a single value (one thread per row).
            if (tid < TILE_DIM) {
                float sum = 0.0f;
                #pragma unroll
                for (int i = 0; i < TILE_DIM; i++) {
                    sum += tile[tid][i];
                }
                row_sums[tid] = sum;
            }
            __syncthreads();

            // Tree-reduce the TILE_DIM row sums down to row_sums[0] and row_sums[1].
            if (tid < (TILE_DIM/2)) {
                row_sums[tid] += row_sums[tid + (TILE_DIM/2)];
            }
            __syncthreads();

            if (tid < (TILE_DIM/4)) {
                row_sums[tid] += row_sums[tid + (TILE_DIM/4)];
            }
            __syncthreads();

            if (tid < (TILE_DIM/8)) {
                row_sums[tid] += row_sums[tid + (TILE_DIM/8)];
            }
            __syncthreads();

            if (tid < (TILE_DIM/16)) {
                row_sums[tid] += row_sums[tid + (TILE_DIM/16)];
            }
            __syncthreads();

            if (tid == 0) {
                thread_sum += (row_sums[0] + row_sums[1]);
            }
            // tile and row_sums are overwritten by the next iteration.
            __syncthreads();
        }
    }

    // Write final result: the plane mean.
    if (tid == 0) {
        output[bid] = thread_sum / static_cast<float>(spatial_size);
    }
}

// Host entry point mirroring the Python reference module_fn:
//   mean(mean(mean(conv_transpose2d(x) * multiplier, [2,3]), [2,3]))
// The scale and per-plane spatial mean are fused in tiled_reduction_kernel;
// the remaining (N, C) values are averaged with ATen.
at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    TORCH_CHECK(x.is_cuda(), "module_fn: x must be a CUDA tensor");

    // Apply transposed convolution (groups = 1, dilation = 1).
    // .contiguous() is a no-op for a standard NCHW result but guarantees the
    // flat per-plane indexing the kernel relies on (e.g. for channels_last input).
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    ).contiguous();
    TORCH_CHECK(y.dim() == 4, "module_fn: expected a 4-D conv output, got ", y.dim(), " dims");
    TORCH_CHECK(y.scalar_type() == at::kFloat,
                "module_fn: only float32 is supported, got ", y.scalar_type());

    const int64_t batch = y.size(0);
    const int64_t channels = y.size(1);
    const int64_t height = y.size(2);
    const int64_t width = y.size(3);
    const int64_t planes = batch * channels;

    // The kernel writes every element, so no zero-fill is needed.
    at::Tensor output = at::empty({batch, channels}, y.options());
    if (planes == 0) {
        // A 0-block grid is an invalid launch configuration; the mean of an
        // empty tensor is NaN, as in the PyTorch reference.
        return output.mean();
    }

    constexpr int kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(planes <= kIntMax && height <= kIntMax && width <= kIntMax,
                "module_fn: tensor dimensions exceed the kernel's 32-bit index range");

    // Launch one block per (n, c) plane on PyTorch's current stream. The
    // kernel declares all of its shared memory statically, so no dynamic
    // shared memory is requested (requesting it would only reduce occupancy).
    constexpr int BLOCK_SIZE = TILE_DIM * BLOCK_ROWS;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    tiled_reduction_kernel<BLOCK_SIZE><<<static_cast<unsigned int>(planes), BLOCK_SIZE, 0, stream>>>(
        y.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(batch),
        static_cast<int>(channels),
        static_cast<int>(height),
        static_cast<int>(width),
        static_cast<float>(multiplier)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Compute final mean
    return output.mean();
}

// Python binding: the compiled extension exposes module_fn as `forward`, taking
// the same positional arguments as the Python reference module_fn.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def("forward", &module_fn, "Module function");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.746 inst/cycle 0.000 5
Executed Ipc Elapsed 1.510 inst/cycle 0.000 5
Issue Slots Busy 43.734 % 0.040 5
Issued Ipc Active 1.748 inst/cycle 0.000 5
SM Busy 43.734 % 0.040 5
Memory Throughput 1389356844142.478 byte/second 44822143570008506368.000 5
Mem Busy 33.792 % 0.029 5
Max Bandwidth 41.502 % 0.040 5
L1/TEX Hit Rate 0.010 % 0.000 5
L2 Hit Rate 3.232 % 0.000 5
Mem Pipes Busy 31.682 % 0.026 5
Warp Cycles Per Issued Instruction 33.412 cycle 0.029 5
Warp Cycles Per Executed Instruction 33.482 cycle 0.029 5
Avg. Active Threads Per Warp 31.940 0.000 5
Avg. Not Predicated Off Threads Per Warp 26.820 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 13.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 91.426 % 0.155 5
Achieved Active Warps Per SM 58.516 warp 0.063 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (31.7%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 6709422.79 μs
Device Time 5256732.58 μs
Self CPU Time 55812.37 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 6653610.42 μs
Device Time 5256732.58 μs
Self CPU Time 72315.71 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 6581294.70 μs
Device Time 5256732.58 μs
Self CPU Time 141325.87 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 4101348.40 μs
Device Time 4277449.34 μs
Self CPU Time 681505.57 μs
Self Device Time 4277317.21 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4149781.17 μs
Device Time 357824.81 μs
Self CPU Time 4149781.17 μs
Self Device Time 357824.81 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 487342.22 μs
Device Time 2580221.31 μs
Self CPU Time 105599.37 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:13:5 bugprone-easily-swappable-parameters
13 | const int N,
| ^~~~~~~~~~~~
14 | const int C,
| ~~~~~~~~~~~~
15 | const int H,
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:13:15: note: the first parameter in the range is 'N'
13 | const int N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:15:15: note: the last parameter in the range is 'H'
15 | const int H,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:22:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:23:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
23 | const int bid = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:29:32: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
29 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:29:74: note: make conversion explicit to silence this warning
29 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:29:74: note: perform multiplication in a wider type
29 | const float* block_input = input + (batch_idx * C * spatial_size) + (channel_idx * spatial_size);
| ^~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:94:36: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
94 | output[bid] = thread_sum / spatial_size;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:99:16: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
99 | at::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:103:16: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
103 | at::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:126:24: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
126 | const int blocks = dims[0] * dims[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:132:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:132:18: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:132:27: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | dims[0], dims[1], dims[2], dims[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_44/b5_s1_shared_memory_tiled_reduction/base/base.cu:132:36: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | dims[0], dims[1], dims[2], dims[3],
| ^