22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish
• fused_kernel_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
scale_factor: float,
clamp_min: float,
clamp_max: float,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
"""
    Applies a linear transformation, scaling, a residual addition, clamping, a LogSumExp reduction over the hidden dimension, and finally multiplies the result by its Mish activation.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, input_size)
scale_factor (float): Factor to scale the output by
clamp_min (float): Minimum value for clamping
clamp_max (float): Maximum value for clamping
weight (torch.Tensor): Weight matrix of shape (hidden_size, input_size)
bias (torch.Tensor): Bias vector of shape (hidden_size)
Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1)
"""
    x = F.linear(x, weight, bias)                # matmul with weight^T plus bias
    x = x * scale_factor                         # scale
    x = x + x                                    # residual addition (doubles the value)
    x = torch.clamp(x, clamp_min, clamp_max)     # clamp to [clamp_min, clamp_max]
    x = torch.logsumexp(x, dim=1, keepdim=True)  # reduce hidden dim -> (batch_size, 1)
    x = x * F.mish(x)                            # multiply by Mish of the LogSumExp result
    return x
class Model(nn.Module):
"""
    Model that performs a matrix multiplication, scales the result, adds a residual connection, clamps the output,
    applies LogSumExp over the hidden dimension, and finally multiplies the result by its Mish activation.
"""
def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
super(Model, self).__init__()
matmul = nn.Linear(input_size, hidden_size)
self.weight = matmul.weight
self.bias = nn.Parameter(
matmul.bias + torch.ones_like(matmul.bias) * 0.02
        )  # ensure the bias is nonzero
def forward(self, x, scale_factor, clamp_min, clamp_max, fn=module_fn):
return fn(x, scale_factor, clamp_min, clamp_max, self.weight, self.bias)
batch_size = 128
input_size = 512
hidden_size = 1024
scale_factor = 2.0
clamp_min = -10.0
clamp_max = 10.0
def get_inputs():
return [torch.randn(batch_size, input_size), scale_factor, clamp_min, clamp_max]
def get_init_inputs():
return [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
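A quick sanity check of the reference path (a minimal sketch; the seed and the __main__ guard are illustrative and not part of the harness) confirms that the LogSumExp reduction with keepdim=True collapses the hidden dimension, so the final output is (batch_size, 1):

if __name__ == "__main__":
    torch.manual_seed(0)  # illustrative seed for reproducibility
    model = Model(*get_init_inputs())
    x, s, cmin, cmax = get_inputs()
    out = model(x, s, cmin, cmax)
    print(out.shape)  # expected: torch.Size([128, 1])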
import torch
import torch.nn as nn
class Model(nn.Module):
"""
    Model that performs a matrix multiplication, scales the result, adds a residual connection, clamps the output,
    applies LogSumExp over the hidden dimension, and finally multiplies the result by its Mish activation.
"""
def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
super(Model, self).__init__()
self.matmul = nn.Linear(input_size, hidden_size)
self.matmul.bias = nn.Parameter(self.matmul.bias + torch.ones_like(self.matmul.bias) * 0.02)
self.scale_factor = scale_factor
self.clamp_min = clamp_min
self.clamp_max = clamp_max
def forward(self, x):
"""
Args:
x: Input tensor of shape (batch_size, input_size).
Returns:
            Output tensor of shape (batch_size, 1).
"""
x = self.matmul(x)
x = x * self.scale_factor
x = x + x
x = torch.clamp(x, self.clamp_min, self.clamp_max)
x = torch.logsumexp(x, dim=1, keepdim=True)
        x = x * torch.nn.functional.mish(x)  # multiply by Mish of the LogSumExp result
return x
batch_size = 128
input_size = 512
hidden_size = 1024
scale_factor = 2.0
clamp_min = -10.0
clamp_max = 10.0
def get_inputs():
return [torch.randn(batch_size, input_size)]
def get_init_inputs():
return [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda_runtime.h>
#include <vector>
#include <math.h>
// Define tile size for matrix multiplication tiling
#define TILE_SIZE 16
// Kernel 1: Perform tiled matrix multiplication and apply per-element transformations
// (add bias, scale, residual addition, and clamping) and write the intermediate result.
__global__ void gemm_post_kernel(
const float* __restrict__ input,
const float* __restrict__ weight,
const float* __restrict__ bias,
float* __restrict__ intermediate,
int batch_size,
int input_size,
int hidden_size,
float scale_factor,
float clamp_min,
float clamp_max
) {
// Shared memory tiles for input and weight
__shared__ float input_shared[TILE_SIZE][TILE_SIZE];
__shared__ float weight_shared[TILE_SIZE][TILE_SIZE];
// Calculate row and column indices for the output matrix
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
float sum = 0.0f;
int num_tiles = (input_size + TILE_SIZE - 1) / TILE_SIZE;
// Loop over tiles
for (int t = 0; t < num_tiles; t++) {
int input_col = t * TILE_SIZE + threadIdx.x;
if (row < batch_size && input_col < input_size) {
input_shared[threadIdx.y][threadIdx.x] = input[row * input_size + input_col];
} else {
input_shared[threadIdx.y][threadIdx.x] = 0.0f;
}
int weight_row = t * TILE_SIZE + threadIdx.y;
if (col < hidden_size && weight_row < input_size) {
// Note: weight is stored as [hidden_size, input_size], and we need its transpose for multiplication
weight_shared[threadIdx.y][threadIdx.x] = weight[col * input_size + weight_row];
} else {
weight_shared[threadIdx.y][threadIdx.x] = 0.0f;
}
__syncthreads();
// Multiply the two tiles together
#pragma unroll
for (int i = 0; i < TILE_SIZE; i++) {
sum += input_shared[threadIdx.y][i] * weight_shared[i][threadIdx.x];
}
__syncthreads();
}
// Write back the computed value with post-operations
if (row < batch_size && col < hidden_size) {
sum += bias[col]; // Add bias
sum *= scale_factor; // Scale
sum += sum; // Residual addition (doubling the value)
sum = fmaxf(fminf(sum, clamp_max), clamp_min); // Clamp between clamp_min and clamp_max
intermediate[row * hidden_size + col] = sum;
}
}
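// Cross-check: for each valid (row, col) this kernel writes
//   intermediate[row][col] = clamp(2 * scale_factor * (dot(input[row, :], weight[col, :]) + bias[col]),
//                                  clamp_min, clamp_max)
// i.e. the fused equivalent of F.linear -> * scale_factor -> x + x -> torch.clamp in the PyTorch reference.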
// Kernel 2: For each row, compute logsumexp followed by a Mish activation and final multiplication.
// This kernel performs a two-phase reduction: first to compute the maximum (for numerical stability)
// and then to compute the sum of exponentials. Finally, it computes the logsumexp and applies Mish.
__global__ void row_logsumexp_mish_kernel(
const float* __restrict__ intermediate,
float* __restrict__ output,
int hidden_size,
int batch_size
) {
int row = blockIdx.x; // one block per row
if (row >= batch_size) return;
extern __shared__ float shared_data[]; // shared memory for reduction; size = blockDim.x
int tid = threadIdx.x;
// First reduction: compute the maximum value in this row for numerical stability
float max_val = -INFINITY;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = intermediate[row * hidden_size + i];
if (val > max_val)
max_val = val;
}
shared_data[tid] = max_val;
__syncthreads();
    // Tree reduction to obtain the row maximum (assumes blockDim.x is a power of two)
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
float other = shared_data[tid + s];
if (other > shared_data[tid])
shared_data[tid] = other;
}
__syncthreads();
}
max_val = shared_data[0];
__syncthreads();
// Second reduction: compute the sum of exp(val - max_val) for the row
float sum_exp = 0.0f;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = intermediate[row * hidden_size + i];
sum_exp += expf(val - max_val);
}
shared_data[tid] = sum_exp;
__syncthreads();
    // Tree reduction to obtain the total sum of exponentials (same power-of-two assumption)
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_data[tid] += shared_data[tid + s];
}
__syncthreads();
}
sum_exp = shared_data[0];
// Thread 0 computes the logsumexp and applies Mish activation
if (tid == 0) {
float lse = max_val + logf(sum_exp);
// Compute Mish: mish(x) = x * tanh(softplus(x)) where softplus(x) = log(1+exp(x))
float softplus = logf(1.0f + expf(lse));
float mish_val = lse * tanhf(softplus);
// Final output: multiply lse with mish value
output[row] = lse * mish_val;
}
}
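// Numerical note: the two reductions above implement the stable identity
//   logsumexp(v) = max(v) + log(sum_i exp(v[i] - max(v)))
// and thread 0 then forms lse * mish(lse) with mish(z) = z * tanh(log(1 + exp(z))),
// mirroring torch.logsumexp(..., dim=1, keepdim=True) followed by x * F.mish(x) in the reference.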
// Host function that fuses the overall computation on the GPU
// 1. Launch gemm_post_kernel to perform matrix multiplication with elementwise ops.
// 2. Launch row_logsumexp_mish_kernel to compute per-row reductions (logsumexp) and apply Mish.
torch::Tensor fused_forward(
torch::Tensor x,
float scale_factor,
float clamp_min,
float clamp_max,
torch::Tensor weight,
torch::Tensor bias
) {
    // The kernels below assume contiguous float32 CUDA tensors
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "all tensors must be CUDA tensors");
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();
    // Dimensions
    int batch_size = x.size(0);
    int input_size = x.size(1);
    int hidden_size = weight.size(0); // weight shape: [hidden_size, input_size]
// Allocate intermediate tensor for the output of GEMM (size [batch_size, hidden_size])
auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
auto intermediate = torch::empty({batch_size, hidden_size}, options);
// Launch gemm_post_kernel with a 2D grid and 2D block for tiled matrix multiplication
dim3 block(TILE_SIZE, TILE_SIZE);
int grid_x = (hidden_size + TILE_SIZE - 1) / TILE_SIZE;
int grid_y = (batch_size + TILE_SIZE - 1) / TILE_SIZE;
dim3 grid(grid_x, grid_y);
gemm_post_kernel<<<grid, block>>>(
x.data_ptr<float>(),
weight.data_ptr<float>(),
bias.data_ptr<float>(),
intermediate.data_ptr<float>(),
batch_size,
input_size,
hidden_size,
scale_factor,
clamp_min,
clamp_max
);
// Allocate output tensor to hold one value per batch (final shape: [batch_size, 1])
auto output = torch::empty({batch_size}, options);
// Launch row_logsumexp_mish_kernel with one block per row.
// Use a blockDim of 256 threads and allocate shared memory accordingly.
int threads = 256;
int shared_mem = threads * sizeof(float);
row_logsumexp_mish_kernel<<<batch_size, threads, shared_mem>>>(
intermediate.data_ptr<float>(),
output.data_ptr<float>(),
hidden_size,
batch_size
);
// Reshape the output to [batch_size, 1] to match the expected final dimensions
return output.reshape({batch_size, 1});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fused_forward, "Fused forward pass for module_fn (CUDA)");
}
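To exercise the extension end to end, the CUDA source can be compiled with torch.utils.cpp_extension.load and compared against an inline PyTorch reference. This is a minimal sketch: the source file name, device choice, and tolerances are assumptions, not part of the benchmark harness.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Build the extension (assumes the CUDA source above is saved as fused_kernel.cu)
fused = load(name="fused_kernel", sources=["fused_kernel.cu"])

batch_size, input_size, hidden_size = 128, 512, 1024
scale_factor, clamp_min, clamp_max = 2.0, -10.0, 10.0

x = torch.randn(batch_size, input_size, device="cuda")
weight = torch.randn(hidden_size, input_size, device="cuda")
bias = torch.randn(hidden_size, device="cuda")

# Inline PyTorch reference: linear -> scale -> double -> clamp -> logsumexp -> x * mish(x)
ref = torch.clamp(F.linear(x, weight, bias) * scale_factor * 2, clamp_min, clamp_max)
ref = torch.logsumexp(ref, dim=1, keepdim=True)
ref = ref * F.mish(ref)

out = fused.forward(x, scale_factor, clamp_min, clamp_max, weight, bias)
print(torch.allclose(ref, out, atol=1e-4, rtol=1e-4))  # expected: True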
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.338 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.200 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 8.538 | % | 0.026 | 5 |
Issued Ipc Active | 0.342 | inst/cycle | 0.000 | 5 |
SM Busy | 8.538 | % | 0.026 | 5 |
Memory Throughput | 80676714385.724 | byte/second | 383907429681214144.000 | 5 |
Mem Busy | 5.546 | % | 0.002 | 5 |
Max Bandwidth | 3.056 | % | 0.001 | 5 |
L1/TEX Hit Rate | 49.810 | % | 0.000 | 5 |
L2 Hit Rate | 65.128 | % | 0.064 | 5 |
Mem Pipes Busy | 4.226 | % | 0.001 | 5 |
Warp Cycles Per Issued Instruction | 22.222 | cycle | 0.008 | 5 |
Warp Cycles Per Executed Instruction | 22.430 | cycle | 0.008 | 5 |
Avg. Active Threads Per Warp | 31.130 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 25.580 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 16.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 16.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 11.994 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 7.676 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.0%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 325027.76 | μs |
Device Time | 170.69 | μs |
Self CPU Time | 60.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 324967.45 | μs |
Device Time | 170.69 | μs |
Self CPU Time | 130.41 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 324388.20 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 134.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 284395.27 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 284395.27 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 73013.61 | μs |
Device Time | 623129.70 | μs |
Self CPU Time | 16048.95 | μs |
Self Device Time | 623129.70 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 569897.14 | μs |
Device Time | 35242.64 | μs |
Self CPU Time | 569897.14 | μs |
Self Device Time | 35242.64 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
gemm_post_kernel(float const*, float const*, float const*, float*, int, int, int, float, float, float) | ||
CPU Time | 0.00 | μs |
Device Time | 244704.63 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 244704.63 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
row_logsumexp_mish_kernel(float const*, float*, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 39144.75 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 39144.75 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 85182.56 | μs |
Device Time | 623129.70 | μs |
Self CPU Time | 12183.88 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 623208.14 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 623208.14 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |