
The AI CUDA Engineer 👷

81_Gemm_Swish_Divide_Clamp_Tanh_Clamp • fused_activation_base

Level 2 • Task 81

Kernel Information

Related Kernels (Level 2, Task 81 • 81_Gemm_Swish_Divide_Clamp_Tanh_Clamp)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 gemm_2d_map_base 0.02 2.19 1.90
🥇 gemm_2d_map_improved_base 0.02 2.19 1.90
🥇 hybrid_aligned_2d_kernel_base 0.02 2.19 1.90
🥇 aligned_memory_access_base_base 0.02 2.19 1.90
🥇 optimized_stride_loop_base 0.02 2.19 1.90
🥇 stride_loop_optimization_base_base 0.02 2.19 1.90
🥇 unrolled_2d_map_base_base 0.02 2.19 1.90
🥇 stride_loop_opt_base 0.02 2.19 1.90
🥇 fused_activation_base 0.02 2.19 1.90
🥇 optimized_stride_loop_with_prefetch_base 0.02 2.19 1.90
🥇 optimized_fused_activation_base 0.02 2.19 1.90
🥇 optimized_thread_block_indexing_base 0.02 2.19 1.90
🥇 dim2_grid_activation_base 0.02 2.19 1.90
14 dynamic_block_size_base_base 0.02 2.10 1.82
14 strided_vectorized_base_base 0.02 2.10 1.82
14 even_load_base 0.02 2.10 1.82
14 81_gemm_swish_divide_clamp_tanh_clamp_optimized_blocks_base 0.02 2.10 1.82
14 stride_loop_correct_base 0.02 2.10 1.82
14 adaptive_vectorized_kernel_base 0.02 2.10 1.82
14 uniform_control_flow_base_base 0.02 2.10 1.82
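
For reference, the operation chain this kernel fuses (GEMM with bias, Swish, divide by 2, clamp to [-1, 1], tanh, clamp to [-1, 1]) can be written in eager PyTorch as the sketch below. The class name and feature dimensions are illustrative only and are not taken from the benchmark harness.

import torch
import torch.nn as nn

class GemmSwishDivideClampTanhClamp(nn.Module):
    """Eager-mode reference for the op chain fused by the kernel below (illustrative)."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # nn.Linear performs the GEMM with bias (addmm) that the CUDA code keeps in cuBLAS
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)             # GEMM + bias
        x = x * torch.sigmoid(x)       # Swish: x * sigmoid(x)
        x = x / 2.0                    # divide by 2
        x = torch.clamp(x, -1.0, 1.0)  # first clamp
        x = torch.tanh(x)              # tanh
        x = torch.clamp(x, -1.0, 1.0)  # second clamp
        return x

The CUDA implementation below keeps the GEMM in cuBLAS via torch::addmm and fuses everything after it into a single elementwise kernel.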
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

// Macros for input checking
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)


// Fused activation kernel using a grid-stride loop for flexibility and efficient load balancing
template <typename scalar_t>
__global__ void fused_activation_kernel(
    const scalar_t* __restrict__ x_in,
    scalar_t* __restrict__ x_out,
    const size_t size) {

    // 64-bit index arithmetic avoids overflow for tensors with more than 2^32 elements
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const scalar_t one = static_cast<scalar_t>(1);
    const scalar_t two = static_cast<scalar_t>(2);
    
    // Grid-stride loop handles arbitrary tensor sizes
    for (; idx < size; idx += static_cast<size_t>(blockDim.x) * gridDim.x) {
        scalar_t x = x_in[idx];

        // Swish activation: x * sigmoid(x)
        scalar_t sigmoid_x = one / (one + exp(-x));
        x = x * sigmoid_x;
        
        // Divide by 2 (the "Divide" step of the op sequence)
        x = x / two;
        
        // Clamp between -1 and 1 (first clamp for numerical stability)
        x = max(min(x, one), -one);

        // Tanh activation
        x = tanh(x);

        // Second clamp: tanh already returns values in [-1, 1], but clamping again enforces the bound exactly and matches the op sequence
        x = max(min(x, one), -one);
        
        x_out[idx] = x;
    }
}

// CUDA forward function fusing linear transformation with activation
// Note: We still use torch::addmm for the linear part since cuBLAS GEMM is highly optimized
// and then fuse the activation in one efficient kernel.

torch::Tensor module_forward_cuda(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {

    // Execute linear transformation: x_linear = bias + x * weight^T
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto x_out = torch::empty_like(x_linear);

    // Compute total number of elements
    size_t numel = x_linear.numel();

    // Use a 1D grid-stride loop kernel for the activation part
    // A block size of 256 threads is chosen to balance occupancy
    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
        fused_activation_kernel<scalar_t><<<blocks, threads>>>(
            x_linear.data_ptr<scalar_t>(),
            x_out.data_ptr<scalar_t>(),
            numel);
    }));

    return x_out;
}

// C++ interface that wraps the CUDA implementation
torch::Tensor module_forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {
    CHECK_INPUT(x);
    CHECK_INPUT(weight);
    CHECK_INPUT(bias);
    return module_forward_cuda(x, weight, bias);
}

// Binding the forward function to Python via PyBind11
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_forward, "Fused linear and activation forward function (CUDA)");
}
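
One way to build and exercise the extension is with torch.utils.cpp_extension.load, as sketched below. The source file name, tensor shapes, and tolerance are assumptions for illustration, not part of the leaderboard harness.

import torch
from torch.utils.cpp_extension import load

# Assumed file name for the listing above; adjust to the actual path.
fused = load(name="fused_activation_ext", sources=["fused_activation.cu"], verbose=True)

batch, in_features, out_features = 128, 1024, 512
x = torch.randn(batch, in_features, device="cuda")
weight = torch.randn(out_features, in_features, device="cuda")
bias = torch.randn(out_features, device="cuda")

out = fused.forward(x, weight, bias)

# Check against the same op chain in eager PyTorch.
ref = torch.addmm(bias, x, weight.t())
ref = ref * torch.sigmoid(ref)
ref = (ref / 2.0).clamp(-1.0, 1.0)
ref = torch.tanh(ref).clamp(-1.0, 1.0)
print(torch.allclose(out, ref, atol=1e-5))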