import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
linear1_weight: torch.Tensor,
linear1_bias: torch.Tensor,
) -> torch.Tensor:
"""
    Performs a matrix multiplication, applies sigmoid, sums each row, and computes the logsumexp over the batch.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, input_size)
linear1_weight (torch.Tensor): Weight matrix for first linear layer of shape (hidden_size, input_size)
linear1_bias (torch.Tensor): Bias vector for first linear layer of shape (hidden_size)
Returns:
        torch.Tensor: Scalar output after applying the linear layer, sigmoid, row-wise sum, and logsumexp
"""
x = F.linear(x, linear1_weight, linear1_bias)
x = torch.sigmoid(x)
x = torch.sum(x, dim=1)
x = torch.logsumexp(x, dim=0)
return x
class Model(nn.Module):
"""
Model that performs a matrix multiplication (Gemm), applies Sigmoid, sums the result, and calculates the LogSumExp.
"""
def __init__(self, input_size, hidden_size, output_size):
super(Model, self).__init__()
lin1 = nn.Linear(input_size, hidden_size)
self.linear1_weight = nn.Parameter(lin1.weight)
self.linear1_bias = nn.Parameter(
lin1.bias
+ torch.randn(
lin1.bias.shape, device=lin1.bias.device, dtype=lin1.bias.dtype
)
* 0.02
)
def forward(self, x, fn=module_fn):
return fn(x, self.linear1_weight, self.linear1_bias)
batch_size = 128
input_size = 10
hidden_size = 20
output_size = 5
def get_inputs():
return [torch.randn(batch_size, input_size)]
def get_init_inputs():
return [input_size, hidden_size, output_size]
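For reference, a minimal usage sketch of the functional variant above (assuming these definitions are in scope) checks that the forward pass collapses the whole batch into a single scalar:

# Minimal usage sketch; assumes the Model, get_inputs and get_init_inputs above are in scope.
model = Model(*get_init_inputs())      # Model(input_size, hidden_size, output_size)
(x,) = get_inputs()                    # x: (batch_size, input_size)
out = model(x)                         # module_fn: linear -> sigmoid -> row-wise sum -> logsumexp
assert out.dim() == 0                  # the whole batch reduces to one scalar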
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Model that performs a matrix multiplication (Gemm), applies Sigmoid, sums the result, and calculates the LogSumExp.
"""
def __init__(self, input_size, hidden_size, output_size):
super(Model, self).__init__()
self.linear1 = nn.Linear(input_size, hidden_size)
self.linear1.bias = nn.Parameter(self.linear1.bias + torch.randn(self.linear1.bias.shape, device=self.linear1.bias.device, dtype=self.linear1.bias.dtype) * 0.02)
def forward(self, x):
x = self.linear1(x)
x = torch.sigmoid(x)
x = torch.sum(x, dim=1)
x = torch.logsumexp(x, dim=0)
return x
batch_size = 128
input_size = 10
hidden_size = 20
output_size = 5
def get_inputs():
return [torch.randn(batch_size, input_size)]
def get_init_inputs():
return [input_size, hidden_size, output_size]
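The two listings define the same computation; a hedged cross-check that runs the functional module_fn with this second Model's nn.Linear parameters might look like:

# Hedged equivalence check between the nn.Linear-based Model and module_fn from the
# first listing, assuming both listings are loaded in the same session.
m = Model(input_size, hidden_size, output_size)
(x,) = get_inputs()
reference = m(x)
functional = module_fn(x, m.linear1.weight, m.linear1.bias)
assert torch.allclose(reference, functional)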
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <math.h>
// Fused kernel that computes matrix multiplication with bias, applies sigmoid, and reduces each row,
// all within a single warp (32 threads per block).
__global__ void fused_linear_sigmoid_rowsum_kernel(
const float* __restrict__ input,
const float* __restrict__ weight,
const float* __restrict__ bias,
float* __restrict__ row_sums,
const int batch_size,
const int input_size,
const int hidden_size
) {
// Each block handles one row; we launch with 32 threads per block (one warp).
int row = blockIdx.x;
if (row < batch_size) {
float thread_sum = 0.0f;
// Each thread processes a subset of the hidden dimension using a stride of 32.
for (int col = threadIdx.x; col < hidden_size; col += 32) {
float dot = 0.0f;
// Compute the dot product: input[row, :] * weight[col, :]
for (int i = 0; i < input_size; i++) {
dot += input[row * input_size + i] * weight[col * input_size + i];
}
dot += bias[col];
// Apply sigmoid activation
float activated = 1.0f / (1.0f + expf(-dot));
thread_sum += activated;
}
// Perform warp-level reduction using shfl_down_sync
// The warp size is assumed to be 32.
for (int offset = 16; offset > 0; offset /= 2) {
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
}
// Thread 0 writes the final reduced sum to global memory.
if (threadIdx.x == 0) {
row_sums[row] = thread_sum;
}
}
}
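// For reference, in eager PyTorch the contents of row_sums correspond to
//   torch.sigmoid(torch.nn.functional.linear(x, weight, bias)).sum(dim=1)
// i.e. one sigmoid-activated row sum per batch element, prior to the logsumexp stage.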
// LogSumExp kernel that reduces the row sums using a single warp and warp-level primitives.
__global__ void logsumexp_warp_kernel(
const float* __restrict__ row_sums,
float* __restrict__ final_output,
const int batch_size
) {
// Launch with a single warp (32 threads).
float thread_max = -INFINITY;
// Each thread processes elements in a grid-stride loop over the row_sums vector.
for (int i = threadIdx.x; i < batch_size; i += 32) {
thread_max = fmaxf(thread_max, row_sums[i]);
}
    // Warp-level max reduction; after this loop only lane 0 holds the true maximum.
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_max = fmaxf(thread_max, __shfl_down_sync(0xffffffff, thread_max, offset));
    }
    // Broadcast lane 0's maximum so every thread shifts its exponentials by the same value.
    float global_max = __shfl_sync(0xffffffff, thread_max, 0);
// Compute the sum of exponentials in a grid-stride loop.
float thread_sum = 0.0f;
for (int i = threadIdx.x; i < batch_size; i += 32) {
thread_sum += expf(row_sums[i] - global_max);
}
// Warp-level sum reduction
for (int offset = 16; offset > 0; offset /= 2) {
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
}
// Thread 0 writes the final logsumexp result.
if (threadIdx.x == 0) {
final_output[0] = logf(thread_sum) + global_max;
}
}
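// For reference, this kernel implements the standard max-shifted identity
//   logsumexp(v) = log(sum_i exp(v[i] - m)) + m,  where m = max_i v[i],
// which matches torch.logsumexp(row_sums, dim=0) in the eager models above.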
// The forward function binds the CUDA kernels to PyTorch.
// It launches one block per row for the fused kernel (with 32 threads per block), and a single block
// of 32 threads for the logsumexp reduction.
torch::Tensor forward(
torch::Tensor input,
torch::Tensor weight,
torch::Tensor bias
) {
    // The kernels assume contiguous float32 CUDA tensors.
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(), "expected CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "kernels only support float32");
    const int batch_size = input.size(0);
    const int input_size = input.size(1);
    const int hidden_size = weight.size(0);
auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
auto row_sums = torch::empty({batch_size}, options);
auto final_output = torch::empty({1}, options);
// Launch one warp (32 threads) per row.
dim3 grid(batch_size);
dim3 block(32);
fused_linear_sigmoid_rowsum_kernel<<<grid, block>>>(
input.data_ptr<float>(),
weight.data_ptr<float>(),
bias.data_ptr<float>(),
row_sums.data_ptr<float>(),
batch_size,
input_size,
hidden_size
);
// Launch the logsumexp kernel with a single warp (32 threads) in one block.
logsumexp_warp_kernel<<<1, 32>>>(
row_sums.data_ptr<float>(),
final_output.data_ptr<float>(),
batch_size
);
return final_output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Warp-level fused forward pass");
}
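A hedged sketch of how this extension might be built and checked against the eager model: the file name fused_warp_kernel.cu is hypothetical and stands for the CUDA source above, and the Python definitions from the earlier listings are assumed to be in scope.

# Hedged build-and-check sketch; "fused_warp_kernel.cu" is a hypothetical file holding
# the CUDA source above, and Model/batch_size/input_size/... come from the listings.
import torch
from torch.utils.cpp_extension import load

ext = load(name="fused_warp_kernel", sources=["fused_warp_kernel.cu"], verbose=False)

x = torch.randn(batch_size, input_size, device="cuda")
model = Model(input_size, hidden_size, output_size).cuda()

reference = model(x)                                              # eager: linear -> sigmoid -> sum -> logsumexp
fused = ext.forward(x, model.linear1.weight, model.linear1.bias)  # kernel path

# The kernel returns a 1-element tensor, the eager path a 0-dim tensor.
assert torch.allclose(reference, fused.squeeze(), atol=1e-5)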
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.060 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.000 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 1.548 | % | 0.004 | 5 |
Issued Ipc Active | 0.062 | inst/cycle | 0.000 | 5 |
SM Busy | 1.548 | % | 0.004 | 5 |
Memory Throughput | 2372922780.628 | byte/second | 5940212748235258.000 | 5 |
Mem Busy | 9.214 | % | 0.079 | 5 |
Max Bandwidth | 4.760 | % | 0.026 | 5 |
L1/TEX Hit Rate | 48.480 | % | 0.000 | 5 |
L2 Hit Rate | 103.682 | % | 0.069 | 5 |
Mem Pipes Busy | 0.000 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 16.410 | cycle | 2.394 | 5 |
Warp Cycles Per Executed Instruction | 16.680 | cycle | 2.473 | 5 |
Avg. Active Threads Per Warp | 30.730 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.150 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 64.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 64.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 32.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 50.000 | % | 0.000 | 5 |
Achieved Occupancy | 1.560 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 1.000 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. This kernel's theoretical occupancy (50.0%) is limited by the required amount of shared memory. The difference between calculated theoretical (50.0%) and measured achieved occupancy (1.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 408300.66 | μs |
Device Time | 5.28 | μs |
Self CPU Time | 55.60 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 408245.06 | μs |
Device Time | 5.28 | μs |
Self CPU Time | 110.28 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 407980.55 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 109.43 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 407681.08 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 407681.08 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 470202.69 | μs |
Device Time | 33559.43 | μs |
Self CPU Time | 470202.69 | μs |
Self Device Time | 33559.43 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 16393.17 | μs |
Device Time | 33370.85 | μs |
Self CPU Time | 16393.17 | μs |
Self Device Time | 33370.85 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 61203.59 | μs |
Device Time | 600741.12 | μs |
Self CPU Time | 13049.25 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 48156.10 | μs |
Device Time | 600741.12 | μs |
Self CPU Time | 16396.30 | μs |
Self Device Time | 600741.12 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 600741.12 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 600741.12 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |