12_Gemm_Multiply_LeakyReLU
• 12_gemm_tiled_coalesced_edit_1
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    multiplier: float,
    negative_slope: float,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``leaky_relu((x @ weight.T + bias) * multiplier)``.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        multiplier (float): Scalar applied to the affine output
        negative_slope (float): Slope LeakyReLU uses for negative values
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, out_features)
    """
    scaled = F.linear(x, weight, bias).mul(multiplier)
    return F.leaky_relu(scaled, negative_slope=negative_slope)
class Model(nn.Module):
    """
    Simple model that performs a Gemm, multiplies the result, and applies LeakyReLU.

    The weight and bias are owned directly (taken from an ``nn.Linear``) so that
    ``forward`` can dispatch to any function with ``module_fn``'s signature.
    """

    def __init__(self, in_features, out_features, multiplier, negative_slope):
        super(Model, self).__init__()
        gemm = nn.Linear(in_features, out_features)
        # Reuse nn.Linear's default initialization, then keep its Parameters.
        self.weight = gemm.weight
        self.bias = gemm.bias
        # Keep the constructor scalars so forward uses them rather than any
        # module-level globals that happen to share these names.
        self.multiplier = multiplier
        self.negative_slope = negative_slope

    def forward(self, x, fn=module_fn):
        """Return ``fn(x, multiplier, negative_slope, weight, bias)``; ``fn`` defaults to ``module_fn``."""
        return fn(x, self.multiplier, self.negative_slope, self.weight, self.bias)
# Problem size and scalar hyper-parameters used by the harness below.
batch_size, in_features, out_features = 128, 1024, 512
multiplier, negative_slope = 2.0, 0.1


def get_inputs():
    """Return forward-pass arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return positional constructor arguments in the order ``Model.__init__`` takes them."""
    return [in_features, out_features, multiplier, negative_slope]
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Gemm, then scalar multiply, then LeakyReLU, built from standard nn layers.
    """

    def __init__(self, in_features, out_features, multiplier, negative_slope):
        super().__init__()
        self.gemm = nn.Linear(in_features, out_features)
        self.multiplier = multiplier
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, x):
        """Return ``leaky_relu(gemm(x) * multiplier)``."""
        return self.leaky_relu(self.gemm(x) * self.multiplier)
# Harness configuration: problem size plus the two scalar hyper-parameters.
batch_size = 128
in_features, out_features = 1024, 512
multiplier, negative_slope = 2.0, 0.1


def get_inputs():
    """Produce the list of forward inputs (a single random activation batch)."""
    activations = torch.randn(batch_size, in_features)
    return [activations]


def get_init_inputs():
    """Produce the constructor arguments ``Model(in_features, out_features, multiplier, negative_slope)``."""
    init_args = [in_features, out_features, multiplier, negative_slope]
    return init_args
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define TILE 16

// Fused output = LeakyReLU((x @ weight^T + bias) * multiplier) with shared-memory tiling.
//
// Each TILE x TILE thread block computes one TILE x TILE tile of `output`:
//   blockIdx.y / threadIdx.y select the batch row,
//   blockIdx.x / threadIdx.x select the output feature.
// The reduction over in_features is processed TILE elements at a time.
// All pointers are indexed as dense row-major arrays.
__global__ void module_fn_kernel(
    const float* __restrict__ x,       // [batch_size, in_features]
    const float* __restrict__ weight,  // [out_features, in_features]
    const float* __restrict__ bias,    // [out_features]
    float* __restrict__ output,        // [batch_size, out_features]
    const int batch_size,
    const int in_features,
    const int out_features,
    const float multiplier,
    const float negative_slope
) {
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row = blockIdx.y * TILE + ty;  // batch index of this thread's output
    const int col = blockIdx.x * TILE + tx;  // out_feature index of this thread's output

    // s_x[r][k] = x[blockIdx.y*TILE + r][tile_base + k]
    // s_w[j][k] = weight[blockIdx.x*TILE + j][tile_base + k]
    // The +1 padding on s_w spreads s_w[tx][k] (stride TILE+1 across tx) over
    // distinct shared-memory banks; with stride TILE it was an 8-way conflict.
    __shared__ float s_x[TILE][TILE];
    __shared__ float s_w[TILE][TILE + 1];

    // Weight row that this thread helps load (rows are indexed by threadIdx.y so
    // that threadIdx.x can walk along the contiguous in_features dimension).
    const int w_row = blockIdx.x * TILE + ty;

    float acc = 0.0f;
    const int num_tiles = (in_features + TILE - 1) / TILE;
    for (int t = 0; t < num_tiles; ++t) {
        // Consecutive tx read consecutive addresses in both loads (coalesced).
        const int k = t * TILE + tx;

        s_x[ty][tx] = (row < batch_size && k < in_features)
            ? x[row * in_features + k]
            : 0.0f;

        // Previously weight[col * in_features + ...] was read with col varying
        // across tx, so a warp touched 16 rows in_features floats apart.
        s_w[ty][tx] = (w_row < out_features && k < in_features)
            ? weight[w_row * in_features + k]
            : 0.0f;

        __syncthreads();

        // s_w[tx][kk] == weight[col][t*TILE + kk]; summation order matches the
        // original kernel, so results are unchanged.
        #pragma unroll
        for (int kk = 0; kk < TILE; ++kk) {
            acc += s_x[ty][kk] * s_w[tx][kk];
        }

        // Keep the tiles alive until every thread has finished reading them.
        __syncthreads();
    }

    // Epilogue: bias, scale, LeakyReLU, then store.
    if (row < batch_size && col < out_features) {
        const float v = (acc + bias[col]) * multiplier;
        output[row * out_features + col] = (v > 0.0f) ? v : v * negative_slope;
    }
}
// Host entry point: validates inputs and launches module_fn_kernel.
//
// Returns a new [batch_size, out_features] tensor equal to
// LeakyReLU((x @ weight^T + bias) * multiplier, negative_slope), where
// x is [batch_size, in_features], weight is [out_features, in_features] and
// bias is [out_features]. All three must be float32 CUDA tensors on one device.
torch::Tensor module_fn_forward(
    torch::Tensor x,
    float multiplier,
    float negative_slope,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.device().is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.device().is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "x, weight and bias must be on the same CUDA device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "x, weight and bias must be float32 tensors");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D [batch_size, in_features]");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D [out_features, in_features]");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1-D [out_features]");

    const int batch_size = x.size(0);
    const int in_features = x.size(1);
    const int out_features = weight.size(0);
    TORCH_CHECK(weight.size(1) == in_features, "Weight in_features must match x in_features");
    TORCH_CHECK(bias.size(0) == out_features, "Bias size must match weight out_features");

    // The kernel indexes raw pointers as dense row-major arrays; strided views
    // (e.g. a transposed x) would otherwise be read incorrectly.
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    // Every output element is written by the kernel, so skip the zero-fill
    // kernel that torch::zeros would launch.
    auto output = torch::empty({batch_size, out_features}, x.options());
    if (output.numel() == 0) {
        return output;  // a zero-sized grid is an invalid launch configuration
    }

    // Launch on x's device and on PyTorch's current stream so the kernel is
    // ordered with the surrounding ATen work.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 block(TILE, TILE);
    dim3 grid(
        (out_features + TILE - 1) / TILE,
        (batch_size + TILE - 1) / TILE
    );
    module_fn_kernel<<<grid, block, 0, stream>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_features,
        out_features,
        multiplier,
        negative_slope
    );
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "module_fn_kernel launch failed: ", cudaGetErrorString(launch_err));
    return output;
}
// Python binding: exposes module_fn_forward to Python as `forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        module_fn_forward,
        "Module function forward CUDA with tiled memory coalescing");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.416 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.384 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 10.414 | % | 0.001 | 5 |
Issued Ipc Active | 0.420 | inst/cycle | 0.000 | 5 |
SM Busy | 10.414 | % | 0.001 | 5 |
Memory Throughput | 33034861881.752 | byte/second | 37804124352725408.000 | 5 |
Mem Busy | 78.312 | % | 0.162 | 5 |
Max Bandwidth | 22.946 | % | 0.015 | 5 |
L1/TEX Hit Rate | 61.020 | % | 0.000 | 5 |
L2 Hit Rate | 86.078 | % | 0.209 | 5 |
Mem Pipes Busy | 9.474 | % | 0.003 | 5 |
Warp Cycles Per Issued Instruction | 37.156 | cycle | 0.012 | 5 |
Warp Cycles Per Executed Instruction | 37.226 | cycle | 0.012 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 31.960 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 21.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 24.178 | % | 0.001 | 5 |
Achieved Active Warps Per SM | 15.476 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (24.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::zeros | ||
CPU Time | 4116376.11 | μs |
Device Time | 86389.99 | μs |
Self CPU Time | 103416.91 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 7102702.33 | μs |
Device Time | 4824119.79 | μs |
Self CPU Time | 218737.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 6883970.23 | μs |
Device Time | 4824119.79 | μs |
Self CPU Time | 263306.82 | μs |
Self Device Time | 4824119.79 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 6853442.79 | μs |
Device Time | 1678.97 | μs |
Self CPU Time | 6853442.79 | μs |
Self Device Time | 1678.97 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
module_fn_kernel(float const*, float const*, float const*, float*, int, int, int, float, float) | ||
CPU Time | 0.00 | μs |
Device Time | 3955583.84 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3955583.84 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 175475.84 | μs |
Device Time | 208522.30 | μs |
Self CPU Time | 175475.84 | μs |
Self Device Time | 208522.30 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 4738120.81 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 4738120.81 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventElapsedTime | ||
CPU Time | 210771.11 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 210771.11 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45290 warnings generated when compiling for host. Suppressed 45325 warnings (45278 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.