← Back to Leaderboard

The AI CUDA Engineer 👷

20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd — reduced_sync_kernel_base_base

Level 2 • Task 20
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D transposed convolution, then apply the bias/residual epilogue.

    With ``y`` the convolution output and ``r`` a detached copy of ``y``, the
    result is ``((y + bias) + r) * r + r``. Because ``r`` is detached, autograd
    only sees the ``y + bias`` term (scaled by ``r``).

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        output_padding (int): Additional size added to output shape
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        bias (torch.Tensor): Bias tensor added to (and broadcast against) the convolution output

    Returns:
        torch.Tensor: Output tensor after applying operations
    """
    conv_out = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    # Snapshot of the convolution output with its autograd history cut off;
    # it serves both as the residual and as the multiplier.
    residual = conv_out.clone().detach()
    summed = conv_out + bias
    summed = summed + residual
    return summed * residual + residual


class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by a sum,
    a residual add, a multiplication, and another residual add.

    The convolution weight and bias are held as standalone parameters. In
    ``forward`` they are passed, together with ``bias_parameter``, to a
    functional implementation (``module_fn`` unless ``fn`` is overridden).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias_shape,
    ):
        super().__init__()
        # The layer is built only to borrow PyTorch's default initialisation of
        # its weight and bias; the layer object itself is not stored.
        template = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = template.weight
        # Shift the default bias by +0.02 (intended to keep it away from zero).
        offset = torch.ones_like(template.bias) * 0.02
        self.conv_transpose_bias = nn.Parameter(template.bias + offset)
        self.bias_parameter = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        """Call ``fn`` on ``x``, the conv settings and the stored parameters; return its result."""
        params = (
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.bias_parameter,
        )
        return fn(x, stride, padding, output_padding, *params)


# Problem size for the sample inputs.
batch_size = 16
in_channels = 32
out_channels = 64
depth = 16
height = 32
width = 32
# ConvTranspose3d hyper-parameters.
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
# One additive bias value per output channel, broadcast over (D, H, W).
bias_shape = (out_channels, 1, 1, 1)


def get_inputs():
    """Return sample positional arguments for ``Model.forward``.

    The list holds a random input volume of shape
    ``(batch_size, in_channels, depth, height, width)``, followed by the
    ``stride``, ``padding`` and ``output_padding`` values.
    """
    volume = torch.randn(batch_size, in_channels, depth, height, width)
    return [volume, stride, padding, output_padding]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``, in signature order."""
    ctor_args = [in_channels, out_channels, kernel_size]
    ctor_args += [stride, padding, output_padding, bias_shape]
    return ctor_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by a sum,
    a residual add, a multiplication, and another residual add.

    With ``y`` the convolution output and ``r`` a detached copy of ``y``, the
    forward pass returns ``((y + self.bias) + r) * r + r``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, bias_shape):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        # Replace the default conv bias with a copy shifted by +0.02.
        default_bias = self.conv_transpose.bias
        self.conv_transpose.bias = nn.Parameter(default_bias + torch.ones_like(default_bias) * 0.02)
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x):
        """Apply the transposed convolution and the bias/residual epilogue to ``x``."""
        conv_out = self.conv_transpose(x)
        # Detached snapshot: used as the residual and the multiplier, with no gradient path.
        residual = conv_out.clone().detach()
        out = conv_out + self.bias
        out = out + residual
        out = out * residual
        return out + residual

# Problem size for the sample inputs.
batch_size = 16
in_channels = 32
out_channels = 64
depth = 16
height = 32
width = 32
# ConvTranspose3d hyper-parameters.
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
# One additive bias value per output channel, broadcast over (D, H, W).
bias_shape = (out_channels, 1, 1, 1)

def get_inputs():
    """Return a one-element list holding a random input volume for ``Model.forward``."""
    sample = torch.randn(batch_size, in_channels, depth, height, width)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``, in signature order."""
    conv_args = [in_channels, out_channels, kernel_size, stride, padding, output_padding]
    return conv_args + [bias_shape]

Kernel Information

Related Kernels (Level 2, Task 20 • 20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// Fused epilogue applied to the transposed-convolution output.
//
// For every element v of `conv_output` that belongs to channel c, writes
//     output = v * (2 * v + element_bias[c] + 1)
// which is the algebraic closed form of the reference sequence
//     t = v + b;  t = t + v;  t = t * v;  t = t + v.
//
// Caller must guarantee:
//   * conv_output and output are contiguous NCDHW buffers of num_elements
//     floats, so the channel of flat index i is (i / spatial_size) % channels;
//   * element_bias holds at least `channels` floats;
//   * the launch provides channels * sizeof(float) bytes of dynamic shared memory.
__global__ void reduced_sync_kernel(
    const float* __restrict__ conv_output,
    const float* __restrict__ element_bias,
    float* __restrict__ output,
    int num_elements,
    int channels,
    int spatial_size
) {
    extern __shared__ float shared_bias[];

    // 64-bit loop counters, so the grid-stride increment cannot overflow even
    // when num_elements is close to INT_MAX. The index math feeding the channel
    // divisions stays 32-bit, because every flat index is < num_elements.
    const long long tid = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;

    // Each block stages the per-channel bias once. Every thread reaches the
    // barrier because it comes before any data-dependent loop.
    for (int c = static_cast<int>(threadIdx.x); c < channels; c += static_cast<int>(blockDim.x)) {
        shared_bias[c] = element_bias[c];
    }
    __syncthreads();

    // The float4 path requires both buffers to be 16-byte aligned. If either
    // is not, every element goes through the scalar loop below instead.
    const bool vec_ok =
        ((reinterpret_cast<uintptr_t>(conv_output) |
          reinterpret_cast<uintptr_t>(output)) & 0xF) == 0;
    const int num_vec = vec_ok ? num_elements / 4 : 0;

    for (long long v = tid; v < num_vec; v += stride) {
        const int base_idx = static_cast<int>(v) * 4;
        const float4 in_val = __ldg(reinterpret_cast<const float4*>(conv_output) + v);

        // The four lanes may straddle a channel boundary, so resolve each one.
        const float b0 = shared_bias[(base_idx / spatial_size) % channels];
        const float b1 = shared_bias[((base_idx + 1) / spatial_size) % channels];
        const float b2 = shared_bias[((base_idx + 2) / spatial_size) % channels];
        const float b3 = shared_bias[((base_idx + 3) / spatial_size) % channels];

        float4 result;
        result.x = in_val.x * (2.0f * in_val.x + b0 + 1.0f);
        result.y = in_val.y * (2.0f * in_val.y + b1 + 1.0f);
        result.z = in_val.z * (2.0f * in_val.z + b2 + 1.0f);
        result.w = in_val.w * (2.0f * in_val.w + b3 + 1.0f);

        reinterpret_cast<float4*>(output)[v] = result;
    }

    // Scalar path: the < 4 leftover elements, or all elements when unaligned.
    // The loop runs over the offset from tail_start, which avoids the signed
    // overflow that `tid + tail_start` could hit for very large tensors.
    const int tail_start = num_vec * 4;
    const int tail_len = num_elements - tail_start;
    for (long long t = tid; t < tail_len; t += stride) {
        const int i = tail_start + static_cast<int>(t);
        const float orig = __ldg(conv_output + i);
        const float b = shared_bias[(i / spatial_size) % channels];
        output[i] = orig * (2.0f * orig + b + 1.0f);
    }
}

// Host entry point exposed to Python as `forward`.
//
// Runs ATen's conv_transpose3d. It then applies the fused epilogue
//     out = y * (2 * y + bias[c] + 1)
// with a single launch of reduced_sync_kernel, where y is the convolution
// output and c is the channel of each element.
//
// Args:
//   x                   input to the transposed convolution; the output must be 5-D (N, C, D, H, W)
//   stride, padding, output_padding
//                       forwarded unchanged to conv_transpose3d
//   conv_transpose      transposed-convolution weight
//   conv_transpose_bias transposed-convolution bias
//   bias                per-channel additive bias: float32, exactly C elements
//                       (e.g. shape (C, 1, 1, 1)), on the same CUDA device
// Returns a new contiguous float32 tensor with the convolution output's shape.
// Throws c10::Error (via TORCH_CHECK) on unsupported inputs or a failed launch.
torch::Tensor forward(
    const torch::Tensor& x,
    int stride,
    int padding,
    int output_padding,
    const torch::Tensor& conv_transpose,
    const torch::Tensor& conv_transpose_bias,
    const torch::Tensor& bias
) {
    // The kernel derives channels from a flat NCDHW index. The convolution
    // result may follow the input's memory format, so force the standard layout.
    const torch::Tensor conv_result = torch::conv_transpose3d(
        x,
        conv_transpose,
        conv_transpose_bias,
        stride,
        padding,
        output_padding
    ).contiguous();

    TORCH_CHECK(conv_result.is_cuda(), "forward: expected CUDA tensors");
    TORCH_CHECK(conv_result.scalar_type() == torch::kFloat,
                "forward: only float32 is supported, got ", conv_result.scalar_type());
    TORCH_CHECK(conv_result.dim() == 5,
                "forward: expected a 5-D (N, C, D, H, W) convolution output, got ",
                conv_result.dim(), "-D");

    const int64_t channels64 = conv_result.size(1);
    const int64_t spatial64 = conv_result.size(2) * conv_result.size(3) * conv_result.size(4);
    const int64_t numel64 = conv_result.numel();
    TORCH_CHECK(numel64 <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "forward: tensor has ", numel64, " elements; the kernel uses 32-bit indexing");

    TORCH_CHECK(bias.device() == conv_result.device(),
                "forward: bias must be on the same device as the convolution output");
    TORCH_CHECK(bias.scalar_type() == torch::kFloat, "forward: bias must be float32");
    TORCH_CHECK(bias.numel() == channels64,
                "forward: bias must hold one value per output channel (", channels64,
                "), got ", bias.numel());
    const torch::Tensor bias_c = bias.contiguous();

    auto output = torch::empty_like(conv_result);
    if (numel64 == 0) {
        return output;  // Nothing to compute; a zero-block launch is an invalid configuration.
    }

    const int channels = static_cast<int>(channels64);
    const int spatial_size = static_cast<int>(spatial64);
    const int num_elements = static_cast<int>(numel64);

    // The grid is capped at 1024 blocks; the kernel's grid-stride loops cover
    // the remaining elements.
    constexpr int threads_per_block = 512;
    const int64_t needed_blocks = (numel64 + threads_per_block - 1) / threads_per_block;
    const int blocks = needed_blocks < 1024 ? static_cast<int>(needed_blocks) : 1024;

    const size_t shared_mem_size = static_cast<size_t>(channels) * sizeof(float);

    // Launch on the tensor's device and on PyTorch's current stream, so the
    // kernel is ordered correctly after conv_transpose3d and before later ops.
    const c10::cuda::CUDAGuard device_guard(conv_result.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    reduced_sync_kernel<<<blocks, threads_per_block, shared_mem_size, stream>>>(
        conv_result.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        output.data_ptr<float>(),
        num_elements,
        channels,
        spatial_size
    );
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "forward: reduced_sync_kernel launch failed: ", cudaGetErrorString(launch_err));

    return output;
}

// Register the host entry point so Python can call it as `<extension>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
    ext_module.def(
        "forward",
        &forward,
        "Fused ConvTranspose3D with Reduced Synchronization");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.248 inst/cycle 0.000 5
Executed Ipc Elapsed 2.096 inst/cycle 0.000 5
Issue Slots Busy 56.240 % 0.028 5
Issued Ipc Active 2.248 inst/cycle 0.000 5
SM Busy 56.240 % 0.028 5
Memory Throughput 2825144950767.830 byte/second 203650958220320145408.000 5
Mem Busy 44.820 % 0.053 5
Max Bandwidth 84.282 % 0.182 5
L1/TEX Hit Rate 0.020 % 0.000 5
L2 Hit Rate 50.806 % 0.003 5
Mem Pipes Busy 8.350 % 0.002 5
Warp Cycles Per Issued Instruction 25.894 cycle 0.009 5
Warp Cycles Per Executed Instruction 25.896 cycle 0.009 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 26.270 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 4.000 block 0.000 5
Block Limit Shared Mem 12.000 block 0.000 5
Block Limit Warps 4.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 91.166 % 0.031 5
Achieved Active Warps Per SM 58.348 warp 0.014 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (49.1%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::conv_transpose3d
CPU Time 1883665.82 μs
Device Time 5183833.69 μs
Self CPU Time 8293.02 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 1875372.80 μs
Device Time 5183833.69 μs
Self CPU Time 10672.20 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 1864700.60 μs
Device Time 5183833.69 μs
Self CPU Time 21641.27 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 535331.59 μs
Device Time 3187567.20 μs
Self CPU Time 142057.24 μs
Self Device Time 3187567.20 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 1890733.08 μs
Device Time 165971.86 μs
Self CPU Time 1890733.08 μs
Self Device Time 165971.86 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4678320.50 μs
Device Time 84171.51 μs
Self CPU Time 4678320.50 μs
Self Device Time 84171.51 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::add_
CPU Time 1301790.90 μs
Device Time 1996266.49 μs
Self CPU Time 18952.88 μs
Self Device Time 1996266.49 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:6:5 bugprone-easily-swappable-parameters
6 | const float* __restrict__ conv_output,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7 | const float* __restrict__ element_bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:6:31: note: the first parameter in the range is 'conv_output'
6 | const float* __restrict__ conv_output,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:7:31: note: the last parameter in the range is 'element_bias'
7 | const float* __restrict__ element_bias,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:9:5: warning: 2 adjacent parameters of 'reduced_sync_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
9 | int num_elements,
| ^~~~~~~~~~~~~~~~~
10 | int channels,
| ~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:9:9: note: the first parameter in the range is 'num_elements'
9 | int num_elements,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:10:9: note: the last parameter in the range is 'channels'
10 | int channels,
| ^~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:14:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
14 | int tid = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:15:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
15 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:16:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
16 | int lane = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:19:43: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
19 | for (int c = lane; c < channels; c += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:61:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
61 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:65:19: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
65 | torch::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:66:5: warning: 2 adjacent parameters of 'forward' of similar type ('torch::Tensor') are easily swapped by mistake [bugprone-easily-swappable-parameters]
66 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
67 | torch::Tensor bias
| ~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:66:19: note: the first parameter in the range is 'conv_transpose_bias'
66 | torch::Tensor conv_transpose_bias,
| ^~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:67:19: note: the last parameter in the range is 'bias'
67 | torch::Tensor bias
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:67:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
67 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:79:26: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | const int channels = sizes[1];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:80:30: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | const int spatial_size = sizes[2] * sizes[3] * sizes[4];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_20/b10_s1_reduced_sync_kernel_base/base/base.cu:81:30: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
81 | const int num_elements = conv_result.numel();
| ^