← Back to Leaderboard

The AI CUDA Engineer 👷

47_Conv3d_Mish_Tanh • aligned_ldg_mish_tanh_base

Level 2 • Task 47
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D convolution and pass the result through Mish, then Tanh.

    Args:
        x: Input of shape (batch_size, in_channels, D, H, W).
        stride: Convolution stride, used for every spatial dimension.
        padding: Zero-padding added on every spatial side.
        conv_weight: Filters of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size).
        conv_bias: Per-output-channel bias of shape (out_channels,).

    Returns:
        tanh(mish(conv3d(x))), of shape (batch_size, out_channels, D', H', W').
    """
    convolved = F.conv3d(x, conv_weight, bias=conv_bias, stride=stride, padding=padding)
    # Mish first, Tanh second — the composition order matters.
    return torch.tanh(F.mish(convolved))


class Model(nn.Module):
    """
    3D convolution followed by Mish and then Tanh.

    The convolution parameters live directly on the module as ``conv_weight``
    and ``conv_bias``. They are applied by the functional implementation
    passed as ``fn`` (``module_fn`` by default).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # Build a throwaway nn.Conv3d only to reuse its default parameter
        # initialisation; the layer object itself is not kept.
        reference_conv = nn.Conv3d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        init_bias = reference_conv.bias
        # Shift the default bias by Gaussian noise with standard deviation 0.02.
        bias_noise = (
            torch.randn(init_bias.shape, device=init_bias.device, dtype=init_bias.dtype)
            * 0.02
        )
        self.conv_weight = nn.Parameter(reference_conv.weight)
        self.conv_bias = nn.Parameter(init_bias + bias_noise)

    def forward(self, x, stride, padding, fn=module_fn):
        """Apply ``fn`` to ``x`` using this module's convolution parameters."""
        weight, bias = self.conv_weight, self.conv_bias
        return fn(x, stride, padding, weight, bias)


# Benchmark problem size.
batch_size = 16
in_channels = 3
out_channels = 16
D, H, W = 16, 32, 32
kernel_size = 3
stride = 1
padding = 0


def get_inputs():
    """Return forward() arguments: a random input volume, then stride and padding."""
    volume = torch.randn(batch_size, in_channels, D, H, W)
    return [volume, stride, padding]


def get_init_inputs():
    """Return Model's constructor arguments in positional order."""
    return [in_channels, out_channels, kernel_size, stride, padding]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Conv3d -> Mish -> Tanh.

    At construction time the convolution's default bias is shifted by
    Gaussian noise scaled by 0.02.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        default_bias = self.conv.bias
        noise = torch.randn(default_bias.shape, device=default_bias.device, dtype=default_bias.dtype) * 0.02
        self.conv.bias = nn.Parameter(default_bias + noise)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, D, H, W).

        Returns:
            torch.Tensor: tanh(mish(conv(x))), of shape (batch_size, out_channels, D', H', W').
        """
        activated = torch.nn.functional.mish(self.conv(x))
        return torch.tanh(activated)

# Benchmark problem size (stride and padding use Model's defaults).
batch_size = 16
in_channels = 3
out_channels = 16
D, H, W = 16, 32, 32
kernel_size = 3

def get_inputs():
    """Return the single forward() argument: a random 5-D input volume."""
    sample = torch.randn(batch_size, in_channels, D, H, W)
    return [sample]

def get_init_inputs():
    """Return Model's required constructor arguments."""
    return [in_channels, out_channels, kernel_size]

Kernel Information

Related Kernels (Level 2, Task 47 • 47_Conv3d_Mish_Tanh)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_fused_mish_tanh_base 0.10 1.09 0.95
🥇 shared_mem_mish_tanh_base_base 0.10 1.09 0.95
🥇 modular_fused_mish_tanh_base 0.10 1.09 0.95
4 aligned_ldg_mish_tanh_base 0.10 1.08 0.94
4 warp_optimized_mish_tanh_base_base 0.10 1.08 0.94
4 vec_nosync_mish_tanh_base 0.10 1.08 0.94
4 block_size_optimization_mish_tanh_base 0.10 1.08 0.94
4 efficient_mish_tanh_shared_memory_base 0.10 1.08 0.94
4 fused_shared_unrolled_base 0.10 1.08 0.94
4 optimized_block_mish_tanh_base 0.10 1.08 0.94
4 manual_loop_unroll_fused_mish_tanh_base 0.10 1.08 0.94
4 stride_loop_mish_tanh_base 0.10 1.08 0.94
13 warp_level_no_shared_base 0.10 1.06 0.93
13 coalesced_memory_access_optimized_base 0.10 1.06 0.93
13 47_conv3d_mish_tanh_inplace_vec4_edit_1 0.10 1.06 0.93
13 47_conv3d_mish_tanh_shared_mem_base 0.10 1.06 0.93
13 47_Conv3d_Mish_Tanh_aligned_edit_1 0.10 1.06 0.93
13 47_Conv3d_Mish_Tanh_modular_base 0.10 1.06 0.93
19 47_conv3d_mish_tanh_unrolled_base 0.10 1.05 0.92
19 47_conv3d_mish_tanh_shared_mem_edit_1 0.10 1.05 0.92
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <vector>

// Device function: fused Mish followed by Tanh.
// Computes tanh(mish(x)), where mish(x) = x * tanh(softplus(x)) and
// softplus(x) = log(1 + exp(x)).
// log1pf(expf(x)) is used instead of logf(1.0f + expf(x)). For strongly
// negative x, expf(x) is far below float epsilon, so 1.0f + expf(x) rounds to
// exactly 1.0f and softplus collapses to 0; log1pf keeps the small value.
// For large positive x, expf overflows to +inf, log1pf(+inf) == +inf and
// tanhf(+inf) == 1, so mish saturates to x as expected.
__device__ __forceinline__ float fused_mish_tanh_activation(float x) {
    const float softplus = log1pf(expf(x));
    const float mish = x * tanhf(softplus);
    return tanhf(mish);
}

// Elementwise kernel: output[i] = tanh(mish(input[i])) for i in [0, total_elements).
//
// Phase 1 handles the first total_vec4 * 4 elements as float4 groups loaded with
// __ldg(). Whenever total_vec4 > 0, `input` and `output` must be 16-byte
// aligned; the caller is responsible for that.
// Phase 2 handles the remaining elements [total_vec4 * 4, total_elements) one
// float at a time.
// Both phases advance by blockDim.x * gridDim.x (grid-stride loop), so any grid
// size covers all elements. Indices are 64-bit so that tensors with more than
// INT_MAX elements, or index + stride sums past INT_MAX, do not overflow.
__global__ void aligned_ldg_mish_tanh_kernel(
    float* __restrict__ output,
    const float* __restrict__ input,
    const int64_t total_elements,
    const int64_t total_vec4  // Number of complete 4-float groups to vectorize
) {
    const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;

    const float4* input4 = reinterpret_cast<const float4*>(input);
    float4* output4 = reinterpret_cast<float4*>(output);

    // Phase 1: 128-bit vectorized loads/stores.
    for (int64_t i = first; i < total_vec4; i += step) {
        const float4 in_val = __ldg(input4 + i);
        float4 out_val;
        out_val.x = fused_mish_tanh_activation(in_val.x);
        out_val.y = fused_mish_tanh_activation(in_val.y);
        out_val.z = fused_mish_tanh_activation(in_val.z);
        out_val.w = fused_mish_tanh_activation(in_val.w);
        output4[i] = out_val;
    }

    // Phase 2: scalar tail that does not fill a whole float4.
    for (int64_t i = total_vec4 * 4 + first; i < total_elements; i += step) {
        output[i] = fused_mish_tanh_activation(__ldg(input + i));
    }
}

// Forward pass: 3D convolution (delegated to at::conv3d) followed by the fused
// Mish -> Tanh activation computed by aligned_ldg_mish_tanh_kernel.
//
// Args:
//   x           - input tensor; must be on a CUDA device.
//   stride      - convolution stride, applied to all three spatial dims.
//   padding     - convolution zero-padding, applied to all three spatial dims.
//   conv_weight - convolution filters; must be on a CUDA device.
//   conv_bias   - convolution bias; must be on a CUDA device.
// Returns:
//   A new tensor, allocated with empty_like(conv output), holding
//   tanh(mish(conv3d(x))).
// Throws (TORCH_CHECK) if any tensor is not on CUDA, or if the convolution
// output is not float32 — the kernel only handles float.
torch::Tensor module_fn_forward(
    const torch::Tensor& x,
    int64_t stride,
    int64_t padding,
    const torch::Tensor& conv_weight,
    const torch::Tensor& conv_bias
) {
    TORCH_CHECK(x.is_cuda(), "Input tensor x must be a CUDA tensor");
    TORCH_CHECK(conv_weight.is_cuda(), "Convolution weight must be a CUDA tensor");
    TORCH_CHECK(conv_bias.is_cuda(), "Convolution bias must be a CUDA tensor");

    // Make x's device current, so the stream lookup and the kernel launch
    // target the GPU that owns the data even when it is not the current device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // Perform 3D convolution using PyTorch's conv3d.
    auto x_conv = at::conv3d(
        x,
        conv_weight,
        conv_bias,
        {stride, stride, stride},
        {padding, padding, padding}
    );
    TORCH_CHECK(x_conv.scalar_type() == at::kFloat,
                "aligned_ldg_mish_tanh_kernel only supports float32 tensors");

    // NOTE(review): the kernel treats both buffers as flat arrays of numel()
    // floats. This assumes x_conv is dense (a freshly allocated conv result);
    // empty_like then gives `output` the same strides.
    auto output = torch::empty_like(x_conv);
    const int64_t total_elements = x_conv.numel();
    if (total_elements == 0) {
        // Nothing to compute, and a zero-sized grid is an invalid launch.
        return output;
    }

    const float* in_ptr = x_conv.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // The float4 path needs 16-byte aligned base pointers. Fresh allocations
    // normally satisfy this. If they ever do not, fall back to the scalar
    // path instead of issuing misaligned vector accesses.
    const bool vec_aligned =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;
    const int64_t total_vec4 = vec_aligned ? total_elements / 4 : 0;

    // One thread per float4 group. If there are no groups (fewer than 4
    // elements, or the unaligned fallback), use one thread per scalar element
    // instead, so the grid is never empty. The kernel's grid-stride loops cover
    // any remainder, which also makes capping the grid size safe.
    constexpr int block_size = 256;
    constexpr int64_t max_blocks = 65535;
    const int64_t work_items = total_vec4 > 0 ? total_vec4 : total_elements;
    const int num_blocks = static_cast<int>(
        std::min<int64_t>((work_items + block_size - 1) / block_size, max_blocks));

    // Launch on PyTorch's current stream so the kernel is ordered after the
    // conv3d that produced x_conv and before any later consumer of `output`.
    aligned_ldg_mish_tanh_kernel<<<num_blocks, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
        out_ptr,
        in_ptr,
        total_elements,
        total_vec4
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python binding: exposes module_fn_forward under the name `forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def(
        "forward",
        &module_fn_forward,
        "Fused convolution with Mish and Tanh activations using aligned __ldg() (CUDA)"
    );
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.838 inst/cycle 0.002 5
Executed Ipc Elapsed 2.334 inst/cycle 0.000 5
Issue Slots Busy 71.668 % 1.094 5
Issued Ipc Active 2.868 inst/cycle 0.002 5
SM Busy 71.668 % 1.094 5
Memory Throughput 1079819606458.638 byte/second 9887681800948221952.000 5
Mem Busy 27.666 % 0.016 5
Max Bandwidth 32.288 % 0.016 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 50.978 % 0.019 5
Mem Pipes Busy 6.602 % 0.001 5
Warp Cycles Per Issued Instruction 18.028 cycle 0.001 5
Warp Cycles Per Executed Instruction 18.206 cycle 0.001 5
Avg. Active Threads Per Warp 21.370 0.000 5
Avg. Not Predicated Off Threads Per Warp 20.690 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 10.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 82.168 % 0.060 5
Achieved Active Warps Per SM 52.588 warp 0.025 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (30.6%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 21.4 threads being active per cycle. This is further reduced to 20.7 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (82.3%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv3d
CPU Time 1009759.07 μs
Device Time 960325.98 μs
Self CPU Time 20321.81 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 989437.25 μs
Device Time 960325.98 μs
Self CPU Time 26016.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 963420.92 μs
Device Time 960325.98 μs
Self CPU Time 51420.67 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution
CPU Time 793440.72 μs
Device Time 835787.22 μs
Self CPU Time 241717.96 μs
Self Device Time 835787.22 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn
CPU Time 0.00 μs
Device Time 835784.27 μs
Self CPU Time 0.00 μs
Self Device Time 835784.27 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 761603.23 μs
Device Time 33382.19 μs
Self CPU Time 761603.23 μs
Self Device Time 33382.19 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45285 warnings generated when compiling for host.
Suppressed 45325 warnings (45278 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:19:5 bugprone-easily-swappable-parameters
19 | const int total_elements,
| ^~~~~~~~~~~~~~~~~~~~~~~~~
20 | const int total_vec4 // Number of 4-float groups
| ~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:19:15: note: the first parameter in the range is 'total_elements'
19 | const int total_elements,
| ^~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:20:15: note: the last parameter in the range is 'total_vec4'
20 | const int total_vec4 // Number of 4-float groups
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:22:26: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | int global_threads = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:25:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
25 | for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < total_vec4; i += global_threads) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:40:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
40 | for (int i = vec4_total + threadIdx.x + blockIdx.x * blockDim.x; i < total_elements; i += global_threads) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:49:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
49 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:52:19: warning: the parameter 'conv_weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
52 | torch::Tensor conv_weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_2/task_47/b6_s2_aligned_ldg_mish_tanh/base/base.cu:70:32: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
70 | const int total_elements = x_conv.numel();
| ^