
The AI CUDA Engineer 👷

19_ConvTranspose2d_GELU_GroupNorm • opt_convtrans_gelu_gn_even_distribution_base

Level 2 • Task 19
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    group_norm_weight: torch.Tensor,
    group_norm_bias: torch.Tensor,
    num_groups: int,
) -> torch.Tensor:
    """
    Applies transposed convolution, GELU activation, and group normalization.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        stride (int): Stride of the transposed convolution
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        group_norm_weight (torch.Tensor): Weight tensor for group normalization
        group_norm_bias (torch.Tensor): Bias tensor for group normalization
        num_groups (int): Number of groups for group normalization

    Returns:
        torch.Tensor: Output tensor after applying transposed convolution, GELU and group norm
    """
    x = F.conv_transpose2d(x, conv_transpose, bias=conv_transpose_bias, stride=stride)
    x = F.gelu(x)
    x = F.group_norm(
        x, num_groups=num_groups, weight=group_norm_weight, bias=group_norm_bias
    )
    return x


class Model(nn.Module):
    """
    Model that performs a transposed convolution, applies GELU, and normalizes with GroupNorm.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, groups, num_groups
    ):
        super(Model, self).__init__()
        conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=stride
        )
        group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        self.conv_transpose_parameter = conv_transpose.weight
        self.conv_transpose_bias = nn.Parameter(
            conv_transpose.bias + torch.ones_like(conv_transpose.bias) * 0.02
        )  # make sure it's nonzero
        self.group_norm_weight = group_norm.weight
        self.group_norm_bias = nn.Parameter(
            group_norm.bias + torch.ones_like(group_norm.bias) * 0.02
        )  # make sure it's nonzero

    def forward(self, x, stride, num_groups, fn=module_fn):
        return fn(
            x,
            stride,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.group_norm_weight,
            self.group_norm_bias,
            num_groups,
        )


batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32
kernel_size = 4
stride = 2
groups = 8
num_groups = 8


def get_inputs():
    return [torch.randn(batch_size, in_channels, height, width), stride, num_groups]


def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, groups, num_groups]

# Equivalent nn.Module-based implementation of the same task:
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a transposed convolution, applies GELU, and normalizes with GroupNorm.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, num_groups):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        # Add the same noise as in the functional implementation
        self.conv_transpose.bias = nn.Parameter(self.conv_transpose.bias + torch.ones_like(self.conv_transpose.bias) * 0.02)
        self.group_norm.bias = nn.Parameter(self.group_norm.bias + torch.ones_like(self.group_norm.bias) * 0.02)

    def forward(self, x):
        x = self.conv_transpose(x)
        x = torch.nn.functional.gelu(x)
        x = self.group_norm(x)
        return x

batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32
kernel_size = 4
stride = 2
groups = 8
num_groups = 8

def get_inputs():
    return [torch.randn(batch_size, in_channels, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, groups, num_groups]
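
For concreteness, the shapes these settings produce can be worked out directly: a stride-2 transposed convolution with kernel_size 4 and no padding maps a 32×32 input to (32 - 1) * 2 + 4 = 66, so the fused stage operates on groups of 8 channels × 66 × 66 elements. A short sanity check in Python, using only the constants defined above:

import torch
import torch.nn as nn

# ConvTranspose2d output size with no padding/output_padding:
# H_out = (H_in - 1) * stride + kernel_size
h_out = (32 - 1) * 2 + 4              # 66
hw = h_out * h_out                    # 4356
channels_per_group = 64 // 8          # 8 channels per GroupNorm group
group_size = channels_per_group * hw  # 34848 elements reduced per group

conv = nn.ConvTranspose2d(32, 64, 4, stride=2)
out = conv(torch.randn(128, 32, 32, 32))
assert out.shape == (128, 64, h_out, h_out)
print(group_size, group_size % 4 == 0)  # 34848 True -> the float4 path applies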

Kernel Information


#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <cmath>

// Kernel: Each block processes one group from the fused convTranspose output.
// Workload is distributed evenly by dynamically choosing the number of threads per block
// based on the group size. Grid-stride loops and optional vectorized loads ensure balanced work.
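// Launch configuration (set up in forward() below): grid = N * num_groups, one block
// per (batch, group) pair; blockDim.x is derived from the group size.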

__global__ void fused_gelu_group_norm_kernel(
    const float* __restrict__ in,
    float* __restrict__ out,
    int group_size,       // = channels_per_group * (H*W)
    int hw,               // H * W
    int channels_per_group,
    int C,                // Total channels
    int num_groups,
    float eps,
    const float* __restrict__ gn_weight,
    const float* __restrict__ gn_bias) {

    // Each block processes one group. Calculate group indices.
    int group_global = blockIdx.x; // global group index
    int n = group_global / num_groups;  // batch index
    int g = group_global % num_groups;  // group index
    int base = n * C * hw + g * channels_per_group * hw;  // starting offset for this group

    float local_sum = 0.0f;
    float local_sum_sq = 0.0f;

    int tid = threadIdx.x;
    int block_stride = blockDim.x;

    // Check if group_size is vectorizable: process 4 elements at a time if group_size is divisible by 4
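    // float4 accesses assume 16-byte alignment of (in + base). That holds here: base is a
    // multiple of group_size, so group_size % 4 == 0 keeps every group's base 4-float aligned,
    // and CUDA allocations are at least 16-byte aligned.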
    bool use_vector = (group_size % 4 == 0);
    if (use_vector) {
        const float4* in_vec = reinterpret_cast<const float4*>(in + base);
        float4* out_vec = reinterpret_cast<float4*>(out + base);
        int vec_count = group_size / 4;
        for (int idx = tid; idx < vec_count; idx += block_stride) {
            float4 vals = in_vec[idx];
            float4 gelu_vals;
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                float v = ((float*)&vals)[j];
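                // Tanh-based GELU approximation (0.7978845608f ≈ sqrt(2/pi)); this matches
                // PyTorch's gelu(approximate='tanh') and differs slightly from the exact
                // erf-based GELU that F.gelu computes by default.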
                float gelu = 0.5f * v * (1.0f + tanhf(0.7978845608f * (v + 0.044715f * v * v * v)));
                ((float*)&gelu_vals)[j] = gelu;
                local_sum += gelu;
                local_sum_sq += gelu * gelu;
            }
            out_vec[idx] = gelu_vals;
        }
    } else {
        // Scalar processing if vector load is not applicable
        for (int idx = tid; idx < group_size; idx += block_stride) {
            float v = in[base + idx];
            float gelu = 0.5f * v * (1.0f + tanhf(0.7978845608f * (v + 0.044715f * v * v * v)));
            out[base + idx] = gelu;
            local_sum += gelu;
            local_sum_sq += gelu * gelu;
        }
    }

    // Warp-level reduction using shuffle for sum and sum of squares
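    // Each __shfl_down_sync step halves the number of lanes holding live partials;
    // after five steps, lane 0 of each warp holds that warp's totals.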
    int lane = tid & 31;
    for (int offset = 16; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
        local_sum_sq += __shfl_down_sync(0xffffffff, local_sum_sq, offset);
    }

    // Shared memory to hold per-warp partial sums (reserve space for up to 32 warps)
    __shared__ float smem_sum[32];
    __shared__ float smem_sum_sq[32];
    int warp_id = tid / 32;
    if (lane == 0) {
        smem_sum[warp_id] = local_sum;
        smem_sum_sq[warp_id] = local_sum_sq;
    }
    __syncthreads();

    // Final reduction from warp sums done by thread 0
    float group_mean = 0.0f;
    float group_inv_std = 0.0f;
    if (tid == 0) {
        int num_warps = (blockDim.x + 31) / 32;
        float sum_tot = 0.0f;
        float sum_sq_tot = 0.0f;
        for (int i = 0; i < num_warps; i++) {
            sum_tot += smem_sum[i];
            sum_sq_tot += smem_sum_sq[i];
        }
        group_mean = sum_tot / group_size;
        float variance = sum_sq_tot / group_size - group_mean * group_mean;
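        // E[x^2] - E[x]^2 can dip slightly below zero from rounding; eps keeps rsqrtf finite.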
        group_inv_std = rsqrtf(variance + eps);
        smem_sum[0] = group_mean;   // reuse shared memory to broadcast
        smem_sum[1] = group_inv_std;
    }
    __syncthreads();

    group_mean = smem_sum[0];
    group_inv_std = smem_sum[1];

    // Normalize and apply affine transformation with grid-stride loop
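    // The GELU results were staged in 'out' during the first pass, so this pass re-reads
    // them from global memory instead of recomputing the activation.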
    if (use_vector) {
        float4* out_vec = reinterpret_cast<float4*>(out + base);
        int vec_count = group_size / 4;
        for (int idx = tid; idx < vec_count; idx += block_stride) {
            float4 vals = out_vec[idx];
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                float gelu = ((float*)&vals)[j];
                float norm = (gelu - group_mean) * group_inv_std;
                // Compute channel index: each channel has 'hw' elements
                int k = idx * 4 + j; // overall element index within the group
                int ch = k / hw;  // channel index within the group
                int global_ch = g * channels_per_group + ch;  // global channel index for group norm params
                float alpha = gn_weight[global_ch];
                float beta = gn_bias[global_ch];
                ((float*)&vals)[j] = norm * alpha + beta;
            }
            out_vec[idx] = vals;
        }
    } else {
        for (int idx = tid; idx < group_size; idx += block_stride) {
            float gelu = out[base + idx];
            float norm = (gelu - group_mean) * group_inv_std;
            int ch = idx / hw;
            int global_ch = g * channels_per_group + ch;
            out[base + idx] = norm * gn_weight[global_ch] + gn_bias[global_ch];
        }
    }
}


torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    torch::Tensor conv_transpose_weight,
    torch::Tensor conv_transpose_bias,
    torch::Tensor group_norm_weight,
    torch::Tensor group_norm_bias,
    int64_t num_groups) {

    // Ensure tensors are contiguous and on CUDA
    x = x.contiguous();
    conv_transpose_weight = conv_transpose_weight.contiguous();
    conv_transpose_bias = conv_transpose_bias.contiguous();
    group_norm_weight = group_norm_weight.contiguous();
    group_norm_bias = group_norm_bias.contiguous();

    if (!x.is_cuda()) x = x.cuda();
    if (!conv_transpose_weight.is_cuda()) conv_transpose_weight = conv_transpose_weight.cuda();
    if (!conv_transpose_bias.is_cuda()) conv_transpose_bias = conv_transpose_bias.cuda();
    if (!group_norm_weight.is_cuda()) group_norm_weight = group_norm_weight.cuda();
    if (!group_norm_bias.is_cuda()) group_norm_bias = group_norm_bias.cuda();

    // Perform transposed convolution
    auto conv_out = at::conv_transpose2d(x, conv_transpose_weight, conv_transpose_bias, {stride});
    auto output = at::empty_like(conv_out);

    int N = conv_out.size(0);
    int C = conv_out.size(1);
    int H = conv_out.size(2);
    int W = conv_out.size(3);
    int hw = H * W;
    int channels_per_group = C / num_groups;
    int group_size = channels_per_group * hw;

    // Dynamically determine block size to evenly distribute the workload for each group,
    // rounding up to a full warp so the kernel's __shfl_down_sync(0xffffffff, ...) calls
    // never see a partially populated warp.
    int threads = (group_size < 256) ? ((group_size < 32) ? 32 : group_size) : 256;
    threads = ((threads + 31) / 32) * 32;
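    // For this task's harness defaults (N=128, C=64, num_groups=8, H=W=66 after the
    // stride-2 transposed conv): hw = 4356, channels_per_group = 8, group_size = 34848
    // (divisible by 4, so the float4 path is taken), threads = 256, and the grid below
    // is 128 * 8 = 1024 blocks.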
    int total_groups = N * num_groups;

    // Note: the kernel only uses its statically declared __shared__ arrays, so this
    // dynamic shared-memory allocation is redundant (but harmless).
    int shared_mem_size = 64 * sizeof(float);

    // Launch one block per group
    fused_gelu_group_norm_kernel<<<total_groups, threads, shared_mem_size>>>(
        conv_out.data_ptr<float>(),
        output.data_ptr<float>(),
        group_size,
        hw,
        channels_per_group,
        C,
        num_groups,
        1e-5f,
        group_norm_weight.data_ptr<float>(),
        group_norm_bias.data_ptr<float>()
    );

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused ConvTranspose2d with GELU+GroupNorm with Even Workload Distribution (CUDA)");
}
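
One way to exercise the extension is to JIT-compile it with torch.utils.cpp_extension.load and compare against the PyTorch reference. A minimal sketch, assuming the source above is saved as fused_op.cu (the module name, filename, and tolerances are illustrative):

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source above ("fused_op" and the filename are placeholders).
ext = load(name="fused_op", sources=["fused_op.cu"], verbose=True)

x = torch.randn(128, 32, 32, 32, device="cuda")
w = torch.randn(32, 64, 4, 4, device="cuda")  # ConvTranspose2d weight layout: (in, out, kH, kW)
b = torch.randn(64, device="cuda")
gn_w = torch.randn(64, device="cuda")
gn_b = torch.randn(64, device="cuda")

out = ext.forward(x, 2, w, b, gn_w, gn_b, 8)

# The kernel uses the tanh GELU approximation, so compare against
# gelu(approximate='tanh') and allow a small tolerance.
ref = F.conv_transpose2d(x, w, bias=b, stride=2)
ref = F.group_norm(F.gelu(ref, approximate="tanh"), 8, gn_w, gn_b)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)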
Performance Metrics
| Metric | Value | Unit | Variance | Samples |
|---|---|---|---|---|
| Executed Ipc Active | 1.374 | inst/cycle | 0.000 | 5 |
| Executed Ipc Elapsed | 1.272 | inst/cycle | 0.000 | 5 |
| Issue Slots Busy | 34.336 | % | 0.008 | 5 |
| Issued Ipc Active | 1.374 | inst/cycle | 0.000 | 5 |
| SM Busy | 34.336 | % | 0.008 | 5 |
| Memory Throughput | 2658103412848.900 | byte/second | 372477460704455753728.000 | 5 |
| Mem Busy | 42.836 | % | 0.094 | 5 |
| Max Bandwidth | 79.304 | % | 0.326 | 5 |
| L1/TEX Hit Rate | 30.132 | % | 0.000 | 5 |
| L2 Hit Rate | 50.558 | % | 0.003 | 5 |
| Mem Pipes Busy | 8.070 | % | 0.004 | 5 |
| Warp Cycles Per Issued Instruction | 43.126 | cycle | 0.035 | 5 |
| Warp Cycles Per Executed Instruction | 43.150 | cycle | 0.036 | 5 |
| Avg. Active Threads Per Warp | 31.730 | | 0.000 | 5 |
| Avg. Not Predicated Off Threads Per Warp | 29.060 | | 0.000 | 5 |
| Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
| Max Cluster Size | 8.000 | block | 0.000 | 5 |
| Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
| Cluster Occupancy | 0.000 | % | 0.000 | 5 |
| Block Limit SM | 32.000 | block | 0.000 | 5 |
| Block Limit Registers | 8.000 | block | 0.000 | 5 |
| Block Limit Shared Mem | 21.000 | block | 0.000 | 5 |
| Block Limit Warps | 8.000 | block | 0.000 | 5 |
| Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
| Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
| Achieved Occupancy | 92.710 | % | 0.013 | 5 |
| Achieved Active Warps Per SM | 59.336 | warp | 0.006 | 5 |
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (22.8%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
| Operation | CPU Time (μs) | Device Time (μs) | Self CPU Time (μs) | Self Device Time (μs) |
|---|---|---|---|---|
| aten::fill_ | 1564465.90 | 588051.33 | 26668.75 | 588051.33 |
| aten::zero_ | 1580993.09 | 588051.33 | 16553.39 | 0.00 |
| aten::conv_transpose2d | 1469355.53 | 2522924.26 | 13515.91 | 0.00 |
| aten::convolution | 1455839.63 | 2522924.26 | 18250.09 | 0.00 |
| aten::_convolution | 1437589.53 | 2522924.26 | 36478.23 | 0.00 |
| aten::cudnn_convolution_transpose | 793976.38 | 1586775.50 | 206000.26 | 1586775.50 |
| cudaLaunchKernel | 3240858.36 | 40781.29 | 3240858.36 | 40781.29 |
| fused_gelu_group_norm_kernel(float const*, float*, int, int, int, int, int, float, float const*, float const*) | 0.00 | 1573543.82 | 0.00 | 1573543.82 |

CPU and device memory usage were 0 B for every operation above.
Status: Completed
45295 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:15:5: warning: 2 adjacent parameters of 'fused_gelu_group_norm_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
15 | int group_size, // = channels_per_group * (H*W)
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16 | int hw, // H * W
| ~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:15:9: note: the first parameter in the range is 'group_size'
15 | int group_size, // = channels_per_group * (H*W)
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:16:9: note: the last parameter in the range is 'hw'
16 | int hw, // H * W
| ^~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:18:5: warning: 3 adjacent parameters of 'fused_gelu_group_norm_kernel' of convertible types are easily swapped by mistake [bugprone-easily-swappable-parameters]
18 | int C, // Total channels
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
19 | int num_groups,
| ~~~~~~~~~~~~~~~
20 | float eps,
| ~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:18:9: note: the first parameter in the range is 'C'
18 | int C, // Total channels
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:20:11: note: the last parameter in the range is 'eps'
20 | float eps,
| ^~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:20:5: note: 'int' and 'float' may be implicitly converted
20 | float eps,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:25:24: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
25 | int group_global = blockIdx.x; // global group index
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:33:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
33 | int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:34:24: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
34 | int block_stride = blockDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:87:25: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
87 | int num_warps = (blockDim.x + 31) / 32;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:94:32: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
94 | group_mean = sum_tot / group_size;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:95:39: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
95 | float variance = sum_sq_tot / group_size - group_mean * group_mean;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:163:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
163 | int N = conv_out.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:164:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
164 | int C = conv_out.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:165:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
165 | int H = conv_out.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:166:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
166 | int W = conv_out.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:168:30: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
168 | int channels_per_group = C / num_groups;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:173:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
173 | int total_groups = N * num_groups;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_19/b5_s2_opt_convtrans_gelu_gn_even_distribution/base/base.cu:185:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
185 | num_groups,
| ^