
The AI CUDA Engineer 👷

10_ResNet101 • warp_pool_optimized_bottleneck_base

Level 3 • Task 10

Kernel Information

Related Kernels (Level 3, Task 10 • 10_ResNet101)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 resnet101_modular_functions_base_base 23.20 1.33 1.33
🥈 resnet101_coalesced_memory_access_base 24.25 1.27 1.28
🥉 resnet101_balanced_workload_base 24.36 1.27 1.27
4 resnet101_balanced_workload_base 24.60 1.26 1.26
5 10_ResNet101_mem_opt_base_base 24.62 1.26 1.26
6 resnet101_uniform_flow_base_base 24.84 1.24 1.25
7 resnet101_shared_mem_sync_optimized_base 24.93 1.24 1.24
8 efficient_resnet_base 25.08 1.23 1.23
9 resnet101_optimized_memory_access_base 25.45 1.21 1.22
10 resnet101_unrolled_loops_base_base 25.58 1.21 1.21
11 resnet101_min_sync_relu_base 25.64 1.21 1.21
12 warp_pool_optimized_bottleneck_base 25.79 1.20 1.20
13 unified_resnet_base 26.38 1.17 1.17
14 10_ResNet101_warp_avg_pool_base 26.55 1.16 1.17
15 resnet101_minimal_sync_base_base 26.90 1.15 1.15
16 10_ResNet101 28.04 1.10 1.10
17 resnet101_fused_distr_base 28.10 1.10 1.10
18 10_resnet101_opt_aligned_mem_edit_1 29.45 1.05 1.05
19 10_resnet101_opt_min_sync_edit_1 29.58 1.04 1.05
20 10_resnet101_opt_base 29.83 1.04 1.04
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cuda_runtime.h>

namespace py = pybind11;

// Global average pooling over one (n, c) plane per thread block. Each block is a
// single warp: lanes accumulate strided partial sums, a shuffle reduction combines
// them, and lane 0 writes the mean. `spatial` equals H * W for the input tensor.
__global__ void warp_avg_pool_kernel(const float* __restrict__ input, float* __restrict__ output, int spatial, int C, int H, int W) {
    int index = blockIdx.x;              // flattened (n, c) index
    int n = index / C;
    int c = index % C;
    int total = H * W;                   // number of elements in this plane (== spatial)
    const float* in_ptr = input + (n * C + c) * total;
    float sum = 0.0f;
    
    // Each lane sums a strided subset of the H*W elements.
    for (int i = threadIdx.x; i < total; i += blockDim.x) {
        sum += in_ptr[i];
    }
    
    // Warp-level tree reduction via shuffles; assumes blockDim.x == warpSize (32).
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    // Lane 0 now holds the complete sum for this plane; write the average.
    if (threadIdx.x == 0) {
        output[index] = sum / total;
    }
}

// One ResNet bottleneck block: 1x1 conv -> BN -> ReLU, 3x3 conv (stride) -> BN -> ReLU,
// 1x1 conv -> BN, plus an optional strided 1x1 conv + BN on the identity path,
// finished with a residual add and ReLU. Convolutions and batch norms dispatch to ATen.
torch::Tensor unified_bottleneck(
    torch::Tensor x,
    torch::Tensor conv1_w,
    torch::Tensor conv2_w,
    torch::Tensor conv3_w,
    torch::Tensor bn1_w,
    torch::Tensor bn1_b,
    torch::Tensor bn1_m,
    torch::Tensor bn1_v,
    torch::Tensor bn2_w,
    torch::Tensor bn2_b,
    torch::Tensor bn2_m,
    torch::Tensor bn2_v,
    torch::Tensor bn3_w,
    torch::Tensor bn3_b,
    torch::Tensor bn3_m,
    torch::Tensor bn3_v,
    bool has_downsample,
    torch::Tensor downsample_conv_w,
    torch::Tensor downsample_bn_w,
    torch::Tensor downsample_bn_b,
    torch::Tensor downsample_bn_m,
    torch::Tensor downsample_bn_v,
    int64_t stride,
    bool is_training
) {
    // 1x1 convolution reduces channels, then BN + ReLU.
    auto out = torch::conv2d(x, conv1_w, /*bias=*/torch::Tensor());
    out = torch::batch_norm(out, bn1_w, bn1_b, bn1_m, bn1_v, is_training, 0.1, 1e-5, true);
    out = torch::relu(out);
    
    // 3x3 convolution with padding 1 (the block's stride is applied here), then BN + ReLU.
    out = torch::conv2d(out, conv2_w, /*bias=*/torch::Tensor(), stride, 1);
    out = torch::batch_norm(out, bn2_w, bn2_b, bn2_m, bn2_v, is_training, 0.1, 1e-5, true);
    out = torch::relu(out);
    
    // 1x1 convolution expands channels, then BN; ReLU is deferred until after the residual add.
    out = torch::conv2d(out, conv3_w, /*bias=*/torch::Tensor());
    out = torch::batch_norm(out, bn3_w, bn3_b, bn3_m, bn3_v, is_training, 0.1, 1e-5, true);
    
    // Identity path: strided 1x1 conv + BN when downsample weights are provided,
    // otherwise the input passes through (matched to the output dtype).
    torch::Tensor identity;
    if (has_downsample) {
        identity = torch::conv2d(x, downsample_conv_w, /*bias=*/torch::Tensor(), stride);
        identity = torch::batch_norm(identity, downsample_bn_w, downsample_bn_b, downsample_bn_m, downsample_bn_v, is_training, 0.1, 1e-5, true);
    } else {
        identity = x.to(out.dtype());
    }
    
    // Residual add followed by the final ReLU.
    return torch::relu(out + identity);
}

// Full ResNet-101 forward pass: stem, four stages of bottleneck blocks, a custom
// warp-level global average pooling kernel, and the final fully connected layer.
// All weights are fetched on the fly from the dict-like `params` object.
torch::Tensor forward(torch::Tensor x, py::object params, bool is_training) {
    auto device = x.device();
    
    // Stem parameters (made contiguous and copied to the input's device with non_blocking=true).
    auto conv1_w = params.attr("get")("conv1_w").cast<torch::Tensor>().contiguous().to(device, true);
    auto bn1_w = params.attr("get")("bn1_w").cast<torch::Tensor>().contiguous().to(device, true);
    auto bn1_b = params.attr("get")("bn1_b").cast<torch::Tensor>().contiguous().to(device, true);
    auto bn1_m = params.attr("get")("bn1_m").cast<torch::Tensor>().contiguous().to(device, true);
    auto bn1_v = params.attr("get")("bn1_v").cast<torch::Tensor>().contiguous().to(device, true);
    
    // Stem: 7x7 stride-2 convolution (padding 3), BN, ReLU, then 3x3 stride-2 max pooling.
    x = torch::conv2d(x, conv1_w, /*bias=*/torch::Tensor(), 2, 3);
    x = torch::batch_norm(x, bn1_w, bn1_b, bn1_m, bn1_v, is_training, 0.1, 1e-5, true);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);
    
    // Stages layer1..layer4: each is a Python list of per-block parameter dicts.
    for (int layer_idx = 1; layer_idx <= 4; ++layer_idx) {
        std::string layer_key = "layer" + std::to_string(layer_idx) + "_blocks";
        py::list blocks = params.attr("get")(py::str(layer_key)).cast<py::list>();
        
        for (size_t block_idx = 0; block_idx < blocks.size(); ++block_idx) {
            py::object block = blocks[block_idx];
            
            auto conv1_w_blk = block.attr("get")("conv1_w").cast<torch::Tensor>().contiguous().to(device, true);
            auto conv2_w_blk = block.attr("get")("conv2_w").cast<torch::Tensor>().contiguous().to(device, true);
            auto conv3_w_blk = block.attr("get")("conv3_w").cast<torch::Tensor>().contiguous().to(device, true);
            
            auto bn1_w_blk = block.attr("get")("bn1_w").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn1_b_blk = block.attr("get")("bn1_b").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn1_m_blk = block.attr("get")("bn1_m").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn1_v_blk = block.attr("get")("bn1_v").cast<torch::Tensor>().contiguous().to(device, true);
            
            auto bn2_w_blk = block.attr("get")("bn2_w").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn2_b_blk = block.attr("get")("bn2_b").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn2_m_blk = block.attr("get")("bn2_m").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn2_v_blk = block.attr("get")("bn2_v").cast<torch::Tensor>().contiguous().to(device, true);
            
            auto bn3_w_blk = block.attr("get")("bn3_w").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn3_b_blk = block.attr("get")("bn3_b").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn3_m_blk = block.attr("get")("bn3_m").cast<torch::Tensor>().contiguous().to(device, true);
            auto bn3_v_blk = block.attr("get")("bn3_v").cast<torch::Tensor>().contiguous().to(device, true);
            
            // A downsample branch exists only when the block dict carries its weights
            // (the first block of each stage, where channels and/or resolution change).
            bool has_downsample = py::bool_(block.attr("__contains__")(py::str("downsample_conv_w")));
            torch::Tensor downsample_conv_w, downsample_bn_w, downsample_bn_b, downsample_bn_m, downsample_bn_v;
            if (has_downsample) {
                downsample_conv_w = block.attr("get")("downsample_conv_w").cast<torch::Tensor>().contiguous().to(device, true);
                downsample_bn_w = block.attr("get")("downsample_bn_w").cast<torch::Tensor>().contiguous().to(device, true);
                downsample_bn_b = block.attr("get")("downsample_bn_b").cast<torch::Tensor>().contiguous().to(device, true);
                downsample_bn_m = block.attr("get")("downsample_bn_m").cast<torch::Tensor>().contiguous().to(device, true);
                downsample_bn_v = block.attr("get")("downsample_bn_v").cast<torch::Tensor>().contiguous().to(device, true);
            }
            
            // Spatial downsampling happens in the first block of layer2, layer3 and layer4.
            int64_t stride = (block_idx == 0 && layer_idx > 1) ? 2 : 1;
            
            x = unified_bottleneck(
                x,
                conv1_w_blk, conv2_w_blk, conv3_w_blk,
                bn1_w_blk, bn1_b_blk, bn1_m_blk, bn1_v_blk,
                bn2_w_blk, bn2_b_blk, bn2_m_blk, bn2_v_blk,
                bn3_w_blk, bn3_b_blk, bn3_m_blk, bn3_v_blk,
                has_downsample,
                downsample_conv_w, downsample_bn_w, downsample_bn_b, downsample_bn_m, downsample_bn_v,
                stride, is_training
            );
        }
    }
    
    // Global average pooling with the custom kernel: one 32-thread block (a single warp)
    // per (n, c) plane. The tensor produced by the ATen ops above is contiguous float32,
    // which is what the kernel expects; the launch uses the default CUDA stream.
    auto sizes = x.sizes();
    int N = sizes[0];
    int C = sizes[1];
    int H = sizes[2];
    int W = sizes[3];
    int spatial = H * W;
    auto pooled = torch::empty({N, C, 1, 1}, x.options());
    
    int grid = N * C;    // one block per (n, c) plane
    int threads = 32;    // one warp, matching the shuffle-reduction width
    warp_avg_pool_kernel<<<grid, threads>>>(x.data_ptr<float>(), pooled.data_ptr<float>(), spatial, C, H, W);
    cudaDeviceSynchronize();  // ensure the pooled result is ready before the FC layer below
    
    x = pooled.view({N, C});
    
    // Final fully connected classifier.
    auto fc_w = params.attr("get")("fc_w").cast<torch::Tensor>().contiguous().to(device, true);
    auto fc_b = params.attr("get")("fc_b").cast<torch::Tensor>().contiguous().to(device, true);
    return torch::linear(x, fc_w, fc_b);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Optimized ResNet101 with warp pooling and streamlined params");
}
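
Usage note (not part of the original listing): the sketch below shows one way this extension could be compiled and called from Python via torch.utils.cpp_extension.load. The file name warp_pool_optimized_bottleneck.cu, the module name, and the packing helpers are illustrative assumptions; forward only requires a dict-like params object whose get() returns the keys read above (conv1_w, bn1_*, layer1_blocks..layer4_blocks with per-block conv/bn keys, fc_w, fc_b). Filling those keys from a torchvision resnet101 is shown as one plausible mapping, not the harness used for the leaderboard measurements.

# Build the extension (assumes the listing above is saved as warp_pool_optimized_bottleneck.cu).
import torch
import torchvision
from torch.utils.cpp_extension import load

ext = load(name="warp_pool_resnet101",
           sources=["warp_pool_optimized_bottleneck.cu"],
           verbose=True)

def pack_bn(bn):
    # Weight, bias and running statistics in the order the extension reads them.
    return bn.weight.data, bn.bias.data, bn.running_mean, bn.running_var

def pack_block(b):
    d = {"conv1_w": b.conv1.weight.data,
         "conv2_w": b.conv2.weight.data,
         "conv3_w": b.conv3.weight.data}
    for i, bn in enumerate((b.bn1, b.bn2, b.bn3), start=1):
        d[f"bn{i}_w"], d[f"bn{i}_b"], d[f"bn{i}_m"], d[f"bn{i}_v"] = pack_bn(bn)
    if b.downsample is not None:  # present only where channels/resolution change
        d["downsample_conv_w"] = b.downsample[0].weight.data
        (d["downsample_bn_w"], d["downsample_bn_b"],
         d["downsample_bn_m"], d["downsample_bn_v"]) = pack_bn(b.downsample[1])
    return d

# One plausible way to fill the expected keys: from a torchvision ResNet-101.
model = torchvision.models.resnet101(weights=None)  # older torchvision: pretrained=False
params = {"conv1_w": model.conv1.weight.data,
          "fc_w": model.fc.weight.data,
          "fc_b": model.fc.bias.data}
params["bn1_w"], params["bn1_b"], params["bn1_m"], params["bn1_v"] = pack_bn(model.bn1)
for i, layer in enumerate((model.layer1, model.layer2, model.layer3, model.layer4), start=1):
    params[f"layer{i}_blocks"] = [pack_block(b) for b in layer]

x = torch.randn(2, 3, 224, 224, device="cuda")   # requires a CUDA device
out = ext.forward(x, params, False)              # is_training=False -> running BN statistics
print(out.shape)                                 # expected: torch.Size([2, 1000])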