
The AI CUDA Engineer 👷

51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd • aligned_memory_access_base_base

Level 2 • Task 51

Kernel Information

Related Kernels (Level 2, Task 51 • 51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 fused_forward_base 0.05 1.62 0.92
🥇 fused_forward_edit_1 0.05 1.62 0.92
🥉 fused_forward_coalesced_base 0.05 1.58 0.90
4 fused_forward_coalesced_edit_1 0.05 1.55 0.89
5 optimized_fused_kernel_base 0.06 1.32 0.76
6 fused_pipeline_base 0.06 1.28 0.73
6 threadblock_mapping_opt_base 0.06 1.28 0.73
8 atomic_optimized_pipeline_base 0.06 1.26 0.72
8 efficient_thread_block_mapping_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
8 fused_pool_gelu_atomic_minimal_base 0.06 1.26 0.72
8 fused_pool_gelu_warp_edit_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
14 constant_memory_optimization_base 0.07 1.24 0.71
14 51_gemm_subtract_unroll_avgpool_logsumexp_gelu_residualadd_edit_1 0.07 1.24 0.71
14 uniform_control_flow_base_base_base 0.07 1.24 0.71
17 modular_device_functions_optimized_base 0.07 1.22 0.70
17 modular_device_functions_base_base 0.07 1.22 0.70
19 experiment_block_sizes_base 0.07 1.19 0.68
19 tiled_gemm_shared_edit_2_base 0.07 1.19 0.68
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

#define MAX_OUT_FEATURES 4096
#define TILE_DIM 16
#define BLOCK_ROWS 16

// Constant memory for bias and subtract vectors
__constant__ float c_bias[MAX_OUT_FEATURES];
__constant__ float c_subtract[MAX_OUT_FEATURES];
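// Note: two arrays of MAX_OUT_FEATURES (4096) floats occupy 32 KB in total,
// well within the 64 KB constant-memory budget of current NVIDIA GPUs.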

//---------------------------------------------------------------------------
// Tiled GEMM kernel that stages x and weight tiles in shared memory
// Computes: out[r, c] = dot(x[r, :], weight[c, :]) + c_bias[c] - c_subtract[c]
// x: [batch_size x in_features]
// weight: [out_features x in_features]
// out: [batch_size x out_features]
//---------------------------------------------------------------------------
__global__ void coalesced_gemm_subtract_kernel(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    float* __restrict__ out,
    int batch_size,
    int in_features,
    int out_features
) {
    // +1 column of padding avoids shared-memory bank conflicts
    __shared__ float tile_x[TILE_DIM][TILE_DIM+1];
    __shared__ float tile_w[TILE_DIM][TILE_DIM+1];

    int bx = blockIdx.x;
    int by = blockIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    int row = by * TILE_DIM + ty;
    int col = bx * TILE_DIM + tx;

    float sum = 0.0f;

    // Loop over tiles of in_features dimension
    for (int t = 0; t < (in_features + TILE_DIM - 1) / TILE_DIM; t++) {
        // Stage one tile of x and one tile of weight in shared memory
        // (the x load is coalesced across threadIdx.x; the weight load is strided)
        if (row < batch_size && (t * TILE_DIM + tx) < in_features)
            tile_x[ty][tx] = x[row * in_features + t * TILE_DIM + tx];
        else
            tile_x[ty][tx] = 0.0f;

        if (col < out_features && (t * TILE_DIM + ty) < in_features)
            tile_w[ty][tx] = weight[col * in_features + t * TILE_DIM + ty];
        else
            tile_w[ty][tx] = 0.0f;

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < TILE_DIM; k++) {
            sum += tile_x[ty][k] * tile_w[k][tx];
        }

        __syncthreads();
    }

    if (row < batch_size && col < out_features) {
        out[row * out_features + col] = sum + c_bias[col] - c_subtract[col];
    }
}
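
//---------------------------------------------------------------------------
// Reference sketch (not part of the original kernel): the GEMM above should
// match this eager libtorch composition up to floating-point rounding. The
// helper name is illustrative only.
//---------------------------------------------------------------------------
static torch::Tensor reference_gemm_subtract(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const torch::Tensor& subtract
) {
    // out[r, c] = dot(x[r, :], weight[c, :]) + bias[c] - subtract[c]
    return torch::addmm(bias - subtract, x, weight.t());
}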

//---------------------------------------------------------------------------
// Fused kernel for global average pooling, GELU activation, and residual addition
// (LogSumExp over the single pooled value is the identity, so it is folded away).
// Combining these steps in one kernel reduces global memory traffic.
// gemm_out: [batch_size x out_features]
// original_x: [batch_size x in_features]
// out: [batch_size x in_features]
//---------------------------------------------------------------------------
__global__ void fused_pool_gelu_residual_kernel(
    const float* __restrict__ gemm_out,
    const float* __restrict__ original_x,
    float* __restrict__ out,
    int batch_size,
    int out_features,
    int in_features
) {
    int row = blockIdx.x;

    extern __shared__ float sdata[];
    float local_sum = 0.0f;

    for (int i = threadIdx.x; i < out_features; i += blockDim.x) {
        local_sum += gemm_out[row * out_features + i];
    }
    sdata[threadIdx.x] = local_sum;
    __syncthreads();

    // Tree reduction in shared memory (assumes blockDim.x is a power of two)
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            sdata[threadIdx.x] += sdata[threadIdx.x + s];
        __syncthreads();
    }

    // Global average pool over the out_features dimension
    float avg = sdata[0] / static_cast<float>(out_features);
    // GELU via the tanh approximation (0.7978845608f = sqrt(2/pi))
    float gelu = avg * 0.5f * (1.0f + tanhf(0.7978845608f * (avg + 0.044715f * avg * avg * avg)));
    __syncthreads();

    for (int j = threadIdx.x; j < in_features; j += blockDim.x) {
        int idx = row * in_features + j;
        out[idx] = original_x[idx] + gelu;
    }
}
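
//---------------------------------------------------------------------------
// Reference sketch (not part of the original kernel): pooling, LogSumExp,
// GELU, and residual addition in eager libtorch. LogSumExp over the single
// pooled column is the identity, which is why the fused kernel above omits
// it. The helper name is illustrative only.
//---------------------------------------------------------------------------
static torch::Tensor reference_pool_gelu_residual(
    const torch::Tensor& gemm_out,
    const torch::Tensor& original_x
) {
    auto pooled = gemm_out.mean({1}, /*keepdim=*/true);           // [batch, 1]
    auto lse = torch::logsumexp(pooled, {1}, /*keepdim=*/true);   // identity on a length-1 dim
    // tanh-approximation GELU; 0.7978845608 = sqrt(2/pi)
    auto act = 0.5 * lse * (1.0 + torch::tanh(0.7978845608 * (lse + 0.044715 * lse.pow(3))));
    return original_x + act;                                      // broadcast over in_features
}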

//---------------------------------------------------------------------------
// Forward function launching the fused pipeline:
// 1. GEMM with bias and subtract using coalesced memory accesses
// 2. Fused kernel for pooling, GELU, and residual addition
//---------------------------------------------------------------------------
torch::Tensor forward_cuda(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const torch::Tensor& subtract
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(subtract.is_cuda(), "subtract must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size x in_features)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2D (out_features x in_features)");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1D (out_features)");
    TORCH_CHECK(subtract.dim() == 1, "subtract must be 1D (out_features)");

    int batch_size = x.size(0);
    int in_features = x.size(1);
    int out_features = weight.size(0);

    TORCH_CHECK(weight.size(1) == in_features, "Mismatch between weight and x dimensions");
    TORCH_CHECK(bias.size(0) == out_features, "bias dimension must match weight output features");
    TORCH_CHECK(subtract.size(0) == out_features, "subtract dimension must match weight output features");
    TORCH_CHECK(out_features <= MAX_OUT_FEATURES, "out_features exceeds maximum allowed for constant memory");

    auto x_contig = x.contiguous();
    auto weight_contig = weight.contiguous();
    auto bias_contig = bias.contiguous();
    auto subtract_contig = subtract.contiguous();

    // Copy bias and subtract into constant memory; both tensors are already on
    // the device, so the copy kind is device-to-device.
    cudaMemcpyToSymbol(c_bias, bias_contig.data_ptr<float>(),
                       out_features * sizeof(float), 0, cudaMemcpyDeviceToDevice);
    cudaMemcpyToSymbol(c_subtract, subtract_contig.data_ptr<float>(),
                       out_features * sizeof(float), 0, cudaMemcpyDeviceToDevice);

    auto original_x = x_contig.clone();

    auto gemm_out = torch::empty({batch_size, out_features}, x.options());
    auto out_tensor = torch::empty({batch_size, in_features}, x.options());

    dim3 threadsGemm(TILE_DIM, TILE_DIM);
    dim3 blocksGemm((out_features + TILE_DIM - 1) / TILE_DIM, (batch_size + TILE_DIM - 1) / TILE_DIM);
    coalesced_gemm_subtract_kernel<<<blocksGemm, threadsGemm>>>(
        x_contig.data_ptr<float>(),
        weight_contig.data_ptr<float>(),
        gemm_out.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );

    int threadsFused = 256;
    fused_pool_gelu_residual_kernel<<<batch_size, threadsFused, threadsFused * sizeof(float)>>>(
        gemm_out.data_ptr<float>(),
        original_x.data_ptr<float>(),
        out_tensor.data_ptr<float>(),
        batch_size,
        out_features,
        in_features
    );
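
    // (A cudaGetLastError() / cudaDeviceSynchronize() check could be added here
    // when debugging the two kernel launches above.)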

    return out_tensor;
}

// PyBind11 interface
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward_cuda, "Fused GEMM, pooling, GELU, and residual add CUDA kernel");
}
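
//---------------------------------------------------------------------------
// Optional self-check (sketch, not part of the original listing): compile with
// -DSTANDALONE_TEST and link against libtorch (plus Python, which
// torch/extension.h pulls in) to compare the fused pipeline against the eager
// reference helpers above. Shapes below are illustrative only.
//---------------------------------------------------------------------------
#ifdef STANDALONE_TEST
#include <cstdio>

int main() {
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto x        = torch::randn({128, 512},  opts);
    auto weight   = torch::randn({1024, 512}, opts);
    auto bias     = torch::randn({1024},      opts);
    auto subtract = torch::randn({1024},      opts);

    auto fused = forward_cuda(x, weight, bias, subtract);
    auto ref   = reference_pool_gelu_residual(
        reference_gemm_subtract(x, weight, bias, subtract), x);

    std::printf("max abs diff: %e\n",
                (fused - ref).abs().max().item<float>());
    return 0;
}
#endif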