← Back to Leaderboard

The AI CUDA Engineer 👷

79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated__optimized_thread_block_mapping_base_base

Level 1 • Task 79
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: int,
    padding: int,
    dilation: int,
) -> torch.Tensor:
    """
    Functional transposed 1D convolution with stride, padding and dilation.

    This is a thin wrapper that forwards every argument to
    ``F.conv_transpose1d``; ``output_padding`` and ``groups`` are left at the
    library defaults.

    Args:
        x (torch.Tensor): Input tensor handed to ``F.conv_transpose1d``
        weight (torch.Tensor): Convolution weights
        bias (torch.Tensor): Bias tensor, or ``None`` to skip the bias term
        stride (int): Stride of the convolution
        padding (int): Padding applied to the input
        dilation (int): Spacing between kernel elements

    Returns:
        torch.Tensor: Result of the transposed convolution
    """
    conv_options = {
        "bias": bias,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
    }
    return F.conv_transpose1d(x, weight, **conv_options)


class Model(nn.Module):
    """
    Transposed 1D convolution whose forward pass is delegated to a swappable
    functional implementation (``module_fn`` by default).

    An ``nn.ConvTranspose1d`` layer is built only to obtain initialized
    parameters; its weight (and bias, when enabled) are cloned into
    ``self.weight`` / ``self.bias``, which are what ``forward`` actually uses.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the square convolution kernel.
        stride (int): Stride of the convolution.
        padding (int): Padding applied to the input.
        dilation (int): Spacing between kernel elements.
        bias (bool): If `True`, adds a learnable bias to the output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int,
        dilation: int,
        bias: bool,
    ):
        super().__init__()
        conv_hyperparams = {"stride": stride, "padding": padding, "dilation": dilation}
        self.conv_transpose1d = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, bias=bias, **conv_hyperparams
        )

        # Clone the freshly initialized parameters so they live directly on
        # this module and can be handed to any functional implementation.
        template = self.conv_transpose1d
        self.weight = nn.Parameter(template.weight.clone())
        if bias:
            self.bias = nn.Parameter(template.bias.clone())
        else:
            self.bias = None

        self.stride, self.padding, self.dilation = stride, padding, dilation

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Run ``fn`` on ``x`` with this module's parameters and hyperparameters.

        Args:
            x (torch.Tensor): Input tensor.
            fn: Callable invoked as
                ``fn(x, weight, bias, stride, padding, dilation)``.

        Returns:
            torch.Tensor: Whatever ``fn`` returns.
        """
        parameters = (self.weight, self.bias)
        hyperparams = (self.stride, self.padding, self.dilation)
        return fn(x, *parameters, *hyperparams)


# Constants: benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 16  # first dimension of the random input built in get_inputs()
in_channels = 32  # second dimension of that input; also returned by get_init_inputs()
out_channels = 64
kernel_size = 3
length = 128  # last dimension of the random input built in get_inputs()
stride = 2
padding = 1
dilation = 2
bias = False  # presumably forwarded to Model(..., bias=...) by the harness; verify against caller


def get_inputs():
    """Return a one-element list holding a random input of shape
    ``(batch_size, in_channels, length)`` drawn with ``torch.randn``."""
    input_shape = (batch_size, in_channels, length)
    return [torch.randn(*input_shape)]


def get_init_inputs():
    """Return the module-level configuration constants as a list, ordered
    like the positional parameters of ``Model.__init__``."""
    init_args = (in_channels, out_channels, kernel_size, stride, padding, dilation, bias)
    return list(init_args)
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Thin wrapper around ``nn.ConvTranspose1d`` with configurable stride,
    padding and dilation.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the square convolution kernel.
        stride (int, optional): Stride of the convolution. Defaults to 1.
        padding (int, optional): Padding applied to the input. Defaults to 0.
        dilation (int, optional): Spacing between kernel elements. Defaults to 1.
        bias (bool, optional): If `True`, adds a learnable bias to the output. Defaults to `False`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        bias: bool = False,
    ):
        super().__init__()
        layer_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "bias": bias,
        }
        self.conv1d_transpose = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, **layer_options
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the wrapped transposed 1D convolution layer to ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, length).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, length_out).
        """
        result = self.conv1d_transpose(x)
        return result


# Constants: benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 16  # first dimension of the random input built in get_inputs()
in_channels = 32  # second dimension of that input; also returned by get_init_inputs()
out_channels = 64
kernel_size = 3
length = 128  # last dimension of the random input built in get_inputs()
stride = 2
padding = 1
dilation = 2
bias = False  # presumably forwarded to Model(..., bias=...) by the harness; verify against caller


def get_inputs():
    """Build the forward-pass inputs: a single ``torch.randn`` tensor of
    shape ``(batch_size, in_channels, length)`` wrapped in a list."""
    sample = torch.randn(batch_size, in_channels, length)
    inputs = [sample]
    return inputs


def get_init_inputs():
    """Return the constructor arguments for ``Model`` as a list, taken from
    the module-level constants in positional order."""
    channel_args = [in_channels, out_channels]
    conv_args = [kernel_size, stride, padding, dilation, bias]
    return channel_args + conv_args

Kernel Information

Related Kernels (Level 1, Task 79 • 79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated__)

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

// Length of the transposed-convolution output along the spatial axis, for
// output_padding == 0:
//   (input_length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1
inline int compute_output_length(int input_length, int stride, int padding, int dilation, int kernel_size) {
    const int input_span = (input_length - 1) * stride;    // offset of the last input sample
    const int kernel_reach = dilation * (kernel_size - 1);  // offset of the last kernel tap
    const int cropped = 2 * padding;                        // `padding` removed from each end
    return input_span - cropped + kernel_reach + 1;
}

// Transposed 1D convolution, gather formulation: each thread produces exactly
// one output element (b, oc, o) by summing every (kernel tap, input channel)
// contribution that lands on it.
//
// Indexing used below (all buffers are addressed as dense row-major arrays):
//   x      : [batch_size, in_channels, input_length]
//   weight : [in_channels, out_channels, kernel_size]
//   bias   : [out_channels], or nullptr when there is no bias
//   output : [batch_size, out_channels, output_length]
//
// `stride` must be positive (it is used as a divisor).
//
// Flat indices and offsets are computed in 64-bit arithmetic so that tensors
// with more than 2^31 - 1 elements do not wrap around; the per-dimension
// coordinates still fit in int because every extent is passed as int.
__global__ void conv_transpose1d_kernel(
    const float* x_ptr,
    const float* weight_ptr,
    const float* bias_ptr,
    float* output_ptr,
    int batch_size,
    int in_channels,
    int out_channels,
    int input_length,
    int output_length,
    int kernel_size,
    int stride,
    int padding,
    int dilation
) {
    const long long total_outputs =
        static_cast<long long>(batch_size) * out_channels * output_length;
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
    // in 32-bit unsigned arithmetic and then narrowed to a signed int.
    const long long idx =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total_outputs) return;

    // Decompose the flat index into (batch, output channel, output position).
    const int o = static_cast<int>(idx % output_length);
    const int oc = static_cast<int>((idx / output_length) % out_channels);
    const int b = static_cast<int>(
        idx / (static_cast<long long>(out_channels) * output_length));

    // Loop-invariant pieces of the x / weight offsets.
    const long long x_batch_base =
        static_cast<long long>(b) * in_channels * input_length;
    const long long weight_ic_stride =
        static_cast<long long>(out_channels) * kernel_size;
    const long long weight_oc_base = static_cast<long long>(oc) * kernel_size;

    float sum = 0.0f;

    for (int k = 0; k < kernel_size; ++k) {
        // Input sample i feeds output position i * stride - padding + k * dilation,
        // so tap k contributes only when (o + padding - k * dilation) is a
        // multiple of stride and the quotient is a valid input index.
        const int i_pos = o + padding - k * dilation;
        // A negative i_pos either leaves a non-zero remainder here or divides
        // to a negative i, which the range check below rejects.
        if (i_pos % stride != 0) continue;
        const int i = i_pos / stride;
        if (i < 0 || i >= input_length) continue;

        for (int ic = 0; ic < in_channels; ++ic) {
            const long long x_idx =
                x_batch_base + static_cast<long long>(ic) * input_length + i;
            const long long weight_idx =
                ic * weight_ic_stride + weight_oc_base + k;
            sum += x_ptr[x_idx] * weight_ptr[weight_idx];
        }
    }

    if (bias_ptr) {
        sum += bias_ptr[oc];
    }

    // idx == b * out_channels * output_length + oc * output_length + o by
    // construction, so it doubles as the output offset.
    output_ptr[idx] = sum;
}

// Narrows a tensor extent to int, raising instead of silently truncating.
static int checked_extent(int64_t value, const char* what) {
    TORCH_CHECK(value <= 2147483647LL,
                what, " (", value, ") does not fit in a 32-bit int");
    return static_cast<int>(value);
}

// Host entry point for the float32 transposed 1D convolution.
//
// Args:
//   x        : CUDA float32 tensor, 3D (batch, in_channels, input_length).
//   weight   : CUDA float32 tensor, 3D (in_channels, out_channels, kernel_size),
//              on the same device as x.
//   bias     : optional CUDA float32 tensor, 1D with out_channels elements.
//   stride   : convolution stride; must be positive.
//   padding  : padding removed from each end of the output.
//   dilation : spacing between kernel taps.
//
// Returns a new tensor of shape (batch, out_channels, output_length) with the
// dtype/device of x, where output_length comes from compute_output_length().
// Raises c10::Error (via TORCH_CHECK) on invalid arguments or a failed launch.
torch::Tensor forward_cuda(
    torch::Tensor x,
    torch::Tensor weight,
    torch::optional<torch::Tensor> bias,
    int stride,
    int padding,
    int dilation
) {
    TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.device().is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(weight.device() == x.device(), "weight must be on the same device as x");
    TORCH_CHECK(x.dim() == 3, "x must be 3D (batch, in_channels, input_length)");
    TORCH_CHECK(weight.dim() == 3, "weight must be 3D (in_channels, out_channels, kernel_size)");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be a float32 tensor");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "weight must be a float32 tensor");
    // The kernel divides by stride, so zero (or a negative value) is invalid.
    TORCH_CHECK(stride > 0, "stride must be positive");

    // Allocate and launch on the device that owns x, not whichever device
    // happens to be current for the calling thread.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes the buffers as dense row-major arrays.
    x = x.contiguous();
    weight = weight.contiguous();

    const int batch_size = checked_extent(x.size(0), "batch size");
    const int in_channels = checked_extent(x.size(1), "in_channels");
    const int input_length = checked_extent(x.size(2), "input_length");
    const int out_channels = checked_extent(weight.size(1), "out_channels");
    const int kernel_size = checked_extent(weight.size(2), "kernel_size");

    TORCH_CHECK(weight.size(0) == in_channels, "weight's in_channels must match x's in_channels");

    torch::Tensor bias_contig;  // keeps the contiguous bias alive for the launch
    const float* bias_ptr = nullptr;
    if (bias.has_value()) {
        bias_contig = bias->contiguous();
        TORCH_CHECK(bias_contig.device().is_cuda(), "bias must be a CUDA tensor");
        TORCH_CHECK(bias_contig.device() == x.device(), "bias must be on the same device as x");
        TORCH_CHECK(bias_contig.dim() == 1, "bias must be 1D");
        TORCH_CHECK(bias_contig.scalar_type() == torch::kFloat32, "bias must be a float32 tensor");
        TORCH_CHECK(bias_contig.size(0) == out_channels, "bias size must match out_channels");
        bias_ptr = bias_contig.data_ptr<float>();
    }

    const int output_length = compute_output_length(input_length, stride, padding, dilation, kernel_size);
    TORCH_CHECK(output_length >= 0,
                "computed output length is negative (", output_length,
                "); check stride, padding, dilation and kernel size");

    // Keep every element count within int range: the kernel's parameters are
    // 32-bit ints and the grid size below is computed in int.
    constexpr int64_t kInt32Max = 2147483647LL;
    const int64_t num_output_elements =
        static_cast<int64_t>(batch_size) * out_channels * output_length;
    TORCH_CHECK(x.numel() <= kInt32Max && weight.numel() <= kInt32Max &&
                    num_output_elements <= kInt32Max,
                "tensors with more than 2^31 - 1 elements are not supported");

    // The kernel stores to every output element, so zero-initializing the
    // buffer would only add a redundant fill kernel.
    auto output = torch::empty({batch_size, out_channels, output_length}, x.options());

    // An empty output needs no work, and a launch with 0 blocks is invalid.
    if (num_output_elements == 0) {
        return output;
    }

    const int threads_per_block = 256;
    const int num_blocks = static_cast<int>(
        (num_output_elements + threads_per_block - 1) / threads_per_block);

    // Launch on PyTorch's current stream so the kernel is ordered with the
    // surrounding ATen operations.
    conv_transpose1d_kernel<<<num_blocks, threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias_ptr,
        output.data_ptr<float>(),
        batch_size,
        in_channels,
        out_channels,
        input_length,
        output_length,
        kernel_size,
        stride,
        padding,
        dilation
    );
    // Surface launch-configuration errors here instead of at some later,
    // unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: exposes forward_cuda as `forward` in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Arguments are bound by name in the same order as forward_cuda's
    // parameters. Only `bias` has a default (None); stride, padding and
    // dilation carry no default and must be supplied by the caller.
    m.def("forward", &forward_cuda, "ConvTranspose1D forward (CUDA)",
          py::arg("x"), py::arg("weight"), py::arg("bias") = py::none(),
          py::arg("stride"), py::arg("padding"), py::arg("dilation"));
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.206 inst/cycle 0.000 5
Executed Ipc Elapsed 1.822 inst/cycle 0.000 5
Issue Slots Busy 55.396 % 0.022 5
Issued Ipc Active 2.216 inst/cycle 0.000 5
SM Busy 59.922 % 0.027 5
Memory Throughput 14663432159.262 byte/second 15736992556671076.000 5
Mem Busy 39.542 % 0.077 5
Max Bandwidth 38.994 % 0.109 5
L1/TEX Hit Rate 80.076 % 0.001 5
L2 Hit Rate 96.158 % 0.039 5
Mem Pipes Busy 38.850 % 0.093 5
Warp Cycles Per Issued Instruction 20.742 cycle 0.004 5
Warp Cycles Per Executed Instruction 20.814 cycle 0.003 5
Avg. Active Threads Per Warp 19.760 0.000 5
Avg. Not Predicated Off Threads Per Warp 18.560 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 71.950 % 0.018 5
Achieved Active Warps Per SM 46.046 warp 0.008 5
Analysis Rules
Rule Description
INF HighPipeUtilization FMA is the highest-utilized pipeline (35.4%) based on active cycles, taking into account the rates of its different instructions. It executes 32-bit floating point (FADD, FMUL, FMAD, ...) and integer (IMUL, IMAD) operations. It is well-utilized, but should not be a bottleneck.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 19.8 threads being active per cycle. This is further reduced to 18.6 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (71.9%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 452473.17 μs
Device Time 16.58 μs
Self CPU Time 53.10 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zeros
CPU Time 6254899.82 μs
Device Time 178996.50 μs
Self CPU Time 141060.05 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 6842727.57 μs
Device Time 7101725.28 μs
Self CPU Time 306661.27 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 6536067.22 μs
Device Time 7101725.28 μs
Self CPU Time 385772.95 μs
Self Device Time 7101725.28 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 6471914.63 μs
Device Time 2755.95 μs
Self CPU Time 6471914.63 μs
Self Device Time 2755.95 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
conv_transpose1d_kernel(float const*, float const*, float const*, float*, int, int, int, int, int, int, int, int, int)
CPU Time 0.00 μs
Device Time 1455658.04 μs
Self CPU Time 0.00 μs
Self Device Time 1455658.04 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 193743.00 μs
Device Time 1147935.98 μs
Self CPU Time 193743.00 μs
Self Device Time 1147935.98 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 6922728.79 μs
Self CPU Time 0.00 μs
Self Device Time 6922728.79 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45289 warnings generated when compiling for host.
Suppressed 45326 warnings (45279 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:11:5 bugprone-easily-swappable-parameters
11 | const float* weight_ptr,
| ^~~~~~~~~~~~~~~~~~~~~~~~
12 | const float* bias_ptr,
| ~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:11:18: note: the first parameter in the range is 'weight_ptr'
11 | const float* weight_ptr,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:12:18: note: the last parameter in the range is 'bias_ptr'
12 | const float* bias_ptr,
| ^~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:14:5: warning: 2 adjacent parameters of 'conv_transpose1d_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
14 | int batch_size,
| ^~~~~~~~~~~~~~~
15 | int in_channels,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:14:9: note: the first parameter in the range is 'batch_size'
14 | int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:15:9: note: the last parameter in the range is 'in_channels'
15 | int in_channels,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:16:5: warning: 2 adjacent parameters of 'conv_transpose1d_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
16 | int out_channels,
| ^~~~~~~~~~~~~~~~~
17 | int input_length,
| ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:16:9: note: the first parameter in the range is 'out_channels'
16 | int out_channels,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:17:9: note: the last parameter in the range is 'input_length'
17 | int input_length,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:18:5: warning: 4 adjacent parameters of 'conv_transpose1d_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
18 | int output_length,
| ^~~~~~~~~~~~~~~~~~
19 | int kernel_size,
| ~~~~~~~~~~~~~~~~
20 | int stride,
| ~~~~~~~~~~~
21 | int padding,
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:18:9: note: the first parameter in the range is 'output_length'
18 | int output_length,
| ^~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:21:9: note: the last parameter in the range is 'padding'
21 | int padding,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:24:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
24 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:79:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | int batch_size = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:80:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | int in_channels = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:81:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
81 | int input_length = x.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:82:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | int out_channels = weight.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_79/b2_s3_optimized_thread_block_mapping_base/base/base.cu:83:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
83 | int kernel_size = weight.size(2);
| ^