import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run the linear -> GELU -> softmax pipeline on a batch of inputs.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)

    Returns:
        torch.Tensor: Softmax (taken over dim 1) of the GELU-activated affine
        projection, with shape (batch_size, out_features)
    """
    projected = F.linear(x, weight, bias)
    activated = F.gelu(projected)
    return F.softmax(activated, dim=1)
class Model(nn.Module):
    """
    Linear projection followed by GELU and softmax, with the computation
    delegated to a swappable functional implementation.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # A throwaway nn.Linear supplies the parameters; only its weight and
        # bias are kept as attributes of this module.
        layer = nn.Linear(in_features, out_features)
        self.weight = layer.weight
        self.bias = layer.bias

    def forward(self, x, fn=module_fn):
        # `fn` is called as fn(x, weight, bias) and its result is returned as is.
        return fn(x, self.weight, self.bias)
# Problem dimensions consumed by get_inputs() / get_init_inputs().
batch_size, in_features, out_features = 128, 100, 10
def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]
def get_init_inputs():
    """Return the constructor arguments as [in_features, out_features]."""
    ctor_args = [in_features, out_features]
    return ctor_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Applies an nn.Linear layer, then GELU, then softmax over dim 1.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        projected = self.linear(x)
        activated = nn.functional.gelu(projected)
        # Normalise across the feature axis (dim 1) so each row sums to 1.
        return nn.functional.softmax(activated, dim=1)
# Problem dimensions consumed by get_inputs() / get_init_inputs().
batch_size, in_features, out_features = 128, 100, 10
def get_inputs():
    """Return the forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]
def get_init_inputs():
    """Return the constructor arguments as [in_features, out_features]."""
    ctor_args = [in_features, out_features]
    return ctor_args
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <float.h>
// Exact GELU activation: 0.5 * x * (1 + erf(x / sqrt(2))).
//
// The Python reference in this file calls F.gelu(x) without an `approximate`
// argument, which selects the erf-based formulation. The tanh approximation
// used previously differs from it by up to a few 1e-4, which then propagates
// through the softmax. erff() keeps the result aligned with the reference.
__device__ float gelu(float x) {
    const float inv_sqrt_2 = 0.70710678118654752440f;  // 1 / sqrt(2)
    return 0.5f * x * (1.0f + erff(x * inv_sqrt_2));
}
// Block-wide maximum of buf[0..count), where count == blockDim.x and every
// thread of the block calls this function. Unlike a plain "stride = n / 2"
// halving tree, the ceil-halving below is correct for ANY count >= 1, not only
// powers of two (for a power of two it performs exactly the same pairings).
// On return all threads hold the result and buf may be overwritten.
__device__ __forceinline__ float block_reduce_max(float* buf, int tid, int count) {
    for (int active = count; active > 1;) {
        const int half = (active + 1) / 2;
        // Writers touch indices < active - half <= half, readers touch
        // indices >= half, so there is no intra-step hazard.
        if (tid + half < active) {
            buf[tid] = fmaxf(buf[tid], buf[tid + half]);
        }
        __syncthreads();
        active = half;
    }
    const float result = buf[0];
    __syncthreads();  // all threads have read buf[0] before the caller reuses buf
    return result;
}

// Block-wide sum of buf[0..count); same contract as block_reduce_max.
__device__ __forceinline__ float block_reduce_sum(float* buf, int tid, int count) {
    for (int active = count; active > 1;) {
        const int half = (active + 1) / 2;
        if (tid + half < active) {
            buf[tid] += buf[tid + half];
        }
        __syncthreads();
        active = half;
    }
    const float result = buf[0];
    __syncthreads();
    return result;
}

// Fused kernel: linear transformation + GELU + row-wise softmax.
//
// Launch contract:
//   * one block per batch row (gridDim.x >= batch_size; surplus blocks exit),
//   * blockDim.x >= out_features (thread `tid` owns output feature `tid`;
//     threads with tid >= out_features only take part in the reductions),
//   * dynamic shared memory of (in_features + blockDim.x) floats.
//
// The input row is staged in shared memory once and re-used by every thread's
// dot product; a second shared region is re-used for the max and the sum
// reductions of the softmax.
__global__ void fused_opt_kernel(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ output,
    int batch_size,
    int in_features,
    int out_features
) {
    const int block_threads = blockDim.x;
    const int row = blockIdx.x;  // each block processes one row of the batch
    const int tid = threadIdx.x;

    // The condition is uniform across the block, so returning here cannot
    // leave some threads waiting at a later __syncthreads().
    if (row >= batch_size) {
        return;
    }

    // Shared memory layout: [ input row : in_features | reduction buffer : blockDim.x ]
    extern __shared__ float shared_mem[];
    float* s_x = shared_mem;
    float* s_buffer = shared_mem + in_features;

    // 1. Stage the input row in shared memory. The row offset is computed in
    //    64 bits so row * in_features cannot overflow int for large tensors.
    const float* x_row = x + static_cast<size_t>(row) * in_features;
    for (int i = tid; i < in_features; i += block_threads) {
        s_x[i] = x_row[i];
    }
    __syncthreads();

    // 2. Dot product + bias + GELU for this thread's output feature.
    //    Threads without a feature contribute the -FLT_MAX sentinel so they
    //    never win the max reduction.
    const bool owns_feature = tid < out_features;
    float act = -FLT_MAX;
    if (owns_feature) {
        const float* w_row = weight + static_cast<size_t>(tid) * in_features;
        float sum = 0.0f;
        for (int k = 0; k < in_features; k++) {
            sum += s_x[k] * w_row[k];
        }
        sum += bias[tid];
        act = gelu(sum);
    }
    s_buffer[tid] = act;
    __syncthreads();

    // 3. Row maximum, subtracted before exponentiation for numerical stability.
    const float row_max = block_reduce_max(s_buffer, tid, block_threads);

    // 4. Stabilised exponentials; padding threads contribute 0 to the sum.
    const float exp_val = owns_feature ? expf(act - row_max) : 0.0f;
    s_buffer[tid] = exp_val;
    __syncthreads();

    // 5. Sum of exponentials.
    const float sum_exp = block_reduce_sum(s_buffer, tid, block_threads);

    // 6. Normalise to obtain the softmax output.
    if (owns_feature) {
        output[static_cast<size_t>(row) * out_features + tid] = exp_val / sum_exp;
    }
}
// Host-side forward function: validates the inputs and launches
// fused_opt_kernel with one block per batch row.
//
// Args:
//   x      - float32 CUDA tensor of shape (batch_size, in_features)
//   weight - float32 CUDA tensor of shape (out_features, in_features)
//   bias   - float32 CUDA tensor of shape (out_features)
// Returns:
//   float32 CUDA tensor of shape (batch_size, out_features) holding
//   softmax(gelu(x @ weight^T + bias), dim=1).
// Raises (via TORCH_CHECK):
//   on non-CUDA / non-float32 inputs, mismatched shapes, out_features larger
//   than one thread block can cover, a shared-memory request above 48 KiB,
//   or a failed kernel launch.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "forward: x, weight and bias must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "forward: x, weight and bias must be float32");
    TORCH_CHECK(x.dim() == 2, "forward: x must be 2-D (batch_size, in_features)");
    TORCH_CHECK(weight.dim() == 2, "forward: weight must be 2-D (out_features, in_features)");
    TORCH_CHECK(bias.dim() == 1, "forward: bias must be 1-D (out_features)");
    TORCH_CHECK(weight.size(1) == x.size(1),
                "forward: weight.size(1) must equal x.size(1)");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "forward: bias.size(0) must equal weight.size(0)");

    // The kernel indexes raw pointers assuming dense row-major storage.
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    const int batch_size = static_cast<int>(x.size(0));
    const int in_features = static_cast<int>(x.size(1));
    const int out_features = static_cast<int>(weight.size(0));

    // One thread per output feature within a single block.
    TORCH_CHECK(out_features <= 1024,
                "forward: out_features (", out_features,
                ") exceeds the 1024 threads available in one block");

    auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
    auto output = torch::empty({batch_size, out_features}, options);

    // A zero-sized grid or block is an invalid launch configuration; there is
    // nothing to compute in that case anyway.
    if (batch_size == 0 || out_features == 0) {
        return output;
    }

    // Thread count: at least one warp, rounded up to a power of two so that
    // halving tree reductions inside the kernel visit every element.
    int threads = 32;
    while (threads < out_features) {
        threads *= 2;
    }

    dim3 blocks(batch_size);
    dim3 threadBlock(threads);

    // Shared memory: one input row (in_features floats) + reduction buffer (threads floats).
    const size_t shared_mem_size =
        (static_cast<size_t>(in_features) + static_cast<size_t>(threads)) * sizeof(float);
    // 48 KiB is the dynamic shared memory available without an explicit opt-in.
    TORCH_CHECK(shared_mem_size <= 48 * 1024,
                "forward: in_features too large for the shared-memory row cache (",
                shared_mem_size, " bytes requested)");

    fused_opt_kernel<<<blocks, threadBlock, shared_mem_size>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );

    // Surface launch-configuration errors instead of returning garbage.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "fused_opt_kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}
// Python binding: registers the host-side `forward` function under the name
// "forward" in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused Linear + GELU + Softmax forward (optimized)");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.060 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.040 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 1.552 | % | 0.001 | 5 |
Issued Ipc Active | 0.060 | inst/cycle | 0.000 | 5 |
SM Busy | 1.552 | % | 0.001 | 5 |
Memory Throughput | 7549785234.974 | byte/second | 4776057634224563.000 | 5 |
Mem Busy | 4.954 | % | 0.008 | 5 |
Max Bandwidth | 3.196 | % | 0.001 | 5 |
L1/TEX Hit Rate | 86.080 | % | 0.000 | 5 |
L2 Hit Rate | 98.388 | % | 1.308 | 5 |
Mem Pipes Busy | 1.300 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 15.930 | cycle | 0.054 | 5 |
Warp Cycles Per Executed Instruction | 16.056 | cycle | 0.054 | 5 |
Avg. Active Threads Per Warp | 17.530 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 14.520 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 64.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 39.000 | block | 0.000 | 5 |
Block Limit Warps | 64.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 32.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 50.000 | % | 0.000 | 5 |
Achieved Occupancy | 1.560 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 1.000 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN ThreadDivergence | Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 17.5 threads being active per cycle. This is further reduced to 14.5 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp(). |
WRN Occupancy | This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. The difference between calculated theoretical (50.0%) and measured achieved occupancy (1.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 759949.28 | μs |
Device Time | 8.48 | μs |
Self CPU Time | 50.51 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 759898.77 | μs |
Device Time | 8.48 | μs |
Self CPU Time | 104.26 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 759646.91 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 111.33 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 755641.01 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 755641.01 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 506738.47 | μs |
Device Time | 22099.92 | μs |
Self CPU Time | 506738.47 | μs |
Self Device Time | 22099.92 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
fused_opt_kernel(float const*, float const*, float const*, float*, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 59753.64 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 59753.64 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 18891.50 | μs |
Device Time | 40985.82 | μs |
Self CPU Time | 18891.50 | μs |
Self Device Time | 40985.82 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 63054.85 | μs |
Device Time | 611992.73 | μs |
Self CPU Time | 14138.22 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 48917.73 | μs |
Device Time | 611992.73 | μs |
Self CPU Time | 15292.25 | μs |
Self Device Time | 611992.73 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 611992.73 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 611992.73 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45288 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.