← Back to Leaderboard

The AI CUDA Engineer 👷

81_Gemm_Swish_Divide_Clamp_Tanh_Clamp / fused_activation_base

Level 2 • Task 81
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a linear projection followed by a fixed activation chain.

    Pipeline: ``linear -> swish -> /2 -> clamp[-1, 1] -> tanh -> clamp[-1, 1]``.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)

    Returns:
        torch.Tensor: Tensor of shape (batch_size, out_features).
    """
    projected = F.linear(x, weight, bias)
    swished = projected * projected.sigmoid()
    halved = swished / 2.0
    bounded = halved.clamp(min=-1.0, max=1.0)
    # tanh already stays inside [-1, 1]; the final clamp mirrors the reference op list.
    return bounded.tanh().clamp(min=-1.0, max=1.0)


class Model(nn.Module):
    """
    Holds the weight and bias of a linear layer and runs a functional forward
    pass (gemm, swish, divide, clamp, tanh, clamp) through a pluggable ``fn``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Build an nn.Linear only to reuse its initialized tensors, then
        # register them directly on this module as `weight` and `bias`.
        reference_linear = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(reference_linear.weight)
        self.bias = nn.Parameter(reference_linear.bias)

    def forward(self, x, fn=module_fn):
        """Return ``fn(x, weight, bias)``; ``fn`` defaults to ``module_fn``."""
        weight, bias = self.weight, self.bias
        return fn(x, weight, bias)


# Benchmark problem size.
batch_size, in_features, out_features = 128, 1024, 512


def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    A single ``nn.Linear`` layer followed by the activation chain
    swish -> divide by 2 -> clamp to [-1, 1] -> tanh -> clamp to [-1, 1].
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        out = self.gemm(x)
        out = out * out.sigmoid()  # swish
        out = (out / 2.0).clamp(min=-1.0, max=1.0)
        # tanh is bounded by [-1, 1]; the trailing clamp mirrors the reference.
        return out.tanh().clamp(min=-1.0, max=1.0)

# Benchmark problem size.
batch_size, in_features, out_features = 128, 1024, 512


def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args

Kernel Information

Related Kernels (Level 2, Task 81 • 81_Gemm_Swish_Divide_Clamp_Tanh_Clamp)

| Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile) |
|---|---|---|---|---|
| 🥇 | gemm_2d_map_base | 0.02 | 2.19 | 1.90 |
| 🥇 | gemm_2d_map_improved_base | 0.02 | 2.19 | 1.90 |
| 🥇 | hybrid_aligned_2d_kernel_base | 0.02 | 2.19 | 1.90 |
| 🥇 | aligned_memory_access_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_optimization_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | unrolled_2d_map_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_opt_base | 0.02 | 2.19 | 1.90 |
| 🥇 | fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_with_prefetch_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_thread_block_indexing_base | 0.02 | 2.19 | 1.90 |
| 🥇 | dim2_grid_activation_base | 0.02 | 2.19 | 1.90 |
| 14 | dynamic_block_size_base_base | 0.02 | 2.10 | 1.82 |
| 14 | strided_vectorized_base_base | 0.02 | 2.10 | 1.82 |
| 14 | even_load_base | 0.02 | 2.10 | 1.82 |
| 14 | 81_gemm_swish_divide_clamp_tanh_clamp_optimized_blocks_base | 0.02 | 2.10 | 1.82 |
| 14 | stride_loop_correct_base | 0.02 | 2.10 | 1.82 |
| 14 | adaptive_vectorized_kernel_base | 0.02 | 2.10 | 1.82 |
| 14 | uniform_control_flow_base_base | 0.02 | 2.10 | 1.82 |
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Input-validation macros. Each macro argument is parenthesized so expression
// arguments bind correctly, and CHECK_INPUT is wrapped in do { } while (0) so
// it expands to a single statement (safe inside an unbraced if/else).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)


// Clamp v into [lo, hi] but let NaN pass through unchanged, matching
// torch.clamp. (CUDA's float min/max map to fminf/fmaxf, which return the
// non-NaN operand and would silently turn a NaN into one of the bounds.)
template <typename T>
__device__ __forceinline__ T clamp_keep_nan(T v, T lo, T hi) {
    return v < lo ? lo : (v > hi ? hi : v);
}

// Fused elementwise epilogue applied to the GEMM output:
//   y = clamp(tanh(clamp(swish(x) / 2, -1, 1)), -1, 1)
// Each thread starts at its global index and advances by the total number of
// launched threads (blockDim.x * gridDim.x), so any grid size covers `size`.
// Arithmetic runs in at::opmath_type<scalar_t> (float for at::Half; float and
// double unchanged) and is rounded back to scalar_t once on store.
template <typename scalar_t>
__global__ void fused_activation_kernel(
    const scalar_t* __restrict__ x_in,
    scalar_t* __restrict__ x_out,
    const size_t size) {
    using acc_t = at::opmath_type<scalar_t>;
    const acc_t one = static_cast<acc_t>(1);
    const acc_t two = static_cast<acc_t>(2);

    // Widen before multiplying: blockIdx.x * blockDim.x and
    // blockDim.x * gridDim.x would otherwise be computed in 32-bit unsigned
    // int and wrap for tensors with more than 2^32 elements.
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < size;
         idx += stride) {
        acc_t v = static_cast<acc_t>(x_in[idx]);

        // Swish activation: v * sigmoid(v)
        v = v * (one / (one + exp(-v)));

        // Scale down by 2
        v = v / two;

        // First clamp to [-1, 1]
        v = clamp_keep_nan(v, -one, one);

        // Tanh activation
        v = tanh(v);

        // tanh already lies in [-1, 1]; kept to mirror the reference op sequence.
        v = clamp_keep_nan(v, -one, one);

        x_out[idx] = static_cast<scalar_t>(v);
    }
}

// Host-side forward: the GEMM goes through torch::addmm (cuBLAS is already
// highly optimized), and the whole activation chain runs in one fused kernel.
//
// Args:
//   x      - input, expected (batch, in_features); validated by module_forward.
//   weight - (out_features, in_features)
//   bias   - (out_features)
// Returns a new tensor shaped like addmm's result, (batch, out_features).
torch::Tensor module_forward_cuda(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias) {
    // Make x's device current so allocation, addmm and the kernel launch all
    // target it even when the caller's current device is a different GPU.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // Linear transformation: x_linear = bias + x * weight^T
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto x_out = torch::empty_like(x_linear);

    const int64_t numel = x_linear.numel();
    if (numel == 0) {
        // Launching with zero blocks is an invalid configuration; nothing to do.
        return x_out;
    }

    constexpr int threads = 256;
    // The kernel strides over the whole tensor, so capping the grid keeps the
    // block count within int range without skipping any elements.
    constexpr int64_t max_blocks = 65535;
    const int blocks = static_cast<int>(
        std::min<int64_t>((numel + threads - 1) / threads, max_blocks));

    // Launch on PyTorch's current stream so the kernel is ordered after the
    // addmm above; PyTorch's pooled streams are non-blocking and do not
    // synchronize with the legacy default stream.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
        fused_activation_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            x_linear.data_ptr<scalar_t>(),
            x_out.data_ptr<scalar_t>(),
            static_cast<size_t>(numel));
    }));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return x_out;
}

// C++ entry point exposed to Python: validates the inputs and runs the fused
// forward. Like F.linear, any number of leading batch dimensions on x is
// accepted; they are folded into rows for the 2-D GEMM and restored afterwards.
torch::Tensor module_forward(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias) {
    CHECK_INPUT(x);
    CHECK_INPUT(weight);
    CHECK_INPUT(bias);

    // Common case: already (batch, in_features).
    if (x.dim() == 2) {
        return module_forward_cuda(x, weight, bias);
    }

    TORCH_CHECK(x.dim() >= 1, "x must have at least one dimension");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D (out_features, in_features)");

    // Product of all leading dims (1 for a 1-D x); computed explicitly so a
    // zero-sized last dimension does not make the reshape ambiguous.
    int64_t rows = 1;
    for (int64_t d = 0; d + 1 < x.dim(); ++d) {
        rows *= x.size(d);
    }
    std::vector<int64_t> out_sizes = x.sizes().vec();
    out_sizes.back() = weight.size(0);

    // x is contiguous (checked above), so this reshape is a view.
    const auto out2d = module_forward_cuda(x.reshape({rows, x.size(-1)}), weight, bias);
    return out2d.reshape(out_sizes);
}

// Python bindings: exposes module_forward as `forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            module_forward,
            "Fused linear and activation forward function (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.404 inst/cycle 0.000 5
Executed Ipc Elapsed 0.148 inst/cycle 0.000 5
Issue Slots Busy 11.086 % 0.084 5
Issued Ipc Active 0.444 inst/cycle 0.000 5
SM Busy 11.086 % 0.084 5
Memory Throughput 76427759640.970 byte/second 1752853668414941184.000 5
Mem Busy 10.576 % 0.019 5
Max Bandwidth 6.874 % 0.011 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 83.022 % 0.145 5
Mem Pipes Busy 1.708 % 0.001 5
Warp Cycles Per Issued Instruction 32.092 cycle 0.111 5
Warp Cycles Per Executed Instruction 35.206 cycle 0.134 5
Avg. Active Threads Per Warp 31.060 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.810 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 21.898 % 0.101 5
Achieved Active Warps Per SM 14.016 warp 0.042 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (22.3%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 408985.31 μs
Device Time 190.49 μs
Self CPU Time 49.91 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 408935.40 μs
Device Time 190.49 μs
Self CPU Time 95.41 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 431470.39 μs
Device Time 0.00 μs
Self CPU Time 23229.11 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 407834.87 μs
Device Time 0.00 μs
Self CPU Time 407834.87 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::addmm
CPU Time 610854.63 μs
Device Time 154765.46 μs
Self CPU Time 204694.73 μs
Self Device Time 154765.46 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize32x32x8_stage3_warpsize1x2x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 140056.57 μs
Self CPU Time 0.00 μs
Self Device Time 140056.57 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 76511.63 μs
Device Time 686927.40 μs
Self CPU Time 15548.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 60964.94 μs
Device Time 686927.40 μs
Self CPU Time 19604.41 μs
Self Device Time 686927.40 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 687005.16 μs
Self CPU Time 0.00 μs
Self Device Time 687005.16 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45290 warnings generated when compiling for host.
Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:8:35: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:9:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:27:31: warning: performing an implicit widening conversion to type 'size_t' (aka 'unsigned long') of a multiplication performed in type 'unsigned int' [bugprone-implicit-widening-of-multiplication-result]
27 | for (; idx < size; idx += blockDim.x * gridDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:27:31: note: make conversion explicit to silence this warning
27 | for (; idx < size; idx += blockDim.x * gridDim.x) {
| ^~~~~~~~~~~~~~~~~~~~~~
| static_cast<size_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:27:31: note: perform multiplication in a wider type
27 | for (; idx < size; idx += blockDim.x * gridDim.x) {
| ^~~~~~~~~~
| static_cast<size_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:55:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
55 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:56:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
56 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:57:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
57 | torch::Tensor bias) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:69:24: warning: narrowing conversion from 'size_t' (aka 'unsigned long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
69 | const int blocks = (numel + threads - 1) / threads;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:71:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
71 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:246:19: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES_AND_HALF'
246 | TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:240:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF'
240 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:83:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
83 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:84:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
84 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b4_s3_fused_activation/base/base.cu:85:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
85 | torch::Tensor bias) {
| ^
| const &