← Back to Leaderboard

The AI CUDA Engineer 👷

20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd — coalesced_atomic_selective_kernel_base

Level 2 • Task 20
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D transposed convolution, then a fixed chain of bias/residual ops.

    With ``c`` the convolution output, the result is
    ``((c + bias) + c) * c + c``, where the three residual operands are a
    detached copy of ``c`` (no gradient flows through them).

    Args:
        x (torch.Tensor): Input volume, presumably (batch, in_channels, D, H, W).
        stride (int): Stride of the transposed convolution.
        padding (int): Padding of the transposed convolution.
        output_padding (int): Extra size added to one side of the output.
        conv_transpose (torch.Tensor): Transposed-convolution weight.
        conv_transpose_bias (torch.Tensor): Transposed-convolution bias.
        bias (torch.Tensor): Tensor added (with broadcasting) to the conv output.

    Returns:
        torch.Tensor: Result of the fused bias/residual sequence.
    """
    conv_out = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    # Snapshot of the conv output used for every residual term; detached so the
    # residual branches do not contribute gradients.
    residual = conv_out.clone().detach()
    result = conv_out + bias
    result = result + residual
    result = result * residual
    return result + residual


class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by a sum,
    a residual add, a multiplication, and another residual add.

    The learnable tensors are held as bare parameters and passed to a
    functional implementation ``fn`` (``module_fn`` by default) on every call,
    so a different implementation with the same signature can be swapped in.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias_shape,
    ):
        # Zero-argument super() binds to *this* class. The explicit
        # super(Model, self) form looks up the global name ``Model`` at call
        # time, which breaks if ``Model`` is later rebound in the module.
        super().__init__()
        # A throwaway layer used only to obtain nn.ConvTranspose3d's default
        # weight/bias initialisation.
        reference_layer = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = reference_layer.weight
        # Shift the default bias by 0.02 (intent per original author: keep it non-zero).
        self.conv_transpose_bias = nn.Parameter(
            reference_layer.bias + torch.ones_like(reference_layer.bias) * 0.02
        )
        # Extra bias added after the convolution; small random initialisation.
        self.bias_parameter = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        """Apply ``fn`` to ``x`` using this module's parameters.

        Args:
            x: Input volume.
            stride, padding, output_padding: Transposed-convolution settings,
                forwarded unchanged to ``fn``.
            fn: Functional implementation with ``module_fn``'s signature.

        Returns:
            Whatever ``fn`` returns.
        """
        return fn(
            x,
            stride,
            padding,
            output_padding,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.bias_parameter,
        )


# Benchmark configuration consumed by get_inputs / get_init_inputs.
batch_size = 16
in_channels = 32
out_channels = 64
# Input volume extent (D, H, W).
depth = 16
height = 32
width = 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
# One extra bias value per output channel, broadcast over D, H and W.
bias_shape = (out_channels, 1, 1, 1)


def get_inputs():
    """Return positional ``forward`` arguments: a random input volume
    followed by the transposed-convolution stride, padding and output_padding."""
    volume = torch.randn(batch_size, in_channels, depth, height, width)
    return [volume, stride, padding, output_padding]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    layer_args = [in_channels, out_channels, kernel_size]
    conv_settings = [stride, padding, output_padding]
    return layer_args + conv_settings + [bias_shape]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by a sum,
    a residual add, a multiplication, and another residual add.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, bias_shape):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        # Replace the default conv bias with a copy shifted by 0.02.
        default_bias = self.conv_transpose.bias
        self.conv_transpose.bias = nn.Parameter(default_bias + torch.ones_like(default_bias) * 0.02)
        # Extra learnable bias added (with broadcasting) to the conv output.
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x):
        """Return ``((c + bias) + c) * c + c`` where ``c = conv_transpose(x)``."""
        conv_out = self.conv_transpose(x)
        # Detached snapshot: the residual branches carry no gradient.
        residual = conv_out.clone().detach()
        result = conv_out + self.bias
        result = result + residual
        result = result * residual
        return result + residual

# Benchmark configuration consumed by get_inputs / get_init_inputs.
batch_size = 16
in_channels = 32
out_channels = 64
# Input volume extent (D, H, W).
depth = 16
height = 32
width = 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
# One extra bias value per output channel, broadcast over D, H and W.
bias_shape = (out_channels, 1, 1, 1)

def get_inputs():
    """Return positional ``forward`` arguments: a single random input volume."""
    shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    layer_args = [in_channels, out_channels, kernel_size]
    conv_settings = [stride, padding, output_padding]
    return layer_args + conv_settings + [bias_shape]

Kernel Information

Related Kernels (Level 2, Task 20 • 20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>

// Compile-time switch for the store path. Each output element is produced by exactly
// one thread (every float4 group and every tail index is visited once by the loops
// below), so plain stores are race-free and USE_ATOMIC should normally stay 0.
// With USE_ATOMIC=1 the kernel *accumulates* into `output` via atomicAdd, which is
// only correct when the caller zero-initialises `output` first.
// May be overridden at build time (e.g. -DUSE_ATOMIC=1).
#ifndef USE_ATOMIC
#define USE_ATOMIC 0
#endif

// Per-element epilogue. The reference computes
//   y = x + bias;  y = y + x;  y = y * x;  y = y + x
// which simplifies to x * (2x + bias + 1).
__device__ __forceinline__ float fused_residual_epilogue(float val, float channel_bias) {
    return val * (2.0f * val + channel_bias + 1.0f);
}

// Writes a single result: a plain store, or an atomic accumulation when USE_ATOMIC is set.
__device__ __forceinline__ void store_result(float* output, int64_t index, float value) {
#if USE_ATOMIC
    atomicAdd(&output[index], value);
#else
    output[index] = value;
#endif
}

// Kernel computes: output[i] = conv_output[i] * (2.0f * conv_output[i] + bias[c] + 1.0f),
// where c = (i / spatial_size) % channels. That index math assumes a dense NCDHW
// buffer with one bias value per channel.
//
// Launch requirements:
//   * dynamic shared memory of channels * sizeof(float) bytes;
//   * element_bias points at >= channels floats.
// Indices are 64-bit, so tensors with more than INT_MAX elements are handled. The
// float4 path is used only when both buffers are 16-byte aligned (required for vector
// loads/stores); otherwise every element goes through the scalar loop.
__global__ void coalesced_atomic_selective_kernel(
    const float* __restrict__ conv_output,
    const float* __restrict__ element_bias,
    float* output,
    int64_t num_elements,
    int channels,
    int64_t spatial_size
) {
    extern __shared__ float shared_bias[];

    // Stage the per-channel bias in shared memory once per block.
    for (int c = threadIdx.x; c < channels; c += blockDim.x) {
        shared_bias[c] = __ldg(&element_bias[c]);
    }
    __syncthreads();

    const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    // float4 accesses require 16-byte alignment of both buffers.
    const bool vec_aligned =
        ((reinterpret_cast<uintptr_t>(conv_output) | reinterpret_cast<uintptr_t>(output)) & 0xF) == 0;
    const int64_t total_vec = vec_aligned ? num_elements / 4 : 0;

    // When spatial_size is a multiple of 4, the four lanes of a group starting at
    // base = 4 * v never straddle a channel boundary, so one division per group suffices.
    const bool one_channel_per_vec = (spatial_size % 4) == 0;

    const float4* in_vecs = reinterpret_cast<const float4*>(conv_output);
    for (int64_t v = tid; v < total_vec; v += grid_stride) {
        const int64_t base = v * 4;
        const float4 in = in_vecs[v];

        float b0, b1, b2, b3;
        if (one_channel_per_vec) {
            b0 = b1 = b2 = b3 = shared_bias[(base / spatial_size) % channels];
        } else {
            b0 = shared_bias[(base / spatial_size) % channels];
            b1 = shared_bias[((base + 1) / spatial_size) % channels];
            b2 = shared_bias[((base + 2) / spatial_size) % channels];
            b3 = shared_bias[((base + 3) / spatial_size) % channels];
        }

        float4 out;
        out.x = fused_residual_epilogue(in.x, b0);
        out.y = fused_residual_epilogue(in.y, b1);
        out.z = fused_residual_epilogue(in.z, b2);
        out.w = fused_residual_epilogue(in.w, b3);

#if USE_ATOMIC
        store_result(output, base + 0, out.x);
        store_result(output, base + 1, out.y);
        store_result(output, base + 2, out.z);
        store_result(output, base + 3, out.w);
#else
        reinterpret_cast<float4*>(output)[v] = out;
#endif
    }

    // Scalar loop: the tail that does not fill a float4 group, or every element
    // when the buffers are not suitably aligned (total_vec == 0).
    for (int64_t i = total_vec * 4 + tid; i < num_elements; i += grid_stride) {
        const float val = conv_output[i];
        store_result(output, i, fused_residual_epilogue(val, shared_bias[(i / spatial_size) % channels]));
    }
}

// Host entry point. Runs conv_transpose3d through ATen, then applies the fused
// residual epilogue with coalesced_atomic_selective_kernel on PyTorch's current
// CUDA stream.
//
// Arguments mirror module_fn in the Python reference. `bias` must contain exactly
// one float per output channel (e.g. shape (C_out, 1, 1, 1)). Only float32 CUDA
// tensors are supported.
// In the default build (USE_ATOMIC == 0) each output element is written by one
// thread with a plain store. An atomic build accumulates, so the output is
// zero-initialised in that case.
torch::Tensor forward(
    const torch::Tensor& x,
    int stride,
    int padding,
    int output_padding,
    const torch::Tensor& conv_transpose,
    const torch::Tensor& conv_transpose_bias,
    const torch::Tensor& bias
) {
    TORCH_CHECK(x.is_cuda(), "forward: x must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "forward: bias must be a CUDA tensor");
    TORCH_CHECK(bias.device() == x.device(), "forward: bias must be on the same device as x");

    // Allocate and launch on x's device, not whatever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes a dense NCDHW buffer, so force that layout in case the
    // convolution returned a different memory format (e.g. channels_last_3d).
    auto conv_result = torch::conv_transpose3d(
        x,
        conv_transpose,
        conv_transpose_bias,
        stride,
        padding,
        output_padding
    ).contiguous();
    TORCH_CHECK(conv_result.dim() == 5, "forward: expected a 5-D convolution output");
    TORCH_CHECK(conv_result.scalar_type() == torch::kFloat32,
                "forward: only float32 tensors are supported");

    const int64_t channels = conv_result.size(1);
    const int64_t spatial_size = conv_result.size(2) * conv_result.size(3) * conv_result.size(4);  // D * H * W
    const int64_t num_elements = conv_result.numel();

    const auto bias_flat = bias.contiguous();
    TORCH_CHECK(bias_flat.scalar_type() == torch::kFloat32, "forward: bias must be float32");
    TORCH_CHECK(bias_flat.numel() == channels,
                "forward: bias must have one element per output channel (expected ", channels,
                ", got ", bias_flat.numel(), ")");

#if USE_ATOMIC
    auto output = torch::zeros_like(conv_result);  // atomicAdd accumulates into output
#else
    auto output = torch::empty_like(conv_result);
#endif

    // A zero-block launch is an invalid configuration; nothing to compute anyway.
    if (num_elements == 0) {
        return output;
    }

    const size_t shared_mem_size = static_cast<size_t>(channels) * sizeof(float);
    TORCH_CHECK(shared_mem_size <= 48 * 1024,
                "forward: too many channels for the shared-memory bias cache (", channels, ")");

    const int threads_per_block = 256;
    const int64_t total_vec = num_elements / 4;
    const int64_t work_items = (total_vec > 0) ? total_vec : num_elements;
    const int64_t wanted_blocks = (work_items + threads_per_block - 1) / threads_per_block;
    // Cap the grid; the kernel's grid-stride loops cover any remaining work.
    const int blocks = static_cast<int>(std::min<int64_t>(wanted_blocks, 1024));

    coalesced_atomic_selective_kernel<<<blocks, threads_per_block, shared_mem_size,
                                        at::cuda::getCurrentCUDAStream()>>>(
        conv_result.data_ptr<float>(),
        bias_flat.data_ptr<float>(),
        output.data_ptr<float>(),
        num_elements,
        static_cast<int>(channels),
        spatial_size
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python binding: exposes the host entry point as `<extension>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        forward,
        "Coalesced Vectorized Fused ConvTranspose3D Kernel with Selective Atomic Residual Addition");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.020 inst/cycle 0.000 5
Executed Ipc Elapsed 1.896 inst/cycle 0.000 5
Issue Slots Busy 50.460 % 0.014 5
Issued Ipc Active 2.020 inst/cycle 0.000 5
SM Busy 50.460 % 0.014 5
Memory Throughput 2811453510615.038 byte/second 250621066563110830080.000 5
Mem Busy 44.674 % 0.065 5
Max Bandwidth 83.874 % 0.223 5
L1/TEX Hit Rate 0.020 % 0.000 5
L2 Hit Rate 50.846 % 0.006 5
Mem Pipes Busy 8.056 % 0.002 5
Warp Cycles Per Issued Instruction 29.058 cycle 0.014 5
Warp Cycles Per Executed Instruction 29.068 cycle 0.014 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 25.700 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 25.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 92.196 % 0.003 5
Achieved Active Warps Per SM 59.006 warp 0.001 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (45.0%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::conv_transpose3d
CPU Time 1751352.00 μs
Device Time 4680503.06 μs
Self CPU Time 8302.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 1743049.66 μs
Device Time 4680503.06 μs
Self CPU Time 11773.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 1731276.00 μs
Device Time 4680503.06 μs
Self CPU Time 23506.13 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 504272.67 μs
Device Time 2848111.79 μs
Self CPU Time 153556.87 μs
Self Device Time 2848111.79 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 1727252.88 μs
Device Time 126952.12 μs
Self CPU Time 1727252.88 μs
Self Device Time 126952.12 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4207368.90 μs
Device Time 64218.04 μs
Self CPU Time 4207368.90 μs
Self Device Time 64218.04 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::add_
CPU Time 1196178.35 μs
Device Time 1832391.27 μs
Self CPU Time 20536.09 μs
Self Device Time 1832391.27 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45294 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:16:5 bugprone-easily-swappable-parameters
16 | const float* __restrict__ conv_output,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 | const float* __restrict__ element_bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:16:31: note: the first parameter in the range is 'conv_output'
16 | const float* __restrict__ conv_output,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:17:31: note: the last parameter in the range is 'element_bias'
17 | const float* __restrict__ element_bias,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:19:5: warning: 2 adjacent parameters of 'coalesced_atomic_selective_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
19 | int num_elements,
| ^~~~~~~~~~~~~~~~~
20 | int channels,
| ~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:19:9: note: the first parameter in the range is 'num_elements'
19 | int num_elements,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:20:9: note: the last parameter in the range is 'channels'
20 | int channels,
| ^~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:26:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | for (int i = threadIdx.x; i < channels; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:26:50: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | for (int i = threadIdx.x; i < channels; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:31:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:32:25: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
32 | int total_threads = gridDim.x * blockDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:64:9: warning: Value stored to 'remainder' during its initialization is never read [clang-analyzer-deadcode.DeadStores]
64 | int remainder = num_elements % 4;
| ^~~~~~~~~ ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:64:9: note: Value stored to 'remainder' during its initialization is never read
64 | int remainder = num_elements % 4;
| ^~~~~~~~~ ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:85:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
85 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:89:19: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
89 | torch::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:90:5: warning: 2 adjacent parameters of 'forward' of similar type ('torch::Tensor') are easily swapped by mistake [bugprone-easily-swappable-parameters]
90 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91 | torch::Tensor bias
| ~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:90:19: note: the first parameter in the range is 'conv_transpose_bias'
90 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:91:19: note: the last parameter in the range is 'bias'
91 | torch::Tensor bias
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:91:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
91 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:103:20: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
103 | int channels = sizes[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:104:24: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
104 | int spatial_size = sizes[2] * sizes[3] * sizes[4]; // D * H * W
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_20/b5_s1_coalesced_atomic_selective_kernel/base/base.cu:105:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
105 | int num_elements = conv_result.numel();
| ^