
The AI CUDA Engineer 👷

53_Min_reduction_over_a_dimension • min_reduce_optimized_base

Level 1 • Task 53
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Applies min reduction over the specified dimension to the input tensor.

    Args:
        x (torch.Tensor): Input tensor
        dim (int): The dimension to reduce over

    Returns:
        torch.Tensor: Output tensor after min reduction over the specified dimension
    """
    return torch.min(x, dim)[0]


class Model(nn.Module):
    """
    Simple model that performs min reduction over a specific dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to reduce over.

        Args:
            dim (int): The dimension to reduce over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Applies min reduction over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor
            fn: Function to apply (defaults to module_fn)

        Returns:
            torch.Tensor: Output tensor after min reduction over the specified dimension
        """
        return fn(x, self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]  # Example, change to desired dimension
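# Original nn.Module reference for the same task (the functional variant
# above wraps the identical reduction through module_fn).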
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs min reduction over a specific dimension.
    """
    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to reduce over.

        Args:
            dim (int): The dimension to reduce over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies min reduction over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after min reduction over the specified dimension.
        """
        return torch.min(x, dim=self.dim)[0]

batch_size = 16
dim1 = 256
dim2 = 256

def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]

def get_init_inputs():
    return [1] # Example, change to desired dimension

Kernel Information

Related Kernels (Level 1, Task 53 • 53_Min_reduction_over_a_dimension)

Rank  Kernel Name  Runtime (ms)  Speedup (vs. Native)  Speedup (vs. Compile)
🥇 min_reduction_warp_base_base 0.01 2.19 3.09
🥇 efficient_min_reduce_kernel_base 0.01 2.19 3.09
🥇 min_reduce_block_size_tuning_base 0.01 2.19 3.09
🥇 min_reduce_optimized_base 0.01 2.19 3.09
🥇 min_reduction_optimized_memory_base 0.01 2.19 3.09
🥇 min_reduce_tunable_blocksize_base 0.01 2.19 3.09
🥇 min_reduce_dynamic_block_base_base 0.01 2.19 3.09
🥇 modular_min_reduce_kernel_base_base 0.01 2.19 3.09
🥇 min_reduce_adaptive_blocks_base_base 0.01 2.19 3.09
🥇 min_reduce_fused_warp_base 0.01 2.19 3.09
🥇 min_reduce_combined_base 0.01 2.19 3.09
🥇 min_reduce_warp_unroll_base 0.01 2.19 3.09
🥇 min_reduce_combined_kernel_base 0.01 2.19 3.09
14 balanced_min_reduction_base 0.01 1.85 2.62
15 min_reduction_warp_shared_hybrid_base 0.01 1.72 2.43
15 optimized_block_size_experiment_base_base 0.01 1.72 2.43
15 min_reduction_shared_base 0.01 1.72 2.43
15 modular_min_reduction_base 0.01 1.72 2.43
15 fast_min_reduction_edit_1 0.01 1.72 2.43
15 vector_load_min_reduction_edit_1 0.01 1.72 2.43
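
The kernel below treats the input as a logical [outer, r, inner] view and reduces the middle axis with one warp per output element. A short PyTorch sketch of the same factorization (the helper name is illustrative, not part of the kernel):

import torch

def min_reduce_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Same [outer, r, inner] factorization the CUDA kernel below uses.
    outer = 1
    for s in x.shape[:dim]:
        outer *= s
    r = x.shape[dim]
    inner = 1
    for s in x.shape[dim + 1:]:
        inner *= s
    # For a contiguous tensor, reducing `dim` is the same as reducing the
    # middle axis of the logical [outer, r, inner] view.
    out = x.contiguous().view(outer, r, inner).min(dim=1)[0]
    return out.view(x.shape[:dim] + x.shape[dim + 1:])

x = torch.randn(16, 256, 256)
assert torch.equal(min_reduce_reference(x, 1), torch.min(x, dim=1)[0])
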
/*
Optimized CUDA kernel for min reduction over a specified dimension.
The input tensor is logically reshaped as [outer, r, inner] and the reduction is performed along the r dimension.
Each warp (32 threads) computes one output element via warp-level reduction. 
This design combines aspects of warp-level reduction (as in kernel1) with streamlined indexing (from kernel2),
while applying loop unrolling for potentially improved performance.
*/

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <limits>
#include <c10/cuda/CUDAStream.h>

// Optimized kernel: each warp computes one output element's min reduction
template <typename scalar_t>
__global__ void min_reduce_optimized_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int outer,
    int r,
    int inner) {

  const int warpSize = 32;
  // Each warp is assigned one output element
  int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  if (warp_id >= outer * inner) return;

  // Determine lane within the warp
  int lane = threadIdx.x % warpSize;

  // Map warp_id to the corresponding (outer, inner) coordinates
  int outer_idx = warp_id / inner;
  int inner_idx = warp_id % inner;
  // Compute the base pointer offset for the reduction dimension
  int base = outer_idx * (r * inner) + inner_idx;

  // Initialize with maximum possible value
  scalar_t local_min = std::numeric_limits<scalar_t>::max();

  // Each thread in the warp iterates over the reduction dimension in stride of warpSize
  #pragma unroll
  for (int j = lane; j < r; j += warpSize) {
    int idx = base + j * inner;
    scalar_t val = input[idx];
    local_min = (val < local_min) ? val : local_min;
  }
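  // Note: adjacent lanes read elements `inner` apart, so these loads are
  // strided (cache-dependent) rather than fully coalesced when inner > 1.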

  // Warp-level reduction using shuffle operations
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    scalar_t other = __shfl_down_sync(0xffffffff, local_min, offset);
    local_min = (other < local_min) ? other : local_min;
  }
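  // After log2(32) = 5 shuffle steps (offsets 16, 8, 4, 2, 1), lane 0 holds
  // the minimum over all 32 lanes (tree reduction via __shfl_down_sync).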

  // The first lane writes the result
  if (lane == 0) {
    output[warp_id] = local_min;
  }
}

// Forward function: sets up tensor dimensions, output shape and kernel launch parameters
torch::Tensor forward(torch::Tensor input, int64_t dim) {
  TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
  if (!input.is_contiguous()) {
    input = input.contiguous();
  }

  int ndim = input.dim();
  // Normalize negative dims to PyTorch's convention before validating
  if (dim < 0) dim += ndim;
  TORCH_CHECK(dim >= 0 && dim < ndim, "dim out of range");

  // Compute sizes: outer dimensions, reduction dimension (r), and inner dimensions
  int outer = 1;
  for (int i = 0; i < dim; i++) {
    outer *= input.size(i);
  }
  int r = input.size(dim);
  int inner = 1;
  for (int i = dim + 1; i < ndim; i++) {
    inner *= input.size(i);
  }

  // Build output shape by removing the reduced dimension
  std::vector<int64_t> output_shape;
  for (int i = 0; i < ndim; i++) {
    if (i != dim) {
      output_shape.push_back(input.size(i));
    }
  }
  auto output = torch::empty(output_shape, input.options());

  // Each warp (32 threads) computes one output element
  int total_output = outer * inner;
  const int threads_per_block = 128;  // 128 threads/block = 4 warps per block
  int num_blocks = (total_output * 32 + threads_per_block - 1) / threads_per_block;
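  // Example: for the [16, 256, 256] input with dim=1, outer=16 and inner=256,
  // so total_output = 4096 warps -> 4096 * 32 / 128 = 1024 blocks.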

  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "min_reduce_optimized_cuda", ([&] {
    min_reduce_optimized_kernel<scalar_t><<<num_blocks, threads_per_block, 0,
      c10::cuda::getCurrentCUDAStream().stream()>>>(
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        outer,
        r,
        inner);
  }));

  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Optimized min reduction using warp-level primitives (CUDA)");
}
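
A minimal host-side sketch of how this extension might be exercised once compiled (e.g., via torch.utils.cpp_extension.load); the module name min_reduce_ext and the presence of a CUDA device are assumptions for illustration, not part of this page:

import torch
# Hypothetical: assumes the kernel above was built into an extension named
# "min_reduce_ext" (e.g., with torch.utils.cpp_extension.load).
import min_reduce_ext

x = torch.randn(16, 256, 256, device="cuda")
out = min_reduce_ext.forward(x, 1)  # custom warp-level kernel
ref = torch.min(x, dim=1)[0]        # PyTorch baseline
# min is an order-independent, exact reduction, so results match bitwise.
assert torch.equal(out, ref)
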
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.360 inst/cycle 0.000 5
Executed Ipc Elapsed 0.260 inst/cycle 0.000 5
Issue Slots Busy 9.252 % 0.001 5
Issued Ipc Active 0.370 inst/cycle 0.000 5
SM Busy 9.252 % 0.001 5
Memory Throughput 467918742118.394 byte/second 7384462541123192832.000 5
Mem Busy 57.624 % 0.105 5
Max Bandwidth 16.564 % 0.005 5
L1/TEX Hit Rate 74.894 % 0.000 5
L2 Hit Rate 62.576 % 0.349 5
Mem Pipes Busy 3.286 % 0.000 5
Warp Cycles Per Issued Instruction 70.066 cycle 0.296 5
Warp Cycles Per Executed Instruction 72.400 cycle 0.316 5
Avg. Active Threads Per Warp 30.450 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.040 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 16.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 40.108 % 0.009 5
Achieved Active Warps Per SM 25.668 warp 0.003 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (40.3%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 274025.90 μs
Device Time 366.05 μs
Self CPU Time 34.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 273991.24 μs
Device Time 366.05 μs
Self CPU Time 87.48 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 273318.05 μs
Device Time 0.00 μs
Self CPU Time 66.36 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 272999.34 μs
Device Time 0.00 μs
Self CPU Time 272999.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 424417.51 μs
Device Time 628.19 μs
Self CPU Time 424417.51 μs
Self Device Time 628.19 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void min_reduce_optimized_kernel<float>(float const*, float*, int, int, int)
CPU Time 0.00 μs
Device Time 48325.13 μs
Self CPU Time 0.00 μs
Self Device Time 48325.13 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 65901.38 μs
Device Time 520286.80 μs
Self CPU Time 10714.98 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 55191.18 μs
Device Time 520286.80 μs
Self CPU Time 14958.71 μs
Self Device Time 520286.80 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 520286.80 μs
Self CPU Time 0.00 μs
Self Device Time 520286.80 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45296 warnings generated when compiling for host.
Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:21:5: warning: 2 adjacent parameters of 'min_reduce_optimized_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
21 | int outer,
| ^~~~~~~~~~
22 | int r,
| ~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:21:9: note: the first parameter in the range is 'outer'
21 | int outer,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:22:9: note: the last parameter in the range is 'r'
22 | int r,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:27:17: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:31:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | int lane = threadIdx.x % warpSize;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:69:14: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
69 | int ndim = input.dim();
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:75:14: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
75 | outer *= input.size(i);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:77:11: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
77 | int r = input.size(dim);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:79:16: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | for (int i = dim + 1; i < ndim; i++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:80:14: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | inner *= input.size(i);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_53/b4_s1_min_reduce_optimized/base/base.cu:97:3: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
97 | AT_DISPATCH_ALL_TYPES(input.scalar_type(), "min_reduce_optimized_cuda", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:482:34: note: expanded from macro 'AT_DISPATCH_ALL_TYPES'
482 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:478:3: note: expanded from macro 'AT_DISPATCH_CASE_ALL_TYPES'
478 | AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:458:3: note: expanded from macro 'AT_DISPATCH_CASE_INTEGRAL_TYPES'
458 | AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
| ^
note: (skipping 2 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^