```python
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
"""
Performs matrix multiplication, adds bias, and applies ReLU activation.
Args:
x (torch.Tensor): Input tensor with shape (batch_size, in_features)
weight (torch.Tensor): Weight matrix with shape (out_features, in_features)
bias (torch.Tensor): Bias tensor with shape (out_features,)
Returns:
torch.Tensor: Output tensor with shape (batch_size, out_features)
"""
x = F.linear(x, weight)
x = x + bias
x = F.relu(x)
return x
class Model(nn.Module):
"""
Simple model that performs a matrix multiplication, adds a bias term, and applies ReLU.
"""
def __init__(self, in_features, out_features, bias_shape):
super(Model, self).__init__()
gemm = nn.Linear(in_features, out_features, bias=False)
self.weight = nn.Parameter(gemm.weight)
self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)
def forward(self, x, fn=module_fn):
return fn(x, self.weight, self.bias)
batch_size = 128
in_features = 1024
out_features = 512
bias_shape = (out_features,)
def get_inputs():
return [torch.randn(batch_size, in_features)]
def get_init_inputs():
    return [in_features, out_features, bias_shape]
```
```python
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Simple model that performs a matrix multiplication, adds a bias term, and applies ReLU.
"""
def __init__(self, in_features, out_features, bias_shape):
super(Model, self).__init__()
self.gemm = nn.Linear(in_features, out_features, bias=False)
self.bias = nn.Parameter(torch.randn(bias_shape)*0.02)
def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor with shape (batch_size, in_features).
Returns:
torch.Tensor: Output tensor with shape (batch_size, out_features).
"""
x = self.gemm(x)
x = x + self.bias
x = torch.relu(x)
return x
batch_size = 128
in_features = 1024
out_features = 512
bias_shape = (out_features,)
def get_inputs():
return [torch.randn(batch_size, in_features)]
def get_init_inputs():
    return [in_features, out_features, bias_shape]
```
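For reference, here is a minimal usage sketch, assuming only the `Model`, `get_inputs`, and `get_init_inputs` definitions from the listing above; it instantiates the reference model and runs one forward pass (it is not part of the benchmark harness itself):

```python
# Usage sketch (assumes the definitions above): build the reference model
# from get_init_inputs() and run it on a batch produced by get_inputs().
import torch

model = Model(*get_init_inputs())   # Model(in_features=1024, out_features=512, bias_shape=(512,))
(x,) = get_inputs()                 # x has shape (batch_size=128, in_features=1024)

with torch.no_grad():
    y = model(x)

print(y.shape)               # torch.Size([128, 512])
print(bool((y >= 0).all()))  # True: the ReLU output is non-negative
```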
Field | Value |
---|---|
Operation Name | 76_Gemm_Add_ReLU |
Level ID | 2 |
Task ID | 76 |
Kernel Name | unrolled_warp_gemm_edit_1 |
CUDA Speedup (Native) | 0.671x |
CUDA Speedup (Compile) | 1.118x |
CUDA Runtime | 0.040 ms |
PyTorch Runtime (Native) | 0.027 ms |
PyTorch Runtime (Compile) | 0.045 ms |
Correct | True |
Max Diff (vs. Reference) | 0.000000 |
Model | bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0 |
Temperature | 0.00 |
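The speedup columns above appear to be the corresponding PyTorch runtime divided by the CUDA runtime; the quick check below is an interpretation of the table, not a documented formula, and it reproduces the figures up to rounding of the listed runtimes:

```python
# Sanity check: speedup ≈ PyTorch runtime / CUDA runtime, using the values from the table above.
cuda_ms, native_ms, compile_ms = 0.040, 0.027, 0.045
print(round(native_ms / cuda_ms, 3))   # 0.675  (reported as 0.671x, likely computed from unrounded runtimes)
print(round(compile_ms / cuda_ms, 3))  # 1.125  (reported as 1.118x)
```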
Rank | Kernel Name | Runtime (ms) | Speedup Native | Speedup Compile |
---|---|---|---|---|
🥇 | shared_warp_tile_kernel_base | 0.03 | 0.93 | 1.54 |
🥇 | combined_warp_tile_base | 0.03 | 0.93 | 1.54 |
🥉 | optimized_block_size_kernel_base | 0.03 | 0.89 | 1.49 |
4 | warp_tile_ldg_base | 0.03 | 0.87 | 1.44 |
4 | even_workload_dist_base_base | 0.03 | 0.87 | 1.44 |
4 | hybrid_warp_tile_kernel_base | 0.03 | 0.87 | 1.44 |
4 | warp_tile_hybrid_base | 0.03 | 0.87 | 1.44 |
8 | warp_tile_ldg_opt_base | 0.03 | 0.81 | 1.36 |
8 | warp_reduction_optimized_base_base | 0.03 | 0.81 | 1.36 |
10 | optimized_shared_memory_base_base | 0.03 | 0.79 | 1.32 |
10 | warp_tile_base_base | 0.03 | 0.79 | 1.32 |
12 | hybrid_optimized_kernel_base | 0.04 | 0.77 | 1.28 |
13 | warp_reduction_gemm_base | 0.04 | 0.71 | 1.18 |
13 | warp_tile_aligned_base_base | 0.04 | 0.71 | 1.18 |
15 | vectorized_warp_unroll_base_base | 0.04 | 0.69 | 1.15 |
15 | vectorized_warp_unroll_base_edit_1 | 0.04 | 0.69 | 1.15 |
15 | warp_reduction_unrolled_gemm_edit_1 | 0.04 | 0.69 | 1.15 |
18 | unrolled_warp_gemm_edit_1 | 0.04 | 0.67 | 1.12 |
18 | unrolled_warp_gemm_base | 0.04 | 0.67 | 1.12 |
18 | vectorized_warp_reduction_base | 0.04 | 0.67 | 1.12 |
```cpp
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
// Each warp computes one output element using warp-level reduction
// This design avoids atomic operations since each warp exclusively works on one output element.
__global__ void linear_relu_unrolled_warp_kernel(const float* x, const float* weight, const float* bias, float* out,
int batch_size, int in_features, int out_features) {
// Compute a global warp id from the thread index
int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int lane = threadIdx.x % 32; // Lane index within the warp
int total_outputs = batch_size * out_features;
if (warp_id >= total_outputs) return;
// Map warp id to output matrix coordinates (i, j)
int i = warp_id / out_features;
int j = warp_id % out_features;
float sum = 0.0f;
// Cache the base indices to avoid repeated calculations
const int x_base = i * in_features;
const int w_base = j * in_features;
// Each thread in the warp processes a strided portion of the in_features dimension
// Unroll by 4 to reduce loop overhead and enable better instruction-level parallelism
#pragma unroll 4
for (int k = lane; k < in_features; k += 32) {
float x_val = x[x_base + k];
float w_val = weight[w_base + k];
sum += x_val * w_val;
}
// Perform warp-level reduction using shuffle operations
for (int offset = 16; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
// The first lane of each warp writes the result
if (lane == 0) {
sum += bias[j];
// Apply ReLU activation
out[i * out_features + j] = sum > 0.0f ? sum : 0.0f;
}
}
torch::Tensor linear_relu_forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
const int batch_size = x.size(0);
const int in_features = x.size(1);
const int out_features = weight.size(0);
// Allocate output tensor
auto out = torch::empty({batch_size, out_features}, x.options());
// Each warp computes one output element. Total number of warps required is batch_size * out_features.
int total_warps = batch_size * out_features;
// Each warp consists of 32 threads. Determine total threads required.
int total_threads = total_warps * 32;
// Choose block size as a multiple of 32, e.g., 256 threads per block
int threads_per_block = 256;
int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
linear_relu_unrolled_warp_kernel<<<blocks, threads_per_block, 0, stream>>>(
x.data_ptr<float>(),
weight.data_ptr<float>(),
bias.data_ptr<float>(),
out.data_ptr<float>(),
batch_size,
in_features,
out_features
);
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &linear_relu_forward, "Unrolled warp-level reduction GEMM+Bias+ReLU (CUDA)");
}
```
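The extension source above can be compiled and exercised in-process with `torch.utils.cpp_extension.load`. The sketch below is a hedged example: the file name `linear_relu_kernel.cu` and the comparison tolerance are assumptions, not part of the published harness:

```python
# Build-and-run sketch: compiles the CUDA source above (assumed to be saved as
# linear_relu_kernel.cu) and compares it against the PyTorch reference path.
import torch
from torch.utils.cpp_extension import load

ext = load(name="linear_relu_unrolled_warp", sources=["linear_relu_kernel.cu"], verbose=True)

x = torch.randn(128, 1024, device="cuda")
weight = torch.randn(512, 1024, device="cuda")
bias = torch.randn(512, device="cuda") * 0.02

out_cuda = ext.forward(x, weight, bias)
out_ref = torch.relu(torch.nn.functional.linear(x, weight) + bias)
print(torch.allclose(out_cuda, out_ref, atol=1e-4))  # expected: True
```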
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 2.320 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 2.198 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 58.162 | % | 0.058 | 5 |
Issued Ipc Active | 2.328 | inst/cycle | 0.000 | 5 |
SM Busy | 58.162 | % | 0.058 | 5 |
Memory Throughput | 67329691765.920 | byte/second | 41758051448048168.000 | 5 |
Mem Busy | 59.174 | % | 0.051 | 5 |
Max Bandwidth | 71.198 | % | 0.018 | 5 |
L1/TEX Hit Rate | 48.936 | % | 0.014 | 5 |
L2 Hit Rate | 96.420 | % | 0.552 | 5 |
Mem Pipes Busy | 58.566 | % | 0.050 | 5 |
Warp Cycles Per Issued Instruction | 25.282 | cycle | 0.016 | 5 |
Warp Cycles Per Executed Instruction | 25.312 | cycle | 0.016 | 5 |
Avg. Active Threads Per Warp | 30.740 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.670 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 10.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 92.328 | % | 0.035 | 5 |
Achieved Active Warps Per SM | 59.090 | warp | 0.014 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (36.4%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
INF Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 198075.69 | μs |
Device Time | 164.96 | μs |
Self CPU Time | 62.21 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 198013.48 | μs |
Device Time | 164.96 | μs |
Self CPU Time | 113.05 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 600132.46 | μs |
Device Time | 14244.51 | μs |
Self CPU Time | 600132.46 | μs |
Self Device Time | 14244.51 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
linear_relu_unrolled_warp_kernel(float const*, float const*, float const*, float*, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 252210.23 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 252210.23 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 19205.20 | μs |
Device Time | 27535.78 | μs |
Self CPU Time | 19205.20 | μs |
Self Device Time | 27535.78 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 214172.35 | μs |
Device Time | 532307.20 | μs |
Self CPU Time | 15245.79 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 198928.06 | μs |
Device Time | 532307.20 | μs |
Self CPU Time | 15501.84 | μs |
Self Device Time | 532307.20 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 532307.20 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 532307.20 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45310 warnings generated when compiling for host. Suppressed 45347 warnings (45300 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.