80_Gemm_Max_Subtract_GELU
• minimal_sync_optimized_kernel_base_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    max_dim: int,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Performs a GEMM, followed by a max operation, mean subtraction, and GELU activation.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        max_dim (int): Dimension to perform the max operation over
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1) when max_dim == 1,
            or (1, out_features) when max_dim == 0
    """
    x = F.linear(x, weight, bias)
    x = torch.max(x, dim=max_dim, keepdim=True).values
    x = x - x.mean(dim=1, keepdim=True)
    x = F.gelu(x)
    return x
class Model(nn.Module):
    """
    Model that performs a GEMM, followed by a max operation, mean subtraction, and GELU activation.
    """

    def __init__(self, in_features, out_features, max_dim):
        super(Model, self).__init__()
        gemm = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(gemm.weight)
        self.bias = nn.Parameter(gemm.bias)

    def forward(self, x, max_dim, fn=module_fn):
        return fn(x, max_dim, self.weight, self.bias)
batch_size = 128
in_features = 512
out_features = 1024
max_dim = 1
def get_inputs():
    return [torch.randn(batch_size, in_features), max_dim]


def get_init_inputs():
    return [in_features, out_features, max_dim]
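A subtle consequence of the reference chain above with the default max_dim = 1: the keepdim max produces a (batch_size, 1) tensor, the mean over dim=1 of that singleton dimension equals the value itself, and GELU(0) = 0, so the output is all zeros. A minimal sanity check of this, assuming the definitions above (the check itself is illustrative):

if __name__ == "__main__":
    x, dim = get_inputs()
    model = Model(*get_init_inputs())
    out = model(x, dim)
    # keepdim max -> (batch_size, 1); subtracting its own row mean gives zeros; GELU(0) == 0
    assert out.shape == (batch_size, 1)
    assert torch.allclose(out, torch.zeros_like(out))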
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a GEMM, followed by a max operation, mean subtraction, and GELU activation.
    """

    def __init__(self, in_features, out_features, max_dim):
        super(Model, self).__init__()
        self.gemm = nn.Linear(in_features, out_features)
        self.max_dim = max_dim

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, in_features)

        Returns:
            Output tensor of shape (batch_size, 1) when max_dim == 1,
            or (1, out_features) when max_dim == 0
        """
        x = self.gemm(x)
        x = torch.max(x, dim=self.max_dim, keepdim=True).values
        x = x - x.mean(dim=1, keepdim=True)
        x = torch.nn.functional.gelu(x)
        return x
batch_size = 128
in_features = 512
out_features = 1024
max_dim = 1
def get_inputs():
    return [torch.randn(batch_size, in_features)]


def get_init_inputs():
    return [in_features, out_features, max_dim]
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
#define GEMM_BLOCK_DIM 32
#define REDUCE_BLOCK_SIZE 512
#define WARP_SIZE 32
__device__ inline float gelu(float x) {
    // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), using the single-precision erff
    return 0.5f * x * (1.0f + erff(x * 0.70710678118654752440f));
}
__global__ void minimal_sync_gemm_kernel(const float* __restrict__ x,
                                         const float* __restrict__ weight,
                                         const float* __restrict__ bias,
                                         float* __restrict__ y,
                                         int batch, int in_features, int out_features) {
    __shared__ float tile_x[GEMM_BLOCK_DIM][GEMM_BLOCK_DIM];
    __shared__ float tile_w[GEMM_BLOCK_DIM][GEMM_BLOCK_DIM];

    const int row = blockIdx.y * GEMM_BLOCK_DIM + threadIdx.y;
    const int col = blockIdx.x * GEMM_BLOCK_DIM + threadIdx.x;
    float sum = 0.0f;

    #pragma unroll 4
    for (int t = 0; t < (in_features + GEMM_BLOCK_DIM - 1) / GEMM_BLOCK_DIM; t++) {
        const int idx = t * GEMM_BLOCK_DIM + threadIdx.x;
        const int idy = t * GEMM_BLOCK_DIM + threadIdx.y;

        // Load tiles with a single synchronization point
        tile_x[threadIdx.y][threadIdx.x] = (row < batch && idx < in_features) ?
            x[row * in_features + idx] : 0.0f;
        tile_w[threadIdx.y][threadIdx.x] = (col < out_features && idy < in_features) ?
            weight[col * in_features + idy] : 0.0f;

        // Single sync after both loads
        __syncthreads();

        #pragma unroll
        for (int k = 0; k < GEMM_BLOCK_DIM; k++) {
            sum += tile_x[threadIdx.y][k] * tile_w[k][threadIdx.x];
        }

        // Single sync at end of tile processing
        __syncthreads();
    }

    if (row < batch && col < out_features) {
        y[row * out_features + col] = sum + bias[col];
    }
}
__device__ inline float warp_reduce_max(float val) {
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}
__global__ void minimal_sync_reduce_gelu_kernel(float* data,
                                                int batch,
                                                int out_features,
                                                int max_dim) {
    __shared__ float sdata[REDUCE_BLOCK_SIZE];
    const int tid = threadIdx.x;
    const int lane_id = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;

    float max_val = -FLT_MAX;

    if (max_dim == 0) {
        const int col = blockIdx.x;
        // Grid-stride loop with minimal synchronization
        #pragma unroll 2
        for (int i = tid; i < batch; i += REDUCE_BLOCK_SIZE) {
            max_val = fmaxf(max_val, data[i * out_features + col]);
        }
    } else {
        const int row = blockIdx.x;
        #pragma unroll 2
        for (int j = tid; j < out_features; j += REDUCE_BLOCK_SIZE) {
            max_val = fmaxf(max_val, data[row * out_features + j]);
        }
    }

    // Warp-level reduction first (no sync needed within warp)
    max_val = warp_reduce_max(max_val);

    // Only the first thread in each warp writes to shared memory
    if (lane_id == 0) {
        sdata[warp_id] = max_val;
    }

    // Single sync before final reduction
    __syncthreads();

    // Final reduction using only the first warp
    if (warp_id == 0) {
        max_val = (tid < (REDUCE_BLOCK_SIZE / WARP_SIZE)) ? sdata[tid] : -FLT_MAX;
        max_val = warp_reduce_max(max_val);
        if (tid == 0) {
            sdata[0] = max_val;
            sdata[1] = max_val / (max_dim == 0 ? batch : out_features);
        }
    }

    // Single sync before reading results
    __syncthreads();

    const float mean = sdata[1];

    // Apply GELU without additional synchronization
    if (max_dim == 0) {
        const int col = blockIdx.x;
        #pragma unroll 2
        for (int i = tid; i < batch; i += REDUCE_BLOCK_SIZE) {
            const int idx = i * out_features + col;
            data[idx] = gelu(data[idx] - mean);
        }
    } else {
        const int row = blockIdx.x;
        #pragma unroll 2
        for (int j = tid; j < out_features; j += REDUCE_BLOCK_SIZE) {
            const int idx = row * out_features + j;
            data[idx] = gelu(data[idx] - mean);
        }
    }
}
torch::Tensor forward(torch::Tensor x, int max_dim, torch::Tensor weight, torch::Tensor bias) {
    const int batch = x.size(0);
    const int in_features = x.size(1);
    const int out_features = weight.size(0);

    auto y = torch::empty({batch, out_features}, x.options());

    dim3 blockDimGEMM(GEMM_BLOCK_DIM, GEMM_BLOCK_DIM);
    dim3 gridDimGEMM((out_features + GEMM_BLOCK_DIM - 1) / GEMM_BLOCK_DIM,
                     (batch + GEMM_BLOCK_DIM - 1) / GEMM_BLOCK_DIM);
    minimal_sync_gemm_kernel<<<gridDimGEMM, blockDimGEMM>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        y.data_ptr<float>(),
        batch, in_features, out_features
    );

    // The reduce + GELU kernel indexes the full (batch, out_features) GEMM result and
    // updates it in place, so it is launched on y: one block per column when
    // max_dim == 0, one block per row otherwise. Its shared memory is statically
    // allocated, so no dynamic shared-memory size is passed at launch.
    const int gridDim = max_dim == 0 ? out_features : batch;
    minimal_sync_reduce_gelu_kernel<<<gridDim, REDUCE_BLOCK_SIZE>>>(
        y.data_ptr<float>(),
        batch,
        out_features,
        max_dim
    );

    return y;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Minimal sync optimized GEMM and reduction");
}
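A usage sketch for the extension above (illustrative only: the source file name minimal_sync_kernel.cu is an assumption, and building with torch.utils.cpp_extension.load requires a local CUDA toolchain). It builds the extension and runs it next to the PyTorch reference chain:

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Build the extension from the CUDA source above (assumed file name).
ext = load(name="minimal_sync_gemm_max_gelu", sources=["minimal_sync_kernel.cu"], verbose=True)

batch_size, in_features, out_features, max_dim = 128, 512, 1024, 1
x = torch.randn(batch_size, in_features, device="cuda")
weight = torch.randn(out_features, in_features, device="cuda")
bias = torch.randn(out_features, device="cuda")

# PyTorch reference: GEMM, keepdim max, mean subtraction, GELU.
y = F.linear(x, weight, bias)
m = torch.max(y, dim=max_dim, keepdim=True).values
ref = F.gelu(m - m.mean(dim=1, keepdim=True))

# Custom CUDA path.
out = ext.forward(x, max_dim, weight, bias)
print(ref.shape, out.shape)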
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.932 | inst/cycle | 0.001 | 5 |
Executed Ipc Elapsed | 0.402 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 23.828 | % | 0.880 | 5 |
Issued Ipc Active | 0.952 | inst/cycle | 0.001 | 5 |
SM Busy | 23.828 | % | 0.880 | 5 |
Memory Throughput | 129687694143.194 | byte/second | 4235837368560357376.000 | 5 |
Mem Busy | 10.154 | % | 0.022 | 5 |
Max Bandwidth | 7.738 | % | 0.013 | 5 |
L1/TEX Hit Rate | 66.670 | % | 0.000 | 5 |
L2 Hit Rate | 75.626 | % | 0.074 | 5 |
Mem Pipes Busy | 5.114 | % | 0.006 | 5 |
Warp Cycles Per Issued Instruction | 16.420 | cycle | 0.405 | 5 |
Warp Cycles Per Executed Instruction | 16.820 | cycle | 0.425 | 5 |
Avg. Active Threads Per Warp | 31.880 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 27.690 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 5.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 12.000 | block | 0.000 | 5 |
Block Limit Warps | 4.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 24.612 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 15.752 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (24.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
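As a quick cross-check of the occupancy figures above (numbers taken from the metrics table; Theoretical Active Warps per SM is 64):

theoretical_warps_per_sm = 64.0   # Theoretical Active Warps per SM (table above)
achieved_warps_per_sm = 15.752    # Achieved Active Warps Per SM (table above)
print(100.0 * achieved_warps_per_sm / theoretical_warps_per_sm)  # ~24.6, the reported Achieved Occupancy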
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 260279.09 | μs |
Device Time | 175.61 | μs |
Self CPU Time | 54.33 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 260224.76 | μs |
Device Time | 175.61 | μs |
Self CPU Time | 109.66 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 259624.05 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 123.49 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 259052.88 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 259052.88 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 597703.42 | μs |
Device Time | 37084.56 | μs |
Self CPU Time | 597703.42 | μs |
Self Device Time | 37084.56 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
minimal_sync_gemm_kernel(float const*, float const*, float const*, float*, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 230863.30 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 230863.30 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 78162.56 | μs |
Device Time | 636794.10 | μs |
Self CPU Time | 12772.89 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 65392.21 | μs |
Device Time | 636794.10 | μs |
Self CPU Time | 16443.67 | μs |
Self Device Time | 636794.10 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 636794.10 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 636794.10 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45294 warnings generated when compiling for host. Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.