← Back to Leaderboard

The AI CUDA Engineer 👷

47_Sum_reduction_over_a_dimension • fully_unrolled_warp_sum_reduction_base

Level 1 • Task 47
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Sums a tensor along one dimension, keeping that dimension with size 1.

    Args:
        x (torch.Tensor): Tensor to reduce.
        dim (int): Index of the dimension that is summed.

    Returns:
        torch.Tensor: Reduced tensor; the summed dimension is retained with
            length 1 (keepdim=True), all other dimensions are unchanged.
    """
    reduced = x.sum(dim=dim, keepdim=True)
    return reduced


class Model(nn.Module):
    """
    Module wrapper around a sum reduction along a fixed dimension.
    """

    def __init__(self, dim: int):
        """
        Stores the dimension that forward() will reduce over.

        Args:
            dim (int): Dimension to reduce over.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Reduces x along the stored dimension by delegating to fn.

        Args:
            x (torch.Tensor): Input tensor of shape (..., dim, ...).
            fn (callable): Implementation invoked as fn(x, self.dim);
                defaults to module_fn.

        Returns:
            torch.Tensor: Whatever fn returns; with the default fn this is
                the sum-reduced tensor of shape (..., 1, ...).
        """
        reduced = fn(x, self.dim)
        return reduced


# Benchmark configuration. get_inputs() draws a tensor of shape
# (batch_size, dim1, dim2); get_init_inputs() passes reduce_dim on as the
# dimension to reduce (index 1, i.e. the dim1 axis of that shape).
batch_size, dim1, dim2 = 16, 256, 256
reduce_dim = 1


def get_inputs():
    """
    Builds the example inputs for a forward pass.

    Returns:
        list: One-element list holding a standard-normal random tensor of
            shape (batch_size, dim1, dim2).
    """
    shape = (batch_size, dim1, dim2)
    return [torch.randn(*shape)]


def get_init_inputs():
    """
    Builds the example constructor arguments.

    Returns:
        list: One-element list containing reduce_dim.
    """
    init_args = [reduce_dim]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Module that sums its input along a fixed dimension.
    """
    def __init__(self, dim: int):
        """
        Stores the dimension that forward() will reduce over.

        Args:
            dim (int): Dimension to reduce over.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sums x along the stored dimension, keeping it with length 1.

        Args:
            x (torch.Tensor): Input tensor of shape (..., dim, ...).

        Returns:
            torch.Tensor: Sum-reduced tensor of shape (..., 1, ...).
        """
        reduced = x.sum(dim=self.dim, keepdim=True)
        return reduced

# Benchmark configuration. get_inputs() draws a tensor of shape
# (batch_size, dim1, dim2); get_init_inputs() passes reduce_dim on as the
# dimension to reduce (index 1, i.e. the dim1 axis of that shape).
batch_size, dim1, dim2 = 16, 256, 256
reduce_dim = 1

def get_inputs():
    """
    Builds the example inputs for a forward pass.

    Returns:
        list: One-element list holding a standard-normal random tensor of
            shape (batch_size, dim1, dim2).
    """
    shape = (batch_size, dim1, dim2)
    return [torch.randn(*shape)]

def get_init_inputs():
    """
    Builds the example constructor arguments.

    Returns:
        list: One-element list containing reduce_dim.
    """
    init_args = [reduce_dim]
    return init_args

Kernel Information

Related Kernels (Level 1, Task 47 • 47_Sum_reduction_over_a_dimension)

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Sum reduction over one dimension using warp-level shuffles, fully unrolled.
//
// The input is indexed as a dense [outer, reduce_size, inner_size] array and the
// output as a dense [outer, inner_size] array. Each warp produces one output
// element at a time: its 32 lanes stride over the reduction dimension, then the
// per-lane partial sums are combined with five __shfl_down_sync steps. No shared
// memory is used.
//
// Parameters:
//   input         - source data, element (o, r, i) at (o * reduce_size + r) * inner_size + i
//   output        - destination, element (o, i) at o * inner_size + i
//   reduce_size   - extent of the reduced dimension
//   inner_size    - product of the extents after the reduced dimension
//   total_outputs - number of output elements (outer * inner_size)
//
// Launch requirement: blockDim.x must be a multiple of 32 so that every warp is
// fully populated; the shuffles use a full 32-lane mask.
//
// All element and warp indices are 64-bit so that tensors and grids whose
// sizes exceed the range of a 32-bit int are indexed correctly.

template <typename scalar_t>
__global__ void fully_unrolled_warp_reduce_sum_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int64_t reduce_size,
    int64_t inner_size,
    int64_t total_outputs) {

    // Named kWarpSize (not warpSize) to avoid shadowing the CUDA built-in.
    constexpr int kWarpSize = 32;
    constexpr unsigned int kFullMask = 0xffffffffu;

    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit product.
    const int64_t global_thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t warp_id = global_thread_id / kWarpSize;             // Which warp this thread belongs to
    const int lane = static_cast<int>(global_thread_id % kWarpSize);  // Lane index within the warp
    const int64_t total_warps =
        (static_cast<int64_t>(gridDim.x) * blockDim.x) / kWarpSize;

    // Each warp computes one output element, with a grid-stride loop over outputs.
    // All lanes of a warp share warp_id, so they run the same number of iterations
    // and all reach the shuffles below together.
    for (int64_t out_idx = warp_id; out_idx < total_outputs; out_idx += total_warps) {
        // Map the 1D output index to (outer, inner) indices
        const int64_t outer_idx = out_idx / inner_size;
        const int64_t inner_idx = out_idx % inner_size;

        // Offset of element (outer_idx, 0, inner_idx) in the input
        const int64_t base = outer_idx * reduce_size * inner_size + inner_idx;
        scalar_t sum = 0;

        // Each lane accumulates every 32nd element along the reduction dimension
        for (int64_t i = lane; i < reduce_size; i += kWarpSize) {
            sum += input[base + i * inner_size];
        }

        // Fully unrolled warp-level reduction using shuffle down;
        // after the last step lane 0 holds the sum of all 32 partials.
        sum += __shfl_down_sync(kFullMask, sum, 16);
        sum += __shfl_down_sync(kFullMask, sum, 8);
        sum += __shfl_down_sync(kFullMask, sum, 4);
        sum += __shfl_down_sync(kFullMask, sum, 2);
        sum += __shfl_down_sync(kFullMask, sum, 1);

        // The first lane writes the final result
        if (lane == 0) {
            output[out_idx] = sum;
        }
    }
}

// Host wrapper: sums `input` over dimension `dim`, keeping that dimension with
// size 1, by launching fully_unrolled_warp_reduce_sum_kernel.
//
// Parameters:
//   input - CUDA tensor of a floating type handled by AT_DISPATCH_FLOATING_TYPES
//   dim   - dimension to reduce; negative values count from the end
//
// Returns a new tensor with the same options as `input` and the reduced
// dimension set to 1.
//
// Raises (via TORCH_CHECK) if `input` is not a CUDA tensor, if `dim` is out of
// range, or if the kernel launch reports an error.
torch::Tensor sum_reduce_cuda(torch::Tensor input, int64_t dim) {
    TORCH_CHECK(input.is_cuda(), "sum_reduce_cuda: input must be a CUDA tensor");

    // Validate before indexing sizes[dim]; then adjust negative dimensions
    const int64_t ndim = input.dim();
    TORCH_CHECK(dim >= -ndim && dim < ndim,
                "sum_reduce_cuda: dim ", dim, " is out of range for a ",
                ndim, "-dimensional tensor");
    if (dim < 0) dim += ndim;

    // The kernel indexes the data as a dense [outer, reduce, inner] array,
    // so strided / transposed inputs must be made contiguous first.
    input = input.contiguous();

    auto sizes = input.sizes().vec();
    const int64_t reduce_size = sizes[dim];

    // Compute outer and inner sizes
    int64_t outer_size = 1;
    for (int64_t i = 0; i < dim; i++) {
        outer_size *= sizes[i];
    }
    int64_t inner_size = 1;
    for (int64_t i = dim + 1; i < ndim; i++) {
        inner_size *= sizes[i];
    }

    // Set the reduced dimension to 1
    sizes[dim] = 1;
    auto output = torch::empty(sizes, input.options());

    // Total number of outputs is outer_size * inner_size
    const int64_t total_outputs = outer_size * inner_size;
    if (total_outputs == 0) {
        // Nothing to compute; a launch with zero blocks would be an invalid configuration.
        return output;
    }

    // Each output element is computed by one warp (32 threads)
    constexpr int64_t kWarpSize = 32;
    constexpr int kThreads = 256;  // Must be a multiple of 32
    constexpr int64_t kWarpsPerBlock = kThreads / kWarpSize;
    // Keep blocks * kThreads within the range of a 32-bit int. The kernel's
    // grid-stride loop covers any outputs beyond what the clamped grid reaches.
    constexpr int64_t kMaxBlocks = 2147483647LL / kThreads;

    const int64_t blocks_needed = (total_outputs + kWarpsPerBlock - 1) / kWarpsPerBlock;
    const int blocks = static_cast<int>(blocks_needed < kMaxBlocks ? blocks_needed : kMaxBlocks);

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "sum_reduce_cuda", ([&] {
        fully_unrolled_warp_reduce_sum_kernel<scalar_t><<<blocks, kThreads>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            reduce_size,
            inner_size,
            total_outputs
        );
    }));

    // Surface launch-configuration errors instead of returning garbage silently
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "sum_reduce_cuda: kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return output;
}

// Python binding: registers sum_reduce_cuda under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &sum_reduce_cuda,
        "Sum reduction forward (CUDA) using fully unrolled warp-level primitives");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.626 inst/cycle 0.000 5
Executed Ipc Elapsed 0.462 inst/cycle 0.000 5
Issue Slots Busy 16.022 % 0.009 5
Issued Ipc Active 0.644 inst/cycle 0.000 5
SM Busy 16.022 % 0.009 5
Memory Throughput 440144928057.034 byte/second 19189658725823696896.000 5
Mem Busy 54.014 % 0.268 5
Max Bandwidth 13.176 % 0.014 5
L1/TEX Hit Rate 87.482 % 0.000 5
L2 Hit Rate 47.544 % 0.002 5
Mem Pipes Busy 3.786 % 0.001 5
Warp Cycles Per Issued Instruction 41.524 cycle 0.246 5
Warp Cycles Per Executed Instruction 42.450 cycle 0.259 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 30.160 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 6.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 48.000 warp 0.000 5
Theoretical Occupancy 75.000 % 0.000 5
Achieved Occupancy 41.290 % 0.019 5
Achieved Active Warps Per SM 26.424 warp 0.008 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy (75.0%) is limited by the number of required registers. The difference between calculated theoretical (75.0%) and measured achieved occupancy (41.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 463973.28 μs
Device Time 344.89 μs
Self CPU Time 37.67 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 463935.61 μs
Device Time 344.89 μs
Self CPU Time 90.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 463281.89 μs
Device Time 0.00 μs
Self CPU Time 70.53 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 463001.66 μs
Device Time 0.00 μs
Self CPU Time 463001.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 563014.21 μs
Device Time 22219.64 μs
Self CPU Time 563014.21 μs
Self Device Time 22219.64 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void fully_unrolled_warp_reduce_sum_kernel<float>(float const*, float*, long, long, long)
CPU Time 0.00 μs
Device Time 67675.48 μs
Self CPU Time 0.00 μs
Self Device Time 67675.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 18553.05 μs
Device Time 44019.60 μs
Self CPU Time 18553.05 μs
Self Device Time 44019.60 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 67525.17 μs
Device Time 657990.82 μs
Self CPU Time 14991.00 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 52535.05 μs
Device Time 657990.82 μs
Self CPU Time 15991.54 μs
Self Device Time 657990.82 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 658069.47 μs
Self CPU Time 0.00 μs
Self Device Time 658069.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45285 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:14:5 bugprone-easily-swappable-parameters
14 | int64_t inner_size,
| ^~~~~~~~~~~~~~~~~~~
15 | int64_t total_outputs) {
| ~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:14:13: note: the first parameter in the range is 'inner_size'
14 | int64_t inner_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:15:13: note: the last parameter in the range is 'total_outputs'
15 | int64_t total_outputs) {
| ^~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:18:28: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
18 | int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:21:23: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
21 | int total_warps = (gridDim.x * blockDim.x) / warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:26:25: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | int outer_idx = out_idx / inner_size;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:27:25: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | int inner_idx = out_idx % inner_size;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:67:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
67 | for (int i = dim + 1; i < sizes.size(); i++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:80:25: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | int total_threads = total_outputs * warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250201_optimize_b10_s4_e0_sweep/level_1/task_47/b7_s3_fully_unrolled_warp_sum_reduction/base/base.cu:84:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
84 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "sum_reduce_cuda", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:237:34: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES'
237 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:233:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES'
233 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^