import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Performs matrix multiplication, adds bias, and applies ReLU activation.

    Args:
        x (torch.Tensor): Input tensor with shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix with shape (out_features, in_features)
        bias (torch.Tensor): Bias tensor with shape (out_features,)

    Returns:
        torch.Tensor: Output tensor with shape (batch_size, out_features)
    """
    x = F.linear(x, weight)
    x = x + bias
    x = F.relu(x)
    return x


class Model(nn.Module):
    """
    Simple model that performs a matrix multiplication, adds a bias term, and applies ReLU.
    """

    def __init__(self, in_features, out_features, bias_shape):
        super(Model, self).__init__()
        gemm = nn.Linear(in_features, out_features, bias=False)
        self.weight = nn.Parameter(gemm.weight)
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x, fn=module_fn):
        return fn(x, self.weight, self.bias)


batch_size = 128
in_features = 1024
out_features = 512
bias_shape = (out_features,)


def get_inputs():
    return [torch.randn(batch_size, in_features)]


def get_init_inputs():
    return [in_features, out_features, bias_shape]
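The functional form above can be checked directly against PyTorch's fused linear path; a minimal sketch with freshly sampled tensors (names chosen here for illustration):

# Equivalence check for module_fn against F.linear with a fused bias (illustrative).
x = torch.randn(batch_size, in_features)
weight = torch.randn(out_features, in_features)
bias = torch.randn(bias_shape)

out = module_fn(x, weight, bias)
ref = F.relu(F.linear(x, weight, bias))   # same math, with the bias fused into F.linear
assert out.shape == (batch_size, out_features)
assert torch.allclose(out, ref, atol=1e-6)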
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs a matrix multiplication, adds a bias term, and applies ReLU.
    """

    def __init__(self, in_features, out_features, bias_shape):
        super(Model, self).__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=False)
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, in_features).

        Returns:
            torch.Tensor: Output tensor with shape (batch_size, out_features).
        """
        x = self.gemm(x)
        x = x + self.bias
        x = torch.relu(x)
        return x


batch_size = 128
in_features = 1024
out_features = 512
bias_shape = (out_features,)


def get_inputs():
    return [torch.randn(batch_size, in_features)]


def get_init_inputs():
    return [in_features, out_features, bias_shape]
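A quick smoke test of the reference model, built from the harness functions above (a minimal sketch; it only checks the output shape and ReLU non-negativity, both of which follow from the definitions above):

# Minimal smoke test for the reference Model (illustrative only).
model = Model(*get_init_inputs())      # Model(1024, 512, (512,))
(x,) = get_inputs()                    # x: (128, 1024)
y = model(x)
assert y.shape == (batch_size, out_features)   # (128, 512)
assert (y >= 0).all()                          # ReLU output is non-negative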
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
#define WARP_SIZE 32
#define TILE_SIZE 4 // Each warp processes TILE_SIZE output features concurrently
// This kernel leverages shared memory to load the input vector for a given batch sample,
// reducing global memory accesses and latency. Threads in each block cooperatively load
// the current input row into shared memory. Then, using warp-level tiling and vectorized
// memory accesses (via float4), each warp computes dot products with multiple weight rows.
// Warp shuffles are used to perform intra-warp reductions. Bias is added and ReLU is applied
// before writing the output.
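//
// Work decomposition (derived from the indexing below, assuming the launcher's
// default of 8 warps per block and TILE_SIZE = 4): each block covers
// 8 * 4 = 32 consecutive output features of one batch sample. Warp w in the
// block at (batch b, feature group g) owns outputs g*32 + w*4 .. g*32 + w*4 + 3,
// and its 32 lanes stride over the shared input row in float4 chunks before a
// shuffle reduction combines the per-lane partial sums.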
__global__ void shared_warp_tile_kernel(const float* __restrict__ x,
                                        const float* __restrict__ weight,
                                        const float* __restrict__ bias,
                                        float* __restrict__ out,
                                        int in_features,
                                        int out_features) {
    // Each block processes one batch sample (blockIdx.x) and a group of output features (blockIdx.y)
    int batch_idx = blockIdx.x;

    // Allocate shared memory for the input row
    extern __shared__ float s_x[];

    int tid = threadIdx.x;
    int block_size = blockDim.x;

    // Cooperatively load the entire input row into shared memory
    for (int i = tid; i < in_features; i += block_size) {
        s_x[i] = x[batch_idx * in_features + i];
    }
    __syncthreads();

    // Determine warp and lane indices
    int warps_per_block = block_size / WARP_SIZE;
    int warp_id = tid / WARP_SIZE;
    int lane_id = tid % WARP_SIZE;

    // Each warp computes a tile of TILE_SIZE consecutive output features
    int base_out_group = blockIdx.y * (warps_per_block * TILE_SIZE);
    int out_base = base_out_group + warp_id * TILE_SIZE;
    if (out_base >= out_features) return;

    // Accumulators for the dot product results for each output in the tile
    float sums[TILE_SIZE] = {0.0f, 0.0f, 0.0f, 0.0f};

    // Set up vectorized load parameters
    int nvec = in_features / 4;  // number of groups of 4 floats
    int rem = in_features % 4;   // any leftovers

    // Reinterpret the shared memory as a float4 pointer for vectorized loads
    float4* s_x_vec = reinterpret_cast<float4*>(s_x);

    // Loop over each output inside the tile
    for (int t = 0; t < TILE_SIZE; t++) {
        int current_out = out_base + t;
        if (current_out < out_features) {
            // Pointer to the corresponding weight row
            const float* w_row = weight + current_out * in_features;
            const float4* w_row_vec = reinterpret_cast<const float4*>(w_row);
            float sum = 0.0f;

            // Vectorized dot product: each thread processes a strided section
            for (int k = lane_id; k < nvec; k += WARP_SIZE) {
                float4 x_val = s_x_vec[k];  // read from shared memory (faster than global)
                float4 w_val = __ldg(w_row_vec + k);
                sum += x_val.x * w_val.x + x_val.y * w_val.y + x_val.z * w_val.z + x_val.w * w_val.w;
            }

            // Process remaining elements with scalar loads
            int offset = nvec * 4;
            for (int k = lane_id; k < rem; k += WARP_SIZE) {
                sum += s_x[offset + k] * __ldg(w_row + offset + k);
            }
            sums[t] = sum;
        }
    }

    // Perform warp-level reduction using shuffle instructions
    for (int t = 0; t < TILE_SIZE; t++) {
        for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
            sums[t] += __shfl_down_sync(0xffffffff, sums[t], offset);
        }
    }

    // The first lane in the warp writes the final output with bias and ReLU activation
    if (lane_id == 0) {
        for (int t = 0; t < TILE_SIZE; t++) {
            int current_out = out_base + t;
            if (current_out < out_features) {
                float result = sums[t] + __ldg(bias + current_out);
                out[batch_idx * out_features + current_out] = (result > 0.0f) ? result : 0.0f;
            }
        }
    }
}
// Host launcher function
// Grid dimensions: grid.x spans the batch dimension, grid.y spans groups of output features
// Block dimensions: each block contains multiple warps, with each warp computing TILE_SIZE outputs
// The dynamic shared memory size is set to (in_features * sizeof(float))
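//
// Worked example for the benchmark sizes used above (batch_size = 128,
// in_features = 1024, out_features = 512) with warps_per_block = 8 and
// TILE_SIZE = 4: blocks_y = ceil(512 / 32) = 16, so the launch is
// grid = (128, 16), block = 256 threads, and the dynamic shared memory
// request is 1024 * sizeof(float) = 4 KiB per block.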
torch::Tensor shared_warp_tile_forward(torch::Tensor x,
                                       torch::Tensor weight,
                                       torch::Tensor bias) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    // The kernel indexes raw float pointers, so require contiguous float32 inputs
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
                "x, weight and bias must be contiguous");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "only float32 tensors are supported");

    // Ensure the kernel is launched on the device that owns the inputs
    const c10::cuda::CUDAGuard device_guard(x.device());

    int batch_size = x.size(0);
    int in_features = x.size(1);
    int out_features = weight.size(0);

    auto out = torch::empty({batch_size, out_features}, x.options());

    // Tunable parameter: number of warps per block
    int warps_per_block = 8;
    int threads_per_block = warps_per_block * WARP_SIZE;

    // Compute grid dimension over the output features
    int blocks_y = (out_features + (warps_per_block * TILE_SIZE) - 1) / (warps_per_block * TILE_SIZE);
    dim3 grid(batch_size, blocks_y);
    dim3 block(threads_per_block);

    // Allocate shared memory to hold one entire input row
    size_t shared_mem_size = in_features * sizeof(float);

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    shared_warp_tile_kernel<<<grid, block, shared_mem_size, stream>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        out.data_ptr<float>(),
        in_features,
        out_features
    );

    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &shared_warp_tile_forward,
          "GEMM with bias and ReLU using shared memory for input vector (CUDA)");
}
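For completeness, one way to compile and sanity-check the extension is via torch.utils.cpp_extension.load; a minimal sketch, assuming the listing above is saved as shared_warp_tile.cu (the file name and tolerance are illustrative, not part of the original listing):

# Hypothetical build-and-check harness for the CUDA extension above.
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

ext = load(name="shared_warp_tile", sources=["shared_warp_tile.cu"])  # assumed file name

x = torch.randn(128, 1024, device="cuda")
weight = torch.randn(512, 1024, device="cuda")
bias = torch.randn(512, device="cuda")

out = ext.forward(x, weight, bias)
ref = F.relu(F.linear(x, weight, bias))
print(torch.allclose(out, ref, atol=1e-4))  # expect True within float32 GEMM tolerance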
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.658 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.480 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 41.464 | % | 0.006 | 5 |
Issued Ipc Active | 1.658 | inst/cycle | 0.000 | 5 |
SM Busy | 41.464 | % | 0.006 | 5 |
Memory Throughput | 85662353124.662 | byte/second | 225757483080286656.000 | 5 |
Mem Busy | 75.880 | % | 0.193 | 5 |
Max Bandwidth | 73.302 | % | 0.187 | 5 |
L1/TEX Hit Rate | 14.500 | % | 0.100 | 5 |
L2 Hit Rate | 93.244 | % | 2.711 | 5 |
Mem Pipes Busy | 25.978 | % | 0.024 | 5 |
Warp Cycles Per Issued Instruction | 25.240 | cycle | 0.005 | 5 |
Warp Cycles Per Executed Instruction | 25.262 | cycle | 0.005 | 5 |
Avg. Active Threads Per Warp | 30.490 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.210 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 6.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 12.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 48.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 75.000 | % | 0.000 | 5 |
Achieved Occupancy | 65.298 | % | 0.021 | 5 |
Achieved Active Warps Per SM | 41.790 | warp | 0.009 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (22.0%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy (75.0%) is limited by the number of required registers. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 890303.11 | μs |
Device Time | 205.85 | μs |
Self CPU Time | 65.01 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 890238.10 | μs |
Device Time | 205.85 | μs |
Self CPU Time | 126.58 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 889560.34 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 149.20 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 883115.41 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 883115.41 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 258503.74 | μs |
Device Time | 391.17 | μs |
Self CPU Time | 258503.74 | μs |
Self Device Time | 391.17 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
shared_warp_tile_kernel(float const*, float const*, float const*, float*, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 92287.19 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 92287.19 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 78907.38 | μs |
Device Time | 279221.34 | μs |
Self CPU Time | 8029.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 70878.79 | μs |
Device Time | 279221.34 | μs |
Self CPU Time | 7930.67 | μs |
Self Device Time | 279221.34 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 279221.34 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 279221.34 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |