
The AI CUDA Engineer 👷

30_Gemm_GroupNorm_Hardtanh • min_warp_divergence_edit_1

Level 2 • Task 30

Kernel Information

Related Kernels (Level 2, Task 30 • 30_Gemm_GroupNorm_Hardtanh)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 warp_divergence_minimization_base 0.06 0.88 0.91
🥈 warp_divergence_minimization_edit_1 0.06 0.86 0.89
🥉 optimized_block_sizes_base_edit_1 0.06 0.85 0.88
🥉 optimized_gemm_groupnorm_hardtanh_edit_1 0.06 0.85 0.88
🥉 ldg_memory_alignment_optimization_base 0.06 0.85 0.88
🥉 optimized_kernel_unroll_loops_base 0.06 0.85 0.88
🥉 modular_device_functions_optimized_v2_base 0.06 0.85 0.88
🥉 modular_device_functions_refactor_base 0.06 0.85 0.88
🥉 optimized_kernel_unroll_loops_edit_1 0.06 0.85 0.88
10 shared_mem_reuse_v1_base 0.06 0.83 0.86
11 unroll_loops_optim_base 0.06 0.79 0.82
11 min_warp_divergence_edit_1 0.06 0.79 0.82
11 sync_reduction_optim_edit_1 0.06 0.79 0.82
14 sync_reduction_optim_base 0.06 0.78 0.81
14 min_warp_divergence_base 0.06 0.78 0.81
14 modular_kernel_edit_1 0.06 0.78 0.81
14 constant_memory_optimization_base_edit_1 0.06 0.78 0.81
18 optimized_kernel_constant_memory_base 0.07 0.74 0.77
18 const_memory_optimized_kernel_edit_1 0.07 0.74 0.77
20 const_memory_optimized_kernel_base 0.07 0.73 0.76
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <cmath>
#include <cstdio>  // printf is used for host-side error reporting below

// Inline device clamp for both float and double. fmin/fmax compile to single hardware
// min/max instructions, so the clamp is branchless and introduces no warp divergence.

template <typename scalar_t>
__device__ inline scalar_t device_clamp(scalar_t v, scalar_t min_val, scalar_t max_val);

template <>
__device__ inline float device_clamp<float>(float v, float min_val, float max_val) {
    return fminf(fmaxf(v, min_val), max_val);
}

template <>
__device__ inline double device_clamp<double>(double v, double min_val, double max_val) {
    return fmin(fmax(v, min_val), max_val);
}

// Compute mean and variance for a group in GroupNorm with loop unrolling
template <typename scalar_t>
__device__ inline void compute_group_mean_var(
    const scalar_t* __restrict__ x,
    int batch,
    int group,
    int channels_per_group,
    int num_channels,
    scalar_t &mean,
    scalar_t &var) {
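  // Two-pass reduction over the group's channels: accumulate the mean first,
  // then the variance; the calling thread walks the whole group serially.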
  mean = 0;
  #pragma unroll
  for (int c = 0; c < channels_per_group; ++c) {
    int channel = group * channels_per_group + c;
    mean += x[batch * num_channels + channel];
  }
  mean /= static_cast<scalar_t>(channels_per_group);
  var = 0;
  #pragma unroll
  for (int c = 0; c < channels_per_group; ++c) {
    int channel = group * channels_per_group + c;
    scalar_t diff = x[batch * num_channels + channel] - mean;
    var += diff * diff;
  }
  var /= static_cast<scalar_t>(channels_per_group);
}

// Normalize and apply scale (gamma) and shift (beta)
template <typename scalar_t>
__device__ inline scalar_t group_norm_normalize(
    scalar_t val,
    scalar_t mean,
    scalar_t var,
    scalar_t eps,
    scalar_t gamma,
    scalar_t beta) {
  // 1/sqrt keeps the computation in the tensor's native precision for both float and
  // double (__frsqrt_rn is a float-only intrinsic and would downcast double inputs).
  scalar_t inv_std = static_cast<scalar_t>(1) / sqrt(var + eps);
  return ((val - mean) * inv_std) * gamma + beta;
}

// Tiled linear kernel with uniform control flow
template <typename scalar_t>
__global__ void linear_forward_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ weight,
    const scalar_t* __restrict__ bias,
    scalar_t* __restrict__ output,
    size_t batch_size,
    size_t in_features,
    size_t out_features) {

  int row = blockIdx.y * blockDim.y + threadIdx.y;  // batch index
  int col = blockIdx.x * blockDim.x + threadIdx.x;  // output feature index
  const int TILE_DIM = 16;

  // Each block computes a TILE_DIM x TILE_DIM tile of the output. All threads take part
  // in the shared-memory loads and barriers (out-of-range threads store zeros), so
  // __syncthreads() is reached uniformly and partial tiles never read stale data.
  __shared__ scalar_t tile_x[TILE_DIM][TILE_DIM];
  __shared__ scalar_t tile_w[TILE_DIM][TILE_DIM];

  scalar_t sum = 0;
  int numTiles = (in_features + TILE_DIM - 1) / TILE_DIM;
  for (int t = 0; t < numTiles; t++) {
    int x_idx = t * TILE_DIM + threadIdx.x;
    int w_idx = t * TILE_DIM + threadIdx.y;
    tile_x[threadIdx.y][threadIdx.x] = (row < batch_size && x_idx < in_features)
        ? x[row * in_features + x_idx] : static_cast<scalar_t>(0);
    tile_w[threadIdx.y][threadIdx.x] = (col < out_features && w_idx < in_features)
        ? weight[col * in_features + w_idx] : static_cast<scalar_t>(0);
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < TILE_DIM; k++) {
      sum += tile_x[threadIdx.y][k] * tile_w[k][threadIdx.x];
    }
    __syncthreads();
  }
  if (row < batch_size && col < out_features) {
    output[row * out_features + col] = sum + bias[col];
  }
}

// Group normalization kernel with loop unrolling; using one thread per block minimizes divergence.
template <typename scalar_t>
__global__ void group_norm_forward_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ gamma,
    const scalar_t* __restrict__ beta,
    scalar_t* __restrict__ output,
    int64_t batch_size,
    int64_t num_channels,
    int64_t num_groups,
    int64_t channels_per_group,
    float eps = 1e-5f) {

  int batch = blockIdx.x;
  int group = blockIdx.y;
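  // Grid mapping: blockIdx.x indexes the batch sample and blockIdx.y the group;
  // the single thread launched per block normalizes every channel of that group.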
  
  if (batch < batch_size && group < num_groups) {
    scalar_t mean, var;
    compute_group_mean_var(x, batch, group, channels_per_group, num_channels, mean, var);
    #pragma unroll
    for (int c = 0; c < channels_per_group; ++c) {
      int channel = group * channels_per_group + c;
      scalar_t val = x[batch * num_channels + channel];
      output[batch * num_channels + channel] = group_norm_normalize(val, mean, var, static_cast<scalar_t>(eps), gamma[channel], beta[channel]);
    }
  }
}

// Hardtanh activation kernel refactored to remove divergent branches using branchless clamping
template <typename scalar_t>
__global__ void hardtanh_forward_kernel(
    const scalar_t* __restrict__ x,
    scalar_t min_val,
    scalar_t max_val,
    scalar_t* __restrict__ output,
    size_t total_elements) {

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < total_elements) {
    scalar_t val = x[idx];
    // Replace conditional branches with branchless clamp to minimize warp divergence
    output[idx] = device_clamp<scalar_t>(val, min_val, max_val);
  }
}

// C++ interface functions

void linear_forward_cuda(
    at::Tensor x, 
    at::Tensor weight, 
    at::Tensor bias, 
    at::Tensor output) {

  const auto batch_size = x.size(0);
  const auto in_features = x.size(1);
  const auto out_features = weight.size(0);

  const int threads = 16;
  const dim3 threadsPerBlock(threads, threads);
  const dim3 numBlocks((out_features + threads - 1) / threads, (batch_size + threads - 1) / threads);
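  // Note: the 16x16 block shape matches TILE_DIM in linear_forward_kernel, so each
  // block computes one 16x16 output tile from shared-memory tiles of x and weight.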

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "linear_forward_cuda", ([&] {
    linear_forward_kernel<scalar_t><<<numBlocks, threadsPerBlock>>>(
        x.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        bias.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        batch_size,
        in_features,
        out_features);
  }));
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in linear_forward_cuda: %s\n", cudaGetErrorString(err));
  }
}

void group_norm_forward_cuda(
    at::Tensor x, 
    at::Tensor gamma, 
    at::Tensor beta, 
    int64_t num_groups,
    at::Tensor output) {

  const int64_t batch_size = x.size(0);
  const int64_t num_channels = x.size(1);
  const int64_t channels_per_group = num_channels / num_groups;

  const dim3 blocks(batch_size, num_groups);
  const int threads = 1; // One thread per block minimizes divergence in group norm
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "group_norm_forward_cuda", ([&] {
    group_norm_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data_ptr<scalar_t>(),
        gamma.data_ptr<scalar_t>(),
        beta.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        batch_size,
        num_channels,
        num_groups,
        channels_per_group);
  }));
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in group_norm_forward_cuda: %s\n", cudaGetErrorString(err));
  }
}

void hardtanh_forward_cuda(
    at::Tensor x, 
    float min_val, 
    float max_val,
    at::Tensor output) {

  const size_t total_elements = x.numel();
  const int threads = 256;
  const int blocks = (total_elements + threads - 1) / threads;
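  // One thread per flattened element, launched in blocks of 256 threads.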
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "hardtanh_forward_cuda", ([&] {
    hardtanh_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data_ptr<scalar_t>(),
        static_cast<scalar_t>(min_val),
        static_cast<scalar_t>(max_val),
        output.data_ptr<scalar_t>(),
        total_elements);
  }));
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in hardtanh_forward_cuda: %s\n", cudaGetErrorString(err));
  }
}

// Combined module function integrating linear, group norm and hardtanh kernels
at::Tensor module_fn_cuda_forward(
    at::Tensor x,
    at::Tensor weight,
    at::Tensor bias,
    at::Tensor group_norm_weight,
    at::Tensor group_norm_bias,
    int64_t num_groups,
    float hardtanh_min,
    float hardtanh_max) {

  x = x.contiguous();
  weight = weight.contiguous();
  bias = bias.contiguous();
  group_norm_weight = group_norm_weight.contiguous();
  group_norm_bias = group_norm_bias.contiguous();
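  // The .contiguous() calls above guarantee dense row-major storage, which the
  // raw-pointer indexing inside the kernels assumes.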

  int64_t batch_size = x.size(0);
  int64_t in_features = x.size(1);
  int64_t out_features = weight.size(0);
  auto options = x.options();
  
  at::Tensor linear_output = at::empty({batch_size, out_features}, options);
  at::Tensor group_norm_output = at::empty({batch_size, out_features}, options);
  at::Tensor output = at::empty({batch_size, out_features}, options);

  linear_forward_cuda(x, weight, bias, linear_output);
  group_norm_forward_cuda(linear_output, group_norm_weight, group_norm_bias, num_groups, group_norm_output);
  hardtanh_forward_cuda(group_norm_output, hardtanh_min, hardtanh_max, output);

  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &module_fn_cuda_forward, "Module function forward (CUDA)");
}
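
For reference, a minimal host-side harness along the lines of the sketch below (hypothetical, not part of the submitted kernel) can be linked against libtorch together with the translation unit above to sanity-check module_fn_cuda_forward against the equivalent stock ATen ops (at::linear, at::group_norm, at::hardtanh). The shapes, group count, and tolerances are illustrative only.

// check_forward.cpp — hypothetical standalone check, not part of the submission.
// Build with nvcc/libtorch and link the kernel translation unit above (compiled
// without the PYBIND11_MODULE block, or with libpython available), then run on a GPU.
#include <torch/torch.h>
#include <cstdio>

// Forward declaration of the entry point defined in the kernel file above.
at::Tensor module_fn_cuda_forward(at::Tensor x, at::Tensor weight, at::Tensor bias,
                                  at::Tensor group_norm_weight, at::Tensor group_norm_bias,
                                  int64_t num_groups, float hardtanh_min, float hardtanh_max);

int main() {
  const int64_t batch = 128, in_features = 512, out_features = 256, num_groups = 8;
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);

  auto x      = at::randn({batch, in_features}, opts);
  auto weight = at::randn({out_features, in_features}, opts);
  auto bias   = at::randn({out_features}, opts);
  auto gamma  = at::randn({out_features}, opts);
  auto beta   = at::randn({out_features}, opts);

  // Fused custom pipeline: GEMM -> GroupNorm -> Hardtanh.
  auto got = module_fn_cuda_forward(x, weight, bias, gamma, beta, num_groups, -2.0f, 2.0f);

  // Reference built from stock ATen ops with the same parameters.
  auto ref = at::hardtanh(
      at::group_norm(at::linear(x, weight, bias), num_groups, gamma, beta, 1e-5),
      -2.0, 2.0);

  std::printf("max abs diff: %g\n", (got - ref).abs().max().item<float>());
  std::printf("allclose: %s\n", at::allclose(got, ref, 1e-4, 1e-5) ? "yes" : "no");
  return 0;
}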