import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(x: torch.Tensor, eps: float) -> torch.Tensor:
    """
    Applies RMS Normalization to the input tensor.

    Each element is divided by the root-mean-square of the values along
    dim 1 (the feature dimension) at the same batch/spatial position:
    ``x / sqrt(mean(x**2, dim=1) + eps)``.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, num_features, *)
        eps (float): Small value added to denominator for numerical stability

    Returns:
        torch.Tensor: Output tensor with RMS Normalization applied, same shape as ``x``
    """
    # keepdim=True leaves a size-1 feature axis so the result broadcasts against x.
    mean_square = x.pow(2).mean(dim=1, keepdim=True)
    return x / (mean_square + eps).sqrt()
class Model(nn.Module):
    """
    Simple model that performs RMS Normalization.

    The normalization itself is delegated to a standalone function
    (``module_fn`` by default), so ``forward`` can be pointed at another
    implementation that takes the same ``(x, eps)`` arguments.
    """

    def __init__(self, num_features: int, eps: float):
        """
        Initializes the RMSNorm layer.

        Args:
            num_features (int): Number of features in the input tensor.
                Accepted for interface compatibility; it is not stored or used.
            eps (float): Small value added to denominator for numerical stability
        """
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Forward pass that delegates to ``fn``.

        Args:
            x (torch.Tensor): Input tensor
            fn: Callable invoked as ``fn(x, eps)``; defaults to module_fn

        Returns:
            torch.Tensor: Whatever ``fn`` returns for ``(x, self.eps)``
        """
        return fn(x, self.eps)
# Benchmark configuration: a (batch_size, features, dim1, dim2) input
# normalized over its `features` axis.
batch_size = 16
features = 64
dim1 = 256
dim2 = 256
eps = 1e-5


def get_inputs():
    """Build the input list: one standard-normal tensor of shape (batch_size, features, dim1, dim2)."""
    return [torch.randn(batch_size, features, dim1, dim2)]


def get_init_inputs():
    """Return ``[features, eps]``, in the (num_features, eps) order of ``Model.__init__``."""
    return [features, eps]
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Simple model that performs RMS Normalization.

    Every element is divided by the root-mean-square of its values along
    dim 1 (the feature dimension): ``x / sqrt(mean(x**2, dim=1) + eps)``.
    No learnable scale is applied.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        """
        Initializes the RMSNorm layer.

        Args:
            num_features (int): Number of features in the input tensor. Stored
                on the module; ``forward`` does not check it against the input.
            eps (float, optional): Added to the mean square before the square
                root so an all-zero feature vector does not divide by zero.
                Defaults to 1e-5.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies RMS Normalization to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_features, *).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        # Mean of squares over the feature axis, kept as a size-1 dim so it
        # broadcasts back against x.
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)
# Benchmark configuration: a (batch_size, features, dim1, dim2) input
# normalized over its `features` axis.
batch_size = 16
features = 64
dim1 = 256
dim2 = 256
eps = 1e-5


def get_inputs():
    """Build the input list: one standard-normal tensor of shape (batch_size, features, dim1, dim2)."""
    return [torch.randn(batch_size, features, dim1, dim2)]


def get_init_inputs():
    """Return ``[features, eps]``, in the (num_features, eps) order of ``Model.__init__``."""
    return [features, eps]
#include <cstdint>
#include <limits>

#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
// Accumulator type for the sum of squares: float for float/half inputs and
// double for double inputs. Accumulating half-precision squares in half loses
// precision on every add and overflows to inf once the sum exceeds 65504.
template <typename scalar_t>
struct RmsNormAccType {
    using type = float;
};
template <>
struct RmsNormAccType<double> {
    using type = double;
};

// RMS normalization over dim 1 of a contiguous tensor viewed as
// (batch_size, num_features, numel_per_batch).
//
// Each thread handles one (batch, offset) sample. It reads that sample's
// num_features values (spaced numel_per_batch elements apart), accumulates
// their sum of squares in acc_t, and writes each value divided by
// sqrt(sum / num_features + eps). A grid-stride loop lets any grid size cover
// all batch_size * numel_per_batch samples. Within a batch, consecutive
// threads touch adjacent addresses for each feature, so warp accesses coalesce.
template <typename scalar_t>
__global__ void rms_norm_kernel_combined(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int batch_size,
    const int num_features,
    const int numel_per_batch,
    const float eps
) {
    using acc_t = typename RmsNormAccType<scalar_t>::type;

    // 64-bit indexing: batch_size * num_features * numel_per_batch (the flat
    // element index range) can exceed INT_MAX even when each argument fits in int.
    const int64_t total_samples = static_cast<int64_t>(batch_size) * numel_per_batch;
    const int64_t feat_stride = numel_per_batch;
    const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total_samples;
         idx += grid_stride) {
        const int64_t batch_id = idx / numel_per_batch;
        const int64_t offset = idx % numel_per_batch;
        const int64_t base = batch_id * num_features * feat_stride + offset;
        const scalar_t* x_col = input + base;
        scalar_t* y_col = output + base;

        // Pass 1: sum of squares over the feature axis, in the wider acc_t.
        acc_t sumsq = 0;
        #pragma unroll 8
        for (int feat = 0; feat < num_features; ++feat) {
            const acc_t v = static_cast<acc_t>(x_col[feat * feat_stride]);
            sumsq += v * v;
        }

        const acc_t rms = sqrt(sumsq / static_cast<acc_t>(num_features) + static_cast<acc_t>(eps));

        // Pass 2: divide (not multiply by a reciprocal) to match x / rms in the
        // PyTorch reference; the result is rounded back to scalar_t once.
        #pragma unroll 8
        for (int feat = 0; feat < num_features; ++feat) {
            const int64_t j = feat * feat_stride;
            y_col[j] = static_cast<scalar_t>(static_cast<acc_t>(x_col[j]) / rms);
        }
    }
}
// Host entry point: RMS-normalizes `input` over dim 1 on the GPU.
//
// The tensor is treated as (batch_size, num_features, numel_per_batch), where
// numel_per_batch is the product of all dimensions after dim 1. Returns a new
// contiguous tensor with the same shape and dtype as `input`.
//
// Fails via TORCH_CHECK if `input` is not a CUDA tensor, has fewer than two
// dimensions, or is too large for the kernel's 32-bit size parameters.
torch::Tensor rms_norm_cuda_forward(torch::Tensor input, float eps) {
    TORCH_CHECK(input.is_cuda(), "rms_norm_cuda_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() >= 2,
                "rms_norm_cuda_forward: input must have at least 2 dimensions (batch, features, ...)");

    // The kernel computes flat indices assuming a dense row-major layout, so a
    // transposed / sliced / channels-last input must be made contiguous first.
    auto in = input.contiguous();
    auto output = torch::empty_like(in);

    // Nothing to normalize; launching a grid with 0 blocks would be an error.
    if (in.numel() == 0) {
        return output;
    }

    // numel > 0 guarantees every size is >= 1, so this division is safe.
    const int64_t batch_size = in.size(0);
    const int64_t num_features = in.size(1);
    const int64_t numel_per_batch = in.numel() / (batch_size * num_features);
    const int64_t total_samples = batch_size * numel_per_batch;

    // The kernel takes int sizes; total_samples bounds batch_size and numel_per_batch.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(num_features <= kIntMax && total_samples <= kIntMax,
                "rms_norm_cuda_forward: input is too large for 32-bit kernel size parameters");

    const int threads_per_block = 256;
    const int64_t max_blocks = 65535;
    const int64_t needed_blocks = (total_samples + threads_per_block - 1) / threads_per_block;
    // Capping is safe: the kernel's grid-stride loop covers the remaining samples.
    const int blocks = static_cast<int>(needed_blocks < max_blocks ? needed_blocks : max_blocks);

    // Run on the input's device and on PyTorch's current stream so the kernel is
    // ordered correctly relative to surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(in.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(in.scalar_type(), "rms_norm_cuda", ([&] {
        rms_norm_kernel_combined<scalar_t><<<blocks, threads_per_block, 0, stream>>>(
            in.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            static_cast<int>(batch_size),
            static_cast<int>(num_features),
            static_cast<int>(numel_per_batch),
            eps
        );
    }));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
// Python bindings: registers rms_norm_cuda_forward as the module-level
// function `forward` of the extension named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "forward",
        &rms_norm_cuda_forward,
        "Combined RMS normalization forward (CUDA)"
    );
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.400 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.390 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 9.944 | % | 0.001 | 5 |
Issued Ipc Active | 0.400 | inst/cycle | 0.000 | 5 |
SM Busy | 14.120 | % | 0.002 | 5 |
Memory Throughput | 2954192925999.685 | byte/second | 15330745663015645184.000 | 5 |
Mem Busy | 48.538 | % | 0.005 | 5 |
Max Bandwidth | 88.140 | % | 0.013 | 5 |
L1/TEX Hit Rate | 1.048 | % | 0.001 | 5 |
L2 Hit Rate | 34.392 | % | 0.002 | 5 |
Mem Pipes Busy | 11.498 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 112.574 | cycle | 0.087 | 5 |
Warp Cycles Per Executed Instruction | 112.740 | cycle | 0.088 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 31.460 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 6.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 48.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 75.000 | % | 0.000 | 5 |
Achieved Occupancy | 69.904 | % | 0.004 | 5 |
Achieved Active Warps Per SM | 44.740 | warp | 0.001 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy (75.0%) is limited by the number of required registers. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::randn | ||
CPU Time | 300018.16 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 79.36 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::normal_ | ||
CPU Time | 299907.32 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 299907.32 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 1906525.85 | μs |
Device Time | 16746.30 | μs |
Self CPU Time | 1906525.85 | μs |
Self Device Time | 16746.30 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void rms_norm_kernel_combined<float>(float const*, float*, int, int, int, float) | ||
CPU Time | 0.00 | μs |
Device Time | 1649266.52 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1649266.52 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 17497.82 | μs |
Device Time | 32227.63 | μs |
Self CPU Time | 17497.82 | μs |
Self Device Time | 32227.63 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 1543082.83 | μs |
Device Time | 480157.99 | μs |
Self CPU Time | 12086.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 1530997.87 | μs |
Device Time | 480157.99 | μs |
Self CPU Time | 14367.17 | μs |
Self Device Time | 480157.99 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 480157.99 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 480157.99 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45286 warnings generated when compiling for host. Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.