← Back to Leaderboard

The AI CUDA Engineer 👷

68_Matmul_Min_Subtract — grid_2d_mapping_edit_1

Level 2 • Task 68

Kernel Information

Related Kernels (Level 2, Task 68 • 68_Matmul_Min_Subtract)

| Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile) |
|------|-------------|--------------|------------------|-------------------|
| 🥇 | optimized_shared_memory_sync_base | 0.01 | 2.95 | 1.92 |
| 🥇 | optimized_thread_block_indexing_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | modularized_device_functions_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | aligned_memory_access_ldg_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | stride_loop_optimization_thread_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | tiled_shared_memory_matmul_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | unrolled_loop_optimization_base | 0.01 | 2.95 | 1.92 |
| 🥇 | optimized_warp_coalesced_base | 0.01 | 2.95 | 1.92 |
| 🥇 | grid_2d_mapping_base | 0.01 | 2.95 | 1.92 |
| 🥇 | aligned_memory_access_base_edit_1 | 0.01 | 2.95 | 1.92 |
| 🥇 | modular_device_functions_edit_1 | 0.01 | 2.95 | 1.92 |
| 🥇 | aligned_memory_access_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | grid_2d_mapping_edit_1 | 0.01 | 2.95 | 1.92 |
| 🥇 | efficient_thread_block_mapping_base | 0.01 | 2.95 | 1.92 |
| 🥇 | stride_loop_optimization_base_base | 0.01 | 2.95 | 1.92 |
| 🥇 | tiled_gemm_thread_mapping_base | 0.01 | 2.95 | 1.92 |
| 17 | modular_device_functions_base | 0.01 | 2.62 | 1.71 |
| 17 | aligned_ldg_vectorized_edit_1 | 0.01 | 2.62 | 1.71 |
| 17 | aligned_ldg_vectorized_base | 0.01 | 2.62 | 1.71 |
| 17 | branchless_min_dot_edit_1 | 0.01 | 2.62 | 1.71 |
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <limits>

// 2D indexed kernel: Each thread computes one element of the output matrix.
__global__ void kernel_2d(
    const float* __restrict__ x,
    const float* __restrict__ linear_weight,
    const float* __restrict__ linear_bias,
    const float* __restrict__ constant,
    float* __restrict__ y,
    int batch_size,
    int in_features,
    int out_features) {
    // Fused linear + clamp + shift:
    //   y[b][o] = fminf(dot(x[b, :], W[o, :]) + bias[o], c) - c,   c = constant[0]
    //
    // Launch layout: 2D grid of 2D blocks. The x dimension (blockIdx.x /
    // threadIdx.x) walks output features; the y dimension walks batch rows.
    // Both dimensions use grid-stride loops, so every output element is
    // written for ANY grid/block shape: a grid covering
    // (out_features, batch_size) gives one element per thread, a smaller grid
    // just makes each thread loop. No shared memory, no __syncthreads().
    //
    // Preconditions (not checked here):
    //   x             dense row-major, batch_size   x in_features
    //   linear_weight dense row-major, out_features x in_features
    //   linear_bias   >= out_features elements
    //   constant      >= 1 element (only constant[0] is read)
    //   y             dense row-major, batch_size   x out_features
    //
    // All row offsets are 64-bit, so products such as batch_size * in_features
    // may exceed INT_MAX without overflowing.

    const int64_t o_start  = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t b_start  = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    const int64_t o_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    const int64_t b_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;

    for (int64_t b = b_start; b < batch_size; b += b_stride) {
        // Start of this batch row in the input and the output.
        const float* x_row = x + b * in_features;
        float* y_row = y + b * out_features;

        for (int64_t o = o_start; o < out_features; o += o_stride) {
            const float* w_row = linear_weight + o * in_features;

            // Plain sequential accumulation (same order as a scalar loop);
            // the unroll only trims loop overhead. Handles any in_features,
            // including values that are not a multiple of 4.
            float dot = 0.0f;
#pragma unroll 4
            for (int j = 0; j < in_features; ++j) {
                dot += x_row[j] * w_row[j];
            }

            const float c = constant[0];
            // NOTE(review): fminf returns the non-NaN operand. A NaN in
            // dot + bias therefore yields 0 rather than NaN; confirm that is
            // acceptable if this must match torch.min semantics.
            y_row[o] = fminf(dot + linear_bias[o], c) - c;
        }
    }
}

// Forward function to launch the kernel
// Forward function to launch the kernel.
//
// Computes y = fminf(x @ linear_weight^T + linear_bias, c) - c, c = constant[0],
// via kernel_2d (one thread per output element on a 16x16-block 2D grid).
//
// Args (all must be float32 CUDA tensors on the same device):
//   x             2D, (batch_size, in_features)
//   linear_weight 2D, (out_features, in_features)
//   linear_bias   exactly out_features elements
//   constant      at least one element; only element 0 is used
// Non-contiguous inputs are accepted: they are made contiguous first, because
// kernel_2d indexes dense row-major arrays.
//
// Returns a new (batch_size, out_features) tensor created with x.options().
// Throws c10::Error (TORCH_CHECK) on invalid inputs or a failed kernel launch.
// The kernel is enqueued asynchronously on PyTorch's current CUDA stream for
// x's device; this function does not synchronize.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor linear_weight,
    torch::Tensor linear_bias,
    torch::Tensor constant) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(linear_weight.is_cuda(), "linear_weight must be a CUDA tensor");
    TORCH_CHECK(linear_bias.is_cuda(), "linear_bias must be a CUDA tensor");
    TORCH_CHECK(constant.is_cuda(), "constant must be a CUDA tensor");

    TORCH_CHECK(linear_weight.device() == x.device() &&
                    linear_bias.device() == x.device() &&
                    constant.device() == x.device(),
                "all tensors must be on the same CUDA device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    linear_weight.scalar_type() == torch::kFloat32 &&
                    linear_bias.scalar_type() == torch::kFloat32 &&
                    constant.scalar_type() == torch::kFloat32,
                "all tensors must be float32");
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size, in_features)");
    TORCH_CHECK(linear_weight.dim() == 2,
                "linear_weight must be 2D (out_features, in_features)");
    TORCH_CHECK(linear_weight.size(1) == x.size(1),
                "linear_weight.size(1) must equal x.size(1) (in_features)");
    TORCH_CHECK(linear_bias.numel() == linear_weight.size(0),
                "linear_bias must have out_features elements");
    TORCH_CHECK(constant.numel() >= 1, "constant must have at least one element");

    // kernel_2d takes its sizes as 32-bit ints.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(x.size(0) <= kIntMax && x.size(1) <= kIntMax &&
                    linear_weight.size(0) <= kIntMax,
                "tensor dimensions must fit in a 32-bit int");

    // kernel_2d assumes dense row-major storage; these are no-ops when the
    // tensors are already contiguous. (constant needs no copy: data_ptr()
    // already points at its first element.)
    x = x.contiguous();
    linear_weight = linear_weight.contiguous();
    linear_bias = linear_bias.contiguous();

    const int batch_size = static_cast<int>(x.size(0));
    const int in_features = static_cast<int>(x.size(1));
    const int out_features = static_cast<int>(linear_weight.size(0));

    // The kernel overwrites every element, so no zero-fill is needed.
    auto y = torch::empty({batch_size, out_features}, x.options());
    if (y.numel() == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return y;
    }

    // Launch on x's device and on PyTorch's current stream, so the kernel is
    // ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const float* x_ptr = x.data_ptr<float>();
    const float* weight_ptr = linear_weight.data_ptr<float>();
    const float* bias_ptr = linear_bias.data_ptr<float>();
    const float* constant_ptr = constant.data_ptr<float>();
    float* y_ptr = y.data_ptr<float>();

    // Use a 2D block mapping to naturally cover the output matrix dimensions.
    const dim3 block(16, 16);
    const unsigned grid_x = (out_features + block.x - 1) / block.x;

    // gridDim.y is limited to 65535 blocks. Very large batches are therefore
    // split into row chunks. Each chunk is an independent launch over a
    // sub-matrix of x and y.
    constexpr int64_t kMaxGridY = 65535;
    const int64_t rows_per_launch = kMaxGridY * block.y;

    for (int64_t row0 = 0; row0 < batch_size; row0 += rows_per_launch) {
        const int rows =
            static_cast<int>(std::min<int64_t>(rows_per_launch, batch_size - row0));
        const dim3 grid(grid_x, (rows + block.y - 1) / block.y);

        kernel_2d<<<grid, block, 0, stream>>>(
            x_ptr + row0 * in_features,
            weight_ptr,
            bias_ptr,
            constant_ptr,
            y_ptr + row0 * out_features,
            rows,
            in_features,
            out_features);

        // Launches do not return errors directly; surface configuration
        // failures here instead of returning an uninitialized tensor.
        const cudaError_t err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess, "kernel_2d launch failed: ",
                    cudaGetErrorString(err));
    }

    return y;
}

// Python binding: registers the host-side forward() under the Python name
// "forward" in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        &forward,
        "CUDA 2D grid indexed forward function");
}