
The AI CUDA Engineer 👷

9_Matmul_Subtract_Multiply_ReLU • efficient_indexing_tile_kernel_base

Level 2 • Task 9

Kernel Information

Related Kernels (Level 2, Task 9 • 9_Matmul_Subtract_Multiply_ReLU)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 unrolled_loop_kernel_base 0.01 4.05 2.63
🥇 9_Matmul_Subtract_Multiply_ReLU 0.01 4.05 2.63
🥇 9_matmul_subtract_multiply_relu_unroll_base 0.01 4.05 2.63
🥇 modular_matmul_subtract_multiply_relu_base 0.01 4.05 2.63
🥇 efficient_indexing_tile_kernel_base 0.01 4.05 2.63
🥇 efficient_thread_block_mapping_base 0.01 4.05 2.63
🥇 warp_divergence_optimized_base 0.01 4.05 2.63
🥇 warp_level_fused_kernel_base 0.01 4.05 2.63
🥇 shared_mem_tiled_base 0.01 4.05 2.63
🥇 tiled_sharedmem_optimized_base 0.01 4.05 2.63
🥇 warp_level_reduction_kernel_base 0.01 4.05 2.63
🥇 strided_thread_blocks_base_base 0.01 4.05 2.63
🥇 optimized_block_size_base 0.01 4.05 2.63
🥇 double_buffered_tiled_kernel_base 0.01 4.05 2.63
🥇 coalesced_memory_matmul_base_base 0.01 4.05 2.63
🥇 tiled_matmul_shared_mem_base 0.01 4.05 2.63
🥇 optimized_tiled_2d_base 0.01 4.05 2.63
🥇 matmul_1d_thread_mapping_base 0.01 4.05 2.63
🥇 modularized_matmul_ops_base 0.01 4.05 2.63
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#define TILE_WIDTH 16

// This kernel uses 2D grid and 2D block indexing to map the output matrix (batch_size x out_features)
// onto tiles of size TILE_WIDTH x TILE_WIDTH. Each block computes one output tile, staging the matching
// tiles of the input and weight matrices in shared memory. threadIdx.x indexes the row within a tile
// (the batch dimension) and threadIdx.y indexes the column within a tile (the output-feature dimension).
// Each global element is therefore loaded once per block rather than once per consuming thread, the
// weight loads are coalesced along threadIdx.x, and the bounds checks below keep the mapping correct
// when the problem sizes are not multiples of TILE_WIDTH.
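//
// Illustrative sizes (not taken from the benchmark): with batch_size = 128, out_features = 64, and
// TILE_WIDTH = 16, the host code below launches an 8 x 4 grid of 16 x 16 blocks; block (2, 1) produces
// output rows 32-47 for output features 16-31, sweeping ceil(in_features / 16) shared-memory tiles
// along the K dimension.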

template <typename scalar_t>
__global__ void efficient_indexing_tile_kernel(
    const scalar_t* __restrict__ input,   // [batch_size, in_features]
    const scalar_t* __restrict__ weight,  // [out_features, in_features]
    const scalar_t* __restrict__ bias,    // [out_features]
    scalar_t* __restrict__ output,
    int batch_size,
    int in_features,
    int out_features,
    float subtract_value,
    float multiply_value) {

    // Compute global row and column indices
    int row = blockIdx.x * TILE_WIDTH + threadIdx.x; // batch dimension
    int col = blockIdx.y * TILE_WIDTH + threadIdx.y; // output feature dimension

    scalar_t sum = 0;

    // Allocate shared memory for a tile of input and weight
    __shared__ scalar_t shared_input[TILE_WIDTH][TILE_WIDTH];
    __shared__ scalar_t shared_weight[TILE_WIDTH][TILE_WIDTH];

    // Number of tiles to loop over the K-dimension
    int numTiles = (in_features + TILE_WIDTH - 1) / TILE_WIDTH;
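    // e.g. (illustrative): in_features = 100 with TILE_WIDTH = 16 gives numTiles = 7; the final partial
    // tile is zero-padded by the bounds checks below.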

    for (int t = 0; t < numTiles; t++) {
        // K-dimension offsets within this tile: the input tile column advances with threadIdx.y, while
        // the weight tile column advances with threadIdx.x so that consecutive threads read consecutive
        // elements of a weight row.
        int input_col = t * TILE_WIDTH + threadIdx.y;
        int weight_col = t * TILE_WIDTH + threadIdx.x;

        // Load one element of the input tile, if within bounds
        if (row < batch_size && input_col < in_features) {
            shared_input[threadIdx.x][threadIdx.y] = input[row * in_features + input_col];
        } else {
            shared_input[threadIdx.x][threadIdx.y] = 0;
        }
        
        // Load one element of the weight tile, if within bounds
        if (col < out_features && weight_col < in_features) {
            shared_weight[threadIdx.x][threadIdx.y] = weight[col * in_features + weight_col];
        } else {
            shared_weight[threadIdx.x][threadIdx.y] = 0;
        }

        __syncthreads();

        // Compute partial sum for the tile
        #pragma unroll
        for (int k = 0; k < TILE_WIDTH; k++) {
            sum += shared_input[threadIdx.x][k] * shared_weight[k][threadIdx.y];
        }

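        // Second barrier: ensure every thread has finished reading this tile before the next iteration
        // overwrites the shared-memory buffers.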
        __syncthreads();
    }

    // Write back the result if within output bounds, applying bias, subtract, multiply, and ReLU activation.
    if (row < batch_size && col < out_features) {
        sum += bias[col];
        sum = (sum - subtract_value) * multiply_value;
        output[row * out_features + col] = (sum > 0) ? sum : static_cast<scalar_t>(0);
    }
}

// PyTorch forward interface

torch::Tensor forward(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor bias,
    float subtract_value,
    float multiply_value) {

    // Raw pointer indexing in the kernel assumes contiguous CUDA tensors.
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(), "expected CUDA tensors");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
                "expected contiguous tensors");

    const int batch_size = input.size(0);
    const int in_features = input.size(1);
    const int out_features = weight.size(0);

    auto output = torch::empty({batch_size, out_features}, input.options());

    // Configure a 2D grid where x-dimension covers the batch and y-dimension covers output features
    dim3 threads(TILE_WIDTH, TILE_WIDTH);
    dim3 blocks(
        (batch_size + TILE_WIDTH - 1) / TILE_WIDTH,
        (out_features + TILE_WIDTH - 1) / TILE_WIDTH
    );
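    // e.g. (illustrative): batch_size = 100 and out_features = 5 yield a 7 x 1 grid; threads that fall
    // outside the output matrix are masked by the bounds checks inside the kernel.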

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "efficient_indexing_tile_kernel", ([&] {
        efficient_indexing_tile_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            bias.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            in_features,
            out_features,
            subtract_value,
            multiply_value
        );
    }));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Efficient 2D tile kernel with optimized thread-block mapping");
}
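
For completeness, the sketch below shows one way the forward entry point above might be checked against the equivalent ATen expression (linear, then subtract, multiply, and ReLU). It is a minimal, hypothetical harness rather than part of the benchmarked kernel: it assumes the kernel file is compiled and linked together with this snippet in a libtorch + CUDA build (the pybind11 module still needs the Python headers as usual), and the file name check_forward.cpp and the tensor sizes are illustrative.

// check_forward.cpp -- hypothetical harness; compile and link together with the kernel file above.
#include <torch/torch.h>
#include <iostream>

// Declaration of the extension entry point defined in the .cu file above.
torch::Tensor forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias,
                      float subtract_value, float multiply_value);

int main() {
    if (!torch::cuda::is_available()) {
        std::cout << "CUDA not available, skipping check\n";
        return 0;
    }

    const int64_t batch_size = 128, in_features = 10, out_features = 5;
    const float subtract_value = 2.0f, multiply_value = 1.5f;

    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto input  = torch::randn({batch_size, in_features}, opts);
    auto weight = torch::randn({out_features, in_features}, opts);
    auto bias   = torch::randn({out_features}, opts);

    // Reference: linear -> subtract -> multiply -> ReLU, which the fused kernel should reproduce.
    auto ref = torch::relu((torch::linear(input, weight, bias) - subtract_value) * multiply_value);
    auto out = forward(input, weight, bias, subtract_value, multiply_value);

    std::cout << "max abs diff: " << (out - ref).abs().max().item<float>() << "\n";
    std::cout << (torch::allclose(out, ref, /*rtol=*/1e-4, /*atol=*/1e-5) ? "OK" : "MISMATCH") << "\n";
    return 0;
}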