
The AI CUDA Engineer 👷

76_Gemm_Add_ReLU • unrolled_warp_gemm_edit_1

Level 2 • Task 76

Kernel Information

Related Kernels (Level 2, Task 76 • 76_Gemm_Add_ReLU)

| Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile) |
|------|-------------|--------------|------------------|-------------------|
| 🥇 | shared_warp_tile_kernel_base | 0.03 | 0.93 | 1.54 |
| 🥇 | combined_warp_tile_base | 0.03 | 0.93 | 1.54 |
| 🥉 | optimized_block_size_kernel_base | 0.03 | 0.89 | 1.49 |
| 4 | warp_tile_ldg_base | 0.03 | 0.87 | 1.44 |
| 4 | even_workload_dist_base_base | 0.03 | 0.87 | 1.44 |
| 4 | hybrid_warp_tile_kernel_base | 0.03 | 0.87 | 1.44 |
| 4 | warp_tile_hybrid_base | 0.03 | 0.87 | 1.44 |
| 8 | warp_tile_ldg_opt_base | 0.03 | 0.81 | 1.36 |
| 8 | warp_reduction_optimized_base_base | 0.03 | 0.81 | 1.36 |
| 10 | optimized_shared_memory_base_base | 0.03 | 0.79 | 1.32 |
| 10 | warp_tile_base_base | 0.03 | 0.79 | 1.32 |
| 12 | hybrid_optimized_kernel_base | 0.04 | 0.77 | 1.28 |
| 13 | warp_reduction_gemm_base | 0.04 | 0.71 | 1.18 |
| 13 | warp_tile_aligned_base_base | 0.04 | 0.71 | 1.18 |
| 15 | vectorized_warp_unroll_base_base | 0.04 | 0.69 | 1.15 |
| 15 | vectorized_warp_unroll_base_edit_1 | 0.04 | 0.69 | 1.15 |
| 15 | warp_reduction_unrolled_gemm_edit_1 | 0.04 | 0.69 | 1.15 |
| 18 | unrolled_warp_gemm_edit_1 | 0.04 | 0.67 | 1.12 |
| 18 | unrolled_warp_gemm_base | 0.04 | 0.67 | 1.12 |
| 18 | vectorized_warp_reduction_base | 0.04 | 0.67 | 1.12 |

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK

// Each warp computes one output element using warp-level reduction
// This design avoids atomic operations since each warp exclusively works on one output element.
__global__ void linear_relu_unrolled_warp_kernel(const float* x, const float* weight, const float* bias, float* out,
                                                 int batch_size, int in_features, int out_features) {
  // Compute a global warp id from the thread index
  int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  int lane = threadIdx.x % 32; // Lane index within the warp
  
  int total_outputs = batch_size * out_features;
  if (warp_id >= total_outputs) return;

  // Map warp id to output matrix coordinates (i, j)
  int i = warp_id / out_features;
  int j = warp_id % out_features;
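  // Example: with out_features = 512, warp_id = 1030 maps to row i = 2, column j = 6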

  float sum = 0.0f;

  // Cache the base indices to avoid repeated calculations
  const int x_base = i * in_features;
  const int w_base = j * in_features;
  
  // Each thread in the warp processes a strided portion of the in_features dimension
  // Unroll by 4 to reduce loop overhead and enable better instruction-level parallelism
  #pragma unroll 4
  for (int k = lane; k < in_features; k += 32) {
    float x_val = x[x_base + k];
    float w_val = weight[w_base + k];
    sum += x_val * w_val;
  }
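  // Consecutive lanes read consecutive addresses of x and weight, so the global loads
  // above are coalesced within each warp.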

  // Perform warp-level reduction using shuffle operations
  for (int offset = 16; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
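  // After the reduction, lane 0 holds the complete dot product of row i of x with row j of weight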

  // The first lane of each warp writes the result
  if (lane == 0) {
    sum += bias[j];
    // Apply ReLU activation
    out[i * out_features + j] = sum > 0.0f ? sum : 0.0f;
  }
}


torch::Tensor linear_relu_forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias) {
  TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
  TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
  TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
  TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
              "x, weight and bias must be contiguous");

  // Ensure the kernel is launched on the device that holds the input tensors
  const c10::cuda::CUDAGuard device_guard(x.device());

  const int batch_size = x.size(0);
  const int in_features = x.size(1);
  const int out_features = weight.size(0);

  // Allocate output tensor
  auto out = torch::empty({batch_size, out_features}, x.options());

  // Each warp computes one output element. Total number of warps required is batch_size * out_features.
  int total_warps = batch_size * out_features;

  // Each warp consists of 32 threads. Determine total threads required.
  int total_threads = total_warps * 32;

  // Choose block size as a multiple of 32, e.g., 256 threads per block
  int threads_per_block = 256;
  int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
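  // Example: batch_size = 128 and out_features = 512 give 65,536 warps,
  // i.e. 2,097,152 threads and 8,192 blocks of 256 threads.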

  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  linear_relu_unrolled_warp_kernel<<<blocks, threads_per_block, 0, stream>>>(
      x.data_ptr<float>(),
      weight.data_ptr<float>(),
      bias.data_ptr<float>(),
      out.data_ptr<float>(),
      batch_size,
      in_features,
      out_features
  );
  
  // Surface launch errors (e.g., invalid configuration) immediately
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &linear_relu_forward, "Unrolled warp-level reduction GEMM+Bias+ReLU (CUDA)");
}
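
For a quick sanity check, a minimal standalone harness along the lines of the sketch below can compare the kernel against PyTorch's own linear + ReLU. This is an assumption-laden sketch, not part of the benchmarked kernel: it assumes the code above is compiled into the same binary, linked against libtorch with CUDA available (with TORCH_EXTENSION_NAME defined on the command line or the PYBIND11_MODULE block omitted), and the tensor shapes are illustrative rather than the task's exact ones.

// Minimal sketch of a correctness check (assumption: compiled together with the code above).
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::manual_seed(0);
  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto x = torch::randn({128, 1024}, opts);      // [batch_size, in_features] -- illustrative shapes
  auto weight = torch::randn({512, 1024}, opts); // [out_features, in_features]
  auto bias = torch::randn({512}, opts);         // [out_features]

  auto out = linear_relu_forward(x, weight, bias);
  auto ref = torch::relu(torch::linear(x, weight, bias));

  std::cout << (out.allclose(ref, /*rtol=*/1e-3, /*atol=*/1e-3) ? "OK" : "MISMATCH") << std::endl;
  return 0;
}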