
The AI CUDA Engineer 👷

81_Gemm_Swish_Divide_Clamp_Tanh_Clamp • stride_loop_opt_base

Level 2 • Task 81

Kernel Information

Related Kernels (Level 2, Task 81 • 81_Gemm_Swish_Divide_Clamp_Tanh_Clamp)

| Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile) |
|------|-------------|--------------|------------------|-------------------|
| 🥇 | gemm_2d_map_base | 0.02 | 2.19 | 1.90 |
| 🥇 | gemm_2d_map_improved_base | 0.02 | 2.19 | 1.90 |
| 🥇 | hybrid_aligned_2d_kernel_base | 0.02 | 2.19 | 1.90 |
| 🥇 | aligned_memory_access_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_optimization_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | unrolled_2d_map_base_base | 0.02 | 2.19 | 1.90 |
| 🥇 | stride_loop_opt_base | 0.02 | 2.19 | 1.90 |
| 🥇 | fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_stride_loop_with_prefetch_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_fused_activation_base | 0.02 | 2.19 | 1.90 |
| 🥇 | optimized_thread_block_indexing_base | 0.02 | 2.19 | 1.90 |
| 🥇 | dim2_grid_activation_base | 0.02 | 2.19 | 1.90 |
| 14 | dynamic_block_size_base_base | 0.02 | 2.10 | 1.82 |
| 14 | strided_vectorized_base_base | 0.02 | 2.10 | 1.82 |
| 14 | even_load_base | 0.02 | 2.10 | 1.82 |
| 14 | 81_gemm_swish_divide_clamp_tanh_clamp_optimized_blocks_base | 0.02 | 2.10 | 1.82 |
| 14 | stride_loop_correct_base | 0.02 | 2.10 | 1.82 |
| 14 | adaptive_vectorized_kernel_base | 0.02 | 2.10 | 1.82 |
| 14 | uniform_control_flow_base_base | 0.02 | 2.10 | 1.82 |
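
The kernel below fuses the elementwise tail of this task into a single pass over the GEMM output. Reading the operations directly off the kernel body, each element is transformed as

$$y = \operatorname{clamp}\!\left(\tanh\!\left(\operatorname{clamp}\!\left(\frac{x\,\sigma(x)}{2},\,-1,\,1\right)\right),\,-1,\,1\right),\qquad \sigma(x)=\frac{1}{1+e^{-x}}.$$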
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>   // for at::cuda::getCurrentCUDAStream()
#include <c10/cuda/CUDAException.h>  // for C10_CUDA_KERNEL_LAUNCH_CHECK()
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>
#include <algorithm>                 // for std::min

// Macros for input checking
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// CUDA kernel using a grid-stride loop to handle workloads larger than the
// number of launched threads, with correct boundary handling
template <typename scalar_t>
__global__ void module_kernel_stride_improved(
    const scalar_t* __restrict__ x_in,
    scalar_t* __restrict__ x_out,
    const size_t size) {

    // Promote to size_t before multiplying so the index math cannot overflow
    // 32-bit arithmetic on very large inputs.
    const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t half = static_cast<scalar_t>(0.5);

    // Grid-stride loop for processing multiple elements per thread
    for (size_t i = tid; i < size; i += stride) {
        scalar_t x = x_in[i];

        // Swish activation: x * sigmoid(x)
        scalar_t sigmoid_x = one / (one + exp(-x));
        x = x * sigmoid_x;

        // Divide by 2
        x = x * half;

        // Clamp between -1 and 1
        x = max(min(x, one), -one);

        // Tanh activation
        x = tanh(x);

        // Final clamp between -1 and 1
        x = max(min(x, one), -one);

        x_out[i] = x;
    }
}
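
// A hypothetical ATen reference for the same activation chain (an illustrative
// addition, not part of the original kernel). Handy for spot-checking the CUDA
// path against stock ops; see the driver sketch at the end of the file.
static torch::Tensor reference_activation(const torch::Tensor& t) {
    auto y = t * torch::sigmoid(t);  // Swish: x * sigmoid(x)
    y = y / 2;                       // divide by 2
    y = torch::clamp(y, -1, 1);      // clamp to [-1, 1]
    y = torch::tanh(y);              // tanh
    return torch::clamp(y, -1, 1);   // final clamp to [-1, 1]
}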

// CUDA forward function: performs the linear transformation with torch::addmm,
// then applies the fused activation via the grid-stride kernel
torch::Tensor module_forward_cuda(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {

    // Perform GEMM: x_linear = bias + x * weight^T
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto x_out = torch::empty_like(x_linear);

    const size_t numel = x_linear.numel();
    const int threads = 256;
    // Cap the grid size: the grid-stride loop covers any remaining elements,
    // and the explicit cast avoids narrowing a large size_t count into an int.
    const int blocks = static_cast<int>(
        std::min<size_t>((numel + threads - 1) / threads, 65535));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
        module_kernel_stride_improved<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            x_linear.data_ptr<scalar_t>(),
            x_out.data_ptr<scalar_t>(),
            numel);
        C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface launch failures immediately
    }));

    return x_out;
}

// C++ interface
torch::Tensor module_forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {
    CHECK_INPUT(x);
    CHECK_INPUT(weight);
    CHECK_INPUT(bias);
    return module_forward_cuda(x, weight, bias);
}

// PyBind11 module definition
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_forward, "Custom module forward function (CUDA) with grid-stride loop");
}
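
For completeness, a rough driver sketch follows, assuming the file is built as a standalone program against a CUDA-enabled libtorch rather than loaded through PyBind11. The main function, its shapes, and the comparison tolerance are illustrative assumptions following the bias + x * weight^T convention used by addmm above, not part of the original extension.

// Hypothetical standalone driver (illustrative assumption; requires <iostream>
// in addition to the headers above, and a CUDA-capable device).
int main() {
    auto x      = torch::randn({128, 1024}, torch::kCUDA);
    auto weight = torch::randn({512, 1024}, torch::kCUDA);
    auto bias   = torch::randn({512},       torch::kCUDA);

    auto y = module_forward(x, weight, bias);  // shape [128, 512]

    // Compare against the ATen reference defined above.
    auto ref = reference_activation(torch::addmm(bias, x, weight.t()));
    std::cout << "max abs diff: "
              << (y - ref).abs().max().item<float>() << std::endl;
    return 0;
}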