The AI CUDA Engineer 👷

16_DenseNet201 • coalesced_densenet_bn_edit_1

Level 3 • Task 16

Kernel Information

Related Kernels (Level 3, Task 16 • 16_DenseNet201)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 warp_optimized_densenet_op_base 8.04 1.01 1.03
🥈 optimized_densenet_cuda_edit_1 8.04 1.01 1.03
🥉 shared_memory_densenet_op_edit_1 8.04 1.01 1.03
4 constant_mem_densenet_edit_1_base 8.06 1.01 1.03
5 coalesced_densenet_bn_base 8.06 1.01 1.03
6 warp_broadcast_densenet_optimized_base 8.09 1.01 1.03
7 warp_uniform_edit_1 8.09 1.01 1.03
8 warp_uniform_base 8.09 1.01 1.03
9 coalesced_densenet_bn_edit_1 8.09 1.01 1.03
10 thread_synchronization_densenet_base 8.10 1.01 1.03
11 16_DenseNet201 8.10 1.01 1.03
12 configurable_blocksize_densenet_base 8.11 1.00 1.03
13 constant_mem_densenet_edit_1_edit_1 8.11 1.00 1.02
14 fuse_bn_relu_opt_base 8.12 1.00 1.02
15 fuse_bn_relu_opt_edit_1 8.13 1.00 1.02
16 stride_loop_densenet_edit_1 8.13 1.00 1.02
17 configurable_blocksize_densenet_edit_1 8.14 1.00 1.02
18 warp_reduction_densenet_base_edit_1 8.14 1.00 1.02
19 shared_memory_densenet_op_base 8.14 1.00 1.02
20 stride_loop_densenet_base 8.15 1.00 1.02
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>
#include <cuda_runtime.h>
#include <algorithm>  // std::min for capping the grid size

#define EPS 1e-5f

// Kernel that uses vectorized loads/stores (float4) to ensure memory coalescing.
// Assumes the total number of elements (N*C*H*W) is divisible by 4 and that the
// input is a contiguous NCHW float32 tensor (PyTorch allocations are aligned
// well beyond the 16 bytes that float4 access requires).
__global__ void coalesced_bn_kernel(
    float* __restrict__ output,
    const float* __restrict__ input,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    const float* __restrict__ mean,
    const float* __restrict__ var,
    int N, int C, int H, int W) {

  int total = N * C * H * W;
  int total_vec = total / 4;  // Each float4 covers 4 consecutive elements
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;

  const float4* input_vec = reinterpret_cast<const float4*>(input);
  float4* output_vec = reinterpret_cast<float4*>(output);

  for (int i = idx; i < total_vec; i += stride) {
    int base_index = i * 4;  // starting index in the original array
    float4 in_val = input_vec[i];
    float4 out_val;
    
    // Process each of the 4 elements
    #pragma unroll
    for (int j = 0; j < 4; j++) {
      int global_index = base_index + j;
      // Compute channel index: in NCHW layout, channels change every H*W elements
      int channel = (global_index / (H * W)) % C;
      float inv_std = rsqrtf(var[channel] + EPS);
      float elem = (j == 0) ? in_val.x : (j == 1) ? in_val.y : (j == 2) ? in_val.z : in_val.w;
      float normalized = (elem - mean[channel]) * inv_std;
      float result = weight[channel] * normalized + bias[channel];
      if (j == 0) out_val.x = result;
      else if (j == 1) out_val.y = result;
      else if (j == 2) out_val.z = result;
      else out_val.w = result;
    }
    output_vec[i] = out_val;
  }
}
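
// Worked example of the channel computation above (illustrative): with C = 3 and
// H = W = 2 (so H * W = 4), element 5 satisfies (5 / 4) % 3 = 1, i.e. it belongs
// to channel 1 of its sample. A single float4 can straddle a channel boundary
// whenever H * W is not a multiple of 4, which is why the channel (and hence
// mean/var/weight/bias) is recomputed for each of the four packed elements.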

// Dense layer function: applies batch normalization (using the coalesced kernel in inference), ReLU, a 3x3 convolution (stride 1, padding 1), then dropout (p = 0, a no-op kept for parity with the reference model).
torch::Tensor dense_layer_fn(
    torch::Tensor x,
    torch::Tensor bn_weight,  // gamma
    torch::Tensor bn_bias,    // beta
    torch::Tensor bn_mean,
    torch::Tensor bn_var,
    torch::Tensor conv_weight,
    bool is_training) {

  auto sizes = x.sizes();
  int N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3];
  auto output = torch::empty_like(x);

  int total = N * C * H * W;
  if (!is_training && (total % 4 == 0)) {
    int total_vec = total / 4;
    const int threads = 256;
    const int max_blocks = 65535;  // Conservative grid cap; the kernel's grid-stride loop covers any remainder
    int desired_blocks = (total_vec + threads - 1) / threads;
    int blocks = std::min(desired_blocks, max_blocks);
    coalesced_bn_kernel<<<blocks, threads>>>(
        output.data_ptr<float>(),
        x.data_ptr<float>(),
        bn_weight.data_ptr<float>(),
        bn_bias.data_ptr<float>(),
        bn_mean.data_ptr<float>(),
        bn_var.data_ptr<float>(),
        N, C, H, W);
  } else {
    output = at::batch_norm(x, bn_weight, bn_bias, bn_mean, bn_var,
                            is_training, 0.1, EPS, true);
  }
  
  output = at::relu(output);
  output = at::conv2d(output,
                      conv_weight,
                      c10::nullopt,
                      at::IntArrayRef({1, 1}),
                      at::IntArrayRef({1, 1}));
  output = at::dropout(output, 0.0, is_training);
  return output;
}
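
// The BN dispatch above is duplicated verbatim in transition_layer_fn and
// forward below. A minimal sketch of a shared helper (illustrative; not called
// anywhere in this file) that assumes a contiguous float32 NCHW tensor whose
// element count is divisible by 4. A production version would also launch on
// at::cuda::getCurrentCUDAStream() rather than the legacy default stream.
static inline void launch_coalesced_bn_sketch(
    torch::Tensor& output, const torch::Tensor& input,
    const torch::Tensor& weight, const torch::Tensor& bias,
    const torch::Tensor& mean, const torch::Tensor& var,
    int N, int C, int H, int W) {
  int total_vec = (N * C * H * W) / 4;  // caller guarantees divisibility by 4
  const int threads = 256;
  int blocks = std::min((total_vec + threads - 1) / threads, 65535);
  coalesced_bn_kernel<<<blocks, threads>>>(
      output.data_ptr<float>(), input.data_ptr<float>(),
      weight.data_ptr<float>(), bias.data_ptr<float>(),
      mean.data_ptr<float>(), var.data_ptr<float>(),
      N, C, H, W);
}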

// Dense block: iteratively applies dense layers and concatenates the features.
torch::Tensor dense_block_fn(torch::Tensor x, pybind11::list layer_params, bool is_training) {
  std::vector<torch::Tensor> features;
  features.push_back(x);
  for (size_t i = 0; i < layer_params.size(); i++) {
    auto params_tuple = layer_params[i].cast<pybind11::tuple>();
    torch::Tensor bn_weight   = params_tuple[0].cast<torch::Tensor>();
    torch::Tensor bn_bias     = params_tuple[1].cast<torch::Tensor>();
    torch::Tensor bn_mean     = params_tuple[2].cast<torch::Tensor>();
    torch::Tensor bn_var      = params_tuple[3].cast<torch::Tensor>();
    torch::Tensor conv_weight = params_tuple[4].cast<torch::Tensor>();

    torch::Tensor new_feature = dense_layer_fn(x, bn_weight, bn_bias, bn_mean, bn_var, conv_weight, is_training);
    features.push_back(new_feature);
    x = at::cat(features, 1);
  }
  return x;
}
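
// Channel bookkeeping (illustrative): starting from C0 input channels with
// growth rate k, the concatenation after layer i has C0 + (i + 1) * k channels;
// e.g. 64 input channels with k = 32 give 64 + 6 * 32 = 256 channels after a
// six-layer block.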

// Transition layer: applies batch normalization (using the coalesced kernel in inference), ReLU, a 1x1 convolution, and 2x2 average pooling.
torch::Tensor transition_layer_fn(
    torch::Tensor x,
    torch::Tensor bn_weight,
    torch::Tensor bn_bias,
    torch::Tensor bn_mean,
    torch::Tensor bn_var,
    torch::Tensor conv_weight,
    bool is_training) {

  auto sizes = x.sizes();
  int N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3];
  auto output = torch::empty_like(x);

  int total = N * C * H * W;
  if (!is_training && (total % 4 == 0)) {
    int total_vec = total / 4;
    const int threads = 256;
    const int max_blocks = 65535;  // Conservative grid cap; the kernel's grid-stride loop covers any remainder
    int desired_blocks = (total_vec + threads - 1) / threads;
    int blocks = std::min(desired_blocks, max_blocks);
    coalesced_bn_kernel<<<blocks, threads>>>(
        output.data_ptr<float>(),
        x.data_ptr<float>(),
        bn_weight.data_ptr<float>(),
        bn_bias.data_ptr<float>(),
        bn_mean.data_ptr<float>(),
        bn_var.data_ptr<float>(),
        N, C, H, W);
  } else {
    output = at::batch_norm(x, bn_weight, bn_bias, bn_mean, bn_var,
                            is_training, 0.1, EPS, true);
  }

  output = at::relu(output);
  output = at::conv2d(output,
                      conv_weight,
                      c10::nullopt,
                      at::IntArrayRef({1, 1}),
                      at::IntArrayRef({0, 0}));
  output = at::avg_pool2d(output,
                          at::IntArrayRef({2, 2}),
                          at::IntArrayRef({2, 2}));
  return output;
}
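
// Spatial bookkeeping (illustrative): the 2x2 average pool with stride 2 halves
// both spatial dimensions, e.g. 56x56 -> 28x28; the 1x1 convolution (stride 1,
// padding 0) changes only the channel count.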

// Forward pass: initial convolution and pooling, dense blocks interleaved with transition layers, final batch norm, global average pooling, and the linear classifier.
torch::Tensor forward(torch::Tensor x, pybind11::object params_obj, bool is_training) {
  pybind11::dict params = params_obj.cast<pybind11::dict>();

  torch::Tensor features_conv_weight = params["features_conv_weight"].cast<torch::Tensor>();
  torch::Tensor features_bn_mean     = params["features_bn_mean"].cast<torch::Tensor>();
  torch::Tensor features_bn_var      = params["features_bn_var"].cast<torch::Tensor>();
  torch::Tensor features_bn_weight   = params["features_bn_weight"].cast<torch::Tensor>();
  torch::Tensor features_bn_bias     = params["features_bn_bias"].cast<torch::Tensor>();

  x = at::conv2d(x,
                 features_conv_weight,
                 c10::nullopt,
                 at::IntArrayRef({2, 2}),
                 at::IntArrayRef({3, 3}));
  
  auto sizes = x.sizes();
  int N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3];
  auto output = torch::empty_like(x);
  int total = N * C * H * W;
  if (!is_training && (total % 4 == 0)) {
    int total_vec = total / 4;
    const int threads = 256;
    const int max_blocks = 65535;  // Conservative grid cap; the kernel's grid-stride loop covers any remainder
    int desired_blocks = (total_vec + threads - 1) / threads;
    int blocks = std::min(desired_blocks, max_blocks);
    coalesced_bn_kernel<<<blocks, threads>>>(
        output.data_ptr<float>(),
        x.data_ptr<float>(),
        features_bn_weight.data_ptr<float>(),
        features_bn_bias.data_ptr<float>(),
        features_bn_mean.data_ptr<float>(),
        features_bn_var.data_ptr<float>(),
        N, C, H, W);
    x = output;
  } else {
    x = at::batch_norm(x,
                       features_bn_weight,
                       features_bn_bias,
                       features_bn_mean,
                       features_bn_var,
                       is_training, 0.1, EPS, true);
  }
  
  x = at::relu(x);
  x = at::max_pool2d(x,
                     at::IntArrayRef({3, 3}),
                     at::IntArrayRef({2, 2}),
                     at::IntArrayRef({1, 1}));
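
  // Stem sizing (illustrative, assuming the standard DenseNet 7x7 stem weight):
  // a 224x224 input becomes 112x112 after the stride-2 convolution with padding
  // 3, and 56x56 after the stride-2 3x3 max pool with padding 1.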

  pybind11::list dense_blocks = params["dense_blocks"].cast<pybind11::list>();
  pybind11::list transition_layers = params["transition_layers"].cast<pybind11::list>();

  int num_dense_blocks = dense_blocks.size();
  for (int i = 0; i < num_dense_blocks; i++) {
    pybind11::list block_params = dense_blocks[i].cast<pybind11::list>();
    x = dense_block_fn(x, block_params, is_training);

    if (i != num_dense_blocks - 1) {
      auto trans_tuple = transition_layers[i].cast<pybind11::tuple>();
      torch::Tensor t_bn_weight = trans_tuple[0].cast<torch::Tensor>();
      torch::Tensor t_bn_bias   = trans_tuple[1].cast<torch::Tensor>();
      torch::Tensor t_bn_mean   = trans_tuple[2].cast<torch::Tensor>();
      torch::Tensor t_bn_var    = trans_tuple[3].cast<torch::Tensor>();
      torch::Tensor t_conv_weight = trans_tuple[4].cast<torch::Tensor>();

      x = transition_layer_fn(x, t_bn_weight, t_bn_bias, t_bn_mean, t_bn_var, t_conv_weight, is_training);
    }
  }
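
  // DenseNet-201 reference configuration: four dense blocks of (6, 12, 48, 32)
  // layers, with a transition layer between each pair of consecutive blocks.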

  torch::Tensor final_bn_mean   = params["final_bn_mean"].cast<torch::Tensor>();
  torch::Tensor final_bn_var    = params["final_bn_var"].cast<torch::Tensor>();
  torch::Tensor final_bn_weight = params["final_bn_weight"].cast<torch::Tensor>();
  torch::Tensor final_bn_bias   = params["final_bn_bias"].cast<torch::Tensor>();

  sizes = x.sizes();
  N = sizes[0]; C = sizes[1]; H = sizes[2]; W = sizes[3];
  output = torch::empty_like(x);
  total = N * C * H * W;
  if (!is_training && (total % 4 == 0)) {
    int total_vec = total / 4;
    const int threads = 256;
    const int max_blocks = 65535;  // Conservative grid cap; the kernel's grid-stride loop covers any remainder
    int desired_blocks = (total_vec + threads - 1) / threads;
    int blocks = std::min(desired_blocks, max_blocks);
    coalesced_bn_kernel<<<blocks, threads>>>(
        output.data_ptr<float>(),
        x.data_ptr<float>(),
        final_bn_weight.data_ptr<float>(),
        final_bn_bias.data_ptr<float>(),
        final_bn_mean.data_ptr<float>(),
        final_bn_var.data_ptr<float>(),
        N, C, H, W);
    x = output;
  } else {
    x = at::batch_norm(x,
                       final_bn_weight,
                       final_bn_bias,
                       final_bn_mean,
                       final_bn_var,
                       is_training, 0.1, EPS, true);
  }
  
  x = at::relu(x);
  x = at::adaptive_avg_pool2d(x, at::IntArrayRef({1, 1}));
  x = x.view({x.size(0), -1});

  torch::Tensor classifier_weight = params["classifier_weight"].cast<torch::Tensor>();
  torch::Tensor classifier_bias   = params["classifier_bias"].cast<torch::Tensor>();
  x = at::linear(x, classifier_weight, classifier_bias);

  return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Custom CUDA forward function with coalesced memory accesses using vectorized loads and stores");
}
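
// Usage sketch (illustrative; the module name is an assumption): build the
// extension with torch.utils.cpp_extension.load(name="densenet_bn", sources=[...])
// and call module.forward(x, params, is_training) from Python, where `params` is
// a dict holding the keys read above ("features_conv_weight", "dense_blocks",
// "transition_layers", the "final_bn_*" tensors, and the "classifier_*" tensors).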