
The AI CUDA Engineer 👷

51_Argmax_over_a_dimension • argmax_coop_red_sync_opt_base

Level 1 • Task 51
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Applies argmax over the specified dimension to the input tensor.

    Args:
        x (torch.Tensor): Input tensor
        dim (int): Dimension to perform argmax over

    Returns:
        torch.Tensor: Output tensor with argmax applied over specified dimension
    """
    return torch.argmax(x, dim)


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor
            fn: Function to apply (defaults to module_fn)

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return fn(x, self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs Argmax over a specified dimension.
    """

    def __init__(self, dim: int):
        """
        Initializes the model with the dimension to perform argmax.

        Args:
            dim (int): The dimension to perform argmax over.
        """
        super(Model, self).__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies argmax over the specified dimension to the input tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor with argmax applied, with the specified dimension removed.
        """
        return torch.argmax(x, dim=self.dim)


batch_size = 16
dim1 = 256
dim2 = 256


def get_inputs():
    x = torch.randn(batch_size, dim1, dim2)
    return [x]


def get_init_inputs():
    return [1]
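
Both listings above implement the same reference operation: the first exposes it as a standalone module_fn for the benchmarking harness, the second as a plain nn.Module. A minimal smoke test, assuming the definitions above are in scope (the shape and equality checks are illustrative additions, not part of the original harness):

model = Model(*get_init_inputs())       # Model(dim=1)
(x,) = get_inputs()                     # shape (16, 256, 256)

out = model(x)                          # argmax over dim=1
assert out.shape == (batch_size, dim2)  # (16, 256): dim 1 is removed
assert torch.equal(out, torch.argmax(x, dim=1))
print(out.dtype)                        # torch.int64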

Kernel Information

Related Kernels (Level 1, Task 51 • 51_Argmax_over_a_dimension)

Rank Kernel Name Runtime (ms) Speedup (Native) Speedup (Compile)
🥇 warp_argmax_nosm_edit_1 0.01 1.85 2.54
🥈 warp_level_argmax_base 0.01 1.72 2.36
🥈 warp_level_argmax_edit_1 0.01 1.72 2.36
🥈 efficient_argmax_base 0.01 1.72 2.36
🥈 stride_loop_argmax_stride_base 0.01 1.72 2.36
🥈 stride_loop_argmax_final_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_base 0.01 1.72 2.36
🥈 argmax_coop_red_tuned_edit_1 0.01 1.72 2.36
🥈 divergence_free_argmax_base 0.01 1.72 2.36
🥈 optimized_argmax_combination_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_base 0.01 1.72 2.36
🥈 argmax_ldg_128_opt_edit_1 0.01 1.72 2.36
🥈 argmax_coop_red_sync_opt_base 0.01 1.72 2.36
🥈 argmax_aligned_mem_base_edit_1 0.01 1.72 2.36
🥈 stride_loop_argmax_final_base 0.01 1.72 2.36
🥈 warp_argmax_nosm_base 0.01 1.72 2.36
17 stride_loop_argmax_base 0.01 1.61 2.20
17 loop_unrolled_argmax_edit_1 0.01 1.61 2.20
17 stride_loop_argmax_edit_1 0.01 1.61 2.20
17 optimized_argmax_kernel_base 0.01 1.61 2.20
#include <torch/extension.h>
#include <vector>
#include <float.h>

// Cooperative argmax: one block reduces one (outer, inner) slice of the input,
// with all threads of the block scanning the reduced dimension together.
__global__ void argmax_kernel_coop_sync_opt(
    const float* __restrict__ x,
    int64_t* __restrict__ indices,
    const int outerSize,
    const int dimSize,
    const int innerSize) {

    // Each block handles one output element (one "slice"): decompose the flat
    // block index into (outer, inner) coordinates and locate the slice's first
    // element along the reduced dimension.
    int slice = blockIdx.x;
    if (slice >= outerSize * innerSize) return;

    int outer_idx = slice / innerSize;
    int inner_idx = slice % innerSize;
    int base_offset = outer_idx * (dimSize * innerSize) + inner_idx;

    // Thread-strided scan over the reduced dimension: each thread keeps a
    // running maximum and the index at which it occurred.
    float local_max = -FLT_MAX;
    int local_argmax = 0;

    for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
        float curr_val = x[base_offset + d * innerSize];
        if (curr_val > local_max) {
            local_max = curr_val;
            local_argmax = d;
        }
    }

    // Dynamic shared memory: blockDim.x floats (candidate maxima) followed by
    // blockDim.x ints (their indices).
    extern __shared__ char shared_mem[];
    float* s_max = reinterpret_cast<float*>(shared_mem);
    int* s_idx = reinterpret_cast<int*>(s_max + blockDim.x);

    s_max[threadIdx.x] = local_max;
    s_idx[threadIdx.x] = local_argmax;
    __syncthreads();

    // Tree reduction in shared memory; __syncthreads() is only required while
    // more than one warp is still participating.
    for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
        if (threadIdx.x < s) {
            if (s_max[threadIdx.x + s] > s_max[threadIdx.x]) {
                s_max[threadIdx.x] = s_max[threadIdx.x + s];
                s_idx[threadIdx.x] = s_idx[threadIdx.x + s];
            }
        }
        __syncthreads();
    }

    // Final unrolled reduction within the first warp. Note that this relies on
    // warp-synchronous execution; on Volta and newer architectures, strictly
    // correct code would mark the shared arrays volatile or add __syncwarp()
    // between steps.
    if (threadIdx.x < 32) {
        if (s_max[threadIdx.x + 32] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 32];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 32];
        }
        if (s_max[threadIdx.x + 16] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 16];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 16];
        }
        if (s_max[threadIdx.x + 8] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 8];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 8];
        }
        if (s_max[threadIdx.x + 4] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 4];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 4];
        }
        if (s_max[threadIdx.x + 2] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 2];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 2];
        }
        if (s_max[threadIdx.x + 1] > s_max[threadIdx.x]) {
            s_max[threadIdx.x] = s_max[threadIdx.x + 1];
            s_idx[threadIdx.x] = s_idx[threadIdx.x + 1];
        }
    }

    if (threadIdx.x == 0) {
        indices[slice] = s_idx[0];
    }
}

torch::Tensor argmax_forward_cuda(const torch::Tensor& x, const int64_t dim) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "Only float32 is supported.");
    auto x_contig = x.contiguous();

    auto sizes = x_contig.sizes();
    int ndim = x_contig.dim();
    TORCH_CHECK(dim >= 0 && dim < ndim, "Invalid dim for argmax.");

    // Collapse the shape to [outerSize, dimSize, innerSize] around `dim`.
    int outerSize = 1;
    for (int d = 0; d < dim; d++) {
        outerSize *= sizes[d];
    }
    int dimSize = sizes[dim];
    int innerSize = 1;
    for (int d = dim + 1; d < ndim; d++) {
        innerSize *= sizes[d];
    }

    std::vector<int64_t> out_sizes;
    for (int d = 0; d < ndim; d++) {
        if (d == dim) continue;
        out_sizes.push_back(sizes[d]);
    }
    auto options = torch::TensorOptions().device(x.device()).dtype(torch::kLong);
    auto indices = torch::empty(out_sizes, options);

    // One block per output element; dynamic shared memory holds one float and
    // one int per thread for the cooperative reduction.
    int slices = outerSize * innerSize;
    const int threads = 128;
    int blocks = slices;
    int sharedMemSize = threads * (sizeof(float) + sizeof(int));

    argmax_kernel_coop_sync_opt<<<blocks, threads, sharedMemSize>>>(
        x_contig.data_ptr<float>(),
        indices.data_ptr<int64_t>(),
        outerSize,
        dimSize,
        innerSize
    );

    return indices;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &argmax_forward_cuda, "ArgMax CUDA forward with optimized synchronization");
}
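
A sketch of how the extension above could be built and checked against the PyTorch reference, assuming the CUDA source is saved to a file named argmax_coop_red_sync_opt.cu (the file name, and the use of torch.utils.cpp_extension.load for a JIT build, are assumptions rather than part of the original listing):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source shown above into a Python extension module.
ext = load(
    name="argmax_coop_red_sync_opt",
    sources=["argmax_coop_red_sync_opt.cu"],
    verbose=True,
)

x = torch.randn(16, 256, 256, device="cuda", dtype=torch.float32)
dim = 1

out_ext = ext.forward(x, dim)        # custom kernel (dim must be non-negative)
out_ref = torch.argmax(x, dim=dim)   # PyTorch reference

# With random float inputs ties are vanishingly unlikely, so the index
# tensors should match exactly.
assert torch.equal(out_ext, out_ref)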
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.964 inst/cycle 0.001 5
Executed Ipc Elapsed 0.762 inst/cycle 0.000 5
Issue Slots Busy 24.200 % 0.373 5
Issued Ipc Active 0.970 inst/cycle 0.001 5
SM Busy 24.200 % 0.373 5
Memory Throughput 369015932520.718 byte/second 6300160817878350848.000 5
Mem Busy 62.996 % 0.300 5
Max Bandwidth 28.418 % 0.081 5
L1/TEX Hit Rate 1.620 % 0.118 5
L2 Hit Rate 84.254 % 0.849 5
Mem Pipes Busy 13.324 % 0.008 5
Warp Cycles Per Issued Instruction 53.014 cycle 1.040 5
Warp Cycles Per Executed Instruction 53.146 cycle 1.046 5
Avg. Active Threads Per Warp 31.580 0.000 5
Avg. Not Predicated Off Threads Per Warp 26.650 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 21.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 16.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 81.412 % 0.615 5
Achieved Active Warps Per SM 52.104 warp 0.253 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (81.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
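
For the benchmark shape (16, 256, 256) reduced over dim 1, the launch arithmetic behind these observations can be worked out directly. A small sketch follows; the 64-warp theoretical limit comes from the metrics table above, and the per-SM reasoning assumes no other resource limit binds, consistent with the block-limit rows:

# Launch arithmetic for argmax_kernel_coop_sync_opt on the benchmark input.
outer_size, dim_size, inner_size = 16, 256, 256   # dim=1 is reduced

threads_per_block = 128                       # hard-coded in argmax_forward_cuda
blocks = outer_size * inner_size              # one block per output element
warps_per_block = threads_per_block // 32
shared_bytes = threads_per_block * (4 + 4)    # one float + one int per thread

print(blocks)            # 4096 blocks
print(warps_per_block)   # 4 warps per block
print(shared_bytes)      # 1024 bytes of dynamic shared memory

# With 64 theoretical warps per SM, up to 16 four-warp blocks fit per SM
# (matching "Block Limit Warps" above), so the ~81% achieved occupancy points
# at scheduling overhead or load imbalance rather than a resource limit.
print(64 // warps_per_block)   # 16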
Operation / Metric Value Unit
aten::to
CPU Time 230124.62 μs
Device Time 386.43 μs
Self CPU Time 39.57 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 230085.05 μs
Device Time 386.43 μs
Self CPU Time 106.16 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 229354.66 μs
Device Time 0.00 μs
Self CPU Time 98.27 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 228957.79 μs
Device Time 0.00 μs
Self CPU Time 228957.79 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 526059.42 μs
Device Time 20396.21 μs
Self CPU Time 526059.42 μs
Self Device Time 20396.21 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
argmax_kernel_coop_sync_opt(float const*, long*, int, int, int)
CPU Time 0.00 μs
Device Time 86967.78 μs
Self CPU Time 0.00 μs
Self Device Time 86967.78 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 18979.15 μs
Device Time 40645.47 μs
Self CPU Time 18979.15 μs
Self Device Time 40645.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 67172.05 μs
Device Time 607133.34 μs
Self CPU Time 13399.54 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 53774.28 μs
Device Time 607133.34 μs
Self CPU Time 15789.32 μs
Self Device Time 607133.34 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 607133.34 μs
Self CPU Time 0.00 μs
Self Device Time 607133.34 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
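
The operator-level timings above can be reproduced with torch.profiler; a minimal sketch (the warm-up and iteration counts are arbitrary choices, and the exact columns depend on the PyTorch version):

import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(16, 256, 256, device="cuda")

# Warm up so one-time CUDA initialization does not dominate the measurement.
for _ in range(3):
    torch.argmax(x, dim=1)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        torch.argmax(x, dim=1)
    torch.cuda.synchronize()

# Per-operator CPU/CUDA self times, similar in spirit to the table above.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))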
Status: Completed
45284 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:8:5: warning: 2 adjacent parameters of 'argmax_kernel_coop_sync_opt' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
8 | const int outerSize,
| ^~~~~~~~~~~~~~~~~~~~
9 | const int dimSize,
| ~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:8:15: note: the first parameter in the range is 'outerSize'
8 | const int outerSize,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:9:15: note: the last parameter in the range is 'dimSize'
9 | const int dimSize,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:12:17: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
12 | int slice = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:22:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:22:49: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | for (int d = threadIdx.x; d < dimSize; d += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:85:16: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
85 | int ndim = x_contig.dim();
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:90:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
90 | outerSize *= sizes[d];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:92:19: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
92 | int dimSize = sizes[dim];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:94:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
94 | for (int d = dim + 1; d < ndim; d++) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_1/task_51/b5_s3_argmax_coop_red_sync_opt/base/base.cu:95:22: warning: narrowing conversion from 'long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
95 | innerSize *= sizes[d];
| ^