← Back to Leaderboard

The AI CUDA Engineer 👷

19_ConvTranspose2d_GELU_GroupNorm — modular_convtrans_gelu_gn_base

Level 2 • Task 19
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    group_norm_weight: torch.Tensor,
    group_norm_bias: torch.Tensor,
    num_groups: int,
) -> torch.Tensor:
    """
    Run the functional pipeline: transposed convolution -> GELU -> group norm.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, height, width)
        stride (int): Stride used by the transposed convolution
        conv_transpose (torch.Tensor): Weight of the transposed convolution
        conv_transpose_bias (torch.Tensor): Bias of the transposed convolution
        group_norm_weight (torch.Tensor): Per-channel scale for group normalization
        group_norm_bias (torch.Tensor): Per-channel shift for group normalization
        num_groups (int): Number of channel groups for group normalization

    Returns:
        torch.Tensor: The group-normalized, GELU-activated transposed-convolution output
    """
    upsampled = F.conv_transpose2d(
        x, conv_transpose, bias=conv_transpose_bias, stride=stride
    )
    # F.gelu is called with its default arguments, exactly as before.
    activated = F.gelu(upsampled)
    return F.group_norm(activated, num_groups, group_norm_weight, group_norm_bias)


class Model(nn.Module):
    """
    Transposed convolution -> GELU -> GroupNorm, with the learnable tensors stored
    directly on the module and the computation delegated to a functional callable.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, groups, num_groups
    ):
        # NOTE: `groups` is accepted but never read in this constructor.
        super().__init__()

        def nudged(param):
            # Shift every entry by 0.02 so the bias is guaranteed to be nonzero.
            return nn.Parameter(param + torch.ones_like(param) * 0.02)

        # Temporary layers are built only to borrow their initialised parameters.
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=stride
        )
        norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)

        # Assignment order is kept so parameter registration order is unchanged.
        self.conv_transpose_parameter = deconv.weight
        self.conv_transpose_bias = nudged(deconv.bias)
        self.group_norm_weight = norm.weight
        self.group_norm_bias = nudged(norm.bias)

    def forward(self, x, stride, num_groups, fn=module_fn):
        """Call `fn` with the input, the stride, the stored parameters and `num_groups`."""
        learned = (
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.group_norm_weight,
            self.group_norm_bias,
        )
        return fn(x, stride, *learned, num_groups)


# Benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32  # spatial size of the random input tensor
kernel_size = 4
stride = 2
groups = 8  # forwarded to Model.__init__, which does not use it
num_groups = 8  # number of GroupNorm groups


def get_inputs():
    """Return the forward() arguments: a random input tensor, the stride and num_groups."""
    sample = torch.randn(batch_size, in_channels, height, width)
    return [sample, stride, num_groups]


def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    ctor_args = (in_channels, out_channels, kernel_size, stride, groups, num_groups)
    return list(ctor_args)
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Transposed convolution followed by GELU and GroupNorm, built from nn layers.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, num_groups):
        # NOTE: `groups` is accepted but never read in this constructor.
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        # Shift both biases by 0.02, mirroring the functional implementation.
        for layer in (self.conv_transpose, self.group_norm):
            layer.bias = nn.Parameter(layer.bias + torch.ones_like(layer.bias) * 0.02)

    def forward(self, x):
        """Apply conv_transpose, then GELU, then group_norm to `x`."""
        activated = torch.nn.functional.gelu(self.conv_transpose(x))
        return self.group_norm(activated)

# Benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32  # spatial size of the random input tensor
kernel_size = 4
stride = 2
groups = 8  # forwarded to Model.__init__, which does not use it
num_groups = 8  # number of GroupNorm groups

def get_inputs():
    """Return the single forward() argument: a random input tensor."""
    shape = (batch_size, in_channels, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    ctor_args = (in_channels, out_channels, kernel_size, stride, groups, num_groups)
    return list(ctor_args)

Kernel Information

Related Kernels (Level 2, Task 19 • 19_ConvTranspose2d_GELU_GroupNorm)

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <cmath>
#include <limits>

// Device function for GELU activation.
//
// Computes the exact (erf-based) form: 0.5 * x * (1 + erf(x / sqrt(2))).
// The Python reference calls F.gelu(x) with default arguments, which selects
// the exact formulation rather than the tanh approximation, so the erf form is
// used here to match the reference numerically.
__device__ inline float gelu_activation(float x) {
    const float kInvSqrt2 = 0.70710678118654752440f;  // 1 / sqrt(2)
    return 0.5f * x * (1.0f + erff(x * kInvSqrt2));
}

// Sums `val` across the lanes of a warp using shuffle-down exchanges.
// The shuffle mask is 0xffffffff, so every lane of the warp must call this
// together. The shift distance is halved each round, starting at warpSize / 2;
// the complete sum ends up in lane 0.
__device__ inline float warpReduceSum(float val) {
    int offset = warpSize >> 1;
    while (offset > 0) {
        val += __shfl_down_sync(0xffffffff, val, offset);
        offset >>= 1;
    }
    return val;
}

// Fused kernel that applies GELU activation and Group Normalization in two passes
// using modular device functions for better readability and reusability.
//
// Launch layout: one block per (sample, group) pair, i.e. gridDim.x == N * num_groups.
// The block-level reduction uses full-mask warp shuffles and 32 shared slots, so
// blockDim.x is expected to be a multiple of 32 and at most 1024.
//
// Parameters:
//   in         - input activations, indexed as a contiguous (N, C, H, W) array
//   out        - output buffer with the same layout; also used as scratch for
//                the activated values between the two passes
//   N, C, H, W - tensor extents (N is implied by the grid and not read here)
//   num_groups - number of channel groups; C is divided by it with integer division
//   gn_weight  - per-channel scale, C entries
//   gn_bias    - per-channel shift, C entries
//   eps        - added to the variance before taking the square root
__global__ void fused_gelu_group_norm_kernel(
    const float* __restrict__ in,
    float* __restrict__ out,
    int N, int C, int H, int W,
    int num_groups,
    const float* __restrict__ gn_weight,
    const float* __restrict__ gn_bias,
    float eps) {

    // Each block processes one (sample, group) pair
    const int group_id = blockIdx.x; // overall group id
    const int n = group_id / num_groups;
    const int g = group_id % num_groups;
    const int channels_per_group = C / num_groups;
    const int spatial = H * W;
    const int group_elems = channels_per_group * spatial;

    // The element offset is computed in 64 bits: n * C * H * W exceeds the range
    // of int once the tensor holds more than 2^31 - 1 elements.
    const size_t base =
        (static_cast<size_t>(n) * C + static_cast<size_t>(g) * channels_per_group) *
        static_cast<size_t>(spatial);
    const float* group_in = in + base;
    float* group_out = out + base;

    float local_sum = 0.0f;
    float local_sum_sq = 0.0f;

    // First pass: apply GELU, stash the activated value, and accumulate the
    // per-thread sum and sum of squares.
    for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
        const float gelu_val = gelu_activation(group_in[idx]);
        group_out[idx] = gelu_val;  // store activated value
        local_sum += gelu_val;
        local_sum_sq += gelu_val * gelu_val;
    }

    // Warp-level reduction for each thread's local sums
    local_sum = warpReduceSum(local_sum);
    local_sum_sq = warpReduceSum(local_sum_sq);

    __shared__ float shared_sum[32];
    __shared__ float shared_sum_sq[32];

    const int lane = threadIdx.x & 31;
    const int wid = threadIdx.x >> 5;
    if (lane == 0) {
        shared_sum[wid] = local_sum;
        shared_sum_sq[wid] = local_sum_sq;
    }
    __syncthreads();

    // Final reduction across warps, performed by the first warp
    float group_sum = 0.0f;
    float group_sum_sq = 0.0f;
    const int num_warps = (blockDim.x + 31) / 32;
    if (threadIdx.x < num_warps) {
        group_sum = shared_sum[threadIdx.x];
        group_sum_sq = shared_sum_sq[threadIdx.x];
    }
    if (threadIdx.x < 32) {
        group_sum = warpReduceSum(group_sum);
        group_sum_sq = warpReduceSum(group_sum_sq);
    }
    if (threadIdx.x == 0) {
        const float group_mean = group_sum / group_elems;
        // E[x^2] - E[x]^2 can come out slightly negative through floating-point
        // cancellation; clamp so the square root below never sees a negative value.
        const float group_var =
            fmaxf(group_sum_sq / group_elems - group_mean * group_mean, 0.0f);
        shared_sum[0] = group_mean;                           // broadcast mean
        shared_sum_sq[0] = 1.0f / sqrtf(group_var + eps);     // broadcast 1 / std
    }
    __syncthreads();
    const float mean = shared_sum[0];
    const float inv_std = shared_sum_sq[0];

    // Second pass: normalize and apply the per-channel affine transformation.
    // Each thread revisits exactly the elements it wrote in the first pass.
    for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
        // Determine channel index within this group
        const int channel = g * channels_per_group + idx / spatial;
        const float norm_val = (group_out[idx] - mean) * inv_std;
        group_out[idx] = norm_val * gn_weight[channel] + gn_bias[channel];
    }
}

// Host function that performs ConvTranspose2d followed by the fused GELU and GroupNorm operations.
//
// Parameters:
//   x                     - input tensor; moved to CUDA and made contiguous if needed
//   stride                - stride forwarded to at::conv_transpose2d
//   conv_transpose_weight - weight of the transposed convolution
//   conv_transpose_bias   - bias of the transposed convolution
//   group_norm_weight     - per-channel scale, one entry per output channel
//   group_norm_bias       - per-channel shift, one entry per output channel
//   num_groups            - number of normalization groups; must divide the channel count
//
// Returns a new float32 CUDA tensor with the shape of the convolution output.
// Raises (via TORCH_CHECK) on invalid arguments or when the kernel launch fails.
//
// The kernel is enqueued on PyTorch's current CUDA stream and the function does
// not block the host; later work on that stream is ordered after the kernel.
torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    torch::Tensor conv_transpose_weight,
    torch::Tensor conv_transpose_bias,
    torch::Tensor group_norm_weight,
    torch::Tensor group_norm_bias,
    int64_t num_groups) {

    TORCH_CHECK(num_groups > 0, "num_groups must be positive, got ", num_groups);

    // Ensure tensors are contiguous and on the CUDA device
    x = x.contiguous();
    conv_transpose_weight = conv_transpose_weight.contiguous();
    conv_transpose_bias = conv_transpose_bias.contiguous();
    group_norm_weight = group_norm_weight.contiguous();
    group_norm_bias = group_norm_bias.contiguous();

    if (!x.is_cuda()) x = x.cuda();
    if (!conv_transpose_weight.is_cuda()) conv_transpose_weight = conv_transpose_weight.cuda();
    if (!conv_transpose_bias.is_cuda()) conv_transpose_bias = conv_transpose_bias.cuda();
    if (!group_norm_weight.is_cuda()) group_norm_weight = group_norm_weight.cuda();
    if (!group_norm_bias.is_cuda()) group_norm_bias = group_norm_bias.cuda();

    // Execute transposed convolution using PyTorch's optimized implementation.
    // The kernel indexes the result as a dense NCHW array, so force that layout.
    auto conv_out = at::conv_transpose2d(
        x, conv_transpose_weight, conv_transpose_bias, {stride}).contiguous();
    auto output = at::empty_like(conv_out);

    TORCH_CHECK(conv_out.dim() == 4,
                "expected a 4-D convolution output, got ", conv_out.dim(), " dimensions");
    TORCH_CHECK(conv_out.scalar_type() == at::kFloat &&
                    group_norm_weight.scalar_type() == at::kFloat &&
                    group_norm_bias.scalar_type() == at::kFloat,
                "fused GELU + GroupNorm kernel only supports float32 tensors");

    const int64_t N64 = conv_out.size(0);
    const int64_t C64 = conv_out.size(1);
    const int64_t H64 = conv_out.size(2);
    const int64_t W64 = conv_out.size(3);

    TORCH_CHECK(C64 % num_groups == 0,
                "number of channels (", C64, ") must be divisible by num_groups (", num_groups, ")");
    TORCH_CHECK(group_norm_weight.numel() == C64 && group_norm_bias.numel() == C64,
                "group_norm_weight and group_norm_bias must each have ", C64, " elements");

    // A launch with zero blocks is an invalid configuration; nothing to compute anyway.
    if (conv_out.numel() == 0) {
        return output;
    }

    // The kernel uses 32-bit indices within a group and for the grid size.
    const int64_t kIntMax = std::numeric_limits<int>::max();
    const int64_t group_elems64 = (C64 / num_groups) * H64 * W64;
    const int64_t total_groups64 = N64 * num_groups;  // One block per (sample, group) pair
    TORCH_CHECK(C64 <= kIntMax && group_elems64 <= kIntMax && total_groups64 <= kIntMax,
                "tensor is too large for the 32-bit indexing used by the fused kernel");

    const int N = static_cast<int>(N64);
    const int C = static_cast<int>(C64);
    const int H = static_cast<int>(H64);
    const int W = static_cast<int>(W64);
    const int groups = static_cast<int>(num_groups);
    const int total_groups = static_cast<int>(total_groups64);
    const int block = 256; // Threads per block
    const float eps = 1e-5f;

    const float* conv_ptr = conv_out.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    const float* gn_weight_ptr = group_norm_weight.data_ptr<float>();
    const float* gn_bias_ptr = group_norm_bias.data_ptr<float>();

    // Launch on the stream PyTorch is currently using so the kernel is ordered
    // after the convolution that produced conv_out.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    fused_gelu_group_norm_kernel<<<total_groups, block, 0, stream>>>(
        conv_ptr, out_ptr, N, C, H, W, groups,
        gn_weight_ptr, gn_bias_ptr, eps);

    // Surface launch-configuration errors instead of silently returning garbage.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "fused_gelu_group_norm_kernel launch failed: ",
                cudaGetErrorString(launch_status));
    return output;
}

// Python bindings: exposes the host-side `forward` function as `forward` on the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Modular Fused ConvTranspose2d with GELU and GroupNorm (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.376 inst/cycle 0.000 5
Executed Ipc Elapsed 1.328 inst/cycle 0.000 5
Issue Slots Busy 34.434 % 0.011 5
Issued Ipc Active 1.376 inst/cycle 0.000 5
SM Busy 34.434 % 0.011 5
Memory Throughput 2220695951331.162 byte/second 33844656467401486336.000 5
Mem Busy 35.834 % 0.009 5
Max Bandwidth 66.254 % 0.030 5
L1/TEX Hit Rate 33.084 % 0.000 5
L2 Hit Rate 50.008 % 0.000 5
Mem Pipes Busy 13.204 % 0.001 5
Warp Cycles Per Issued Instruction 43.552 cycle 0.014 5
Warp Cycles Per Executed Instruction 43.558 cycle 0.013 5
Avg. Active Threads Per Warp 31.890 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.770 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 25.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 93.754 % 0.000 5
Achieved Active Warps Per SM 60.002 warp 0.000 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (21.9%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::conv_transpose2d
CPU Time 510680.18 μs
Device Time 2344556.21 μs
Self CPU Time 12344.85 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 498335.33 μs
Device Time 2344556.21 μs
Self CPU Time 16040.80 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 482294.54 μs
Device Time 2344556.21 μs
Self CPU Time 32298.29 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 372075.06 μs
Device Time 1469166.40 μs
Self CPU Time 190676.91 μs
Self Device Time 1469166.40 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceSynchronize
CPU Time 4191179.32 μs
Device Time 37648.25 μs
Self CPU Time 4191179.32 μs
Self Device Time 37648.25 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
fused_gelu_group_norm_kernel(float const*, float*, int, int, int, int, int, float const*, float const*, float)
CPU Time 0.00 μs
Device Time 1728841.13 μs
Self CPU Time 0.00 μs
Self Device Time 1728841.13 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45299 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
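/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:27:5: warning: 2 adjacent parameters of 'fused_gelu_group_norm_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]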
27 | int N, int C, int H, int W,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:27:9: note: the first parameter in the range is 'N'
27 | int N, int C, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:27:16: note: the last parameter in the range is 'C'
27 | int N, int C, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:27:26: warning: 2 adjacent parameters of 'fused_gelu_group_norm_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
27 | int N, int C, int H, int W,
| ^~~~~~
28 | int num_groups,
| ~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:27:30: note: the first parameter in the range is 'W'
27 | int N, int C, int H, int W,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:28:9: note: the last parameter in the range is 'num_groups'
28 | int num_groups,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:29:5: warning: 2 adjacent parameters of 'fused_gelu_group_norm_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
29 | const float* __restrict__ gn_weight,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
30 | const float* __restrict__ gn_bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:29:31: note: the first parameter in the range is 'gn_weight'
29 | const float* __restrict__ gn_weight,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:30:31: note: the last parameter in the range is 'gn_bias'
30 | const float* __restrict__ gn_bias,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:34:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
34 | int group_id = blockIdx.x; // overall group id
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:45:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
45 | for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:45:59: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
45 | for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:60:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
60 | int lane = threadIdx.x & 31;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:61:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
61 | int wid = threadIdx.x >> 5;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:71:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
71 | int num_warps = (blockDim.x + 31) / 32;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:82:28: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
82 | mean = group_sum / group_elems;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:83:30: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
83 | var = group_sum_sq / group_elems - mean * mean;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:92:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
92 | for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:92:59: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
92 | for (int idx = threadIdx.x; idx < group_elems; idx += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:131:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
131 | int N = conv_out.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:132:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | int C = conv_out.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:133:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
133 | int H = conv_out.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:134:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
134 | int W = conv_out.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:136:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
136 | int total_groups = N * num_groups; // One block per (sample, group) pair
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b3_s1_modular_convtrans_gelu_gn/base/base.cu:146:40: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
146 | conv_ptr, out_ptr, N, C, H, W, num_groups,
| ^