← Back to Leaderboard

The AI CUDA Engineer 👷

99_Matmul_GELU_Softmax — fused_shared_mem_kernel_base

Level 2 • Task 99
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``softmax(gelu(x @ weight.T + bias), dim=1)``.

    Args:
        x (torch.Tensor): Input batch, shape (batch_size, in_features).
        weight (torch.Tensor): Linear weight, shape (out_features, in_features).
        bias (torch.Tensor): Linear bias, shape (out_features).

    Returns:
        torch.Tensor: Row-wise softmax of the GELU-activated affine map,
            shape (batch_size, out_features).
    """
    # Innermost call runs first: affine map, then GELU, then softmax over features.
    return F.softmax(F.gelu(F.linear(x, weight, bias)), dim=1)


class Model(nn.Module):
    """
    Linear -> GELU -> Softmax model whose forward pass is delegated to a functional ``fn``.

    The parameters are stored directly on the module as ``weight`` and ``bias``
    so that any function with the signature ``fn(x, weight, bias)`` can be swapped in.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Build a throwaway nn.Linear purely for its default parameter
        # initialisation, then register its weight and bias on this module.
        reference_layer = nn.Linear(in_features, out_features)
        self.weight = reference_layer.weight
        self.bias = reference_layer.bias

    def forward(self, x, fn=module_fn):
        params = (self.weight, self.bias)
        return fn(x, *params)


# Benchmark problem size: (batch rows, input features, output features).
batch_size, in_features, out_features = 128, 100, 10


def get_inputs():
    """Return the positional forward() inputs: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Linear layer followed by GELU and a softmax over the feature dimension.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        functional = torch.nn.functional
        # Affine map, GELU activation, then normalise each row (dim=1) with softmax.
        return functional.softmax(functional.gelu(self.linear(x)), dim=1)

# Benchmark problem size: (batch rows, input features, output features).
batch_size, in_features, out_features = 128, 100, 10

def get_inputs():
    """Return the positional forward() inputs: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args

Kernel Information

Related Kernels (Level 2, Task 99 • 99_Matmul_GELU_Softmax)

Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile)
🥇 fused_opt_base 0.01 2.78 2.26
🥇 warp_divergence_minimized_kernel_base 0.01 2.78 2.26
🥇 block_size_experiment_fused_kernel_base 0.01 2.78 2.26
🥇 optimized_fused_kernel_base 0.01 2.78 2.26
🥇 fused_shared_mem_kernel_base 0.01 2.78 2.26
🥇 balanced_workload_fused_kernel_base 0.01 2.78 2.26
7 reduced_sync_matmul_gelu_softmax_base 0.01 2.53 2.06
7 aligned_ldg_fused_kernel_base_base 0.01 2.53 2.06
7 warp_optimized_fused_kernel_base_base 0.01 2.53 2.06
7 fused_ldg_vec_kernel_base 0.01 2.53 2.06
7 unrolled_fused_matmul_gelu_softmax_base_base 0.01 2.53 2.06
7 fused_optim_base 0.01 2.53 2.06
7 fused_linear_gelu_softmax_optimized_base 0.01 2.53 2.06
7 warp_reduced_fused_kernel_base 0.01 2.53 2.06
7 fused_nodivergence_kernel_base 0.01 2.53 2.06
7 modular_device_functions_base 0.01 2.53 2.06
7 optimized_linear_gelu_softmax_base 0.01 2.53 2.06
7 optimized_linear_gelu_softmax_edit_1 0.01 2.53 2.06
7 optimized_linear_gelu_softmax_base 0.01 2.53 2.06
7 optimized_linear_gelu_softmax_edit_1 0.01 2.53 2.06
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>

// Exact (erf-based) GELU, matching PyTorch's default F.gelu / nn.GELU
// (approximate='none'):
//     gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// The tanh formula previously used here corresponds to approximate='tanh',
// which is NOT PyTorch's default and deviates from the reference by up to ~1e-3.
__device__ __forceinline__ float gelu(float x) {
    const float inv_sqrt2 = 0.70710678118654752440f;  // 1 / sqrt(2)
    return 0.5f * x * (1.0f + erff(x * inv_sqrt2));
}

// Fused kernel: output[row, :] = softmax(gelu(x[row, :] @ weight^T + bias)).
//
// Launch layout expected from the host: one block per batch row (gridDim.x ==
// batch_size) and one thread per output feature, with blockDim.x >= out_features.
// Threads with tid >= out_features are padding and only feed neutral values
// (-FLT_MAX for the max, 0 for the sum) into the reductions. Both reductions are
// exact for ANY blockDim.x; a power of two is not required.
//
// Dynamic shared memory must be (in_features + blockDim.x) floats:
//   s_x       : the input row, loaded once and reused by every thread's dot product
//   s_softmax : one slot per thread, scratch space for the max and sum reductions
__global__ void fused_shared_mem_kernel(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ output,
    int batch_size,
    int in_features,
    int out_features
) {
    const int padded = static_cast<int>(blockDim.x);
    const int row = static_cast<int>(blockIdx.x);
    const int tid = static_cast<int>(threadIdx.x);

    // Block-uniform guard (row is identical for every thread in the block), so
    // returning here cannot desynchronise the __syncthreads() calls below.
    if (row >= batch_size) {
        return;
    }

    extern __shared__ float shared_mem[];
    float* s_x = shared_mem;                      // in_features floats
    float* s_softmax = shared_mem + in_features;  // padded floats

    // 64-bit row offsets: row * in_features can exceed INT_MAX for large batches.
    const float* x_row = x + static_cast<size_t>(row) * in_features;
    float* out_row = output + static_cast<size_t>(row) * out_features;

    // 1. Cooperatively stage the input row in shared memory.
    for (int i = tid; i < in_features; i += padded) {
        s_x[i] = x_row[i];
    }
    __syncthreads();

    // 2. Each valid thread computes one output feature: dot product + bias, then GELU.
    float act = 0.0f;
    if (tid < out_features) {
        const float* w_row = weight + static_cast<size_t>(tid) * in_features;
        float sum = 0.0f;
        for (int k = 0; k < in_features; ++k) {
            sum += s_x[k] * w_row[k];
        }
        act = gelu(sum + bias[tid]);
        s_softmax[tid] = act;
    } else {
        s_softmax[tid] = -FLT_MAX;  // neutral element for max
    }
    __syncthreads();

    // 3. Row maximum, used to make the exponentials numerically stable.
    //    `active` shrinks by ceil(active / 2) each step. This folds the upper
    //    part onto the lower part without ever skipping an element, unlike a
    //    `stride = padded / 2; stride /= 2` loop, which loses values when
    //    padded is not a power of two. `active` is block-uniform, so the
    //    barrier inside the loop is reached by all threads.
    for (int active = padded; active > 1; ) {
        const int half = (active + 1) >> 1;
        if (tid < active - half) {
            s_softmax[tid] = fmaxf(s_softmax[tid], s_softmax[tid + half]);
        }
        __syncthreads();
        active = half;
    }
    const float row_max = s_softmax[0];
    __syncthreads();  // everyone has read s_softmax[0] before it is overwritten

    // 4. Shifted exponentials; padding threads contribute 0 to the sum.
    float exp_val = 0.0f;
    if (tid < out_features) {
        exp_val = expf(act - row_max);
    }
    s_softmax[tid] = exp_val;
    __syncthreads();

    // 5. Sum of exponentials, using the same any-size reduction as step 3.
    for (int active = padded; active > 1; ) {
        const int half = (active + 1) >> 1;
        if (tid < active - half) {
            s_softmax[tid] += s_softmax[tid + half];
        }
        __syncthreads();
        active = half;
    }
    const float sum_exp = s_softmax[0];

    // 6. Normalise and write the valid outputs.
    if (tid < out_features) {
        out_row[tid] = exp_val / sum_exp;
    }
}

// Host wrapper: validates the inputs, chooses the launch configuration, and
// launches fused_shared_mem_kernel on PyTorch's current CUDA stream.
//
// Args:
//   x      : (batch_size, in_features) float32 CUDA tensor
//   weight : (out_features, in_features) float32 CUDA tensor
//   bias   : (out_features) float32 CUDA tensor
// Returns:
//   (batch_size, out_features) float32 tensor equal to
//   softmax(gelu(x @ weight^T + bias), dim=1)
// Throws (via TORCH_CHECK) on device/dtype/shape mismatches, on more than 1024
// output features, or when the row does not fit in 48 KiB of shared memory.
torch::Tensor forward(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias
) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "forward: x, weight and bias must all be CUDA tensors");
    TORCH_CHECK(x.device() == weight.device() && x.device() == bias.device(),
                "forward: x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "forward: only float32 tensors are supported");
    TORCH_CHECK(x.dim() == 2 && weight.dim() == 2 && bias.dim() == 1,
                "forward: expected x (B, in), weight (out, in), bias (out)");
    TORCH_CHECK(weight.size(1) == x.size(1),
                "forward: weight.size(1) must equal x.size(1)");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "forward: bias.size(0) must equal weight.size(0)");

    // Make sure allocation and launch happen on the inputs' GPU.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes raw row-major buffers, so force a dense layout.
    const torch::Tensor x_c = x.contiguous();
    const torch::Tensor w_c = weight.contiguous();
    const torch::Tensor b_c = bias.contiguous();

    const int64_t batch_size = x_c.size(0);
    const int64_t in_features = x_c.size(1);
    const int64_t out_features = w_c.size(0);

    auto output = torch::empty({batch_size, out_features}, x_c.options());
    if (batch_size == 0 || out_features == 0) {
        return output;  // nothing to compute; a zero-block launch is a CUDA error
    }

    TORCH_CHECK(out_features <= 1024,
                "forward: out_features (", out_features,
                ") exceeds the 1024 threads-per-block limit of this kernel");
    TORCH_CHECK(batch_size <= 2147483647LL,
                "forward: batch_size (", batch_size, ") exceeds the grid x-dimension limit");

    // One thread per output feature, rounded up to a power of two (at least one
    // full warp). The result is a multiple of 32 and keeps halving reductions exact.
    int threads = 32;
    while (threads < out_features) {
        threads <<= 1;
    }

    // Dynamic shared memory: the input row plus one reduction slot per thread.
    const size_t shared_mem_size =
        static_cast<size_t>(in_features + threads) * sizeof(float);
    // Going above 48 KiB would require an explicit cudaFuncSetAttribute opt-in.
    TORCH_CHECK(shared_mem_size <= 48 * 1024,
                "forward: in_features (", in_features,
                ") too large for the 48 KiB shared-memory budget of this kernel");

    const dim3 blocks(static_cast<unsigned int>(batch_size));
    const dim3 threadBlock(static_cast<unsigned int>(threads));
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    fused_shared_mem_kernel<<<blocks, threadBlock, shared_mem_size, stream>>>(
        x_c.data_ptr<float>(),
        w_c.data_ptr<float>(),
        b_c.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(batch_size),
        static_cast<int>(in_features),
        static_cast<int>(out_features)
    );
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fused_shared_mem_kernel launch failed: ", cudaGetErrorString(err));

    return output;
}

// Python bindings: exposes `forward(x, weight, bias)` on the compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        &forward,
        "Fused Linear + GELU + Softmax forward with shared memory");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.060 inst/cycle 0.000 5
Executed Ipc Elapsed 0.042 inst/cycle 0.000 5
Issue Slots Busy 1.598 % 0.001 5
Issued Ipc Active 0.062 inst/cycle 0.000 5
SM Busy 1.598 % 0.001 5
Memory Throughput 7503425666.734 byte/second 17213957605964890.000 5
Mem Busy 4.946 % 0.012 5
Max Bandwidth 3.184 % 0.003 5
L1/TEX Hit Rate 86.080 % 0.000 5
L2 Hit Rate 98.754 % 0.466 5
Mem Pipes Busy 1.294 % 0.001 5
Warp Cycles Per Issued Instruction 15.916 cycle 0.058 5
Warp Cycles Per Executed Instruction 16.044 cycle 0.059 5
Avg. Active Threads Per Warp 17.610 0.000 5
Avg. Not Predicated Off Threads Per Warp 14.410 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 64.000 block 0.000 5
Block Limit Shared Mem 39.000 block 0.000 5
Block Limit Warps 64.000 block 0.000 5
Theoretical Active Warps per SM 32.000 warp 0.000 5
Theoretical Occupancy 50.000 % 0.000 5
Achieved Occupancy 1.560 % 0.000 5
Achieved Active Warps Per SM 1.000 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 17.6 threads being active per cycle. This is further reduced to 14.4 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. The difference between calculated theoretical (50.0%) and measured achieved occupancy (1.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 550971.24 μs
Device Time 8.54 μs
Self CPU Time 51.91 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 550919.33 μs
Device Time 8.54 μs
Self CPU Time 95.04 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 550675.46 μs
Device Time 0.00 μs
Self CPU Time 95.68 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 550183.26 μs
Device Time 0.00 μs
Self CPU Time 550183.26 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 514842.98 μs
Device Time 22448.48 μs
Self CPU Time 514842.98 μs
Self Device Time 22448.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
fused_shared_mem_kernel(float const*, float const*, float const*, float*, int, int, int)
CPU Time 0.00 μs
Device Time 60898.98 μs
Self CPU Time 0.00 μs
Self Device Time 60898.98 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 19026.77 μs
Device Time 41506.17 μs
Self CPU Time 19026.77 μs
Self Device Time 41506.17 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 63066.88 μs
Device Time 620266.75 μs
Self CPU Time 14031.19 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 49039.52 μs
Device Time 620266.75 μs
Self CPU Time 15815.15 μs
Self Device Time 620266.75 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 620266.75 μs
Self CPU Time 0.00 μs
Self Device Time 620266.75 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45288 warnings generated when compiling for host.
Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:19:5 bugprone-easily-swappable-parameters
19 | const float* __restrict__ x,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
20 | const float* __restrict__ weight,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:19:31: note: the first parameter in the range is 'x'
19 | const float* __restrict__ x,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:21:31: note: the last parameter in the range is 'bias'
21 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:23:5: warning: 3 adjacent parameters of 'fused_shared_mem_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
23 | int batch_size,
| ^~~~~~~~~~~~~~~
24 | int in_features,
| ~~~~~~~~~~~~~~~~
25 | int out_features
| ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:23:9: note: the first parameter in the range is 'batch_size'
23 | int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:25:9: note: the last parameter in the range is 'out_features'
25 | int out_features
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:28:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
28 | int padded = blockDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:36:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
36 | int row = blockIdx.x; // Each block processes one row of the batch
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:37:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
37 | int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:101:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
101 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:102:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
102 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:103:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
103 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:105:28: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
105 | const int batch_size = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:106:29: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
106 | const int in_features = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:107:30: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
107 | const int out_features = weight.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_99/b5_s1_fused_shared_mem_kernel/base/base.cu:118:27: warning: narrowing conversion from 'unsigned long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
118 | int shared_mem_size = (in_features + threads) * sizeof(float);
| ^