The AI CUDA Engineer 👷


Level 2 • Task 89
import torch
import torch.nn as nn
import torch.nn.functional as F

def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    pool_kernel_size: int,
    pool_stride: int,
    pool_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """Apply ConvTranspose3d -> MaxPool3d -> Softmax -> Subtract -> Swish -> Max.

    Args:
        x: Input tensor of shape (batch_size, in_channels, depth, height, width).
        stride: Stride for the transposed convolution.
        padding: Padding for the transposed convolution.
        output_padding: Output padding for the transposed convolution.
        pool_kernel_size: Kernel size for max pooling.
        pool_stride: Stride for max pooling.
        pool_padding: Padding for max pooling.
        conv_transpose: Weight tensor for the transposed convolution, in the
            (in_channels, out_channels, kD, kH, kW) layout that
            ``F.conv_transpose3d`` expects.
        conv_transpose_bias: Bias tensor for the transposed convolution.
        subtract: Per-channel values subtracted after the softmax; reshaped to
            (1, C, 1, 1, 1) so it broadcasts over batch and spatial dims.

    Returns:
        Tensor of shape (batch_size, D', H', W') holding, for every location,
        the maximum over channels of swish(softmax(x) - subtract).
    """
    x = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    x = F.max_pool3d(
        x, kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
    )
    x = F.softmax(x, dim=1)  # softmax across channels
    x = x - subtract.view(1, -1, 1, 1, 1)  # broadcast one value per channel
    x = torch.sigmoid(x) * x  # Swish
    x = torch.max(x, dim=1)[0]  # reduce channels; [0] keeps values, drops indices
    return x

class Model(nn.Module):
    """
    A model that performs a sequence of operations:
        - ConvTranspose3d
        - MaxPool3d
        - Softmax
        - Subtract
        - Swish
        - Max

    Parameters live on the module; the computation itself is delegated to
    ``fn`` (``module_fn`` by default), so an alternative implementation with
    the same signature can be substituted.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ):
        super(Model, self).__init__()
        # The layer is built only to obtain default-initialized weight and bias
        # tensors; their shapes depend on the channel counts and kernel_size,
        # while stride/padding are applied functionally by ``fn``.
        conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
        self.conv_transpose_parameter = conv_transpose.weight
        self.conv_transpose_bias = conv_transpose.bias
        self.subtract_parameter = nn.Parameter(torch.randn(out_channels) * 0.02)
        # Defaults used by forward() when a hyperparameter is not passed.
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.pool_kernel_size = pool_kernel_size
        self.pool_stride = pool_stride
        self.pool_padding = pool_padding

    def forward(
        self,
        x,
        stride=None,
        padding=None,
        output_padding=None,
        pool_kernel_size=None,
        pool_stride=None,
        pool_padding=None,
        fn=module_fn,
    ):
        """Run ``fn`` on ``x`` with this module's parameters.

        Any hyperparameter left as ``None`` falls back to the value given to
        ``__init__``, so both ``model(x)`` and
        ``model(x, stride, padding, output_padding, pool_kernel_size,
        pool_stride, pool_padding)`` are accepted.
        """

        def pick(value, default):
            return default if value is None else value

        return fn(
            x,
            pick(stride, self.stride),
            pick(padding, self.padding),
            pick(output_padding, self.output_padding),
            pick(pool_kernel_size, self.pool_kernel_size),
            pick(pool_stride, self.pool_stride),
            pick(pool_padding, self.pool_padding),
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.subtract_parameter,
        )

batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0


def get_inputs():
    """Return the positional arguments for ``Model.forward``.

    The random input batch comes first, followed by the convolution and
    pooling hyperparameters in ``module_fn``'s parameter order.
    """
    return [
        torch.randn(batch_size, in_channels, depth, height, width),
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]


def get_init_inputs():
    """Return the positional arguments for ``Model.__init__``."""
    return [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    A model that performs a sequence of operations:
        - ConvTranspose3d
        - MaxPool3d
        - Softmax
        - Subtract
        - Swish
        - Max
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
        # One learnable value per output channel, broadcast over batch and spatial dims in forward().
        self.subtract = nn.Parameter(torch.randn(out_channels) * 0.02)

    def forward(self, x):
        """Return the channel-wise max of swish(softmax(pool(convT(x))) - subtract).

        The result has the channel dimension reduced away: (N, D', H', W').
        """
        x = self.conv_transpose(x)
        x = self.max_pool(x)
        x = torch.softmax(x, dim=1)  # softmax across channels (dim=1)
        x = x - self.subtract.view(1, -1, 1, 1, 1)  # subtract one value per channel
        x = torch.sigmoid(x) * x  # Swish activation
        x = torch.max(x, dim=1).values  # max reduction over the channel dimension
        return x

# Benchmark configuration for the reference Model.
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = 3
stride, padding, output_padding = 2, 1, 1
pool_kernel_size, pool_stride, pool_padding = 2, 2, 0


def get_inputs():
    """Return the forward() arguments: a single random NCDHW input batch."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*input_shape)]


def get_init_inputs():
    """Return the Model constructor arguments, in signature order."""
    init_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    )
    return list(init_args)

Kernel Information

#include <torch/extension.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <limits>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// This CUDA kernel fuses the softmax (along the channel dimension), subtraction,
// swish activation and channel-wise max reduction into a single kernel. It
// operates on the tensor produced by the ConvTranspose3d and MaxPool3d
// operations.

// Assumptions:
// 1. Input tensor is a contiguous float32 tensor in NCDHW layout.
// 2. 'subtract' tensor has size [C] and is broadcast along spatial dimensions.
// 3. The output tensor is of shape [N, D, H, W] representing the maximum over
//    channels after the fused operations.
//
// Each thread owns one (n, d, h, w) output element and makes three passes over
// its C channel values: (1) max, for a numerically stable softmax; (2) the sum
// of exponentials; (3) softmax -> subtract -> swish, reduced with max.
//
// The channel loops run for exactly C iterations, so any channel count is
// handled. `#pragma unroll 4` is only a partial-unroll hint: a fixed
// 64-iteration, fully unrolled loop would both ignore channels >= 64 and
// inflate register pressure (limiting occupancy).

__global__ void fused_softmax_subtract_swish_max_kernel(
    const float* __restrict__ input,      // [N, C, D, H, W]
    const float* __restrict__ subtract,   // [C]
    float* __restrict__ output,           // [N, D, H, W]
    int N, int C, int D, int H, int W
) {
    // 64-bit arithmetic: N*C*D*H*W can exceed INT_MAX even when N*D*H*W does not.
    const int64_t spatial_size = static_cast<int64_t>(D) * H * W;
    const int64_t total = static_cast<int64_t>(N) * spatial_size;
    const int64_t idx =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    // Output is [N, D, H, W] contiguous, so idx = n * spatial_size + spatial_offset,
    // where spatial_offset = d*H*W + h*W + w.
    const int64_t n = idx / spatial_size;
    const int64_t spatial_offset = idx - n * spatial_size;

    // In NCDHW, channel c of this location lives at
    //   n*(C*D*H*W) + c*(D*H*W) + spatial_offset,
    // i.e. px[c * spatial_size] with px pointing at channel 0.
    const float* __restrict__ px =
        input + n * static_cast<int64_t>(C) * spatial_size + spatial_offset;

    // Step 1: maximum over channels, for numerical stability of the softmax.
    float max_val = -FLT_MAX;
    #pragma unroll 4
    for (int c = 0; c < C; ++c) {
        const float val = px[c * spatial_size];
        if (val > max_val)
            max_val = val;
    }

    // Step 2: softmax denominator.
    float sum_exp = 0.0f;
    #pragma unroll 4
    for (int c = 0; c < C; ++c) {
        sum_exp += expf(px[c * spatial_size] - max_val);
    }

    // Step 3: softmax -> subtract the channel-specific value -> swish, keeping
    // the maximum activation across channels.
    float max_swish = -FLT_MAX;
    #pragma unroll 4
    for (int c = 0; c < C; ++c) {
        const float softmax_val = expf(px[c * spatial_size] - max_val) / sum_exp;
        const float subtracted = softmax_val - subtract[c];
        const float sigmoid = 1.0f / (1.0f + expf(-subtracted));
        const float swish = subtracted * sigmoid;
        if (swish > max_swish)
            max_swish = swish;
    }

    output[idx] = max_swish;
}

// Wrapper to launch the fused kernel with one thread per (n, d, h, w) output
// element. The kernel fuses four operations: softmax along the channel
// dimension, subtraction of a broadcast per-channel tensor, swish activation
// (x * sigmoid(x)), and a max reduction over channels.

torch::Tensor fused_forward(
    torch::Tensor input,          // [N, C, D, H, W]
    torch::Tensor subtract_tensor // [C]
) {
    // Validate what the kernel silently assumes (raw float pointers, NCDHW, [C]).
    TORCH_CHECK(input.is_cuda(), "fused_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 5,
                "fused_forward: input must be 5-D [N, C, D, H, W], got ",
                input.dim(), "-D");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "fused_forward: input must be float32");
    TORCH_CHECK(subtract_tensor.device() == input.device(),
                "fused_forward: subtract_tensor must be on the same device as input");
    TORCH_CHECK(subtract_tensor.scalar_type() == torch::kFloat32,
                "fused_forward: subtract_tensor must be float32");

    // The kernel indexes raw pointers, so both tensors must be densely packed.
    input = input.contiguous();
    subtract_tensor = subtract_tensor.contiguous();

    const int64_t N = input.size(0);
    const int64_t C = input.size(1);
    const int64_t D = input.size(2);
    const int64_t H = input.size(3);
    const int64_t W = input.size(4);

    TORCH_CHECK(C > 0, "fused_forward: channel dimension must be non-empty");
    TORCH_CHECK(subtract_tensor.numel() == C,
                "fused_forward: subtract_tensor must have ", C,
                " elements, got ", subtract_tensor.numel());

    // Kernel parameters are 32-bit; reject shapes that would be truncated.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(N <= kIntMax && C <= kIntMax && D <= kIntMax &&
                    H <= kIntMax && W <= kIntMax,
                "fused_forward: every dimension must fit in a 32-bit int");

    auto output = torch::empty({N, D, H, W}, input.options());

    const int64_t total_spatial = N * D * H * W;
    if (total_spatial == 0) {
        // Nothing to compute; launching a zero-block grid is a CUDA error.
        return output;
    }

    constexpr int threads = 256;
    const int64_t blocks = (total_spatial + threads - 1) / threads;
    TORCH_CHECK(blocks <= kIntMax,
                "fused_forward: input too large for a 1-D launch grid");

    // Launch on input's device and on PyTorch's current stream so the kernel
    // is ordered with the surrounding ATen ops.
    const c10::cuda::CUDAGuard device_guard(input.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    fused_softmax_subtract_swish_max_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
        input.data_ptr<float>(),
        subtract_tensor.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(N), static_cast<int>(C), static_cast<int>(D),
        static_cast<int>(H), static_cast<int>(W)
    );
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fused_softmax_subtract_swish_max_kernel launch failed: ",
                cudaGetErrorString(err));

    return output;
}

// The complete forward function performs the following operations in sequence:
// 1. ConvTranspose3d (at::conv_transpose3d, groups = 1, dilation = 1)
// 2. MaxPool3d (at::max_pool3d)
// 3. Fused operations (softmax along channel, subtract, swish, max over channel)
//    using the custom CUDA kernel via fused_forward
// Tensor arguments are only read here, so they are taken by const reference.

torch::Tensor forward(
    const torch::Tensor& x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    int64_t pool_kernel_size,
    int64_t pool_stride,
    int64_t pool_padding,
    const torch::Tensor& conv_transpose_weight,
    const torch::Tensor& conv_transpose_bias,
    const torch::Tensor& subtract_tensor
) {
    // 1. Transposed convolution
    auto out = at::conv_transpose3d(
        x,
        conv_transpose_weight,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding},
        {output_padding, output_padding, output_padding},
        /*groups=*/1,
        /*dilation=*/{1, 1, 1}
    );

    // 2. MaxPool3d (dilation and ceil_mode keep their ATen defaults)
    out = at::max_pool3d(
        out,
        {pool_kernel_size, pool_kernel_size, pool_kernel_size},
        {pool_stride, pool_stride, pool_stride},
        {pool_padding, pool_padding, pool_padding}
    );

    // 3. Fused kernel: softmax (along channels) -> subtract -> swish -> max over channels.
    // subtract_tensor is flattened to [C]; reshape (unlike view) also accepts
    // non-contiguous inputs.
    return fused_forward(out, subtract_tensor.reshape({-1}));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused CUDA forward pass for 89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max with loop unrolling");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.890 inst/cycle 0.000 5
Executed Ipc Elapsed 0.870 inst/cycle 0.000 5
Issue Slots Busy 22.342 % 0.000 5
Issued Ipc Active 0.890 inst/cycle 0.000 5
SM Busy 30.422 % 0.000 5
Memory Throughput 207956348308.540 byte/second 23633605407968464.000 5
Mem Busy 11.762 % 0.000 5
Max Bandwidth 11.762 % 0.000 5
L1/TEX Hit Rate 67.920 % 0.000 5
L2 Hit Rate 6.890 % 0.001 5
Mem Pipes Busy 11.762 % 0.000 5
Warp Cycles Per Issued Instruction 8.790 cycle 0.000 5
Warp Cycles Per Executed Instruction 8.790 cycle 0.000 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 14.190 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 1.000 block 0.000 5
Block Limit Shared Mem 8.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 8.000 warp 0.000 5
Theoretical Occupancy 12.500 % 0.000 5
Achieved Occupancy 12.290 % 0.000 5
Achieved Active Warps Per SM 7.860 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 32.0 threads being active per cycle. This is further reduced to 14.2 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy (12.5%) is limited by the number of required registers. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
CPU Time 8198072.10 μs
Device Time 6719391.82 μs
Self CPU Time 4144.28 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 8193927.83 μs
Device Time 6719391.82 μs
Self CPU Time 6032.21 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 8187895.62 μs
Device Time 6719391.82 μs
Self CPU Time 14193.21 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 8139328.92 μs
Device Time 5321011.62 μs
Self CPU Time 114369.33 μs
Self Device Time 5321011.62 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 5648566.07 μs
Device Time 0.00 μs
Self CPU Time 5648566.07 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 0.00 μs
Device Time 3755539.27 μs
Self CPU Time 0.00 μs
Self Device Time 3755539.27 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45355 warnings generated when compiling for host.
Suppressed 45389 warnings (45342 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:23:5 bugprone-easily-swappable-parameters
23 | const float* __restrict__ input, // [N, C, D, H, W]
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
24 | const float* __restrict__ subtract, // [C]
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:23:31: note: the first parameter in the range is 'input'
23 | const float* __restrict__ input, // [N, C, D, H, W]
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:24:31: note: the last parameter in the range is 'subtract'
24 | const float* __restrict__ subtract, // [C]
| ^~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:26:5: warning: 2 adjacent parameters of 'fused_softmax_subtract_swish_max_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
26 | int N, int C, int D, int H, int W
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:26:9: note: the first parameter in the range is 'N'
26 | int N, int C, int D, int H, int W
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:26:16: note: the last parameter in the range is 'C'
26 | int N, int C, int D, int H, int W
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:29:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
29 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:107:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
107 | int N = sizes[0];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:108:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | int C = sizes[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:109:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
109 | int D = sizes[2];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:110:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
110 | int H = sizes[3];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:111:13: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
111 | int W = sizes[4];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:134:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
134 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:137:5: warning: 2 adjacent parameters of 'forward' of similar type ('int64_t') are easily swapped by mistake [bugprone-easily-swappable-parameters]
137 | int64_t output_padding,
| ^~~~~~~~~~~~~~~~~~~~~~~
138 | int64_t pool_kernel_size,
| ~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:137:13: note: the first parameter in the range is 'output_padding'
137 | int64_t output_padding,
| ^~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:138:13: note: the last parameter in the range is 'pool_kernel_size'
138 | int64_t pool_kernel_size,
| ^~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:141:19: warning: the parameter 'conv_transpose_weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
141 | torch::Tensor conv_transpose_weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:142:5: warning: 2 adjacent parameters of 'forward' of similar type ('torch::Tensor') are easily swapped by mistake [bugprone-easily-swappable-parameters]
142 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
143 | torch::Tensor subtract_tensor
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:142:19: note: the first parameter in the range is 'conv_transpose_bias'
142 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:143:19: note: the last parameter in the range is 'subtract_tensor'
143 | torch::Tensor subtract_tensor
| ^~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_89/b4_s3_fused_unroll/edit_1/edit_1.cu:143:19: warning: the parameter 'subtract_tensor' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
143 | torch::Tensor subtract_tensor
| ^
| const &