← Back to Leaderboard

The AI CUDA Engineer 👷

81_Gemm_Swish_Divide_Clamp_Tanh_Clamp — dim2_grid_activation_base

Level 2 • Task 81
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a linear layer followed by swish, halving, clamp, tanh and clamp.

    Computes ``clamp(tanh(clamp(swish(x @ weight.T + bias) / 2, -1, 1)), -1, 1)``
    where ``swish(z) = z * sigmoid(z)``.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, out_features), with
        values clamped to [-1, 1].
    """
    z = F.linear(x, weight, bias)
    z = z * z.sigmoid()  # swish
    halved = z / 2.0
    bounded = halved.clamp(min=-1.0, max=1.0)
    # tanh already maps into [-1, 1]; the trailing clamp is kept so the op
    # sequence matches the reference pipeline exactly.
    return bounded.tanh().clamp(min=-1.0, max=1.0)


class Model(nn.Module):
    """
    Simple model that performs a gemm, swish, divide, clamp, tanh, and clamp operations.

    The linear layer's tensors are stored directly as ``weight`` and ``bias`` so
    that ``forward`` can hand them to a functional implementation (``module_fn``
    by default, or any callable with the same ``(x, weight, bias)`` signature).

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        # Zero-argument super() resolves the class through the compiler-provided
        # __class__ cell. The explicit ``super(Model, self)`` form instead looks
        # up the *global* name ``Model`` at call time and raises TypeError if
        # that name has since been rebound to a different class in this module.
        super().__init__()
        # Borrow nn.Linear's default initialization, then keep only its tensors.
        mm = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(mm.weight)
        self.bias = nn.Parameter(mm.bias)

    def forward(self, x, fn=module_fn):
        """
        Run ``fn(x, self.weight, self.bias)`` and return its result.

        Args:
            x: Input tensor passed through unchanged to ``fn``.
            fn: Callable implementing the forward computation; defaults to
                ``module_fn``.
        """
        return fn(x, self.weight, self.bias)


# Problem sizes consumed by get_inputs() and get_init_inputs() below.
batch_size, in_features, out_features = 128, 1024, 512


def get_inputs():
    """Return the forward() arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the constructor arguments ``[in_features, out_features]`` for Model."""
    ctor_args = [in_features, out_features]
    return ctor_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a gemm, swish, divide, clamp, tanh, and clamp operations.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the linear layer has an additive bias.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # The attribute name ``gemm`` determines the parameter names
        # (``gemm.weight`` / ``gemm.bias``) seen by state_dict().
        self.gemm = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        h = self.gemm(x)
        h = h * h.sigmoid()  # swish
        h = (h / 2.0).clamp(min=-1.0, max=1.0)
        # tanh already lies in [-1, 1]; the final clamp mirrors the reference.
        return h.tanh().clamp(min=-1.0, max=1.0)

# Problem sizes consumed by get_inputs() and get_init_inputs() below.
batch_size, in_features, out_features = 128, 1024, 512

def get_inputs():
    """Return the forward() arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]

def get_init_inputs():
    """Return the constructor arguments ``[in_features, out_features]`` for Model."""
    ctor_args = [in_features, out_features]
    return ctor_args

Kernel Information

Related Kernels (Level 2, Task 81 • 81_Gemm_Swish_Divide_Clamp_Tanh_Clamp)

| Rank | Kernel Name | Runtime (ms) | Speedup Native | Speedup Compile |
|------|-------------|--------------|----------------|-----------------|
| 🥇 | gemm_2d_map_base | 0.02 | 2.19 | 1.90 |
| 🥇 | gemm_2d_map_improved_base | 0.02 | 2.19 | 1.90 |
| 🥇 | hybrid_aligned_2d_kernel_base | 0.02 | 2.19 | 1.90 |
| 🥇 | aligned_memory_access_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_optimization_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | unrolled_2d_map_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_opt_base | 0.02 | 2.19 | 1.90 |
| 🥇 | fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_with_prefetch_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_thread_block_indexing_base | 0.02 | 2.19 | 1.90 |
| 🥇 | dim2_grid_activation_base | 0.02 | 2.19 | 1.90 |
| 14 | dynamic_block_size_base_base | 0.02 | 2.10 | 1.82 |
| 14 | strided_vectorized_base_base | 0.02 | 2.10 | 1.82 |
| 14 | even_load_base | 0.02 | 2.10 | 1.82 |
| 14 | 81_gemm_swish_divide_clamp_tanh_clamp_optimized_blocks_base | 0.02 | 2.10 | 1.82 |
| 14 | stride_loop_correct_base | 0.02 | 2.10 | 1.82 |
| 14 | adaptive_vectorized_kernel_base | 0.02 | 2.10 | 1.82 |
| 14 | uniform_control_flow_base_base | 0.02 | 2.10 | 1.82 |
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <type_traits>

// Macros for input checking.
// Arguments are parenthesized so any expression can be passed safely, and
// CHECK_INPUT is wrapped in do { } while (0) so it expands to exactly one
// statement (e.g. it stays correct as the unbraced body of an `if`).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)

// Element-wise epilogue over a row-major (m x n) matrix:
//   swish -> divide by 2 -> clamp[-1, 1] -> tanh -> clamp[-1, 1].
//
// The .x grid/block dimension walks columns and .y walks rows. Both loops are
// grid-stride loops (step = blockDim * gridDim), so the kernel stays correct
// when the host launches fewer blocks than needed to cover the matrix.
//
// Arithmetic is done in opmath_t (float for half/float, double for double), so
// half-precision inputs are rounded once at the store instead of after every
// intermediate op. Indices are 64-bit so that r * n + c cannot overflow.
template <typename scalar_t>
__global__ void twoD_activation_kernel(
    const scalar_t* __restrict__ x_in,
    scalar_t* __restrict__ x_out,
    const int64_t m,
    const int64_t n) {
    using opmath_t = typename std::conditional<
        std::is_same<scalar_t, double>::value, double, float>::type;

    const int64_t col0 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t row0 = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    const int64_t col_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    const int64_t row_stride = static_cast<int64_t>(blockDim.y) * gridDim.y;

    const opmath_t one = static_cast<opmath_t>(1);
    const opmath_t half = static_cast<opmath_t>(0.5);

    for (int64_t r = row0; r < m; r += row_stride) {
        for (int64_t c = col0; c < n; c += col_stride) {
            const int64_t idx = r * n + c;
            opmath_t x = static_cast<opmath_t>(x_in[idx]);

            // Swish activation: x * sigmoid(x)
            x = x * (one / (one + exp(-x)));

            // Divide by 2
            x = x * half;

            // Clamp to [-1, 1]. Plain comparisons are used instead of
            // min/max (fminf/fmaxf), which return the non-NaN operand. This
            // way NaN propagates exactly as torch.clamp does.
            x = (x < -one) ? -one : ((x > one) ? one : x);

            // Tanh activation
            x = tanh(x);

            // Second clamp: tanh already lies in [-1, 1]; kept to mirror the
            // reference op sequence.
            x = (x < -one) ? -one : ((x > one) ? one : x);

            x_out[idx] = static_cast<scalar_t>(x);
        }
    }
}

// CUDA forward: computes x_linear = bias + x @ weight^T with torch::addmm, then
// applies the activation chain with twoD_activation_kernel. These are two
// separate launches; only the element-wise activations are fused together.
//
// Args:
//   x      - left operand of addmm (addmm requires it to be 2D).
//   weight - transposed (as a view) and used as addmm's right operand.
//   bias   - passed to addmm, which broadcasts it over the product.
// Returns a new tensor with the shape and dtype of the addmm result.
torch::Tensor module_forward_cuda(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias) {
    // Make x's device current so the custom kernel launches on it even when the
    // caller's current device differs.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // Perform linear operation: x_linear = bias + x * weight^T
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto x_out = torch::empty_like(x_linear);

    // An empty output would produce a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to compute anyway.
    if (x_out.numel() == 0) {
        return x_out;
    }

    const int64_t m = x_linear.size(0);
    const int64_t n = x_linear.size(1);

    // 2D block/grid matching the matrix layout. Grid dimension y is limited to
    // 65535 blocks (x to 2^31 - 1). The kernel's grid-stride loops cover any
    // rows/columns beyond a capped grid.
    const dim3 block(32, 8);  // 256 threads per block
    constexpr int64_t kMaxGridX = 2147483647;
    constexpr int64_t kMaxGridY = 65535;
    const int64_t blocks_x = std::min<int64_t>((n + block.x - 1) / block.x, kMaxGridX);
    const int64_t blocks_y = std::min<int64_t>((m + block.y - 1) / block.y, kMaxGridY);
    const dim3 grid(static_cast<unsigned int>(blocks_x), static_cast<unsigned int>(blocks_y));

    // Launch on PyTorch's current stream so the kernel is ordered after addmm,
    // which also runs on that stream.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
        twoD_activation_kernel<scalar_t><<<grid, block, 0, stream>>>(
            x_linear.data_ptr<scalar_t>(),
            x_out.data_ptr<scalar_t>(),
            m,
            n);
    }));
    // Surface launch-configuration errors immediately instead of at a later,
    // unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return x_out;
}

// C++ interface exposed to Python: validates the inputs, then dispatches to
// module_forward_cuda.
//
// Requires x, weight and bias to be contiguous CUDA tensors on one device,
// x and weight to be 2D, and x.size(1) == weight.size(1). addmm would reject
// violations of the rank/size/device conditions anyway; checking them here
// gives error messages phrased in terms of this module's arguments.
torch::Tensor module_forward(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias) {
    CHECK_INPUT(x);
    CHECK_INPUT(weight);
    CHECK_INPUT(bias);
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size, in_features), got ", x.dim(), "D");
    TORCH_CHECK(weight.dim() == 2,
                "weight must be 2D (out_features, in_features), got ", weight.dim(), "D");
    TORCH_CHECK(x.size(1) == weight.size(1),
                "x.size(1) (", x.size(1), ") must equal weight.size(1) (", weight.size(1), ")");
    TORCH_CHECK(x.device() == weight.device() && x.device() == bias.device(),
                "x, weight and bias must be on the same device");
    return module_forward_cuda(x, weight, bias);
}

// PyBind11 module definition: exposes module_forward as `forward`.
// Argument names are declared so Python callers may also pass them by keyword;
// positional calls behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_forward, "2D grid based fused activation forward function (CUDA)",
          pybind11::arg("x"), pybind11::arg("weight"), pybind11::arg("bias"));
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.454 inst/cycle 0.000 5
Executed Ipc Elapsed 0.174 inst/cycle 0.000 5
Issue Slots Busy 13.028 % 0.265 5
Issued Ipc Active 0.520 inst/cycle 0.000 5
SM Busy 13.028 % 0.265 5
Memory Throughput 75069238986.818 byte/second 3438714706811368960.000 5
Mem Busy 10.322 % 0.051 5
Max Bandwidth 6.744 % 0.022 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 83.048 % 0.056 5
Mem Pipes Busy 3.150 % 0.005 5
Warp Cycles Per Issued Instruction 26.816 cycle 1.161 5
Warp Cycles Per Executed Instruction 30.656 cycle 1.516 5
Avg. Active Threads Per Warp 31.160 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.120 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 21.868 % 0.160 5
Achieved Active Warps Per SM 13.998 warp 0.064 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (21.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 676471.35 μs
Device Time 191.58 μs
Self CPU Time 66.56 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 676404.78 μs
Device Time 191.58 μs
Self CPU Time 125.22 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 694895.90 μs
Device Time 0.00 μs
Self CPU Time 19285.25 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 671030.15 μs
Device Time 0.00 μs
Self CPU Time 671030.15 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::addmm
CPU Time 507797.38 μs
Device Time 128403.29 μs
Self CPU Time 168224.60 μs
Self Device Time 128403.29 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize32x32x8_stage3_warpsize1x2x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 116192.66 μs
Self CPU Time 0.00 μs
Self Device Time 116192.66 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 62960.46 μs
Device Time 571063.95 μs
Self CPU Time 11183.78 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 51777.79 μs
Device Time 571063.95 μs
Self CPU Time 17492.31 μs
Self Device Time 571063.95 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 571063.95 μs
Self CPU Time 0.00 μs
Self Device Time 571063.95 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45295 warnings generated when compiling for host.
Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:8:35 bugprone-macro-parentheses
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:9:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:21:5: warning: 2 adjacent parameters of 'twoD_activation_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
21 | const int m,
| ^~~~~~~~~~~~
22 | const int n) {
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:21:15: note: the first parameter in the range is 'm'
21 | const int m,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:22:15: note: the last parameter in the range is 'n'
22 | const int n) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:25:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
25 | int col = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:26:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | int row = blockIdx.y * blockDim.y + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:32:35: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
32 | for (int r = row; r < m; r += blockDim.y * gridDim.y) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:33:39: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
33 | for (int c = col; c < n; c += blockDim.x * gridDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:61:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
61 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:62:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
62 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:63:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
63 | torch::Tensor bias) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:69:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
69 | int m = x_linear.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:70:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
70 | int n = x_linear.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:76:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
76 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:246:19: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES_AND_HALF'
246 | TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:240:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF'
240 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:89:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
89 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:90:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
90 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_81/b10_s1_dim2_grid_activation/base/base.cu:91:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
91 | torch::Tensor bias) {
| ^
| const &