← Back to Leaderboard

The AI CUDA Engineer 👷

24_LogSoftmax • 24_logsoftmax_unroll_base

Level 1 • Task 24
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Return ``log_softmax(x)`` normalized along axis ``dim``.

    Args:
        x (torch.Tensor): Tensor to normalize; the benchmark harness in this
            file passes a ``(batch_size, dim)`` tensor.
        dim (int): Axis over which the log-probabilities are normalized.

    Returns:
        torch.Tensor: Log-probabilities with exactly the shape of ``x``.
    """
    log_probs = F.log_softmax(x, dim=dim)
    return log_probs


class Model(nn.Module):
    """
    LogSoftmax module whose computation is delegated to a pluggable function.
    """

    def __init__(self, dim):
        super().__init__()
        # Axis forwarded to ``fn`` on every call to ``forward``.
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """Apply ``fn(x, self.dim)``; by default this is ``module_fn``."""
        axis = self.dim
        return fn(x, axis)


# Benchmark configuration: inputs are (batch_size, dim) and normalized along sm_dim.
batch_size, dim, sm_dim = 16, 16_384, 1


def get_inputs():
    """Return the forward arguments: a single standard-normal (batch_size, dim) tensor."""
    return [torch.randn(batch_size, dim)]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``: the softmax axis."""
    init_args = [sm_dim]
    return init_args
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Minimal module wrapping ``torch.log_softmax`` over a fixed axis.
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        # Axis normalized by every forward pass.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert ``x`` into log-probabilities along ``self.dim``.

        Args:
            x (torch.Tensor): Input tensor; the harness in this file supplies
                shape (batch_size, dim).

        Returns:
            torch.Tensor: LogSoftmax of ``x`` along ``self.dim``, same shape as ``x``.
        """
        axis = self.dim
        log_probs = torch.log_softmax(x, dim=axis)
        return log_probs


# Benchmark configuration: inputs are (batch_size, dim) and normalized along sm_dim.
batch_size, dim, sm_dim = 16, 16_384, 1


def get_inputs():
    """Return the forward arguments: a single standard-normal (batch_size, dim) tensor."""
    return [torch.randn(batch_size, dim)]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``: the softmax axis."""
    init_args = [sm_dim]
    return init_args

Kernel Information

Related Kernels (Level 1, Task 24 • 24_LogSoftmax)

Rank, Kernel Name, Runtime (ms), Speedup vs. Native, Speedup vs. Compile
🥇 hybrid_logsoftmax_kernel_base 0.01 1.07 3.72
🥇 unroll_tuned_logsoftmax_base 0.01 1.07 3.72
🥇 shared_warp_logsoftmax_base_base 0.01 1.07 3.72
🥇 min_atomic_logsoftmax_base 0.01 1.07 3.72
🥇 combined_logsoftmax_base 0.01 1.07 3.72
🥇 efficient_logsoftmax_combined_kernel_base 0.01 1.07 3.72
🥇 optimized_128_ldg_logsoftmax_base 0.01 1.07 3.72
🥇 atomic_free_logsoftmax_base 0.01 1.07 3.72
🥇 strided_logsoftmax_base_base 0.01 1.07 3.72
🥇 optimized_reduction_logsoftmax_base 0.01 1.07 3.72
11 24_logsoftmax_vectorized_loads_edit_1 0.01 0.97 3.38
11 24_logsoftmax_unroll_edit_1 0.01 0.97 3.38
11 tuned_logsoftmax_base 0.01 0.97 3.38
11 grid2d_logsoftmax_base 0.01 0.97 3.38
11 24_logsoftmax_fast_edit_1 0.01 0.97 3.38
11 24_logsoftmax_fast_base 0.01 0.97 3.38
11 unroll_logsoftmax_base 0.01 0.97 3.38
11 24_logsoftmax_unroll_base 0.01 0.97 3.38
11 24_logsoftmax_with_stride_base 0.01 0.97 3.38
20 log_softmax_2d_blocking_base 0.01 0.89 3.10
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <limits>
#include <cmath>

// Row-wise LogSoftmax kernel. Each thread block handles one contiguous row of
// `dim_size` elements in three passes:
//   1. block-wide maximum of the row,
//   2. block-wide sum of exp(x - max),
//   3. output = (x - max) - log(sum).
// Warp-level reductions use __shfl_down_sync with `#pragma unroll` to cut loop
// overhead; per-warp partials are combined through dynamic shared memory,
// which must hold ceil(blockDim.x / 32) scalar_t values.
template <typename scalar_t>
__global__ void log_softmax_forward_kernel_unroll(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int dim_size) {

    // Named kWarpSize so it does not shadow CUDA's built-in `warpSize`.
    constexpr int kWarpSize = 32;

    // Each block processes one row. The offset is computed in 64-bit so that
    // (row index * dim_size) cannot overflow a 32-bit int on large tensors.
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * dim_size;
    const scalar_t* input_row = input + row_offset;
    scalar_t* output_row = output + row_offset;

    const int tid = static_cast<int>(threadIdx.x);
    const int block_size = static_cast<int>(blockDim.x);
    const int lane = tid % kWarpSize;
    const int warp_id = tid / kWarpSize;
    const int num_warps = (block_size + kWarpSize - 1) / kWarpSize;

    // Number of lanes that actually exist in this thread's warp. It is below
    // 32 only for the last warp when blockDim.x is not a multiple of 32. The
    // shuffle mask must list only existing lanes, and values shuffled in from
    // non-existent lanes are discarded by the `lane + offset` guards below.
    const int lanes_in_warp = min(kWarpSize, block_size - warp_id * kWarpSize);
    const unsigned int mask = (lanes_in_warp == kWarpSize)
        ? 0xffffffffu
        : ((1u << lanes_in_warp) - 1u);

    // One slot per warp for partial results.
    extern __shared__ __align__(sizeof(scalar_t)) unsigned char smem[];
    scalar_t* shared_data = reinterpret_cast<scalar_t*>(smem);

    // ---- Pass 1: row maximum (subtracted before exp for numerical stability).
    scalar_t local_max = -std::numeric_limits<scalar_t>::infinity();
    for (int i = tid; i < dim_size; i += block_size) {
        const scalar_t val = input_row[i];
        local_max = (val > local_max) ? val : local_max;
    }

    #pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        const scalar_t other = __shfl_down_sync(mask, local_max, offset);
        if (lane + offset < lanes_in_warp) {
            local_max = (other > local_max) ? other : local_max;
        }
    }
    if (lane == 0) {
        shared_data[warp_id] = local_max;
    }
    __syncthreads();

    // Thread 0 folds the per-warp maxima and broadcasts through shared_data[0].
    if (tid == 0) {
        scalar_t block_max = shared_data[0];
        for (int w = 1; w < num_warps; ++w) {
            block_max = (shared_data[w] > block_max) ? shared_data[w] : block_max;
        }
        shared_data[0] = block_max;
    }
    __syncthreads();
    const scalar_t global_max = shared_data[0];
    // Every thread must finish reading the broadcast maximum before the warp
    // leaders below reuse shared_data (including slot 0) for the sum partials.
    __syncthreads();

    // ---- Pass 2: sum of exp(x - max). Nothing is stored here; pass 3
    // recomputes the output directly from the input.
    scalar_t local_sum = 0;
    for (int i = tid; i < dim_size; i += block_size) {
        local_sum += exp(input_row[i] - global_max);
    }

    #pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        const scalar_t other = __shfl_down_sync(mask, local_sum, offset);
        if (lane + offset < lanes_in_warp) {
            local_sum += other;
        }
    }
    if (lane == 0) {
        shared_data[warp_id] = local_sum;
    }
    __syncthreads();

    // Thread 0 folds the per-warp sums and broadcasts through shared_data[0].
    if (tid == 0) {
        scalar_t block_sum = shared_data[0];
        for (int w = 1; w < num_warps; ++w) {
            block_sum += shared_data[w];
        }
        shared_data[0] = block_sum;
    }
    __syncthreads();
    const scalar_t log_sum = log(shared_data[0]);

    // ---- Pass 3: final log-softmax values.
    for (int i = tid; i < dim_size; i += block_size) {
        output_row[i] = (input_row[i] - global_max) - log_sum;
    }
}

// Host entry point: computes log_softmax(input, dim) on the GPU.
//
// The reduction axis is permuted to the innermost position and made
// contiguous, so the kernel sees `batch_size` rows of `dim_size` elements
// (one CUDA block per row). The result is permuted back to the input's shape;
// when `dim` is not the last axis this is a non-contiguous view.
torch::Tensor log_softmax_cuda_forward(torch::Tensor input, int64_t dim) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(
        input.scalar_type() == torch::kFloat32 || input.scalar_type() == torch::kFloat64,
        "input must be float32 or float64");

    const int64_t ndim = input.dim();
    TORCH_CHECK(dim >= -ndim && dim < ndim, "dim out of range");
    dim = (dim >= 0) ? dim : dim + ndim;

    // Empty tensors need no work. Returning early also avoids dividing by a
    // zero-length reduction axis and launching a grid with zero blocks.
    if (input.numel() == 0) {
        return torch::empty_like(input);
    }

    // Allocate and launch on the tensor's device, not the current default one.
    const c10::cuda::OptionalCUDAGuard device_guard(input.device());

    // Permutation that moves `dim` to the last position.
    std::vector<int64_t> permute_dims;
    permute_dims.reserve(static_cast<size_t>(ndim));
    for (int64_t i = 0; i < ndim; ++i) {
        if (i != dim) {
            permute_dims.push_back(i);
        }
    }
    permute_dims.push_back(dim);
    input = input.permute(permute_dims).contiguous();

    const int64_t dim_size = input.size(-1);
    const int64_t batch_size = input.numel() / dim_size;
    // The kernel takes the row length as `int`, and gridDim.x is limited to 2^31 - 1.
    TORCH_CHECK(dim_size <= std::numeric_limits<int>::max(),
                "log_softmax: reduction dimension is too large");
    TORCH_CHECK(batch_size <= std::numeric_limits<int>::max(),
                "log_softmax: too many rows for a single launch");
    auto output = torch::empty_like(input);

    // Threads per block: next power of two >= dim_size, clamped to [32, 1024].
    // The lower bound keeps every warp full; bounding the loop itself prevents
    // `threads` from overflowing for very long rows.
    int threads = 32;
    while (threads < dim_size && threads < 1024) {
        threads <<= 1;
    }

    // Dynamic shared memory: one scalar per warp for the block reductions.
    constexpr int kWarpSize = 32;
    const int num_warps = (threads + kWarpSize - 1) / kWarpSize;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax_forward_cuda_unroll", ([&] {
        const size_t shared_mem_size = static_cast<size_t>(num_warps) * sizeof(scalar_t);
        log_softmax_forward_kernel_unroll<scalar_t>
            <<<static_cast<unsigned int>(batch_size), threads, shared_mem_size, stream>>>(
                input.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                static_cast<int>(dim_size));
    }));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Inverse permutation restores the original axis order.
    std::vector<int64_t> inverse_permute_dims(static_cast<size_t>(ndim));
    for (int64_t i = 0; i < ndim; ++i) {
        inverse_permute_dims[static_cast<size_t>(permute_dims[static_cast<size_t>(i)])] = i;
    }
    return output.permute(inverse_permute_dims);
}

// Python binding: exposes log_softmax_cuda_forward as `forward(input, dim)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &log_softmax_cuda_forward,
        "LogSoftmax forward with loop unrolling (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.512 inst/cycle 0.000 5
Executed Ipc Elapsed 0.140 inst/cycle 0.000 5
Issue Slots Busy 37.936 % 0.291 5
Issued Ipc Active 1.520 inst/cycle 0.000 5
SM Busy 37.936 % 0.291 5
Memory Throughput 110006274950.140 byte/second 1550931883055260160.000 5
Mem Busy 6.134 % 0.004 5
Max Bandwidth 7.370 % 0.006 5
L1/TEX Hit Rate 60.000 % 0.000 5
L2 Hit Rate 76.170 % 0.021 5
Mem Pipes Busy 2.478 % 0.001 5
Warp Cycles Per Issued Instruction 20.590 cycle 0.009 5
Warp Cycles Per Executed Instruction 20.662 cycle 0.009 5
Avg. Active Threads Per Warp 31.500 0.000 5
Avg. Not Predicated Off Threads Per Warp 30.190 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 2.000 block 0.000 5
Block Limit Shared Mem 7.000 block 0.000 5
Block Limit Warps 2.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 49.020 % 0.000 5
Achieved Active Warps Per SM 31.372 warp 0.000 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (26.8%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (49.0%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 378716.88 μs
Device Time 40.19 μs
Self CPU Time 50.38 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 378666.50 μs
Device Time 40.19 μs
Self CPU Time 105.96 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 399628.48 μs
Device Time 0.00 μs
Self CPU Time 21421.04 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 378028.67 μs
Device Time 0.00 μs
Self CPU Time 378028.67 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 491345.44 μs
Device Time 22751.92 μs
Self CPU Time 491345.44 μs
Self Device Time 22751.92 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void log_softmax_forward_kernel_unroll<float>(float const*, float*, int)
CPU Time 0.00 μs
Device Time 73976.10 μs
Self CPU Time 0.00 μs
Self Device Time 73976.10 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 22910.56 μs
Device Time 44721.47 μs
Self CPU Time 22910.56 μs
Self Device Time 44721.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 66188.23 μs
Device Time 655404.68 μs
Self CPU Time 14024.99 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 52164.21 μs
Device Time 655404.68 μs
Self CPU Time 16922.14 μs
Self Device Time 655404.68 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 655404.68 μs
Self CPU Time 0.00 μs
Self Device Time 655404.68 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45282 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
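/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_1/task_24/b5_s2_24_logsoftmax_unroll/base/base.cu:18:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]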
18 | int batch_idx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_1/task_24/b5_s2_24_logsoftmax_unroll/base/base.cu:22:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_1/task_24/b5_s2_24_logsoftmax_unroll/base/base.cu:23:27: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
23 | const int blockSize = blockDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_1/task_24/b5_s2_24_logsoftmax_unroll/base/base.cu:138:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
138 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax_forward_cuda_unroll", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:237:34: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES'
237 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:233:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES'
233 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_1/task_24/b5_s2_24_logsoftmax_unroll/base/base.cu:149:49: warning: narrowing conversion from 'size_t' (aka 'unsigned long') to signed type 'value_type' (aka 'long') is implementation-defined [bugprone-narrowing-conversions]
149 | inverse_permute_dims[permute_dims[i]] = i;
| ^