← Back to Leaderboard

The AI CUDA Engineer 👷

51_Argmax_over_a_dimension • warp_argmax_nosm_base

Level 1 • Task 51
import torch
import torch.nn as nn
import torch.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Return the indices of the maximum values of ``x`` along ``dim``.

    Args:
        x (torch.Tensor): Tensor to reduce.
        dim (int): Dimension along which the argmax is taken.

    Returns:
        torch.Tensor: Tensor of argmax indices computed over ``dim``.
    """
    argmax_indices = x.argmax(dim=dim)
    return argmax_indices


class Model(nn.Module):
    """
    Thin module wrapper that applies an argmax-style function along a fixed dimension.
    """

    def __init__(self, dim: int):
        """
        Store the dimension that ``forward`` reduces over.

        Args:
            dim (int): Dimension along which the argmax is taken.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Call ``fn`` on ``x`` with the dimension stored at construction time.

        Args:
            x (torch.Tensor): Input tensor.
            fn: Callable invoked as ``fn(x, dim)``; defaults to ``module_fn``.

        Returns:
            torch.Tensor: The value returned by ``fn(x, self.dim)``.
        """
        reduce_dim = self.dim
        return fn(x, reduce_dim)


# Shape of the random input tensor built by get_inputs(): (batch_size, dim1, dim2).
batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    """Return a one-element list holding a random tensor of shape (batch_size, dim1, dim2)."""
    shape = (batch_size, dim1, dim2)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Return a fresh one-element list containing 1 (presumably the ``dim`` argument for ``Model``)."""
    argmax_dim = 1
    return [argmax_dim]
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Module that takes the argmax of its input along a dimension fixed at construction.
    """

    def __init__(self, dim: int):
        """
        Store the dimension that ``forward`` reduces over.

        Args:
            dim (int): Dimension along which the argmax is taken.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the argmax of ``x`` along the stored dimension.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Tensor of argmax indices computed over ``self.dim``.
        """
        reduce_dim = self.dim
        return x.argmax(dim=reduce_dim)


# Shape of the random input tensor built by get_inputs(): (batch_size, dim1, dim2).
batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    """Return a one-element list holding a random tensor of shape (batch_size, dim1, dim2)."""
    shape = (batch_size, dim1, dim2)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Return a fresh one-element list containing 1 (presumably the ``dim`` argument for ``Model``)."""
    argmax_dim = 1
    return [argmax_dim]

Kernel Information

Related Kernels (Level 1, Task 51 • 51_Argmax_over_a_dimension)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 warp_argmax_nosm_edit_1 0.01 1.85 2.54
🥈 warp_level_argmax_base 0.01 1.72 2.36
🥈 warp_level_argmax_edit_1 0.01 1.72 2.36
🥈 efficient_argmax_base 0.01 1.72 2.36
🥈 stride_loop_argmax_stride_base 0.01 1.72 2.36
🥈 stride_loop_argmax_final_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_base 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_edit_1 0.01 1.72 2.36
🥈 divergence_free_argmax_base 0.01 1.72 2.36
🥈 optimized_argmax_combination_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_sync_opt_base 0.01 1.72 2.36
🥈 argmax_aligned_mem_base_edit_1 0.01 1.72 2.36
🥈 stride_loop_argmax_final_base 0.01 1.72 2.36
🥈 warp_argmax_nosm_base 0.01 1.72 2.36
17 stride_loop_argmax_base 0.01 1.61 2.20
17 loop_unrolled_argmax_edit_1 0.01 1.61 2.20
17 stride_loop_argmax_edit_1 0.01 1.61 2.20
17 optimized_argmax_kernel_base 0.01 1.61 2.20
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>
#include <limits>
#include <vector>

// Argmax over one dimension of a contiguous float tensor, addressed as
// [outerSize, dimSize, innerSize], using only warp-level primitives.
// Each block handles one (outer, inner) pair and must be launched with exactly
// 32 threads (one warp). Each thread scans the reduction dimension with a
// stride of 32, then __shfl_down_sync() combines the per-lane candidates, so no
// shared memory is used for the reduction.

// Returns true when the candidate (val, idx) should replace (best_val, best_idx).
// Ordering: a NaN beats every non-NaN value; otherwise the larger value wins;
// ties (equal values, or two NaNs) go to the smaller index. This is intended to
// match torch.argmax, which treats NaN as the maximum and returns the first
// maximal element.
__device__ __forceinline__ bool argmax_candidate_wins(
    const float val, const int idx, const float best_val, const int best_idx) {
    if (isnan(val)) {
        return !isnan(best_val) || idx < best_idx;
    }
    if (isnan(best_val)) {
        return false;
    }
    return val > best_val || (val == best_val && idx < best_idx);
}

// Parameters:
//   x         - input values, laid out contiguously as [outerSize, dimSize, innerSize]
//   indices   - output, one int64 index per (outer, inner) pair
//   outerSize - product of the dimensions before the reduced one
//   dimSize   - length of the reduced dimension
//   innerSize - product of the dimensions after the reduced one
__global__ void warp_argmax_nosm_kernel(
    const float* __restrict__ x,
    int64_t* __restrict__ indices,
    const int outerSize,
    const int dimSize,
    const int innerSize) {

    // Named constants; a local called `warpSize` would shadow the CUDA built-in.
    constexpr int kWarpSize = 32;
    constexpr unsigned int kFullMask = 0xffffffffu;  // all 32 lanes participate

    // Each block handles one (outer, inner) pair. All offset arithmetic is done
    // in 64 bits: the individual sizes fit in int, but their products may not.
    const int64_t idx = blockIdx.x;
    if (idx >= static_cast<int64_t>(outerSize) * innerSize) return;

    const int64_t outer_idx = idx / innerSize;
    const int64_t inner_idx = idx % innerSize;
    const int64_t start_offset = outer_idx * dimSize * innerSize + inner_idx;

    // Per-lane running best. Starting from (-inf, dimSize) means "nothing seen
    // yet": any real element, including -inf itself, wins against it because
    // dimSize is larger than every valid index.
    float thread_max = -INFINITY;
    int thread_arg = dimSize;

    // 64-bit loop counter so `d += kWarpSize` cannot wrap when dimSize is close
    // to INT_MAX.
    for (int64_t d = threadIdx.x; d < dimSize; d += kWarpSize) {
        // __ldg routes the load through the read-only data cache.
        const float val = __ldg(&x[start_offset + d * innerSize]);
        const int d_idx = static_cast<int>(d);
        if (argmax_candidate_wins(val, d_idx, thread_max, thread_arg)) {
            thread_max = val;
            thread_arg = d_idx;
        }
    }

    // Warp-level tree reduction. Lanes whose source lane would be >= 32 receive
    // their own value back, which never wins against itself, so they are
    // unaffected; lane 0 ends up holding the overall winner.
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        const float other_max = __shfl_down_sync(kFullMask, thread_max, offset);
        const int other_arg = __shfl_down_sync(kFullMask, thread_arg, offset);
        if (argmax_candidate_wins(other_max, other_arg, thread_max, thread_arg)) {
            thread_max = other_max;
            thread_arg = other_arg;
        }
    }

    // The first lane writes the final argmax result for this (outer, inner) pair.
    if (threadIdx.x == 0) {
        indices[idx] = thread_arg;
    }
}

// Host function that launches the CUDA argmax kernel.
//
// Views the (contiguous) input as [outerSize, dimSize, innerSize] around the
// reduced dimension and launches one warp (32 threads) per (outer, inner) pair.
//
// Args:
//   x   - float32 CUDA tensor.
//   dim - dimension to reduce; negative values count from the end, as in
//         torch.argmax.
// Returns:
//   int64 tensor on x's device with `dim` removed from the shape.
// Raises (via TORCH_CHECK):
//   if x is not a float32 CUDA tensor, dim is out of range, the reduced
//   dimension is empty, the tensor is too large for the kernel's 32-bit size
//   parameters, or the kernel launch fails.
torch::Tensor argmax_forward_cuda(const torch::Tensor& x, const int64_t dim) {
    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor.");
    TORCH_CHECK(x.scalar_type() == at::kFloat, "Only float32 is supported.");

    auto x_contig = x.contiguous();
    const auto sizes = x_contig.sizes();
    const int64_t ndim = x_contig.dim();

    // Accept negative dims the same way torch.argmax does.
    const int64_t red_dim = dim < 0 ? dim + ndim : dim;
    TORCH_CHECK(red_dim >= 0 && red_dim < ndim, "Invalid dimension for argmax.");

    // Compute the three extents in 64 bits; they are narrowed to int only after
    // the range check below.
    int64_t outerSize = 1;
    for (int64_t i = 0; i < red_dim; ++i) {
        outerSize *= sizes[i];
    }
    const int64_t dimSize = sizes[red_dim];
    int64_t innerSize = 1;
    for (int64_t i = red_dim + 1; i < ndim; ++i) {
        innerSize *= sizes[i];
    }
    TORCH_CHECK(dimSize > 0, "argmax: the reduction dimension must be non-empty.");

    // Build the output shape by removing the reduction dimension.
    std::vector<int64_t> out_sizes;
    out_sizes.reserve(static_cast<size_t>(ndim - 1));
    for (int64_t i = 0; i < ndim; ++i) {
        if (i != red_dim) {
            out_sizes.push_back(sizes[i]);
        }
    }

    auto options = torch::TensorOptions().device(x.device()).dtype(torch::kLong);
    auto indices = torch::empty(out_sizes, options);

    // Each output element corresponds to one (outer, inner) pair.
    const int64_t total = outerSize * innerSize;
    if (total == 0) {
        // Nothing to compute, and a zero-block launch is an invalid configuration.
        return indices;
    }

    // The kernel takes its sizes as 32-bit ints; requiring the element count to
    // fit guarantees that every size and every product of sizes fits as well.
    TORCH_CHECK(x_contig.numel() <= std::numeric_limits<int>::max(),
                "argmax: input has too many elements for this kernel (32-bit indexing).");

    // Launch one warp (32 threads) per output element.
    // NOTE(review): this launches on the default stream rather than PyTorch's
    // current stream and does not set the device from x — confirm callers only
    // use the default stream and the current device.
    const int threads = 32;
    const int blocks = static_cast<int>(total);

    warp_argmax_nosm_kernel<<<blocks, threads>>>(
        x_contig.data_ptr<float>(),
        indices.data_ptr<int64_t>(),
        static_cast<int>(outerSize),
        static_cast<int>(dimSize),
        static_cast<int>(innerSize));

    // Surface launch-configuration errors instead of returning garbage.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "warp_argmax_nosm_kernel launch failed: ", cudaGetErrorString(launch_err));

    return indices;
}

// Python binding: registers argmax_forward_cuda under the name "forward" on the
// extension module. TORCH_EXTENSION_NAME is not defined in this file; presumably
// the extension build supplies it.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &argmax_forward_cuda, "ArgMax CUDA forward (warp-level reduction, no shared memory)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.478 inst/cycle 0.000 5
Executed Ipc Elapsed 0.358 inst/cycle 0.000 5
Issue Slots Busy 11.936 % 0.017 5
Issued Ipc Active 0.478 inst/cycle 0.000 5
SM Busy 13.566 % 0.021 5
Memory Throughput 383209954448.874 byte/second 11889782905367425024.000 5
Mem Busy 63.644 % 1.140 5
Max Bandwidth 31.102 % 0.105 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 85.292 % 1.129 5
Mem Pipes Busy 3.594 % 0.001 5
Warp Cycles Per Issued Instruction 54.812 cycle 0.107 5
Warp Cycles Per Executed Instruction 55.036 cycle 0.121 5
Avg. Active Threads Per Warp 30.430 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.810 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 64.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 64.000 block 0.000 5
Theoretical Active Warps per SM 32.000 warp 0.000 5
Theoretical Occupancy 50.000 % 0.000 5
Achieved Occupancy 40.924 % 0.007 5
Achieved Active Warps Per SM 26.190 warp 0.003 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. This kernel's theoretical occupancy (50.0%) is limited by the required amount of shared memory. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 631076.75 μs
Device Time 387.55 μs
Self CPU Time 38.91 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 631037.84 μs
Device Time 387.55 μs
Self CPU Time 93.55 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 630321.57 μs
Device Time 0.00 μs
Self CPU Time 73.78 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 625424.83 μs
Device Time 0.00 μs
Self CPU Time 625424.83 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 502821.69 μs
Device Time 21311.11 μs
Self CPU Time 502821.69 μs
Self Device Time 21311.11 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
warp_argmax_nosm_kernel(float const*, long*, int, int, int)
CPU Time 0.00 μs
Device Time 75021.14 μs
Self CPU Time 0.00 μs
Self Device Time 75021.14 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 22836.76 μs
Device Time 39570.47 μs
Self CPU Time 22836.76 μs
Self Device Time 39570.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 65450.48 μs
Device Time 591722.59 μs
Self CPU Time 14105.33 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 51349.62 μs
Device Time 591722.59 μs
Self CPU Time 15223.92 μs
Self Device Time 591722.59 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 591722.59 μs
Self CPU Time 0.00 μs
Self Device Time 591722.59 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45283 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:15:5 bugprone-easily-swappable-parameters
15 | const int outerSize,
| ^~~~~~~~~~~~~~~~~~~~
16 | const int dimSize,
| ~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:15:15: note: the first parameter in the range is 'outerSize'
15 | const int outerSize,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:16:15: note: the last parameter in the range is 'dimSize'
16 | const int dimSize,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:20:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
20 | int idx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:33:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
33 | for (int d = threadIdx.x; d < dimSize; d += warpSize) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:73:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
73 | const int ndim = x_contig.dim();
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:78:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
78 | outerSize *= sizes[i];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:80:19: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | int dimSize = sizes[dim];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:82:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | for (int i = dim + 1; i < ndim; i++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/base/base.cu:83:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
83 | innerSize *= sizes[i];
| ^