
The AI CUDA Engineer 👷

81_Gemm_Swish_Divide_Clamp_Tanh_Clamp • adaptive_vectorized_kernel_base

Level 2 • Task 81

Kernel Information

Related Kernels (Level 2, Task 81 • 81_Gemm_Swish_Divide_Clamp_Tanh_Clamp)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 gemm_2d_map_base 0.02 2.19 1.90
🥇 gemm_2d_map_improved_base 0.02 2.19 1.90
🥇 hybrid_aligned_2d_kernel_base 0.02 2.19 1.90
🥇 aligned_memory_access_base_base 0.02 2.19 1.90
🥇 optimized_stride_loop_base 0.02 2.19 1.90
🥇 stride_loop_optimization_base_base 0.02 2.19 1.90
🥇 unrolled_2d_map_base_base 0.02 2.19 1.90
🥇 stride_loop_opt_base 0.02 2.19 1.90
🥇 fused_activation_base 0.02 2.19 1.90
🥇 optimized_stride_loop_with_prefetch_base 0.02 2.19 1.90
🥇 optimized_fused_activation_base 0.02 2.19 1.90
🥇 optimized_thread_block_indexing_base 0.02 2.19 1.90
🥇 dim2_grid_activation_base 0.02 2.19 1.90
14 dynamic_block_size_base_base 0.02 2.10 1.82
14 strided_vectorized_base_base 0.02 2.10 1.82
14 even_load_base 0.02 2.10 1.82
14 81_gemm_swish_divide_clamp_tanh_clamp_optimized_blocks_base 0.02 2.10 1.82
14 stride_loop_correct_base 0.02 2.10 1.82
14 adaptive_vectorized_kernel_base 0.02 2.10 1.82
14 uniform_control_flow_base_base 0.02 2.10 1.82
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

template <typename scalar_t>
__device__ __forceinline__ scalar_t activation_pipeline(scalar_t x) {
    // Swish + divide by 2 combined
    const scalar_t sigmoid_x = static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + exp(-x));
    x = x * sigmoid_x * static_cast<scalar_t>(0.5);
    
    // Clamp to [-1, 1], apply tanh, then clamp the result again
    x = max(min(x, static_cast<scalar_t>(1)), static_cast<scalar_t>(-1));
    x = tanh(x);
    return max(min(x, static_cast<scalar_t>(1)), static_cast<scalar_t>(-1));
}

template <typename scalar_t>
__global__ void module_kernel_adaptive(
    const scalar_t* __restrict__ x_in,
    scalar_t* __restrict__ x_out,
    const int height,
    const int width,
    const bool use_vectorized) {
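    // Two execution modes, selected by the host:
    //  - vectorized: 1D grid-stride loop, 4 elements per thread, 128-bit float4 stores
    //  - scalar:     2D grid, one thread per (row, col) element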
    
    if (use_vectorized) {
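        // Grid-stride loop over groups of 4 elements. The float4 store below
        // assumes scalar_t is 4 bytes wide, so the host only selects this path
        // for float tensors.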
        const int tid = blockIdx.x * blockDim.x + threadIdx.x;
        const int stride = blockDim.x * gridDim.x;
        const int total_size = height * width;
        const int vec_size = 4;
        const size_t aligned_size = total_size & ~(vec_size - 1);
        
        for (size_t i = tid * vec_size; i < aligned_size; i += stride * vec_size) {
            float4 in_vec;
            scalar_t* in_ptr = (scalar_t*)&in_vec;
            
            #pragma unroll
            for (int j = 0; j < vec_size; j++) {
                in_ptr[j] = __ldg(&x_in[i + j]);
            }
            
            #pragma unroll
            for (int j = 0; j < vec_size; j++) {
                in_ptr[j] = activation_pipeline(in_ptr[j]);
            }
            
            *reinterpret_cast<float4*>(&x_out[i]) = in_vec;
        }
        
        // Tail loop for elements past the last full group of 4 (not reached here,
        // since the host only takes this path when width is a multiple of 4).
        for (size_t i = tid + aligned_size; i < total_size; i += stride) {
            x_out[i] = activation_pipeline(__ldg(&x_in[i]));
        }
    } else {
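        // Scalar path: one thread per (row, col) element on a 2D grid.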
        const int row = blockIdx.y * blockDim.y + threadIdx.y;
        const int col = blockIdx.x * blockDim.x + threadIdx.x;
        
        if (row < height && col < width) {
            const int index = row * width + col;
            x_out[index] = activation_pipeline(__ldg(&x_in[index]));
        }
    }
}

torch::Tensor module_forward_cuda(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {
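    // GEMM with bias via torch::addmm (x @ weight^T + bias), then the fused
    // activation pipeline applied elementwise by the adaptive kernel above.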
    
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto x_out = torch::empty_like(x_linear);
    
    const int height = x_linear.size(0);
    const int width = x_linear.size(1);
    const int total_elements = height * width;
    
    // Use the vectorized path only for sufficiently large float tensors whose row
    // width is a multiple of 4: the kernel's float4 stores assume 4-byte elements.
    const bool use_vectorized = (total_elements >= 16384) && (width % 4 == 0) &&
                                (x_linear.scalar_type() == at::ScalarType::Float);
    
    if (use_vectorized) {
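        // Each thread processes 4 elements, so size the grid for total_elements / 4.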
        const int threads = 256;
        const int blocks = (total_elements + threads * 4 - 1) / (threads * 4);
        dim3 grid(blocks);
        dim3 block(threads);
        
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
            module_kernel_adaptive<scalar_t><<<grid, block>>>(
                x_linear.data_ptr<scalar_t>(),
                x_out.data_ptr<scalar_t>(),
                height,
                width,
                true);
        }));
    } else {
        dim3 block(16, 16);
        dim3 grid((width + block.x - 1) / block.x, (height + block.y - 1) / block.y);
        
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_linear.scalar_type(), "module_forward_cuda", ([&] {
            module_kernel_adaptive<scalar_t><<<grid, block>>>(
                x_linear.data_ptr<scalar_t>(),
                x_out.data_ptr<scalar_t>(),
                height,
                width,
                false);
        }));
    }
    
    return x_out;
}

torch::Tensor module_forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias) {
    CHECK_INPUT(x);
    CHECK_INPUT(weight);
    CHECK_INPUT(bias);
    return module_forward_cuda(x, weight, bias);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_forward, "Custom module forward function (CUDA)");
}
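
Usage sketch (not part of the original kernel listing): one way to build and exercise the extension from Python is torch.utils.cpp_extension.load, assuming the CUDA source above is saved as adaptive_vectorized_kernel.cu; the file name, tensor shapes, and tolerances below are illustrative assumptions. The pure-PyTorch reference mirrors the fused pipeline in the task name (GEMM, swish, divide by 2, clamp, tanh, clamp) and serves only as a sanity check.

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source above (assumed saved as adaptive_vectorized_kernel.cu).
ext = load(
    name="adaptive_vectorized_kernel",
    sources=["adaptive_vectorized_kernel.cu"],
    verbose=True,
)

def reference(x, weight, bias):
    # Pure-PyTorch version of the fused pipeline:
    # GEMM + bias, swish, divide by 2, clamp, tanh, clamp.
    y = torch.addmm(bias, x, weight.t())
    y = y * torch.sigmoid(y) / 2.0
    y = torch.clamp(y, -1.0, 1.0)
    y = torch.tanh(y)
    return torch.clamp(y, -1.0, 1.0)

x = torch.randn(128, 1024, device="cuda")
weight = torch.randn(512, 1024, device="cuda")
bias = torch.randn(512, device="cuda")

out = ext.forward(x, weight, bias)
torch.testing.assert_close(out, reference(x, weight, bias), rtol=1e-4, atol=1e-5)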