
The AI CUDA Engineer 👷

46_NetVladWithGhostClusters • 46_netvlad_reduced_sync_base

Level 3 • Task 46

Kernel Information

Related Kernels (Level 3, Task 46 • 46_NetVladWithGhostClusters)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 modular_netvlad_ghost_base 0.10 1.99 0.78
🥇 warp_reduction_netvlad_base 0.10 1.99 0.78
🥇 netvlad_warp_shfl_sync_optimized_base 0.10 1.99 0.78
4 sync_optimized_netvlad_base_base 0.10 1.97 0.77
4 warp_reduction_netvlad_optimized_base 0.10 1.97 0.77
4 shared_memory_netvlad_v2_base 0.10 1.97 0.77
4 netvlad_modular_device_funcs_base 0.10 1.97 0.77
4 netvlad_warp_shfl_optimized_edit_1 0.10 1.97 0.77
4 netvlad_warp_atomic_optimized_edit_1 0.10 1.97 0.77
10 netvlad_warp_shfl_optimized_base 0.10 1.95 0.76
10 netvlad_block_size_optimized_base 0.10 1.95 0.76
10 46_NetVladWithGhostClusters 0.10 1.95 0.76
13 46_netvlad_reduced_sync_base 0.10 1.93 0.75
13 shared_memory_netvlad_optimized_base 0.10 1.93 0.75
13 optimized_netvlad_cuda_edit_1 0.10 1.93 0.75
13 46_netvlad_reduced_sync_edit_1 0.10 1.93 0.75
13 netvlad_block_size_optimized_edit_1 0.10 1.93 0.75
13 netvlad_warp_atomic_optimized_base 0.10 1.93 0.75
13 netvlad_modular_device_funcs_edit_1 0.10 1.93 0.75
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>

torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor clusters,
    torch::Tensor clusters2,
    torch::Tensor bn_weight,
    torch::Tensor bn_bias,
    torch::Tensor bn_mean,
    torch::Tensor bn_var,
    bool is_training,
    int64_t cluster_size,
    int64_t feature_size
) {
    auto B = x.size(0);
    auto N = x.size(1);
    auto D = x.size(2);
    TORCH_CHECK(D == feature_size, "feature_size mismatch.");

    // Flatten x to [B*N, D] - no sync needed
    x = x.reshape({B * N, D});

    // Fuse operations that don't require sync:
    // project descriptors onto the K real + G ghost clusters -> [B*N, K+G]
    auto assignment = at::matmul(x, clusters);
    assignment = at::batch_norm(
        assignment, bn_weight, bn_bias, bn_mean, bn_var,
        is_training, 0.1, 1e-5, true
    );
    
    // Single sync point after softmax; then drop the ghost clusters,
    // keeping the first cluster_size columns -> [B, N, K]
    assignment = at::softmax(assignment, 1);
    assignment = assignment.narrow(1, 0, cluster_size);
    assignment = assignment.reshape({B, N, cluster_size});

    // Fuse sum and multiplication operations:
    // a_sum is [B, 1, K]; broadcasting against clusters2 gives a of shape [B, D, K]
    auto a_sum = assignment.sum(1, true);
    auto a = a_sum * clusters2;

    // Transpose assignment to [B, K, N] (contiguous for bmm) and restore x to [B, N, D]
    assignment = assignment.transpose(1, 2).contiguous();
    x = x.reshape({B, N, D});

    // Fuse matrix multiplication and transpose:
    // bmm([B, K, N], [B, N, D]) -> [B, K, D], transposed to [B, D, K]
    auto vlad = at::bmm(assignment, x).transpose(1, 2);
    vlad = vlad - a;

    // Combine normalization operations:
    // intra-normalize over the feature dimension, flatten, then L2-normalize the whole descriptor
    vlad = vlad / (vlad.norm(2, {1}, true) + 1e-12);
    vlad = vlad.reshape({B, D * cluster_size});
    vlad = vlad / (vlad.norm(2, {1}, true) + 1e-12);

    return vlad;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "NetVLAD with ghost clusters (CUDA)");
}
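
For reference, below is a minimal standalone driver sketch (not part of the submitted kernel) showing one way forward could be exercised from C++ with libtorch. The sizes B, N, D, K, G and the tensor shapes are illustrative assumptions inferred from the broadcasting in the listing above (clusters as [D, K+G], clusters2 as [1, D, K]); it assumes a CUDA device is available and that the listing is compiled into the same program (with TORCH_EXTENSION_NAME defined, or the pybind block omitted).

#include <torch/torch.h>
#include <iostream>

// Forward declaration of the function defined in the listing above.
torch::Tensor forward(
    torch::Tensor x, torch::Tensor clusters, torch::Tensor clusters2,
    torch::Tensor bn_weight, torch::Tensor bn_bias,
    torch::Tensor bn_mean, torch::Tensor bn_var,
    bool is_training, int64_t cluster_size, int64_t feature_size);

int main() {
    // Illustrative sizes (assumptions, not the benchmark configuration)
    const int64_t B = 32, N = 100, D = 128;  // batch, descriptors per sample, feature size
    const int64_t K = 32, G = 16;            // real clusters, ghost clusters

    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);

    auto x         = torch::randn({B, N, D}, opts);
    auto clusters  = torch::randn({D, K + G}, opts);  // projection onto real + ghost clusters
    auto clusters2 = torch::randn({1, D, K}, opts);   // cluster centers used for the residual term
    auto bn_weight = torch::ones({K + G}, opts);
    auto bn_bias   = torch::zeros({K + G}, opts);
    auto bn_mean   = torch::zeros({K + G}, opts);
    auto bn_var    = torch::ones({K + G}, opts);

    auto vlad = forward(x, clusters, clusters2,
                        bn_weight, bn_bias, bn_mean, bn_var,
                        /*is_training=*/false, /*cluster_size=*/K, /*feature_size=*/D);

    std::cout << "vlad sizes: " << vlad.sizes() << std::endl;  // expected: [B, D*K]
    return 0;
}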