99_Matmul_GELU_Softmax
• warp_optimized_fused_kernel_base_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``softmax(gelu(x @ weight.T + bias), dim=1)``.

    GELU is ``F.gelu`` with its default (exact, erf-based) formulation.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_features).
        weight (torch.Tensor): Weight of shape (out_features, in_features).
        bias (torch.Tensor): Bias of shape (out_features).

    Returns:
        torch.Tensor: Row-wise probabilities of shape (batch_size, out_features);
        each row sums to 1.
    """
    projected = F.linear(x, weight, bias)
    return F.softmax(F.gelu(projected), dim=1)
class Model(nn.Module):
    """
    Linear projection -> GELU -> softmax over dim 1, delegating the math to a
    pluggable functional implementation (``module_fn`` by default).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Build a throwaway nn.Linear only to reuse its default parameter
        # initialisation, then keep its weight and bias as this module's params.
        reference_linear = nn.Linear(in_features, out_features)
        self.weight, self.bias = reference_linear.weight, reference_linear.bias

    def forward(self, x, fn=module_fn):
        # `fn` lets callers swap in another implementation with the same
        # (x, weight, bias) signature.
        return fn(x, self.weight, self.bias)
batch_size = 128
in_features = 100
out_features = 10


def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [in_features, out_features]
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Reference model: nn.Linear, then GELU, then softmax across features (dim 1).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Default F.gelu (exact erf form), then normalise each row.
        activated = nn.functional.gelu(self.linear(x))
        return nn.functional.softmax(activated, dim=1)
batch_size = 128
in_features = 100
out_features = 10


def get_inputs():
    """Produce the benchmark's forward() arguments as a one-element list."""
    return [torch.randn(batch_size, in_features)]


def get_init_inputs():
    """Produce ``Model(in_features, out_features)`` constructor arguments."""
    return [in_features, out_features]
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <float.h>
#include <math.h>
// GELU activation matching PyTorch's default F.gelu (approximate='none'):
//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// The reference model calls F.gelu(x) without `approximate`, so the exact
// erf form is used here. The tanh-based formula previously used corresponds
// to F.gelu(x, approximate='tanh') and does not match the reference exactly.
__device__ float gelu(float x) {
    const float inv_sqrt2 = 0.70710678118654752f;  // 1 / sqrt(2)
    return 0.5f * x * (1.0f + erff(x * inv_sqrt2));
}
// Warp-wide maximum via shuffle-down: each step folds in the value held by the
// lane `delta` positions higher. After the loop, lane 0 holds the maximum over
// all lanes; other lanes hold partial results. Uses the full 0xffffffff mask,
// so every lane of the warp must call this function.
__inline__ __device__ float warpReduceMax(float val) {
    for (int delta = warpSize >> 1; delta != 0; delta >>= 1) {
        const float neighbour = __shfl_down_sync(0xffffffff, val, delta);
        val = max(val, neighbour);
    }
    return val;
}
// Warp-wide sum via shuffle-down: each step adds the value held by the lane
// `delta` positions higher. After the loop, lane 0 holds the total over all
// lanes; other lanes hold partial sums. Uses the full 0xffffffff mask, so every
// lane of the warp must call this function.
__inline__ __device__ float warpReduceSum(float val) {
    for (int delta = warpSize >> 1; delta != 0; delta >>= 1) {
        const float neighbour = __shfl_down_sync(0xffffffff, val, delta);
        val = val + neighbour;
    }
    return val;
}
// Fused kernel: Performs linear transformation, applies GELU activation, and softmax normalization
// Utilizes warp-level primitives to optimize reduction operations, minimizing shared memory usage
//
// Launch contract (implied by the indexing below):
//   - One block per row: block `row` = blockIdx.x reads x[row, :] and writes output[row, :].
//     `batch_size` is never read, so the grid size alone must equal batch_size.
//   - Thread tid < out_features computes output feature `tid`; blockDim.x must therefore be
//     >= out_features. Threads with tid >= out_features only feed neutral values
//     (-FLT_MAX / 0) into the reductions.
//   - blockDim.x must be a multiple of warpSize (the reductions use a full 0xffffffff shuffle
//     mask, so every lane of every warp must be present) and at most 32 warps (smax/ssum
//     have 32 slots, indexed by warpId).
//   - x, weight, bias, output are indexed as dense row-major arrays:
//     x[row * in_features + k], weight[tid * in_features + k], output[row * out_features + tid].
// NOTE(review): offsets are computed in 32-bit int; tensors with more than INT_MAX
// elements would overflow — confirm sizes stay below that if reused elsewhere.
__global__ void warp_optimized_fused_kernel(
const float* __restrict__ x,
const float* __restrict__ weight,
const float* __restrict__ bias,
float* __restrict__ output,
int batch_size,
int in_features,
int out_features
) {
int row = blockIdx.x; // Each block processes one row of the batch
int tid = threadIdx.x;
// 1. Compute the dot product for the linear transformation for each valid output feature
// Each valid thread walks in_features serially: dot(x[row, :], weight[tid, :]) + bias[tid],
// then applies gelu(). Padding threads keep act = 0 (it is masked out below).
float act = 0.0f;
if (tid < out_features) {
float sum = 0.0f;
for (int k = 0; k < in_features; k++) {
sum += x[row * in_features + k] * weight[tid * in_features + k];
}
sum += bias[tid];
act = gelu(sum);
}
// 2. Compute maximum value using warp-level reduction
// Padding threads contribute -FLT_MAX so they can never be the maximum.
float max_val = (tid < out_features) ? act : -FLT_MAX;
// After this call only lane 0 of each warp holds that warp's maximum.
max_val = warpReduceMax(max_val);
// Use shared memory to store warp-level maxima
__shared__ float smax[32]; // supports up to 32 warps per block (1024 threads/block)
int lane = tid & (warpSize - 1);
int warpId = tid / warpSize;
if (lane == 0) {
smax[warpId] = max_val;
}
__syncthreads();
// First warp reduces the warp maxima
if (tid < warpSize) {
// All 32 lanes of warp 0 take part in the shuffle; only lanes < numWarps load a real
// per-warp maximum, the rest load the neutral -FLT_MAX.
int numWarps = (blockDim.x + warpSize - 1) / warpSize;
float tmp = (tid < numWarps) ? smax[tid] : -FLT_MAX;
tmp = warpReduceMax(tmp);
// Only lane 0's value is the full block maximum; smax[0] is the only entry read below.
smax[tid] = tmp;
}
__syncthreads();
float row_max = smax[0];
// 3. Compute the exponentials and sum using warp-level reduction
// Subtracting the row maximum keeps every expf argument <= 0, avoiding overflow.
// Padding threads contribute 0 to the sum.
float exp_val = (tid < out_features) ? expf(act - row_max) : 0.0f;
float sum_val = warpReduceSum(exp_val);
// Kept separate from smax: warp 0's lane 0 writes slot 0 below while other warps may still
// be reading smax[0] (no barrier between the read and this write), so reusing smax would race.
__shared__ float ssum[32];
if (lane == 0) {
ssum[warpId] = sum_val;
}
__syncthreads();
if (tid < warpSize) {
int numWarps = (blockDim.x + warpSize - 1) / warpSize;
float tmp = (tid < numWarps) ? ssum[tid] : 0.0f;
tmp = warpReduceSum(tmp);
// As above, only ssum[0] (written by lane 0) holds the full block sum.
ssum[tid] = tmp;
}
__syncthreads();
float sum_exp = ssum[0];
// 4. Write the normalized softmax result for valid output features
if (tid < out_features) {
output[row * out_features + tid] = exp_val / sum_exp;
}
}
// Forward function that wraps the kernel launch.
//
// Computes softmax(gelu(x @ weight^T + bias), dim=1) with a single fused kernel
// launch: one thread block per row of x, one thread per output feature (the
// thread count is rounded up to a whole number of 32-thread warps).
//
// Args:
//   x:      float32 CUDA tensor of shape (batch_size, in_features)
//   weight: float32 CUDA tensor of shape (out_features, in_features)
//   bias:   float32 CUDA tensor of shape (out_features)
// Returns:
//   float32 tensor of shape (batch_size, out_features) on x's device.
// Throws:
//   c10::Error (via TORCH_CHECK) for non-CUDA or non-float32 inputs, inputs on
//   different devices, inconsistent shapes, or out_features > 1024. The kernel
//   maps one thread to each output feature, and a block holds at most 1024 threads.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "forward: x, weight and bias must be CUDA tensors");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "forward: x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    weight.scalar_type() == torch::kFloat32 &&
                    bias.scalar_type() == torch::kFloat32,
                "forward: x, weight and bias must be float32");
    TORCH_CHECK(x.dim() == 2,
                "forward: x must be 2-D (batch_size, in_features), got ", x.dim(), "-D");
    TORCH_CHECK(weight.dim() == 2,
                "forward: weight must be 2-D (out_features, in_features), got ", weight.dim(), "-D");
    TORCH_CHECK(bias.dim() == 1,
                "forward: bias must be 1-D (out_features), got ", bias.dim(), "-D");
    TORCH_CHECK(weight.size(1) == x.size(1),
                "forward: in_features mismatch: x has ", x.size(1),
                ", weight has ", weight.size(1));
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "forward: out_features mismatch: weight has ", weight.size(0),
                ", bias has ", bias.size(0));

    // One thread per output feature and at most 32 warps of shared scratch in
    // the kernel => out_features is bounded by the 1024-thread block limit.
    constexpr int64_t kMaxThreadsPerBlock = 1024;
    TORCH_CHECK(weight.size(0) <= kMaxThreadsPerBlock,
                "forward: out_features must be <= ", kMaxThreadsPerBlock,
                ", got ", weight.size(0));

    // The kernel indexes raw pointers as dense row-major arrays.
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    const int batch_size = static_cast<int>(x.size(0));
    const int in_features = static_cast<int>(x.size(1));
    const int out_features = static_cast<int>(weight.size(0));

    auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
    auto output = torch::empty({batch_size, out_features}, options);

    // A zero-sized grid or block is an invalid launch configuration; the empty
    // result needs no computation.
    if (batch_size == 0 || out_features == 0) {
        return output;
    }

    // Launch on x's device and on PyTorch's current stream so the kernel is
    // ordered correctly with surrounding ATen work.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Round out_features up to the next multiple of 32 (whole warps), as the
    // kernel's full-mask warp shuffles require.
    const int threads = ((out_features + 31) / 32) * 32;
    dim3 blocks(batch_size);
    dim3 threadBlock(threads);

    // The kernel uses only statically sized __shared__ arrays (smax/ssum), so
    // no dynamic shared memory is requested.
    warp_optimized_fused_kernel<<<blocks, threadBlock, 0, stream>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
// Python binding: exposes the host wrapper `forward` on the compiled extension
// module as `forward(x, weight, bias)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
    ext_module.def(
        "forward",
        &forward,
        "Fused Linear + GELU + Softmax forward with warp-level optimizations");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.060 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.050 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 1.568 | % | 0.000 | 5 |
Issued Ipc Active | 0.060 | inst/cycle | 0.000 | 5 |
SM Busy | 1.568 | % | 0.000 | 5 |
Memory Throughput | 6694981612.050 | byte/second | 8837190495254405.000 | 5 |
Mem Busy | 3.948 | % | 0.010 | 5 |
Max Bandwidth | 3.126 | % | 0.003 | 5 |
L1/TEX Hit Rate | 87.140 | % | 0.000 | 5 |
L2 Hit Rate | 99.330 | % | 0.943 | 5 |
Mem Pipes Busy | 1.492 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 15.822 | cycle | 0.038 | 5 |
Warp Cycles Per Executed Instruction | 15.908 | cycle | 0.039 | 5 |
Avg. Active Threads Per Warp | 17.850 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 16.230 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 64.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 42.000 | block | 0.000 | 5 |
Block Limit Warps | 64.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 32.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 50.000 | % | 0.000 | 5 |
Achieved Occupancy | 1.560 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 1.000 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN ThreadDivergence | Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 17.8 threads being active per cycle. This is further reduced to 16.2 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp(). |
WRN Occupancy | This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. The difference between calculated theoretical (50.0%) and measured achieved occupancy (1.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 312433.68 | μs |
Device Time | 9.09 | μs |
Self CPU Time | 41.88 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 312391.79 | μs |
Device Time | 9.09 | μs |
Self CPU Time | 91.66 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 312162.53 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 89.70 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 311896.43 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 311896.43 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 512931.79 | μs |
Device Time | 21566.74 | μs |
Self CPU Time | 512931.79 | μs |
Self Device Time | 21566.74 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
warp_optimized_fused_kernel(float const*, float const*, float const*, float*, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 71710.07 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 71710.07 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 37978.57 | μs |
Device Time | 42920.04 | μs |
Self CPU Time | 37978.57 | μs |
Self Device Time | 42920.04 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 69148.22 | μs |
Device Time | 640754.20 | μs |
Self CPU Time | 14406.23 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 54743.76 | μs |
Device Time | 640754.20 | μs |
Self CPU Time | 17596.85 | μs |
Self Device Time | 640754.20 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 640754.20 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 640754.20 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45289 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.