47_Sum_reduction_over_a_dimension
• warp_sum_reduce_unroll_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Sum ``x`` along ``dim``, keeping the collapsed axis with size 1.

    Args:
        x: Tensor to reduce.
        dim: Axis to collapse (negative values count from the end, as in
            ``torch.sum``).

    Returns:
        A tensor of the same rank as ``x`` whose ``dim`` axis has length 1.
    """
    return x.sum(dim=dim, keepdim=True)
class Model(nn.Module):
    """Reduction module that delegates its work to a pluggable function.

    Args:
        dim: Axis handed to the reduction function on every forward call.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """Return ``fn(x, self.dim)``.

        Args:
            x: Tensor to reduce.
            fn: Callable invoked as ``fn(tensor, dim)``; defaults to
                ``module_fn`` (a keepdim sum).

        Returns:
            Whatever ``fn`` returns. With the default, ``x`` summed over
            ``self.dim`` with that axis kept at size 1.
        """
        reducer = fn
        return reducer(x, self.dim)
# Problem sizes used by get_inputs() / get_init_inputs().
batch_size, dim1, dim2 = 16, 256, 256
reduce_dim = 1


def get_inputs():
    """Return the forward arguments: one random ``(batch_size, dim1, dim2)`` tensor."""
    return [torch.randn(batch_size, dim1, dim2)]


def get_init_inputs():
    """Return the constructor arguments for ``Model``: the axis to reduce."""
    return [reduce_dim]
import torch
import torch.nn as nn
class Model(nn.Module):
    """Sum-reduce the input along one fixed axis, keeping that axis.

    Args:
        dim: Axis to reduce; negative values count from the end, as in
            ``torch.sum``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` summed over ``self.dim`` with that axis kept at size 1."""
        summed = x.sum(dim=self.dim, keepdim=True)
        return summed
batch_size = 16  # leading dimension of the generated input
dim1 = 256       # middle dimension (the axis selected by reduce_dim)
dim2 = 256       # trailing dimension
reduce_dim = 1   # axis passed to Model(...)


def get_inputs():
    """Build the forward arguments: a single random tensor."""
    sample = torch.randn(batch_size, dim1, dim2)
    return [sample]


def get_init_inputs():
    """Build the constructor arguments for ``Model``."""
    return [reduce_dim]
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>
// Warp-per-output sum reduction over the middle axis of a dense
// [outer, reduce_size, inner_size] buffer.
//
// Block b writes output[b] = sum over r of
//   input[outer * reduce_size * inner_size + r * inner_size + inner],
// where outer = b / inner_size and inner = b % inner_size.
// Lane j of the block's warp accumulates rows j, j + warpSize,
// j + 2 * warpSize, ...; the per-lane partials are then folded with
// __shfl_down_sync so that lane 0 holds the complete sum and stores it.
//
// Launch contract (not checked here): blockDim.x == warpSize, i.e. exactly
// one full warp (the lane stride and the 0xffffffff shuffle mask rely on it),
// and at least total_output blocks along x. Partial sums are kept in
// scalar_t.
// NOTE(review): neighbouring lanes read addresses inner_size elements apart,
// so the loads are not coalesced when inner_size > 1.
template <typename scalar_t>
__global__ void warp_sum_reduce_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int64_t reduce_size,
    int64_t inner_size,
    int64_t total_output) {
    const int out_pos = blockIdx.x;
    if (out_pos >= total_output) {
        return;
    }

    const int outer = out_pos / inner_size;
    const int inner = out_pos % inner_size;
    const int lane = threadIdx.x;
    // Address of element r = 0 of this output's reduction slice.
    const int64_t slice_start = outer * reduce_size * inner_size + inner;

    // Phase 1: each lane sums its strided share of the slice.
    scalar_t acc = 0;
    #pragma unroll
    for (int r = lane; r < reduce_size; r += warpSize) {
        acc += input[slice_start + r * inner_size];
    }

    // Phase 2: halve the active distance each step until lane 0 has the total.
    #pragma unroll
    for (int delta = warpSize / 2; delta > 0; delta /= 2) {
        acc += __shfl_down_sync(0xffffffff, acc, delta);
    }

    if (lane == 0) {
        output[out_pos] = acc;
    }
}
// Host wrapper: sum-reduces `input` over `dim`, keeping the reduced axis with
// size 1 (the CUDA counterpart of torch.sum(input, dim, keepdim=True)).
//
// The tensor is treated as a dense [outer_size, reduce_size, inner_size]
// array, and one 32-thread block (a single warp) is launched per output
// element, matching the blockDim.x == warpSize contract of
// warp_sum_reduce_kernel. A zero-length reduction axis yields zeros, because
// the kernel's accumulation loop does not execute.
//
// Args:
//   input: CUDA tensor whose dtype is accepted by AT_DISPATCH_FLOATING_TYPES
//          (float or double). Non-contiguous inputs are copied to a
//          contiguous buffer first.
//   dim:   dimension to reduce; negative values count from the end.
// Returns:
//   A new tensor shaped like `input`, but with sizes[dim] == 1.
// Throws (via TORCH_CHECK) if `input` is not a CUDA tensor, is 0-dim, `dim`
// is out of range, or there are more output elements than the one-block-per-
// output launch can address. Kernel launch failures are surfaced by
// C10_CUDA_KERNEL_LAUNCH_CHECK.
torch::Tensor sum_reduce_cuda(torch::Tensor input, int64_t dim) {
    TORCH_CHECK(input.is_cuda(), "sum_reduce_cuda: input must be a CUDA tensor");
    const int64_t ndim = input.dim();
    TORCH_CHECK(ndim > 0, "sum_reduce_cuda: input must have at least one dimension");
    TORCH_CHECK(dim >= -ndim && dim < ndim,
                "sum_reduce_cuda: dim ", dim,
                " is out of range for a tensor with ", ndim, " dimensions");
    if (dim < 0) {
        dim += ndim;
    }

    // Launch on the input's device even when it is not the current device.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes the raw buffer as a dense row-major array, so a
    // strided view (e.g. a transpose or slice) must be materialized first;
    // otherwise the kernel would read the wrong elements.
    input = input.contiguous();

    auto sizes = input.sizes().vec();
    const int64_t reduce_size = sizes[dim];

    // outer_size: product of the dimensions before `dim`.
    // inner_size: product of the dimensions after it (the stride of `dim`).
    int64_t outer_size = 1;
    for (int64_t i = 0; i < dim; ++i) {
        outer_size *= sizes[i];
    }
    int64_t inner_size = 1;
    for (int64_t i = dim + 1; i < ndim; ++i) {
        inner_size *= sizes[i];
    }

    // The output keeps every dimension; the reduced one becomes size 1.
    sizes[dim] = 1;
    auto output = torch::empty(sizes, input.options());

    const int64_t total_output = outer_size * inner_size;
    // Nothing to compute. A launch with zero blocks would also be an
    // invalid-configuration error.
    if (total_output == 0) {
        return output;
    }
    // One block per output element on a 1-D grid; the kernel also stores the
    // block index in an `int`.
    TORCH_CHECK(total_output <= std::numeric_limits<int>::max(),
                "sum_reduce_cuda: too many output elements (", total_output,
                ") for a one-block-per-output launch");

    const int threads = 32;  // exactly one warp per block (kernel contract)
    const int blocks = static_cast<int>(total_output);
    // Use PyTorch's current stream so the kernel is ordered with surrounding ops.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "sum_reduce_cuda", ([&] {
        warp_sum_reduce_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            reduce_size,
            inner_size,
            total_output);
    }));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
// Python binding: exposes sum_reduce_cuda as `forward(input, dim)` on the
// compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        sum_reduce_cuda,
        "Sum reduction forward (CUDA) using warp-level primitives");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.620 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.434 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 15.642 | % | 0.001 | 5 |
Issued Ipc Active | 0.628 | inst/cycle | 0.000 | 5 |
SM Busy | 15.642 | % | 0.001 | 5 |
Memory Throughput | 367094058761.408 | byte/second | 41443318936938332160.000 | 5 |
Mem Busy | 55.826 | % | 1.205 | 5 |
Max Bandwidth | 32.672 | % | 0.316 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 84.150 | % | 1.307 | 5 |
Mem Pipes Busy | 2.578 | % | 0.002 | 5 |
Warp Cycles Per Issued Instruction | 45.720 | cycle | 0.799 | 5 |
Warp Cycles Per Executed Instruction | 46.002 | cycle | 0.819 | 5 |
Avg. Active Threads Per Warp | 31.150 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 30.010 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 84.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 64.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 32.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 50.000 | % | 0.000 | 5 |
Achieved Occupancy | 44.202 | % | 0.012 | 5 |
Achieved Active Warps Per SM | 28.288 | warp | 0.005 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
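WRN Occupancy | This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. Each block holds a single 32-thread warp, so the 32-blocks-per-SM limit caps residency at 32 of the 64 possible warps (see Block Limit SM, Block Limit Warps and Theoretical Active Warps per SM above). The shared-memory block limit is also 32 blocks, so Nsight Compute reports it as a limiter as well. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |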
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 349735.31 | μs |
Device Time | 353.44 | μs |
Self CPU Time | 36.06 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 349699.25 | μs |
Device Time | 353.44 | μs |
Self CPU Time | 74.37 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 349058.94 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 66.80 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 348783.58 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 348783.58 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 557084.11 | μs |
Device Time | 21343.42 | μs |
Self CPU Time | 557084.11 | μs |
Self Device Time | 21343.42 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void warp_sum_reduce_kernel<float>(float const*, float*, long, long, long) | ||
CPU Time | 0.00 | μs |
Device Time | 85400.71 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 85400.71 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 23081.68 | μs |
Device Time | 42480.56 | μs |
Self CPU Time | 23081.68 | μs |
Self Device Time | 42480.56 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 69854.04 | μs |
Device Time | 634803.27 | μs |
Self CPU Time | 14169.52 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 55686.24 | μs |
Device Time | 634803.27 | μs |
Self CPU Time | 15699.01 | μs |
Self Device Time | 634803.27 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 634880.90 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 634880.90 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45285 warnings generated when compiling for host. Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.