51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd
• fused_forward_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """
    Performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix for linear layer of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector for linear layer of shape (out_features)
        subtract (torch.Tensor): Vector to subtract of shape (out_features)

    Returns:
        torch.Tensor: Tensor of shape (batch_size, in_features). Each row is the
        GELU of that row's pooled scalar, broadcast-added to the detached input ``x``.
    """
    # The residual branch is detached, so gradients reach ``x`` only through the
    # GEMM path. ``detach()`` alone is enough: nothing below mutates ``x`` in
    # place, so the previous extra ``clone()`` copy of the input was unnecessary.
    residual = x.detach()
    # Gemm
    x = F.linear(x, weight, bias)
    # Subtract
    x = x - subtract
    # GlobalAvgPool -> shape (batch_size, 1)
    x = torch.mean(x, dim=1, keepdim=True)
    # LogSumExp over a dimension of size 1 is the identity; kept to mirror the spec.
    x = torch.logsumexp(x, dim=1, keepdim=True)
    # GELU
    x = F.gelu(x)
    # ResidualAdd: (batch_size, 1) broadcasts against (batch_size, in_features)
    return x + residual
class Model(nn.Module):
    """
    Model that performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Reuse nn.Linear's default initialisation, then expose its tensors as
        # flat parameters so they can be handed straight to ``fn``.
        linear = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(linear.weight)
        self.bias = nn.Parameter(linear.bias)
        # Drawn after the Linear init, so the RNG consumption order is unchanged.
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` (``module_fn`` by default) to ``x`` with this module's parameters."""
        weight, bias, subtract = self.weight, self.bias, self.subtract
        return fn(x, weight, bias, subtract)
batch_size = 128
in_features = 1024
out_features = 512


def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=bias)
        # Created after the Linear layer, so the RNG consumption order is unchanged.
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, in_features).

        Returns:
            Tensor of shape (batch_size, in_features): GELU of each row's pooled
            scalar, broadcast-added to the detached input ``x``.
        """
        # Detached residual: gradients reach ``x`` only through the GEMM path.
        # No in-place op touches ``x`` below, so no ``clone()`` copy is needed.
        residual = x.detach()
        # Gemm
        x = self.gemm(x)
        # Subtract
        x = x - self.subtract
        # GlobalAvgPool -> shape (batch_size, 1)
        x = torch.mean(x, dim=1, keepdim=True)
        # LogSumExp over a dimension of size 1 is the identity; kept to mirror the spec.
        x = torch.logsumexp(x, dim=1, keepdim=True)
        # GELU
        x = torch.nn.functional.gelu(x)
        # ResidualAdd: (batch_size, 1) broadcasts against (batch_size, in_features)
        return x + residual
batch_size, in_features, out_features = 128, 1024, 512


def get_inputs():
    """Build the inputs for one forward call: a single random input batch."""
    return [torch.randn((batch_size, in_features))]


def get_init_inputs():
    """Build the constructor arguments ``(in_features, out_features)`` as a list."""
    return list((in_features, out_features))
/*
Fused Forward CUDA Kernel
This kernel fuses the series of operations from the original implementations:
1. GEMM (matrix multiplication with bias)
2. Subtraction of a per-column constant
3. Global average pooling
4. LogSumExp (the identity here: it reduces over the size-1 dimension left by the pooling)
5. GELU activation
6. Residual addition with the original input
Observation:
The original sequence computes, for each row i and each column j:
gemm_out[i,j] = dot(x[i,:], weight[j,:]) + bias[j] - subtract[j]
pool[i] = (1/out_features) * sum_j gemm_out[i,j]
pool[i] = gelu(pool[i])
out[i,k] = original_x[i,k] + pool[i]
Notice that the sum over j can be re-ordered as:
pool[i] = (1/out_features) * ( dot(x[i,:], sum_{j} weight[j,:]) + sum_{j}(bias[j]-subtract[j]) )
= ( dot(x[i,:], weight_sum) + constant ) / out_features
where:
weight_sum[k] = sum_{j=0}^{out_features-1} weight[j,k]   (row-major storage: weight[j * in_features + k])
constant = sum_{j=0}^{out_features-1} (bias[j] - subtract[j])
This transformation replaces the GEMM, which costs O(batch_size * out_features * in_features), with
a column reduction of weight (O(out_features * in_features)) plus one dot product per row over
in_features elements. Applying GELU to the pooled scalar and adding it back through the residual
connection then yields the same result as the original, up to floating-point summation order.
This implementation precomputes weight_sum and constant (using PyTorch tensor operations which run on GPU),
and then launches a fused CUDA kernel that, for each row, computes the dot product x[i] * weight_sum,
applies the necessary normalization, GELU activation, and broadcasts the result as a residual add to x[i].
The fused kernel uses one block per row and a shared memory reduction for computing the dot product.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>
#include <limits>
//------------------------------------------------------------------------------
// Tanh-based GELU approximation:
//   gelu(v) ~= v * 0.5 * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
// NOTE(review): torch.nn.functional.gelu defaults to the exact erf-based form
// (approximate='none'), so this only approximates that reference.
__device__ float gelu_approx(float v) {
    constexpr float kCubicCoeff = 0.044715f;
    constexpr float kSqrt2OverPi = 0.7978845608f;  // sqrt(2/M_PI)
    const float tanh_arg = kSqrt2OverPi * (v + kCubicCoeff * v * v * v);
    return v * (0.5f * (1.0f + tanhf(tanh_arg)));
}
//------------------------------------------------------------------------------
// Exact GELU: 0.5 * v * (1 + erf(v / sqrt(2))).
// This is the form torch.nn.functional.gelu computes by default
// (approximate='none'), which is what the reference implementation calls.
__device__ __forceinline__ float gelu_exact(float val) {
    const float kInvSqrt2 = 0.7071067811865475f;  // 1 / sqrt(2)
    return 0.5f * val * (1.0f + erff(val * kInvSqrt2));
}

// Threads per block for fused_forward_kernel. The shared-memory tree reduction
// halves the active range each step, so this must be a power of two.
constexpr int kFusedThreads = 256;

//------------------------------------------------------------------------------
// Fused kernel: Computes the dot product of x[i] and weight_sum with a reduction,
// applies normalization using out_features and constant, then applies GELU,
// and finally performs a residual add with x to produce the final output.
// Each block processes one row. blockDim.x must be a power of two and the
// dynamic shared memory must hold blockDim.x floats.
__global__ void fused_forward_kernel(
    const float* __restrict__ x,           // Input x: (batch_size, in_features), row-major
    const float* __restrict__ weight_sum,  // Precomputed weight_sum: (in_features)
    const float* __restrict__ constant,    // Device scalar: sum(bias - subtract)
    float* __restrict__ out,               // Output: (batch_size, in_features)
    int64_t batch_size,
    int64_t in_features,
    int64_t out_features                   // Needed for normalization
) {
    const int64_t row = blockIdx.x;
    if (row >= batch_size) return;

    extern __shared__ float sdata[];  // blockDim.x partial sums

    // 64-bit offsets: batch_size * in_features may exceed INT_MAX.
    const float* x_row = x + row * in_features;
    float* out_row = out + row * in_features;

    // Per-thread partial dot product over a strided subset of in_features.
    float sum_val = 0.0f;
    for (int64_t k = threadIdx.x; k < in_features; k += blockDim.x) {
        sum_val += x_row[k] * weight_sum[k];
    }
    sdata[threadIdx.x] = sum_val;
    __syncthreads();

    // Tree reduction in shared memory (requires power-of-two blockDim.x).
    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
            sdata[threadIdx.x] += sdata[threadIdx.x + stride];
        }
        __syncthreads();
    }

    // Only thread 0 reads and rewrites sdata[0] here; the other threads wait
    // at the barrier below, so there is no read/write race on the slot.
    if (threadIdx.x == 0) {
        const float pooled = (sdata[0] + *constant) / static_cast<float>(out_features);
        sdata[0] = gelu_exact(pooled);
    }
    __syncthreads();
    const float pool_val = sdata[0];

    // Residual add: broadcast the per-row scalar over the original input row.
    for (int64_t k = threadIdx.x; k < in_features; k += blockDim.x) {
        out_row[k] = x_row[k] + pool_val;
    }
}

//------------------------------------------------------------------------------
// Forward function for the fused kernel.
// Precomputes the reductions (weight_sum and constant) on the GPU and launches
// the fused kernel on PyTorch's current stream. The constant stays on the
// device, so no host-device synchronization happens per call.
torch::Tensor forward_cuda_fused(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const torch::Tensor& subtract
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(subtract.is_cuda(), "subtract must be a CUDA tensor");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device() &&
                    subtract.device() == x.device(),
                "x, weight, bias and subtract must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be float32");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat, "weight must be float32");
    TORCH_CHECK(bias.scalar_type() == torch::kFloat, "bias must be float32");
    TORCH_CHECK(subtract.scalar_type() == torch::kFloat, "subtract must be float32");
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size x in_features)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2D (out_features x in_features)");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1D (out_features)");
    TORCH_CHECK(subtract.dim() == 1, "subtract must be 1D (out_features)");

    const int64_t batch_size = x.size(0);
    const int64_t in_features = x.size(1);
    const int64_t out_features = weight.size(0);
    TORCH_CHECK(weight.size(1) == in_features, "weight.shape[1] must match x.shape[1]");
    TORCH_CHECK(bias.size(0) == out_features, "bias.shape[0] must match weight.shape[0]");
    TORCH_CHECK(subtract.size(0) == out_features, "subtract.shape[0] must match weight.shape[0]");
    // One block per row: the row count must fit in gridDim.x.
    TORCH_CHECK(batch_size <= std::numeric_limits<int>::max(),
                "batch_size exceeds the maximum CUDA grid size");

    // Allocate and launch on x's device, even if it is not the current device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    auto x_contig = x.contiguous();
    auto weight_contig = weight.contiguous();
    auto bias_contig = bias.contiguous();
    auto subtract_contig = subtract.contiguous();

    // Allocate output tensor (same shape as x)
    auto out = torch::empty({batch_size, in_features}, x.options());
    if (batch_size == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to compute.
        return out;
    }

    // weight_sum[k] = sum_j weight[j, k]; shape (in_features,)
    auto weight_sum = torch::sum(weight_contig, 0);
    // 0-dim device tensor holding sum(bias - subtract). Calling .item() here
    // would force a device-to-host sync on every forward call.
    auto constant_tensor = torch::sum(bias_contig - subtract_contig);

    const int threads = kFusedThreads;
    const int blocks = static_cast<int>(batch_size);  // One block per row in x
    const size_t shared_mem_bytes = threads * sizeof(float);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    fused_forward_kernel<<<blocks, threads, shared_mem_bytes, stream>>>(
        x_contig.data_ptr<float>(),
        weight_sum.data_ptr<float>(),
        constant_tensor.data_ptr<float>(),
        out.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}
//------------------------------------------------------------------------------
// Python bindings: the extension module exposes a single function, `forward`,
// implemented by forward_cuda_fused.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        forward_cuda_fused,
        "Fused Forward CUDA Kernel");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.314 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.174 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 7.920 | % | 0.012 | 5 |
Issued Ipc Active | 0.316 | inst/cycle | 0.000 | 5 |
SM Busy | 7.920 | % | 0.012 | 5 |
Memory Throughput | 95206188918.680 | byte/second | 1108962367567673344.000 | 5 |
Mem Busy | 7.528 | % | 0.003 | 5 |
Max Bandwidth | 5.786 | % | 0.002 | 5 |
L1/TEX Hit Rate | 25.000 | % | 0.000 | 5 |
L2 Hit Rate | 76.076 | % | 0.225 | 5 |
Mem Pipes Busy | 4.266 | % | 0.001 | 5 |
Warp Cycles Per Issued Instruction | 25.374 | cycle | 0.175 | 5 |
Warp Cycles Per Executed Instruction | 25.720 | cycle | 0.177 | 5 |
Avg. Active Threads Per Warp | 31.430 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 24.510 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 10.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 16.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 12.460 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 7.972 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
cudaStreamSynchronize | ||
CPU Time | 833923.32 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 833923.32 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::sum | ||
CPU Time | 264990.49 | μs |
Device Time | 252833.55 | μs |
Self CPU Time | 160598.46 | μs |
Self Device Time | 252833.55 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 226274.27 | μs |
Device Time | 14861.13 | μs |
Self CPU Time | 226274.27 | μs |
Self Device Time | 14861.13 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::reduce_kernel<128, 4, at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4> >(at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4>) | ||
CPU Time | 0.00 | μs |
Device Time | 220042.45 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 220042.45 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::item | ||
CPU Time | 893076.14 | μs |
Device Time | 21993.44 | μs |
Self CPU Time | 8800.60 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_local_scalar_dense | ||
CPU Time | 884275.54 | μs |
Device Time | 21993.44 | μs |
Self CPU Time | 25476.96 | μs |
Self Device Time | 21993.44 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 88416.97 | μs |
Device Time | 884106.50 | μs |
Self CPU Time | 19708.34 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 68710.28 | μs |
Device Time | 884106.50 | μs |
Self CPU Time | 25791.27 | μs |
Self Device Time | 884106.50 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 884184.93 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 884184.93 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45289 warnings generated when compiling for host. Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.