← Back to Leaderboard

The AI CUDA Engineer 👷

78_ConvTranspose3d_Max_Max_Sum • balanced_workload_distribution_base

Level 2 • Task 78
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D transposed convolution, max-pool twice, then sum over channels.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, depth, height, width)
        stride (int): Stride used by the transposed convolution
        padding (int): Padding used by the transposed convolution
        conv_transpose (torch.Tensor): Weight of the transposed convolution
        conv_transpose_bias (torch.Tensor): Bias of the transposed convolution

    Returns:
        torch.Tensor: The pooled result summed over dim 1, with that dim kept at size 1
    """
    upsampled = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
    )

    # Two successive poolings: a window of 2 first, then a window of 3.
    pooled = upsampled
    for window in (2, 3):
        pooled = F.max_pool3d(pooled, kernel_size=window)

    return torch.sum(pooled, dim=1, keepdim=True)


class Model(nn.Module):
    """
    Transposed 3D convolution followed by two max poolings and a channel sum.

    The convolution weight and bias are stored directly on the module; the
    computation itself is delegated to a function passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # A temporary layer is built only to obtain initialised parameters.
        # ``stride`` and ``padding`` are not used here; they are supplied on
        # every call to ``forward`` instead.
        template = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
        self.conv_transpose_parameter = nn.Parameter(template.weight)
        self.conv_transpose_bias = nn.Parameter(template.bias)

    def forward(self, x, stride, padding, fn=module_fn):
        # Hand the stored parameters to the functional implementation.
        weight = self.conv_transpose_parameter
        bias = self.conv_transpose_bias
        return fn(x, stride, padding, weight, bias)


# Benchmark configuration read by get_inputs() and get_init_inputs().
batch_size, in_channels, out_channels = 16, 8, 16
depth = 16
height = 32
width = 32
kernel_size, stride, padding = 3, 2, 1


def get_inputs():
    """Return the forward() arguments: a random input tensor, then stride and padding."""
    sample = torch.randn(batch_size, in_channels, depth, height, width)
    return [sample, stride, padding]


def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    init_args = [in_channels, out_channels, kernel_size]
    init_args += [stride, padding]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Transposed 3D convolution, then max pooling with window 2, then max pooling
    with window 3, then a sum over the channel dimension (kept as size 1).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
        )
        self.max_pool1 = nn.MaxPool3d(kernel_size=2)
        self.max_pool2 = nn.MaxPool3d(kernel_size=3)

    def forward(self, x):
        # Convolution and both poolings applied in sequence.
        pooled = self.max_pool2(self.max_pool1(self.conv_transpose(x)))
        # Collapse the channel axis while keeping it as a singleton dimension.
        return torch.sum(pooled, dim=1, keepdim=True)

# Input tensor dimensions.
batch_size = 16
in_channels = 8
depth = 16
height = width = 32

# Transposed-convolution hyperparameters.
out_channels = 16
kernel_size = 3
stride = 2
padding = 1

def get_inputs():
    """Return a one-element list holding a random input tensor for forward()."""
    shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    conv_args = (in_channels, out_channels, kernel_size, stride, padding)
    return list(conv_args)

Kernel Information

Related Kernels (Level 2, Task 78 • 78_ConvTranspose3d_Max_Max_Sum)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_maxpool_kernel_base 0.58 1.05 1.21
🥇 adaptive_blocksize_maxpool_opt_base 0.58 1.05 1.21
🥉 minimized_divergence_maxpool_base_base 0.58 1.05 1.21
4 unrolled_78_convtranspose3d_optimized_base 0.59 1.03 1.19
4 modular_maxpool_kernel_base 0.59 1.03 1.19
6 fully_unrolled_maxpool_base_base 0.59 1.03 1.19
7 balanced_load_distribution_maxpool_base 0.59 1.03 1.19
8 manual_unroll_maxpool_base_base 0.59 1.03 1.19
9 coalesced_maxpool_shared_mem_base 0.60 1.02 1.18
10 unrolled_78_convtranspose3d_base 0.61 1.01 1.16
11 78_ConvTranspose3d_Max_Max_Sum 0.61 1.00 1.16
12 unroll_conv3d_max_sum_base 0.61 1.00 1.15
13 modular_conv3d_max_sum_edit_1 0.61 1.00 1.15
13 modular_conv3d_max_sum_base 0.61 1.00 1.15
13 shared_mem_reduction_max_sum_base 0.61 1.00 1.15
13 unroll_conv3d_max_sum_edit_1 0.61 1.00 1.15
17 optimized_stride_max_pool_base 0.61 1.00 1.15
17 shared_mem_reduction_max_sum_edit_1 0.61 1.00 1.15
19 constant_memory_optimization_base_edit_1 0.62 0.99 1.14
19 balanced_workload_distribution_base 0.62 0.99 1.14
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Channel-sum reduction kernel.
//
// Grid mapping: grid.x = N * D, grid.y = H, grid.z = W, with BLOCK_SIZE threads per block.
// Each block sums the C channel values of one (n, d, h, w) location of a contiguous
// [N, C, D, H, W] float tensor and stores the result in the matching element of a
// contiguous [N, 1, D, H, W] output tensor.
//
// Template parameter:
//   BLOCK_SIZE - number of threads per block; the kernel must be launched with
//                blockDim.x == BLOCK_SIZE because the channel loop strides by it.
// Parameters:
//   input      - source tensor data, indexed as [n, c, d, h, w]
//   output     - destination tensor data, indexed as [n, 0, d, h, w]
//   N, C, D, H, W - tensor extents (N is not needed for indexing and is unused)

template <int BLOCK_SIZE>
__global__ void sum_channels_kernel(const float* __restrict__ input, float* __restrict__ output,
                                      int N, int C, int D, int H, int W) {
    // The two-stage reduction below only works when the block is made of whole warps
    // and the per-warp partial sums fit in (and halve cleanly within) a single warp.
    static_assert(BLOCK_SIZE >= WARP_SIZE && BLOCK_SIZE % WARP_SIZE == 0,
                  "BLOCK_SIZE must be a positive multiple of WARP_SIZE");
    constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
    static_assert(NUM_WARPS <= WARP_SIZE && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                  "BLOCK_SIZE / WARP_SIZE must be a power of two no larger than WARP_SIZE");

    (void)N;  // extent of the batch dimension is implied by the grid

    // Recover (n, d, h, w) from the block index: N and D are folded into grid.x
    const int nd = static_cast<int>(blockIdx.x);   // ranges over [0, N*D)
    const int n = nd / D;
    const int d = nd % D;
    const int h = static_cast<int>(blockIdx.y);
    const int w = static_cast<int>(blockIdx.z);

    // Offsets are computed in 64-bit so that tensors with more than 2^31 elements
    // do not overflow. Element [n, c, d, h, w] of the input lives at
    //   n * (C*D*H*W) + c * (D*H*W) + d * (H*W) + h * W + w
    const long long channel_stride = static_cast<long long>(D) * H * W;
    const long long spatial_offset =
        static_cast<long long>(d) * H * W + static_cast<long long>(h) * W + w;
    const long long base = static_cast<long long>(n) * C * channel_stride + spatial_offset;

    float thread_sum = 0.0f;

    // Each thread accumulates channels threadIdx.x, threadIdx.x + BLOCK_SIZE, ...
    for (int c = static_cast<int>(threadIdx.x); c < C; c += BLOCK_SIZE) {
        thread_sum += input[base + static_cast<long long>(c) * channel_stride];
    }

    // Stage 1: reduce within each warp using shuffle intrinsics
    const unsigned int mask = 0xffffffff;
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        thread_sum += __shfl_down_sync(mask, thread_sum, offset);
    }

    // Lane 0 of every warp publishes its warp's partial sum to shared memory
    __shared__ float warp_sums[NUM_WARPS];
    const int lane = static_cast<int>(threadIdx.x) % WARP_SIZE;
    const int warp_id = static_cast<int>(threadIdx.x) / WARP_SIZE;
    if (lane == 0) {
        warp_sums[warp_id] = thread_sum;
    }
    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partial sums
    float block_sum = 0.0f;
    if (threadIdx.x < NUM_WARPS) {
        block_sum = warp_sums[threadIdx.x];
    }
    if (threadIdx.x < WARP_SIZE) {
        for (int offset = NUM_WARPS / 2; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(mask, block_sum, offset);
        }
    }

    if (threadIdx.x == 0) {
        // Output tensor has shape [N, 1, D, H, W], so its per-sample stride is D*H*W
        output[static_cast<long long>(n) * channel_stride + spatial_offset] = block_sum;
    }
}

// Forward function: conv_transpose3d -> max_pool3d(2) -> max_pool3d(3) -> channel sum.
//
// Parameters:
//   x                   - CUDA float32 input tensor
//   stride, padding     - applied identically to all three spatial dimensions of the
//                         transposed convolution
//   conv_transpose      - CUDA weight tensor of the transposed convolution
//   conv_transpose_bias - CUDA bias tensor of the transposed convolution
// Returns:
//   A tensor of shape [N, 1, D, H, W] holding, for every spatial location of the pooled
//   result, the sum over its channels.
// Raises (via TORCH_CHECK):
//   if an argument is not on CUDA, x is not float32, the pooled result is not 5-D,
//   its extents exceed what the launch configuration supports, or the launch fails.
torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias) {

    TORCH_CHECK(x.is_cuda(), "Input x must be a CUDA tensor");
    TORCH_CHECK(conv_transpose.is_cuda(), "conv_transpose must be a CUDA tensor");
    TORCH_CHECK(conv_transpose_bias.is_cuda(), "conv_transpose_bias must be a CUDA tensor");
    // The reduction kernel reads raw float data, so reject other dtypes up front
    TORCH_CHECK(x.scalar_type() == at::kFloat, "Input x must be a float32 tensor");

    // Ensure inputs are contiguous
    x = x.contiguous();
    conv_transpose = conv_transpose.contiguous();
    conv_transpose_bias = conv_transpose_bias.contiguous();

    // Apply transposed convolution
    x = at::conv_transpose3d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding}
    );

    // Apply max pooling operations
    x = at::max_pool3d(x, {2, 2, 2});
    x = at::max_pool3d(x, {3, 3, 3});

    // The kernel indexes the data as a dense [N, C, D, H, W] array
    TORCH_CHECK(x.dim() == 5, "Expected a 5-D pooled tensor [N, C, D, H, W], got ", x.dim(), "-D");
    x = x.contiguous();

    const int64_t N = x.size(0);
    const int64_t C = x.size(1);
    const int64_t D = x.size(2);
    const int64_t H = x.size(3);
    const int64_t W = x.size(4);

    // Every output element is written by exactly one block, so no zero-fill is needed
    auto output = torch::empty({N, 1, D, H, W}, x.options());

    // A grid with a zero extent is an invalid launch; there is nothing to compute anyway
    if (output.numel() == 0) {
        return output;
    }

    // CUDA grid limits: x up to 2^31 - 1, y and z up to 65535. The kernel also takes
    // its extents as 32-bit ints.
    constexpr int64_t kMaxGridX = 2147483647;
    constexpr int64_t kMaxGridYZ = 65535;
    TORCH_CHECK(N * D <= kMaxGridX && C <= kMaxGridX,
                "Pooled tensor is too large for the channel-sum kernel (N*D=", N * D, ", C=", C, ")");
    TORCH_CHECK(H <= kMaxGridYZ && W <= kMaxGridYZ,
                "Pooled height/width exceed the CUDA grid limit of 65535 (H=", H, ", W=", W, ")");

    // Launch kernel: grid dimensions (N*D, H, W) cover the output spatial domain.
    // Use PyTorch's current stream so the kernel is ordered after the ATen ops above.
    constexpr int BLOCK_SIZE = 256;
    const dim3 grid(static_cast<unsigned int>(N * D),
                    static_cast<unsigned int>(H),
                    static_cast<unsigned int>(W));
    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
    sum_channels_kernel<BLOCK_SIZE><<<grid, BLOCK_SIZE, 0, cuda_stream>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(N),
        static_cast<int>(C),
        static_cast<int>(D),
        static_cast<int>(H),
        static_cast<int>(W)
    );

    // Surface launch-configuration errors instead of returning uninitialised memory
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "sum_channels_kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}

// Python binding: exposes the C++ forward() as `forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Optimized 78_ConvTranspose3d_Max_Max_Sum kernel with balanced workload distribution");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.532 inst/cycle 0.000 5
Executed Ipc Elapsed 2.042 inst/cycle 0.000 5
Issue Slots Busy 63.502 % 0.025 5
Issued Ipc Active 2.542 inst/cycle 0.000 5
SM Busy 63.502 % 0.025 5
Memory Throughput 46082311619.352 byte/second 74727649242934016.000 5
Mem Busy 27.298 % 0.026 5
Max Bandwidth 23.554 % 0.019 5
L1/TEX Hit Rate 4.150 % 0.009 5
L2 Hit Rate 91.188 % 0.037 5
Mem Pipes Busy 44.162 % 0.067 5
Warp Cycles Per Issued Instruction 18.718 cycle 0.007 5
Warp Cycles Per Executed Instruction 18.778 cycle 0.007 5
Avg. Active Threads Per Warp 30.440 0.000 5
Avg. Not Predicated Off Threads Per Warp 23.350 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 28.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 75.132 % 0.001 5
Achieved Active Warps Per SM 48.082 warp 0.000 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (55.1%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 30.4 threads being active per cycle. This is further reduced to 23.4 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (75.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv_transpose3d
CPU Time 3442611.49 μs
Device Time 6963737.43 μs
Self CPU Time 27816.73 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 3414794.76 μs
Device Time 6963737.43 μs
Self CPU Time 36624.07 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 3378170.69 μs
Device Time 6963737.43 μs
Self CPU Time 75475.08 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 3154805.23 μs
Device Time 5442574.73 μs
Self CPU Time 335465.62 μs
Self Device Time 5442574.73 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 6461365.07 μs
Device Time 680.06 μs
Self CPU Time 6461365.07 μs
Self Device Time 680.06 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn
CPU Time 0.00 μs
Device Time 3641286.64 μs
Self CPU Time 0.00 μs
Self Device Time 3641286.64 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 3298299.22 μs
Device Time 1089443.57 μs
Self CPU Time 60398.26 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45290 warnings generated when compiling for host.
Suppressed 45325 warnings (45278 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:13:39 bugprone-easily-swappable-parameters
13 | int N, int C, int D, int H, int W) {
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:13:43: note: the first parameter in the range is 'N'
13 | int N, int C, int D, int H, int W) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:13:50: note: the last parameter in the range is 'C'
13 | int N, int C, int D, int H, int W) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:15:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
15 | int nd = blockIdx.x; // nd ranges from 0 to (N*D - 1)
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:18:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
18 | int h = blockIdx.y; // h index
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:19:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
19 | int w = blockIdx.z; // w index
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:28:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
28 | for (int c = threadIdx.x; c < C; c += BLOCK_SIZE) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:40:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
40 | int lane = threadIdx.x % WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:41:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
41 | int warp_id = threadIdx.x / WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:105:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
105 | sizes[0],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:106:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
106 | sizes[1],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:107:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
107 | sizes[2],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:108:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
108 | sizes[3],
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_2/task_78/b3_s1_balanced_workload_distribution/base/base.cu:109:9: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
109 | sizes[4]
| ^