
The AI CUDA Engineer 👷

51_Argmax_over_a_dimension • argmax_coop_red_tuned_base

Level 1 • Task 51
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Applies argmax over the specified dimension to the input tensor.

    Args:
        x (torch.Tensor): Input tensor
        dim (int): Dimension to perform argmax over

    Returns:
        torch.Tensor: Output tensor with argmax applied over specified dimension
    """
    return torch.argmax(x, dim)


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor
            fn: Function to apply (defaults to module_fn)

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return fn(x, self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return torch.argmax(x, dim=self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]
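
For orientation, here is a minimal usage sketch of the reference Model above (either variant behaves identically). It uses only the names defined in the listings (Model, get_inputs, get_init_inputs) and shows that argmax over dim=1 removes that dimension and returns int64 indices:

import torch

# Instantiate the reference model the same way the harness does:
# get_init_inputs() returns [1], so argmax is taken over dim=1.
model = Model(*get_init_inputs())
(x,) = get_inputs()            # float32 tensor of shape (16, 256, 256)

out = model(x)                 # equivalent to torch.argmax(x, dim=1)
print(out.shape, out.dtype)    # torch.Size([16, 256]) torch.int64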

Kernel Information

Related Kernels (Level 1, Task 51 • 51_Argmax_over_a_dimension)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 warp_argmax_nosm_edit_1 0.01 1.85 2.54
🥈 warp_level_argmax_base 0.01 1.72 2.36
🥈 warp_level_argmax_edit_1 0.01 1.72 2.36
🥈 efficient_argmax_base 0.01 1.72 2.36
🥈 stride_loop_argmax_stride_base 0.01 1.72 2.36
🥈 stride_loop_argmax_final_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_base 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_edit_1 0.01 1.72 2.36
🥈 divergence_free_argmax_base 0.01 1.72 2.36
🥈 optimized_argmax_combination_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_sync_opt_base 0.01 1.72 2.36
🥈 argmax_aligned_mem_base_edit_1 0.01 1.72 2.36
🥈 stride_loop_argmax_final_base 0.01 1.72 2.36
🥈 warp_argmax_nosm_base 0.01 1.72 2.36
17 stride_loop_argmax_base 0.01 1.61 2.20
17 loop_unrolled_argmax_edit_1 0.01 1.61 2.20
17 stride_loop_argmax_edit_1 0.01 1.61 2.20
17 optimized_argmax_kernel_base 0.01 1.61 2.20
#include <torch/extension.h>
#include <vector>
#include <float.h>

__global__ void argmax_kernel_coop_tuned(
    const float* __restrict__ x,
    int64_t* __restrict__ indices,
    const int outerSize,
    const int dimSize,
    const int innerSize) {

    int slice = blockIdx.x;
    if (slice >= outerSize * innerSize) return;

    int outer_idx = slice / innerSize;
    int inner_idx = slice % innerSize;
    int base_offset = outer_idx * (dimSize * innerSize) + inner_idx;

    // Use 128 threads per block for better occupancy on H100
    float local_max = -FLT_MAX;
    int local_argmax = 0;

    // Each thread handles multiple elements with stride
    for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
        float curr_val = x[base_offset + d * innerSize];
        if (curr_val > local_max) {
            local_max = curr_val;
            local_argmax = d;
        }
    }

    // Shared memory for reduction
    extern __shared__ char shared_mem[];
    float* s_max = reinterpret_cast<float*>(shared_mem);
    int* s_idx = reinterpret_cast<int*>(s_max + blockDim.x);

    s_max[threadIdx.x] = local_max;
    s_idx[threadIdx.x] = local_argmax;
    __syncthreads();

    // Block-wide tree reduction in shared memory.
    // The fixed offsets below assume blockDim.x == 128, matching the host launch code.
    if (threadIdx.x < 64) {
        if (s_max[threadIdx.x + 64] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 64];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 64];
        }
    }
    __syncthreads();

    if (threadIdx.x < 32) {
        // Final warp-level reduction. This relies on implicit warp-synchronous
        // execution; on Volta and newer GPUs, volatile shared-memory accesses or
        // __syncwarp() between steps would be needed to guarantee correctness.
        if (s_max[threadIdx.x + 32] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 32];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 32];
        }
        if (s_max[threadIdx.x + 16] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 16];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 16];
        }
        if (s_max[threadIdx.x + 8] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 8];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 8];
        }
        if (s_max[threadIdx.x + 4] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 4];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 4];
        }
        if (s_max[threadIdx.x + 2] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 2];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 2];
        }
        if (s_max[threadIdx.x + 1] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 1];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 1];
        }
    }

    if (threadIdx.x == 0) {
        indices[slice] = s_idx[0];
    }
}

torch::Tensor argmax_forward_cuda(const torch::Tensor& x, const int64_t dim) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "Only float32 is supported.");
    auto x_contig = x.contiguous();

    auto sizes = x_contig.sizes();
    int ndim = x_contig.dim();
    TORCH_CHECK(dim >= 0 && dim < ndim, "Invalid dim for argmax.");

    int outerSize = 1;
    for (int d = 0; d < dim; d++) {
        outerSize *= sizes[d];
    }
    int dimSize = sizes[dim];
    int innerSize = 1;
    for (int d = dim + 1; d < ndim; d++) {
        innerSize *= sizes[d];
    }

    std::vector<int64_t> out_sizes;
    for (int d = 0; d < ndim; d++) {
        if (d == dim) continue;
        out_sizes.push_back(sizes[d]);
    }
    auto options = torch::TensorOptions().device(x.device()).dtype(torch::kLong);
    auto indices = torch::empty(out_sizes, options);

    int slices = outerSize * innerSize;
    
    // Use 128 threads per block for better occupancy
    const int threads = 128;
    int blocks = slices;
    int sharedMemSize = threads * (sizeof(float) + sizeof(int));

    argmax_kernel_coop_tuned<<<blocks, threads, sharedMemSize>>>(
        x_contig.data_ptr<float>(),
        indices.data_ptr<int64_t>(),
        outerSize,
        dimSize,
        innerSize
    );

    return indices;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &argmax_forward_cuda, "ArgMax CUDA forward with tuned cooperative reduction");
}
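
To try the kernel end to end, one plausible path is PyTorch's inline extension loader. The sketch below is only an assumption about the build setup: the source file name argmax_coop_red_tuned.cu is a placeholder, and a CUDA toolchain plus a CUDA-capable GPU are assumed. It loads the extension and checks the bound forward against torch.argmax:

import torch
from torch.utils.cpp_extension import load

# Build and load the extension from the CUDA source above.
# The file name is a placeholder; nvcc and a CUDA GPU are assumed.
ext = load(name="argmax_coop_red_tuned",
           sources=["argmax_coop_red_tuned.cu"],
           verbose=True)

x = torch.randn(16, 256, 256, device="cuda", dtype=torch.float32)
dim = 1

ref = torch.argmax(x, dim=dim)   # PyTorch reference result
out = ext.forward(x, dim)        # custom kernel under test

assert out.shape == ref.shape == (16, 256)
# Exact index agreement is expected; exact ties are vanishingly unlikely with randn inputs.
assert torch.equal(out, ref)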
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.910 inst/cycle 0.000 5
Executed Ipc Elapsed 0.718 inst/cycle 0.000 5
Issue Slots Busy 22.796 % 0.136 5
Issued Ipc Active 0.910 inst/cycle 0.000 5
SM Busy 22.796 % 0.136 5
Memory Throughput 371106283553.588 byte/second 10954306873758126080.000 5
Mem Busy 62.708 % 0.631 5
Max Bandwidth 28.740 % 0.092 5
L1/TEX Hit Rate 1.452 % 0.044 5
L2 Hit Rate 84.886 % 0.253 5
Mem Pipes Busy 13.422 % 0.017 5
Warp Cycles Per Issued Instruction 56.444 cycle 0.930 5
Warp Cycles Per Executed Instruction 56.594 cycle 0.930 5
Avg. Active Threads Per Warp 31.550 0.000 5
Avg. Not Predicated Off Threads Per Warp 26.970 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 21.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 16.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 81.184 % 0.223 5
Achieved Active Warps Per SM 51.960 warp 0.092 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (80.9%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
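
As a cross-check on the occupancy figures above, here is a small back-of-envelope (a sketch using only values visible in the host code and the metrics table) for the launch configuration on the (16, 256, 256), dim=1 test case:

# Launch-configuration arithmetic for argmax_kernel_coop_tuned,
# mirroring the host code above for the (16, 256, 256), dim=1 case.
outer, dim_size, inner = 16, 256, 256

threads         = 128
blocks          = outer * inner      # one block per output slice -> 4096
warps_per_block = threads // 32      # 4 warps per block
smem_per_block  = threads * (4 + 4)  # one float + one int per thread -> 1024 bytes

print(blocks, warps_per_block, smem_per_block)
# The table reports 64 theoretical active warps per SM, i.e. 64 / 4 = 16
# resident blocks per SM, consistent with "Block Limit Warps 16.000".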
Operation / Metric Value Unit
aten::to
CPU Time 513688.09 μs
Device Time 384.19 μs
Self CPU Time 41.88 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 513646.21 μs
Device Time 384.19 μs
Self CPU Time 110.25 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 512890.21 μs
Device Time 0.00 μs
Self CPU Time 108.27 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 503631.42 μs
Device Time 0.00 μs
Self CPU Time 503631.42 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 544596.43 μs
Device Time 20950.53 μs
Self CPU Time 544596.43 μs
Self Device Time 20950.53 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
argmax_kernel_coop_tuned(float const*, long*, int, int, int)
CPU Time 0.00 μs
Device Time 81817.38 μs
Self CPU Time 0.00 μs
Self Device Time 81817.38 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 19299.16 μs
Device Time 41709.10 μs
Self CPU Time 19299.16 μs
Self Device Time 41709.10 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 66976.44 μs
Device Time 622694.70 μs
Self CPU Time 13456.74 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 53521.13 μs
Device Time 622694.70 μs
Self CPU Time 16040.01 μs
Self Device Time 622694.70 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 622694.70 μs
Self CPU Time 0.00 μs
Self Device Time 622694.70 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45284 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:8:5: warning: 2 adjacent parameters of 'argmax_kernel_coop_tuned' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
8 | const int outerSize,
| ^~~~~~~~~~~~~~~~~~~~
9 | const int dimSize,
| ~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:8:15: note: the first parameter in the range is 'outerSize'
8 | const int outerSize,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:9:15: note: the last parameter in the range is 'dimSize'
9 | const int dimSize,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:12:17: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
12 | int slice = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:24:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
24 | for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:24:49: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
24 | for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:88:16: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
88 | int ndim = x_contig.dim();
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:93:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
93 | outerSize *= sizes[d];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:95:19: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
95 | int dimSize = sizes[dim];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:97:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
97 | for (int d = dim + 1; d < ndim; d++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b3_s2_argmax_coop_red_tuned/base/base.cu:98:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
98 | innerSize *= sizes[d];
| ^