
The AI CUDA Engineer 👷

51_Argmax_over_a_dimension • warp_argmax_nosm_edit_1

Level 1 • Task 51
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Applies argmax over the specified dimension to the input tensor.

    Args:
        x (torch.Tensor): Input tensor
        dim (int): Dimension to perform argmax over

    Returns:
        torch.Tensor: Output tensor with argmax applied over specified dimension
    """
    return torch.argmax(x, dim)


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor
            fn: Function to apply (defaults to module_fn)

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return fn(x, self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]

import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return torch.argmax(x, dim=self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]
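
Both reference definitions above compute the same thing: argmax over dim=1 of a (16, 256, 256) input, which removes the reduced dimension and yields a (16, 256) tensor of int64 indices. A minimal sanity check, assuming the harness simply instantiates Model with get_init_inputs() and calls it on get_inputs():

# Hypothetical sanity check, not part of the benchmark harness.
model = Model(*get_init_inputs())          # dim = 1
(x,) = get_inputs()                        # x: (16, 256, 256), float32
out = model(x)
assert out.shape == (batch_size, dim2)     # (16, 256): dim1 was reduced away
assert out.dtype == torch.int64            # argmax returns int64 indices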

Kernel Information

Related Kernels (Level 1, Task 51 • 51_Argmax_over_a_dimension)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 warp_argmax_nosm_edit_1 0.01 1.85 2.54
🥈 warp_level_argmax_base 0.01 1.72 2.36
🥈 warp_level_argmax_edit_1 0.01 1.72 2.36
🥈 efficient_argmax_base 0.01 1.72 2.36
🥈 stride_loop_argmax_stride_base 0.01 1.72 2.36
🥈 stride_loop_argmax_final_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_base 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_edit_1 0.01 1.72 2.36
🥈 divergence_free_argmax_base 0.01 1.72 2.36
🥈 optimized_argmax_combination_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_sync_opt_base 0.01 1.72 2.36
🥈 argmax_aligned_mem_base_edit_1 0.01 1.72 2.36
🥈 stride_loop_argmax_final_base 0.01 1.72 2.36
🥈 warp_argmax_nosm_base 0.01 1.72 2.36
17 stride_loop_argmax_base 0.01 1.61 2.20
17 loop_unrolled_argmax_edit_1 0.01 1.61 2.20
17 stride_loop_argmax_edit_1 0.01 1.61 2.20
17 optimized_argmax_kernel_base 0.01 1.61 2.20
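
The CUDA listing below treats the contiguous input as an (outerSize, dimSize, innerSize) view and assigns one warp to each (outer, inner) pair, reducing over dimSize. As a reference for that index math only, here is a slow pure-Python emulation (the helper name emulate_argmax is made up for illustration); it walks the same flat offsets (start_offset + d * innerSize) that the kernel reads:

import torch

def emulate_argmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Illustrative emulation of the kernel's indexing; Python loops, not for real use.
    x = x.contiguous()
    sizes = list(x.shape)
    outer_size = 1
    for s in sizes[:dim]:
        outer_size *= s
    dim_size = sizes[dim]
    inner_size = 1
    for s in sizes[dim + 1:]:
        inner_size *= s

    flat = x.reshape(-1)                                   # contiguous float buffer
    out = torch.empty(outer_size * inner_size, dtype=torch.int64)
    for idx in range(outer_size * inner_size):             # one warp per idx on the GPU
        outer, inner = divmod(idx, inner_size)
        start = outer * dim_size * inner_size + inner      # start_offset in the kernel
        best_val, best_d = float("-inf"), 0
        for d in range(dim_size):                          # strided across 32 lanes on the GPU
            val = flat[start + d * inner_size].item()
            if val > best_val:                             # ties keep the smaller d
                best_val, best_d = val, d
        out[idx] = best_d
    return out.reshape([s for i, s in enumerate(sizes) if i != dim])
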
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <vector>

// This kernel computes argmax over a specified dimension using only warp-level primitives.
// Each block is assigned one (outer, inner) pair and is launched with exactly 32 threads (one warp).
// Each thread processes several elements along the reduction dimension in a stride loop, keeping its running maximum and index in registers.
// Then, warp-level intrinsic __shfl_down_sync() is used to reduce and determine the maximum value and its index,
// completely avoiding shared memory operations for the reduction phase.

__global__ void warp_argmax_nosm_kernel(
    const float* __restrict__ x,
    int64_t* __restrict__ indices,
    const int outerSize,
    const int dimSize,
    const int innerSize) {

    // Each block handles one (outer, inner) pair
    int idx = blockIdx.x;
    if (idx >= outerSize * innerSize) return;

    int outer_idx = idx / innerSize;
    int inner_idx = idx % innerSize;
    int start_offset = outer_idx * (dimSize * innerSize) + inner_idx;
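    // The contiguous input is viewed as [outerSize, dimSize, innerSize]; element
    // (outer_idx, d, inner_idx) therefore lives at start_offset + d * innerSize.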

    // Each thread in the warp computes a partial maximum over the reduction dimension.
    // Using a stride loop with a step equal to the warp size.
    float thread_max = -FLT_MAX;
    int thread_arg = 0;
    const int warpSize = 32;

    for (int d = threadIdx.x; d < dimSize; d += warpSize) {
        // Use __ldg to read through the read-only data cache for better throughput
        float val = __ldg(&x[start_offset + d * innerSize]);
        if (val > thread_max) {
            thread_max = val;
            thread_arg = d;
        } else if (val == thread_max && d < thread_arg) {
            // Tie-breaker: choose the smaller index
            thread_arg = d;
        }
    }

    // Perform warp-level reduction using shuffle intrinsics
    unsigned int mask = 0xffffffff; // Full mask for 32 threads
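    // Each step folds lane i with lane i + offset, so after log2(32) = 5 steps
    // lane 0 holds the warp-wide maximum and its index.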
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        float other_max = __shfl_down_sync(mask, thread_max, offset);
        int other_arg = __shfl_down_sync(mask, thread_arg, offset);
        if (other_max > thread_max) {
            thread_max = other_max;
            thread_arg = other_arg;
        } else if (other_max == thread_max && other_arg < thread_arg) {
            thread_arg = other_arg;
        }
    }

    // The first thread in the warp writes the final argmax result
    if (threadIdx.x == 0) {
        indices[idx] = thread_arg;
    }
}

// Host function to launch the CUDA kernel for argmax
// This function computes outerSize, dimSize, and innerSize based on the input tensor dimensions
// and then launches one warp (32 threads) per (outer, inner) pair.

torch::Tensor argmax_forward_cuda(const torch::Tensor& x, const int64_t dim) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "Only float32 is supported.");
    
    auto x_contig = x.contiguous();
    auto sizes = x_contig.sizes();
    const int ndim = x_contig.dim();
    TORCH_CHECK(dim >= 0 && dim < ndim, "Invalid dimension for argmax.");

    int outerSize = 1;
    for (int i = 0; i < dim; i++) {
        outerSize *= sizes[i];
    }
    int dimSize = sizes[dim];
    int innerSize = 1;
    for (int i = dim + 1; i < ndim; i++) {
        innerSize *= sizes[i];
    }

    // Build the output shape by removing the reduction dimension
    std::vector<int64_t> out_sizes;
    for (int i = 0; i < ndim; i++) {
        if (i != dim) {
            out_sizes.push_back(sizes[i]);
        }
    }
    
    auto options = torch::TensorOptions().device(x.device()).dtype(torch::kLong);
    auto indices = torch::empty(out_sizes, options);

    // Each output element corresponds to one outer*inner pair
    int total = outerSize * innerSize;
    // Launch one warp (32 threads) per output element
    const int threads = 32;
    const int blocks = total;

    warp_argmax_nosm_kernel<<<blocks, threads>>>(
        x_contig.data_ptr<float>(),
        indices.data_ptr<int64_t>(),
        outerSize,
        dimSize,
        innerSize);

    return indices;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &argmax_forward_cuda, "ArgMax CUDA forward (warp-level reduction, no shared memory)");
}
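
Since the listing above already defines its own PYBIND11_MODULE, one way to try it is torch.utils.cpp_extension.load, which JIT-compiles the file and exposes the bound forward function. A sketch only, assuming the source has been saved locally as warp_argmax_nosm.cu and that nvcc plus a CUDA-capable GPU are available:

import torch
from torch.utils.cpp_extension import load

# Assumes the CUDA listing above was saved as warp_argmax_nosm.cu next to this script.
ext = load(name="warp_argmax_nosm", sources=["warp_argmax_nosm.cu"], verbose=True)

x = torch.randn(16, 256, 256, device="cuda", dtype=torch.float32)
idx_custom = ext.forward(x, 1)            # kernel requires float32 and a non-negative dim
idx_ref = torch.argmax(x, dim=1)
assert idx_custom.shape == idx_ref.shape == (16, 256)
# Results can only differ on exact ties, which torch.randn makes vanishingly unlikely.
assert torch.equal(idx_custom, idx_ref)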
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.478 inst/cycle 0.000 5
Executed Ipc Elapsed 0.346 inst/cycle 0.000 5
Issue Slots Busy 11.944 % 0.002 5
Issued Ipc Active 0.478 inst/cycle 0.000 5
SM Busy 13.582 % 0.003 5
Memory Throughput 381847501582.306 byte/second 3917388363601052672.000 5
Mem Busy 62.256 % 9.271 5
Max Bandwidth 30.214 % 2.087 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 85.130 % 0.984 5
Mem Pipes Busy 3.492 % 0.028 5
Warp Cycles Per Issued Instruction 55.194 cycle 0.212 5
Warp Cycles Per Executed Instruction 55.408 cycle 0.221 5
Avg. Active Threads Per Warp 30.430 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.810 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 64.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 64.000 block 0.000 5
Theoretical Active Warps per SM 32.000 warp 0.000 5
Theoretical Occupancy 50.000 % 0.000 5
Achieved Occupancy 40.904 % 0.003 5
Achieved Active Warps Per SM 26.178 warp 0.001 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. This kernel's theoretical occupancy (50.0%) is limited by the required amount of shared memory. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 869396.25 μs
Device Time 367.10 μs
Self CPU Time 39.50 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 869356.75 μs
Device Time 367.10 μs
Self CPU Time 114.93 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 868633.48 μs
Device Time 0.00 μs
Self CPU Time 90.57 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 841738.99 μs
Device Time 0.00 μs
Self CPU Time 841738.99 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 540045.50 μs
Device Time 20766.24 μs
Self CPU Time 540045.50 μs
Self Device Time 20766.24 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
warp_argmax_nosm_kernel(float const*, long*, int, int, int)
CPU Time 0.00 μs
Device Time 85198.67 μs
Self CPU Time 0.00 μs
Self Device Time 85198.67 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 23237.61 μs
Device Time 41363.93 μs
Self CPU Time 23237.61 μs
Self Device Time 41363.93 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 67449.46 μs
Device Time 617664.19 μs
Self CPU Time 12938.39 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 54515.04 μs
Device Time 617664.19 μs
Self CPU Time 15571.23 μs
Self Device Time 617664.19 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 617664.19 μs
Self CPU Time 0.00 μs
Self Device Time 617664.19 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45283 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:15:5 bugprone-easily-swappable-parameters
15 | const int outerSize,
| ^~~~~~~~~~~~~~~~~~~~
16 | const int dimSize,
| ~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:15:15: note: the first parameter in the range is 'outerSize'
15 | const int outerSize,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:16:15: note: the last parameter in the range is 'dimSize'
16 | const int dimSize,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:20:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
20 | int idx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:33:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
33 | for (int d = threadIdx.x; d < dimSize; d += warpSize) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:73:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
73 | const int ndim = x_contig.dim();
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:78:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
78 | outerSize *= sizes[i];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:80:19: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
80 | int dimSize = sizes[dim];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:82:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | for (int i = dim + 1; i < ndim; i++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250212_optimize_b5_s4_e1_v2/level_1/task_51/b5_s2_warp_argmax_nosm/edit_1/edit_1.cu:83:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
83 | innerSize *= sizes[i];
| ^