← Back to Leaderboard

The AI CUDA Engineer 👷

51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAddfused_forward_edit_1

Level 2 • Task 51
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """
    Performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix for linear layer of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector for linear layer of shape (out_features)
        subtract (torch.Tensor): Vector to subtract of shape (out_features)

    Returns:
        torch.Tensor: Output tensor after applying all operations
    """
    original_x = x.clone().detach()

    # Gemm
    x = F.linear(x, weight, bias)

    # Subtract
    x = x - subtract

    # GlobalAvgPool
    x = torch.mean(x, dim=1, keepdim=True)

    # LogSumExp
    x = torch.logsumexp(x, dim=1, keepdim=True)

    # GELU
    x = F.gelu(x)

    # ResidualAdd
    x = x + original_x

    return x


class Model(nn.Module):
    """
    Model that performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.
    """

    def __init__(self, in_features, out_features):
        super(Model, self).__init__()
        gemm = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(gemm.weight)
        self.bias = nn.Parameter(gemm.bias)
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x, fn=module_fn):
        return fn(x, self.weight, self.bias, self.subtract)


batch_size = 128
in_features = 1024
out_features = 512


def get_inputs():
    return [torch.randn(batch_size, in_features)]


def get_init_inputs():
    return [in_features, out_features]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(Model, self).__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=bias)
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x):
        original_x = x.clone().detach()
        # Gemm
        x = self.gemm(x)

        # Subtract
        x = x - self.subtract

        # GlobalAvgPool
        x = torch.mean(x, dim=1, keepdim=True)

        # LogSumExp
        x = torch.logsumexp(x, dim=1, keepdim=True)

        # GELU
        x = torch.nn.functional.gelu(x)

        # ResidualAdd
        x = x + original_x

        return x

batch_size = 128
in_features = 1024
out_features = 512

def get_inputs():
    return [torch.randn(batch_size, in_features)]

def get_init_inputs():
    return [in_features, out_features]

Kernel Information

Related Kernels (Level 2, Task 51 • 51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 fused_forward_base 0.05 1.62 0.92
🥇 fused_forward_edit_1 0.05 1.62 0.92
🥉 fused_forward_coalesced_base 0.05 1.58 0.90
4 fused_forward_coalesced_edit_1 0.05 1.55 0.89
5 optimized_fused_kernel_base 0.06 1.32 0.76
6 fused_pipeline_base 0.06 1.28 0.73
6 threadblock_mapping_opt_base 0.06 1.28 0.73
8 atomic_optimized_pipeline_base 0.06 1.26 0.72
8 efficient_thread_block_mapping_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
8 fused_pool_gelu_atomic_minimal_base 0.06 1.26 0.72
8 fused_pool_gelu_warp_edit_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
14 constant_memory_optimization_base 0.07 1.24 0.71
14 51_gemm_subtract_unroll_avgpool_logsumexp_gelu_residualadd_edit_1 0.07 1.24 0.71
14 uniform_control_flow_base_base_base 0.07 1.24 0.71
17 modular_device_functions_optimized_base 0.07 1.22 0.70
17 modular_device_functions_base_base 0.07 1.22 0.70
19 experiment_block_sizes_base 0.07 1.19 0.68
19 tiled_gemm_shared_edit_2_base 0.07 1.19 0.68
/*
   Fused Forward CUDA Kernel
   This kernel fuses the series of operations from the original implementations:
   1. GEMM (matrix multiplication with bias) 
   2. Subtraction of a per-column constant
   3. Global average pooling
   4. LogSumExp (which is mathematically the identity in this case)
   5. GELU activation
   6. Residual addition with the original input

   Observation:
   The original sequence computes, for each row i and each column j:

       gemm_out[i,j] = dot(x[i,:], weight[j,:]) + bias[j] - subtract[j]
       pool[i] = (1/out_features) * sum_j gemm_out[i,j]
       pool[i] = gelu(pool[i])
       out[i,k] = original_x[i,k] + pool[i]

   Notice that the sum over j can be re-ordered as:

       pool[i] = (1/out_features) * ( dot(x[i,:], sum_{j} weight[j,:]) + sum_{j}(bias[j]-subtract[j]) )
                = ( dot(x[i,:], weight_sum) + constant ) / out_features

   where:
       weight_sum[k] = sum_{j=0}^{out_features-1} weight[j * in_features + k]
       constant = sum_{j=0}^{out_features-1} (bias[j] - subtract[j])

   This transformation allows us to replace the heavy GEMM over (batch_size x out_features) with
   a fast dot product per row over in_features elements. Then, after applying GELU on the pooled
   scalar and adding back via a residual connection, we obtain the same overall result as the original.

   This implementation precomputes weight_sum and constant (using PyTorch tensor operations which run on GPU),
   and then launches a fused CUDA kernel that, for each row, computes the dot product x[i] * weight_sum, 
   applies the necessary normalization, GELU activation, and broadcasts the result as a residual add to x[i].

   The fused kernel uses one block per row and a shared memory reduction for computing the dot product.
*/

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

//------------------------------------------------------------------------------
// GELU approximation function
__device__ float gelu_approx(float val) {
    const float kAlpha = 0.044715f;
    const float kBeta  = 0.7978845608f; // sqrt(2/M_PI)
    float inner = kBeta * (val + kAlpha * val * val * val);
    float cdf   = 0.5f * (1.0f + tanhf(inner));
    return val * cdf;
}

//------------------------------------------------------------------------------
// Fused kernel: Computes the dot product of x[i] and weight_sum with a reduction,
// applies normalization using out_features and constant, then applies GELU,
// and finally performs a residual add with x to produce the final output.
// Each block processes one row.
__global__ void fused_forward_kernel(
    const float* __restrict__ x,            // Input x: shape (batch_size, in_features)
    const float* __restrict__ weight_sum,     // Precomputed weight_sum: shape (in_features)
    float constant,                           // Precomputed constant: sum(bias - subtract)
    float* __restrict__ out,                  // Output: shape (batch_size, in_features)
    int batch_size,
    int in_features,
    int out_features                        // Needed for normalization
) {
    int row = blockIdx.x;
    if (row >= batch_size) return;

    __shared__ float sdata[256]; // Shared memory for reduction
    float sum_val = 0.0f;
    
    // Each thread processes a subset of the in_features dimension
    for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
         float x_val = x[row * in_features + k];
         float ws = weight_sum[k];
         sum_val += x_val * ws;
    }
    sdata[threadIdx.x] = sum_val;
    __syncthreads();

    // Reduction in shared memory to compute the dot product
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
         if (threadIdx.x < stride)
              sdata[threadIdx.x] += sdata[threadIdx.x + stride];
         __syncthreads();
    }
    float pool_val = sdata[0];

    // Thread 0 normalizes the sum, applies GELU, and writes back to shared memory
    if (threadIdx.x == 0) {
         pool_val = (pool_val + constant) / static_cast<float>(out_features);
         pool_val = gelu_approx(pool_val);
         sdata[0] = pool_val; // Broadcast the result
    }
    __syncthreads();
    pool_val = sdata[0];

    // Broadcast residual addition: each thread adds pool_val to the corresponding
    // element of the original input x to produce out.
    for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
         out[row * in_features + k] = x[row * in_features + k] + pool_val;
    }
}

//------------------------------------------------------------------------------
// Forward function for the fused kernel
// Precomputes the necessary reductions (weight_sum and constant) and launches the fused kernel.

torch::Tensor forward_cuda_fused(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const torch::Tensor& subtract
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(subtract.is_cuda(), "subtract must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size x in_features)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2D (out_features x in_features)");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1D (out_features)");
    TORCH_CHECK(subtract.dim() == 1, "subtract must be 1D (out_features)");

    int64_t batch_size  = x.size(0);
    int64_t in_features = x.size(1);
    int64_t out_features = weight.size(0);

    TORCH_CHECK(weight.size(1) == in_features, "weight.shape[1] must match x.shape[1]");
    TORCH_CHECK(bias.size(0) == out_features, "bias.shape[0] must match weight.shape[0]");
    TORCH_CHECK(subtract.size(0) == out_features, "subtract.shape[0] must match weight.shape[0]");

    auto x_contig = x.contiguous();
    auto weight_contig = weight.contiguous();
    auto bias_contig = bias.contiguous();
    auto subtract_contig = subtract.contiguous();

    // Precompute weight_sum: sum over rows of weight (weight is out_features x in_features)
    // weight_sum will have shape (in_features,)
    auto weight_sum = torch::sum(weight_contig, 0);

    // Precompute constant = sum(bias - subtract) [a scalar]
    auto constant_tensor = torch::sum(bias_contig - subtract_contig);
    float constant = constant_tensor.item<float>();

    // Allocate output tensor (same shape as x)
    auto out = torch::empty({batch_size, in_features}, x.options());

    int threads = 256;
    int blocks = batch_size; // One block per row in x
    size_t shared_mem_bytes = threads * sizeof(float);
    
    fused_forward_kernel<<<blocks, threads, shared_mem_bytes>>>(
        x_contig.data_ptr<float>(),
        weight_sum.data_ptr<float>(),
        constant,
        out.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );

    return out;
}

//------------------------------------------------------------------------------
// PyBind11 module definition
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward_cuda_fused, "Fused Forward CUDA Kernel");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.310 inst/cycle 0.000 5
Executed Ipc Elapsed 0.170 inst/cycle 0.000 5
Issue Slots Busy 7.834 % 0.031 5
Issued Ipc Active 0.312 inst/cycle 0.000 5
SM Busy 7.834 % 0.031 5
Memory Throughput 93916272945.768 byte/second 1025445480603973376.000 5
Mem Busy 7.392 % 0.012 5
Max Bandwidth 5.710 % 0.004 5
L1/TEX Hit Rate 25.000 % 0.000 5
L2 Hit Rate 75.956 % 0.018 5
Mem Pipes Busy 4.212 % 0.002 5
Warp Cycles Per Issued Instruction 25.458 cycle 0.331 5
Warp Cycles Per Executed Instruction 25.808 cycle 0.338 5
Avg. Active Threads Per Warp 31.430 0.000 5
Avg. Not Predicated Off Threads Per Warp 24.510 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 10.000 block 0.000 5
Block Limit Shared Mem 21.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 12.460 % 0.000 5
Achieved Active Warps Per SM 7.976 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
cudaStreamSynchronize
CPU Time 830151.14 μs
Device Time 0.00 μs
Self CPU Time 830151.14 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::sum
CPU Time 265912.89 μs
Device Time 251697.45 μs
Self CPU Time 163729.87 μs
Self Device Time 251697.45 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 220204.56 μs
Device Time 14775.34 μs
Self CPU Time 220204.56 μs
Self Device Time 14775.34 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::reduce_kernel<128, 4, at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4> >(at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4>)
CPU Time 0.00 μs
Device Time 219035.09 μs
Self CPU Time 0.00 μs
Self Device Time 219035.09 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::item
CPU Time 890045.92 μs
Device Time 21882.56 μs
Self CPU Time 9128.06 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_local_scalar_dense
CPU Time 880917.86 μs
Device Time 21882.56 μs
Self CPU Time 25849.23 μs
Self Device Time 21882.56 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 90069.28 μs
Device Time 880097.94 μs
Self CPU Time 19841.78 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 70228.90 μs
Device Time 880097.94 μs
Self CPU Time 26042.24 μs
Self Device Time 880097.94 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 880176.22 μs
Self CPU Time 0.00 μs
Self Device Time 880176.22 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45289 warnings generated when compiling for host.
Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:60:5 bugprone-easily-swappable-parameters
60 | const float* __restrict__ x, // Input x: shape (batch_size, in_features)
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61 | const float* __restrict__ weight_sum, // Precomputed weight_sum: shape (in_features)
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:60:31: note: the first parameter in the range is 'x'
60 | const float* __restrict__ x, // Input x: shape (batch_size, in_features)
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:61:31: note: the last parameter in the range is 'weight_sum'
61 | const float* __restrict__ weight_sum, // Precomputed weight_sum: shape (in_features)
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:64:5: warning: 3 adjacent parameters of 'fused_forward_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
64 | int batch_size,
| ^~~~~~~~~~~~~~~
65 | int in_features,
| ~~~~~~~~~~~~~~~~
66 | int out_features // Needed for normalization
| ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:64:9: note: the first parameter in the range is 'batch_size'
64 | int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:66:9: note: the last parameter in the range is 'out_features'
66 | int out_features // Needed for normalization
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:68:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
68 | int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:75:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
75 | for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:75:53: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
75 | for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:84:23: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
84 | for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:102:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:102:53: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | for (int k = threadIdx.x; k < in_features; k += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:151:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
151 | int blocks = batch_size; // One block per row in x
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:159:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
159 | batch_size,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:160:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
160 | in_features,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b4_s3_fused_forward/edit_1/edit_1.cu:161:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
161 | out_features
| ^