← Back to Leaderboard

The AI CUDA Engineer 👷

49_ConvTranspose3d_Softmax_Sigmoid • warp_divergence_reduction_base_base

Level 2 • Task 49
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    bias_flag: bool,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D transposed convolution, then softmax over dim 1, then sigmoid.

    Args:
        x (torch.Tensor): Input passed to ``F.conv_transpose3d``
        stride (int): Stride forwarded to the transposed convolution
        padding (int): Padding forwarded to the transposed convolution
        output_padding (int): Output padding forwarded to the transposed convolution
        bias_flag (bool): When falsy, the convolution is run without a bias
        conv_transpose (torch.Tensor): Weight of the transposed convolution
        conv_transpose_bias (torch.Tensor): Bias used only when ``bias_flag`` is truthy

    Returns:
        torch.Tensor: ``sigmoid(softmax(conv_transpose3d(x), dim=1))``
    """
    upsampled = F.conv_transpose3d(
        x,
        conv_transpose,
        # The bias tensor is ignored entirely unless the caller opts in.
        bias=conv_transpose_bias if bias_flag else None,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    # Normalise across dim 1, then squash each value.
    channel_probs = F.softmax(upsampled, dim=1)
    return torch.sigmoid(channel_probs)


class Model(nn.Module):
    """
    Parameter holder for the functional ConvTranspose3d -> Softmax -> Sigmoid pipeline.

    The weight (and optionally a slightly perturbed bias) of a freshly built
    ``nn.ConvTranspose3d`` are stored as parameters; the computation itself is
    delegated to the callable passed to ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias,
    ):
        super().__init__()
        # A throwaway layer is built only to obtain its initialised tensors.
        layer = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        self.conv_transpose_parameter = nn.Parameter(layer.weight)

        if bias:
            initial_bias = layer.bias
            # Add small Gaussian noise (scaled by 0.02) to the initial bias.
            jitter = torch.randn(
                initial_bias.shape,
                device=initial_bias.device,
                dtype=initial_bias.dtype,
            )
            self.conv_transpose_bias = nn.Parameter(initial_bias + jitter * 0.02)
        else:
            self.conv_transpose_bias = None

    def forward(self, x, stride, padding, output_padding, bias, fn=module_fn):
        """Call ``fn`` with the inputs followed by the stored weight and bias."""
        weight = self.conv_transpose_parameter
        bias_tensor = self.conv_transpose_bias
        return fn(x, stride, padding, output_padding, bias, weight, bias_tensor)


# Benchmark configuration read by get_inputs() / get_init_inputs().
batch_size = 16
in_channels = 32
out_channels = 64

# Spatial extent of the input volume.
D = 16
H = 32
W = 32

# Transposed-convolution hyper-parameters.
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
bias = True


def get_inputs():
    """Return a random input volume followed by the forward-call settings."""
    volume = torch.randn(batch_size, in_channels, D, H, W)
    return [volume, stride, padding, output_padding, bias]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` as a list."""
    init_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias,
    )
    return list(init_args)
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    ConvTranspose3d followed by a softmax over dim 1 and an element-wise sigmoid.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        if bias:
            # Perturb the freshly initialised bias with small Gaussian noise.
            initial_bias = self.conv_transpose.bias
            jitter = torch.randn(
                initial_bias.shape,
                device=initial_bias.device,
                dtype=initial_bias.dtype,
            )
            self.conv_transpose.bias = nn.Parameter(initial_bias + jitter * 0.02)
        else:
            self.conv_transpose.bias = None
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor fed to the transposed convolution.

        Returns:
            torch.Tensor: Convolution output after softmax (dim 1) and sigmoid.
                Its spatial size is whatever the transposed convolution
                produces, which need not equal the input's.
        """
        return self.sigmoid(self.softmax(self.conv_transpose(x)))

# Benchmark configuration read by get_inputs() / get_init_inputs().
batch_size = 16
in_channels = 32
out_channels = 64

# Spatial extent of the input volume.
D = 16
H = 32
W = 32

# Transposed-convolution hyper-parameters.
kernel_size = 3
stride = 2
padding = 1
output_padding = 1

def get_inputs():
    """Return a single-element list holding one random input volume."""
    volume = torch.randn(batch_size, in_channels, D, H, W)
    return [volume]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` as a list."""
    init_args = (in_channels, out_channels, kernel_size, stride, padding, output_padding)
    return list(init_args)

Kernel Information

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
#include <vector>
#include <math.h>

// Threads per block used by the host launcher for the kernels in this file.
#define BLOCK_SIZE 256
// Number of channels processed per step of the outer channel loop in the
// compile-time-channel-count kernel.
#define UNROLL_FACTOR 4

// Fused softmax (over the channel axis) + sigmoid with the channel count fixed
// at compile time. Max/sum updates use select expressions instead of
// if/else branches so all threads of a warp follow the same instruction path.
//
// Layout assumption: `input` and `output` are indexed as dense
// (batch, CHANNELS, depth, height, width) buffers, i.e. the element for channel
// c of a given location lives at  b * CHANNELS * spatial + c * spatial + pixel.
//
// Each thread owns one (batch, d, h, w) location and walks all channels twice:
//   pass 1: single-sweep running max and running sum of exp(x - max);
//   pass 2: writes sigmoid(exp(x - max) / sum) for every channel.
template <typename scalar_t, int CHANNELS>
__global__ void warp_divergence_reduced_softmax_sigmoid_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int batch,
    const int depth,
    const int height,
    const int width) {

    // 64-bit arithmetic so that large tensors cannot overflow the offsets.
    const int64_t spatial = static_cast<int64_t>(depth) * height * width;
    const int64_t total_pixels = static_cast<int64_t>(batch) * spatial;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (idx >= total_pixels) {
        return;
    }

    const int64_t b = idx / spatial;
    // d * height * width + h * width + w is exactly the within-batch pixel
    // index, so there is no need to split it into d/h/w and recombine.
    const int64_t pixel_idx = idx % spatial;

    const int64_t stride = spatial;
    const int64_t base = b * CHANNELS * stride + pixel_idx;

    // The sum starts at 0 because the loop below also visits channel 0;
    // starting at 1 would count that channel twice.
    scalar_t max_val = input[base];
    scalar_t sum_exp = scalar_t(0);

    #pragma unroll
    for (int c = 0; c < CHANNELS; c += UNROLL_FACTOR) {
        // Issue the loads for the whole chunk before the dependent math.
        scalar_t vals[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR && (c + u) < CHANNELS; ++u) {
            vals[u] = input[base + static_cast<int64_t>(c + u) * stride];
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR && (c + u) < CHANNELS; ++u) {
            const scalar_t val = vals[u];
            const bool update_max = val > max_val;
            // exp(-|val - max_val|), computed against the OLD maximum. When the
            // maximum grows this is the factor that rescales the running sum;
            // otherwise it is the new term to add.
            const scalar_t e = exp(update_max ? max_val - val : val - max_val);
            sum_exp = update_max ? sum_exp * e + scalar_t(1) : sum_exp + e;
            max_val = update_max ? val : max_val;
        }
    }

    #pragma unroll
    for (int c = 0; c < CHANNELS; c += UNROLL_FACTOR) {
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR && (c + u) < CHANNELS; ++u) {
            const int64_t pos = base + static_cast<int64_t>(c + u) * stride;
            const scalar_t softmax_val = exp(input[pos] - max_val) / sum_exp;
            output[pos] = scalar_t(1) / (scalar_t(1) + exp(-softmax_val));
        }
    }
}

// Runtime-channel-count variant of the fused softmax + sigmoid kernel
// (templated on the scalar type only).
//
// Layout assumption: `input` and `output` are indexed as dense
// (batch, channels, depth, height, width) buffers, i.e. the element for channel
// c of a given location lives at  b * channels * spatial + c * spatial + pixel.
//
// Each thread owns one (batch, d, h, w) location and walks all channels twice:
//   pass 1: single-sweep running max and running sum of exp(x - max);
//   pass 2: writes sigmoid(exp(x - max) / sum) for every channel.
template <typename scalar_t>
__global__ void dynamic_warp_divergence_reduced_softmax_sigmoid_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int batch,
    const int channels,
    const int depth,
    const int height,
    const int width) {

    // 64-bit arithmetic so that large tensors cannot overflow the offsets.
    const int64_t spatial = static_cast<int64_t>(depth) * height * width;
    const int64_t total_pixels = static_cast<int64_t>(batch) * spatial;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (idx >= total_pixels) {
        return;
    }

    const int64_t b = idx / spatial;
    // d * height * width + h * width + w is exactly the within-batch pixel
    // index, so there is no need to split it into d/h/w and recombine.
    const int64_t pixel_idx = idx % spatial;

    const int64_t stride = spatial;
    const int64_t base = b * channels * stride + pixel_idx;

    // The sum starts at 0 because the loop below also visits channel 0;
    // starting at 1 would count that channel twice.
    scalar_t max_val = input[base];
    scalar_t sum_exp = scalar_t(0);

    #pragma unroll 4
    for (int c = 0; c < channels; ++c) {
        const scalar_t val = input[base + static_cast<int64_t>(c) * stride];
        const bool update_max = val > max_val;
        // exp(-|val - max_val|), computed against the OLD maximum. When the
        // maximum grows this is the factor that rescales the running sum;
        // otherwise it is the new term to add.
        const scalar_t e = exp(update_max ? max_val - val : val - max_val);
        sum_exp = update_max ? sum_exp * e + scalar_t(1) : sum_exp + e;
        max_val = update_max ? val : max_val;
    }

    #pragma unroll 4
    for (int c = 0; c < channels; ++c) {
        const int64_t pos = base + static_cast<int64_t>(c) * stride;
        const scalar_t softmax_val = exp(input[pos] - max_val) / sum_exp;
        output[pos] = scalar_t(1) / (scalar_t(1) + exp(-softmax_val));
    }
}

// Host entry point: runs ATen's conv_transpose3d, then applies the fused
// channel-softmax + sigmoid kernel to its output.
//
// Args:
//   input               - tensor passed to torch::conv_transpose3d
//   stride/padding/output_padding - forwarded to the transposed convolution
//   bias_flag           - when false the convolution runs without a bias
//   conv_transpose      - transposed-convolution weight
//   conv_transpose_bias - bias, used only when bias_flag is true
// Returns:
//   A new tensor with the same shape as the convolution output.
// Raises (via TORCH_CHECK):
//   if the convolution output is not a 5-D CUDA tensor, is too large for the
//   kernels' int parameters, or the kernel launch reports an error.
torch::Tensor forward(
    torch::Tensor input,
    int stride,
    int padding,
    int output_padding,
    bool bias_flag,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias) {

    // The kernels index the buffer as dense NCDHW, so force that layout in
    // case the convolution returned a different memory format.
    auto x = torch::conv_transpose3d(
        input,
        conv_transpose,
        bias_flag ? conv_transpose_bias : torch::Tensor(),
        stride,
        padding,
        output_padding
    ).contiguous();

    TORCH_CHECK(x.is_cuda(), "forward: convolution output must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 5, "forward: expected a 5-D convolution output, got ", x.dim(), "-D");

    auto output = torch::empty_like(x);

    // A zero-sized grid is an invalid launch configuration; nothing to do anyway.
    if (x.numel() == 0) {
        return output;
    }

    const int64_t total_pixels = x.size(0) * x.size(2) * x.size(3) * x.size(4);
    TORCH_CHECK(
        total_pixels <= std::numeric_limits<int>::max(),
        "forward: batch * depth * height * width (", total_pixels,
        ") exceeds the supported range");

    // Safe narrowing: every factor is bounded by total_pixels checked above,
    // except channels, which is checked separately.
    TORCH_CHECK(
        x.size(1) <= std::numeric_limits<int>::max(),
        "forward: channel count (", x.size(1), ") exceeds the supported range");
    const int batch = static_cast<int>(x.size(0));
    const int channels = static_cast<int>(x.size(1));
    const int depth = static_cast<int>(x.size(2));
    const int height = static_cast<int>(x.size(3));
    const int width = static_cast<int>(x.size(4));

    const int threads = BLOCK_SIZE;
    // Ceil-divide in 64 bits so the rounding term cannot overflow.
    const int blocks = static_cast<int>((total_pixels + threads - 1) / threads);

    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "warp_divergence_reduced_softmax_sigmoid_kernel", ([&] {
        // Choose specialized kernel based on channel count
        if (channels == 32) {
            warp_divergence_reduced_softmax_sigmoid_kernel<scalar_t, 32><<<blocks, threads>>>(
                x.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                batch,
                depth,
                height,
                width);
        } else if (channels == 64) {
            warp_divergence_reduced_softmax_sigmoid_kernel<scalar_t, 64><<<blocks, threads>>>(
                x.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                batch,
                depth,
                height,
                width);
        } else {
            dynamic_warp_divergence_reduced_softmax_sigmoid_kernel<scalar_t><<<blocks, threads>>>(
                x.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                batch,
                channels,
                depth,
                height,
                width);
        }
    }));

    // Surface launch-configuration errors instead of returning garbage.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "forward: softmax/sigmoid kernel launch failed: ",
        cudaGetErrorString(launch_status));

    return output;
}

// Python binding: exposes the host function `forward` as `forward` on the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Warp Divergence Reduced Fused ConvTranspose3d with Softmax and Sigmoid");
}
Performance Metrics
Metric Value Unit Variance Samples
Analysis Rules
Rule Description
Operation / Metric Value Unit
aten::conv_transpose3d
CPU Time 1867364.10 μs
Device Time 5102816.65 μs
Self CPU Time 7314.24 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 1860049.86 μs
Device Time 5102816.65 μs
Self CPU Time 10525.07 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 1849524.79 μs
Device Time 5102816.65 μs
Self CPU Time 20798.46 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 536692.82 μs
Device Time 3143161.67 μs
Self CPU Time 143442.79 μs
Self Device Time 3143161.67 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 1857613.28 μs
Device Time 163474.62 μs
Self CPU Time 1857613.28 μs
Self Device Time 163474.62 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4589055.64 μs
Device Time 84714.55 μs
Self CPU Time 4589055.64 μs
Self Device Time 84714.55 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::add_
CPU Time 1286288.72 μs
Device Time 1959654.98 μs
Self CPU Time 19109.21 μs
Self Device Time 1959654.98 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45295 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
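/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:17:5: warning: 2 adjacent parameters of 'warp_divergence_reduced_softmax_sigmoid_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]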
17 | const int batch,
| ^~~~~~~~~~~~~~~~
18 | const int depth,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:17:15: note: the first parameter in the range is 'batch'
17 | const int batch,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:18:15: note: the last parameter in the range is 'depth'
18 | const int depth,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:24:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
24 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:74:5: warning: 3 adjacent parameters of 'dynamic_warp_divergence_reduced_softmax_sigmoid_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
74 | const int batch,
| ^~~~~~~~~~~~~~~~
75 | const int channels,
| ~~~~~~~~~~~~~~~~~~~
76 | const int depth,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:74:15: note: the first parameter in the range is 'batch'
74 | const int batch,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:76:15: note: the last parameter in the range is 'depth'
76 | const int depth,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:82:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:117:19: warning: the parameter 'input' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
117 | torch::Tensor input,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:122:19: warning: the parameter 'conv_transpose' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
122 | torch::Tensor conv_transpose,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:128:21: warning: parameter 'conv_transpose_bias' is passed by value and only copied once; consider moving it to avoid unnecessary copies [performance-unnecessary-value-param]
128 | bias_flag ? conv_transpose_bias : torch::Tensor(),
| ^
| std::move( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:134:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
134 | const int batch = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:135:26: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
135 | const int channels = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:136:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
136 | const int depth = x.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:137:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
137 | const int height = x.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:138:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
138 | const int width = x.size(4);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_49/b7_s0_warp_divergence_reduction_base/base/base.cu:147:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
147 | AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "warp_divergence_reduced_softmax_sigmoid_kernel", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:237:34: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES'
237 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:233:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES'
233 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^