46_NetVladWithGhostClusters
• netvlad_modular_device_funcs_base
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch as th
def module_fn(
    x: torch.Tensor,
    clusters: torch.Tensor,
    clusters2: torch.Tensor,
    bn_weight: torch.Tensor,
    bn_bias: torch.Tensor,
    bn_mean: torch.Tensor,
    bn_var: torch.Tensor,
    is_training: bool,
    cluster_size: int,
    feature_size: int,
) -> torch.Tensor:
    """
    Functional version of NetVLAD with ghost clusters.

    Notation: B = batch size, N = features per sample, D = feature_size,
    K = cluster_size, G = number of ghost clusters.

    Args:
        x: Input tensor of shape (B, N, D). It does not need to be contiguous.
        clusters: Assignment weights of shape (D, K + G).
        clusters2: Visual words c_k, broadcast against a (B, D, K) tensor
            (the wrapping Model creates it as (1, D, K)).
        bn_weight: BatchNorm weight, one entry per cluster (K + G).
        bn_bias: BatchNorm bias, one entry per cluster (K + G).
        bn_mean: BatchNorm running mean, one entry per cluster (K + G).
        bn_var: BatchNorm running variance, one entry per cluster (K + G).
        is_training: Forwarded to F.batch_norm as `training`. When True, batch
            statistics are used and bn_mean / bn_var are updated in place.
        cluster_size: Number of real clusters K. The trailing G ghost columns
            are dropped after the softmax.
        feature_size: Feature dimension D.

    Returns:
        Tensor of shape (B, K * D). Each row is L2-normalized.

    Raises:
        ValueError: If x and clusters are on different devices.
    """
    max_sample = x.size(1)
    # `reshape` instead of `view` so that non-contiguous inputs (e.g. a
    # transposed tensor) are accepted. For inputs `view` could already handle,
    # `reshape` returns the same zero-copy view.
    x = x.reshape(-1, feature_size)  # B x N x D -> BN x D
    if x.device != clusters.device:
        msg = f"x.device {x.device} != cluster.device {clusters.device}"
        raise ValueError(msg)
    assignment = th.matmul(x, clusters)  # (BN x D) x (D x (K+G)) -> BN x (K+G)
    assignment = F.batch_norm(
        assignment, bn_mean, bn_var, bn_weight, bn_bias, is_training
    )
    assignment = F.softmax(assignment, dim=1)  # BN x (K+G) -> BN x (K+G)
    # Remove the ghost-cluster columns; they only absorb assignment mass.
    assignment = assignment[:, :cluster_size]
    assignment = assignment.view(-1, max_sample, cluster_size)  # -> B x N x K
    a_sum = th.sum(assignment, dim=1, keepdim=True)  # B x N x K -> B x 1 x K
    a = a_sum * clusters2  # broadcast with clusters2 -> B x D x K
    assignment = assignment.transpose(1, 2)  # B x N x K -> B x K x N
    x = x.view(-1, max_sample, feature_size)  # BN x D -> B x N x D
    vlad = th.matmul(assignment, x)  # (B x K x N) x (B x N x D) -> B x K x D
    vlad = vlad.transpose(1, 2)  # -> B x D x K
    vlad = vlad - a
    # Intra-normalization: L2 over the D axis (dim=1) for each cluster.
    vlad = F.normalize(vlad)
    # Flatten, then L2-normalize the whole descriptor.
    vlad = vlad.reshape(-1, cluster_size * feature_size)  # -> B x DK
    vlad = F.normalize(vlad)
    return vlad  # B x DK
class Model(nn.Module):
    """NetVLAD-with-ghost-clusters module that delegates its math to a function.

    All learnable state, including the BatchNorm tensors, lives in explicit
    ``nn.Parameter`` attributes so that ``forward`` can pass each one to ``fn``
    by position.
    """

    def __init__(self, cluster_size, feature_size, ghost_clusters):
        super().__init__()
        self.feature_size = feature_size
        self.cluster_size = cluster_size
        self.ghost_clusters = ghost_clusters

        total_clusters = cluster_size + ghost_clusters
        scale = 1 / math.sqrt(feature_size)

        # Assignment weights, i.e. the paper's `(w, b)`: feature_size x (K + G).
        self.clusters = nn.Parameter(scale * th.randn(feature_size, total_clusters))

        # Take default-initialized BatchNorm1d tensors and store copies as Parameters.
        # NOTE(review): the running mean/var become trainable Parameters rather
        # than buffers, so an optimizer would update them. The default `fn`
        # (F.batch_norm in training mode) also mutates them in place. Confirm
        # this is intended.
        template_bn = nn.BatchNorm1d(total_clusters)
        for attr_name, source in (
            ("bn_weight", template_bn.weight),
            ("bn_bias", template_bn.bias),
            ("bn_mean", template_bn.running_mean),
            ("bn_var", template_bn.running_var),
        ):
            setattr(self, attr_name, nn.Parameter(source.data.clone()))

        # Visual words, i.e. the paper's `c_k`: 1 x feature_size x K.
        self.clusters2 = nn.Parameter(scale * th.randn(1, feature_size, cluster_size))
        self.out_dim = cluster_size * feature_size

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` to ``x`` using this module's tensors, mode and sizes."""
        positional_args = (
            x,
            self.clusters,
            self.clusters2,
            self.bn_weight,
            self.bn_bias,
            self.bn_mean,
            self.bn_var,
            self.training,
            self.cluster_size,
            self.feature_size,
        )
        return fn(*positional_args)
# Benchmark configuration: B x N x D inputs, K real clusters plus G ghost clusters.
batch_size = 32
num_features = 100
num_clusters = 32
feature_size = 512
ghost_clusters = 16


def get_inputs():
    """Return the positional forward() arguments: one random B x N x D tensor."""
    shape = (batch_size, num_features, feature_size)
    features = torch.randn(*shape)
    return [features]


def get_init_inputs():
    """Return the positional constructor arguments: (K, D, G)."""
    init_args = [num_clusters, feature_size, ghost_clusters]
    return init_args
# Copyright 2018 Antoine Miech All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code modified from here
https://github.com/albanie/collaborative-experts/blob/master/model/net_vlad.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch as th
class Model(nn.Module):
    """NetVLAD aggregation with ghost clusters.

    Notation: B = batch size, N = features per sample, D = feature_size,
    K = cluster_size, G = ghost_clusters.

    Args:
        cluster_size: Number of real clusters K that appear in the output.
        feature_size: Feature dimension D of the inputs.
        ghost_clusters: Number of ghost clusters G. They take part in the soft
            assignment but are dropped before aggregation.
    """

    def __init__(self, cluster_size, feature_size, ghost_clusters):
        super(Model, self).__init__()
        self.feature_size = feature_size
        self.cluster_size = cluster_size
        self.ghost_clusters = ghost_clusters
        init_sc = (1 / math.sqrt(feature_size))
        clusters = cluster_size + ghost_clusters
        # The `clusters` weights are the `(w,b)` in the paper: D x (K+G)
        self.clusters = nn.Parameter(init_sc * th.randn(feature_size, clusters))
        self.batch_norm = nn.BatchNorm1d(clusters)
        # The `clusters2` weights are the visual words `c_k` in the paper: 1 x D x K
        self.clusters2 = nn.Parameter(init_sc * th.randn(1, feature_size, cluster_size))
        self.out_dim = self.cluster_size * feature_size

    def forward(self, x, mask=None):
        """Aggregate feature maps into a fixed-size representation.

        Args:
            x (th.Tensor): B x N x D. It does not need to be contiguous.
            mask: Accepted for interface compatibility. It is not used.

        Returns:
            (th.Tensor): B x DK. Each row is L2-normalized.

        Raises:
            ValueError: If x and the cluster weights are on different devices.
        """
        max_sample = x.size(1)
        # `reshape` instead of `view` so that non-contiguous inputs (e.g. a
        # transposed tensor) are accepted. It is zero-copy whenever `view` works.
        x = x.reshape(-1, self.feature_size)  # B x N x D -> BN x D
        if x.device != self.clusters.device:
            msg = f"x.device {x.device} != cluster.device {self.clusters.device}"
            raise ValueError(msg)
        assignment = th.matmul(x, self.clusters)  # (BN x D) x (D x (K+G)) -> BN x (K+G)
        assignment = self.batch_norm(assignment)
        assignment = F.softmax(assignment, dim=1)  # BN x (K+G) -> BN x (K+G)
        # Remove the ghost-cluster columns.
        assignment = assignment[:, :self.cluster_size]
        assignment = assignment.view(-1, max_sample, self.cluster_size)  # -> B x N x K
        a_sum = th.sum(assignment, dim=1, keepdim=True)  # B x N x K -> B x 1 x K
        a = a_sum * self.clusters2  # -> B x D x K
        assignment = assignment.transpose(1, 2)  # B x N x K -> B x K x N
        x = x.view(-1, max_sample, self.feature_size)  # BN x D -> B x N x D
        vlad = th.matmul(assignment, x)  # (B x K x N) x (B x N x D) -> B x K x D
        vlad = vlad.transpose(1, 2)  # -> B x D x K
        vlad = vlad - a
        # Intra-normalization: L2 over D (dim=1) for each cluster.
        vlad = F.normalize(vlad)
        # Flatten, then L2-normalize the whole descriptor.
        vlad = vlad.reshape(-1, self.cluster_size * self.feature_size)  # -> B x DK
        vlad = F.normalize(vlad)
        return vlad  # B x DK
# Benchmark configuration: B x N x D inputs, K real clusters plus G ghost clusters.
batch_size = 32
num_features = 100
num_clusters = 32
feature_size = 512
ghost_clusters = 16


def get_inputs():
    """Return the positional forward() arguments: one random B x N x D tensor."""
    shape = (batch_size, num_features, feature_size)
    features = torch.randn(*shape)
    return [features]


def get_init_inputs():
    """Return the positional constructor arguments: (K, D, G)."""
    init_args = [num_clusters, feature_size, ghost_clusters]
    return init_args
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include <limits>
// Capacity, in floats, of the constant-memory staging buffer declared below:
// 16384 * sizeof(float) = 64 KB. forward() only uses the custom kernel when
// clusters.numel() fits here. Larger cluster matrices go through at::matmul.
#define MAX_CLUSTERS_SIZE 16384
// Constant-memory copy of the clusters matrix, stored row-major as [D, K+G]:
// element (k, c) is at index k * (K+G) + c. This is a single global buffer,
// overwritten by each forward() call that takes the custom-kernel path.
__constant__ float d_clusters[MAX_CLUSTERS_SIZE];
// Warp-wide sum using shuffle-down. After log2(warpSize) halving steps, lane 0
// holds the sum of `val` over every lane; other lanes hold partial sums.
// Because the mask is 0xFFFFFFFF, all lanes of the warp must call this together.
__device__ float warpReduceSum(float val) {
    int offset = warpSize / 2;
    while (offset > 0) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
        offset /= 2;
    }
    return val;
}
// Per-lane partial dot product between row `row` of the row-major [M, D] matrix
// `x` and column `col` of the row-major [D, N_out] matrix held in d_clusters.
// Lane `lane` covers elements lane, lane + warpSize, lane + 2*warpSize, ...
// Adding up the results of all lanes (e.g. with warpReduceSum) gives the full
// dot product.
// Performance note: within one instruction the lanes read different d_clusters
// addresses. Constant memory serializes such reads instead of broadcasting them.
__device__ float dotProduct(const float* __restrict__ x, int row, int D, int N_out, int col, int lane) {
    const float* x_row = x + row * D;
    float acc = 0.0f;
    for (int k = lane; k < D; k += warpSize) {
        acc += x_row[k] * d_clusters[k * N_out + col];
    }
    return acc;
}
// Computes out = x * clusters, where x is a row-major [M, D] matrix and clusters
// is the row-major [D, N_out] matrix staged in d_clusters. out is row-major [M, N_out].
// One warp produces one output element: each lane accumulates a strided partial
// dot product, the warp reduces the partials, and lane 0 writes the result.
// Launch requirement: blockDim.x must be a multiple of warpSize so that warps
// never straddle blocks. warpReduceSum uses a full-warp shuffle mask.
__global__ void matmul_warp_kernel(const float* __restrict__ x, float* __restrict__ out, int M, int D, int N_out) {
    // Do the index math in 64-bit. With one warp per output element the total
    // thread count is warpSize * M * N_out, which can overflow the 32-bit
    // blockIdx.x * blockDim.x + threadIdx.x for large inputs.
    const long long global_thread =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long warp_id = global_thread / warpSize;
    const int lane = threadIdx.x & (warpSize - 1);
    const long long total_outputs = static_cast<long long>(M) * N_out;

    // warp_id is the same for every lane of a warp, so this branch never splits
    // a warp and the full-mask shuffle in warpReduceSum remains valid.
    if (warp_id < total_outputs) {
        const int row = static_cast<int>(warp_id / N_out);
        const int col = static_cast<int>(warp_id % N_out);
        // Offset to the start of the row here, in 64-bit, so dotProduct's
        // `row * D` (computed in int) cannot overflow. It is given row 0.
        const float* x_row = x + static_cast<long long>(row) * D;
        const float sum = warpReduceSum(dotProduct(x_row, 0, D, N_out, col, lane));
        if (lane == 0) {
            out[static_cast<long long>(row) * N_out + col] = sum;
        }
    }
}
// Forward pass of NetVLAD with ghost clusters.
//
// Shapes (B = batch, N = features per sample, D = feature_size,
// K = cluster_size, G = ghost clusters):
//   x          [B, N, D]
//   clusters   [D, K+G]
//   clusters2  broadcast against [B, D, K] (the Python module uses [1, D, K])
//   bn_*       one entry per cluster, K+G
// Returns [B, D*K]; each row is L2-normalized.
//
// The assignment logits x * clusters come from matmul_warp_kernel (clusters
// staged in constant memory) only when:
//   - x and clusters are float32 CUDA tensors on the same device,
//   - both are non-empty,
//   - clusters fits in MAX_CLUSTERS_SIZE floats, and
//   - the launch fits in int indexing.
// Otherwise at::matmul is used, which also keeps CPU and non-float32 inputs working.
// NOTE: d_clusters is one global buffer, so concurrent calls on different
// streams are not safe on the custom-kernel path.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor clusters,
    torch::Tensor clusters2,
    torch::Tensor bn_weight,
    torch::Tensor bn_bias,
    torch::Tensor bn_mean,
    torch::Tensor bn_var,
    bool is_training,
    int64_t cluster_size,
    int64_t feature_size
) {
    TORCH_CHECK(x.dim() == 3, "x must be [B, N, D], got a tensor with ", x.dim(), " dims.");
    const int64_t B = x.size(0);
    const int64_t N = x.size(1);
    const int64_t D = x.size(2);
    TORCH_CHECK(D == feature_size, "feature_size mismatch.");
    // The custom kernel indexes d_clusters assuming exactly D rows, so a wrong
    // shape would otherwise read garbage instead of failing.
    TORCH_CHECK(clusters.dim() == 2 && clusters.size(0) == D,
                "clusters must be [D, K+G] with D == ", D, ".");
    const int64_t N_out = clusters.size(1);

    // Flatten x to [M, D] where M = B * N.
    const int64_t M = B * N;
    auto x_flat = x.reshape({M, D}).contiguous();

    constexpr int kThreadsPerBlock = 256;  // 8 warps per block
    constexpr int kWarpsPerBlock = kThreadsPerBlock / 32;

    const bool use_custom_kernel =
        x_flat.is_cuda() && clusters.is_cuda() &&
        x_flat.device() == clusters.device() &&
        x_flat.scalar_type() == at::kFloat &&
        clusters.scalar_type() == at::kFloat &&
        x_flat.numel() > 0 &&                       // a 0-block launch is an error
        clusters.numel() > 0 &&
        clusters.numel() <= MAX_CLUSTERS_SIZE &&
        // One 32-thread warp per output: total threads must fit in int.
        M * N_out <= std::numeric_limits<int>::max() / 32;

    torch::Tensor assignment;
    if (use_custom_kernel) {
        // Make x's device current so that the d_clusters copy, the stream and the
        // launch all target the device the tensors live on.
        const c10::cuda::CUDAGuard device_guard(x_flat.device());
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        auto clusters_contig = clusters.contiguous();
        const size_t clusters_bytes =
            static_cast<size_t>(clusters_contig.numel()) * sizeof(float);
        // clusters is device memory, so this is a device-to-device copy. Queue it on
        // the current stream: it then runs after the work that produced `clusters`
        // and before the kernel below.
        cudaError_t err = cudaMemcpyToSymbolAsync(
            d_clusters, clusters_contig.data_ptr<float>(), clusters_bytes, 0,
            cudaMemcpyDeviceToDevice, stream);
        TORCH_CHECK(err == cudaSuccess, "cudaMemcpyToSymbol failed: ", cudaGetErrorString(err));

        assignment = torch::empty({M, N_out}, x_flat.options());
        const int64_t total_warps = M * N_out;
        const int blocks =
            static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
        matmul_warp_kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(
            x_flat.data_ptr<float>(),
            assignment.data_ptr<float>(),
            static_cast<int>(M), static_cast<int>(D), static_cast<int>(N_out)
        );
        // Only launch errors are checked here. Execution errors surface at the next
        // synchronizing call, as with ATen kernels, so no device-wide synchronize.
        err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess, "Kernel launch failed: ", cudaGetErrorString(err));
    } else {
        assignment = at::matmul(x_flat, clusters);
    }

    // Batch normalization over the K+G cluster logits.
    assignment = at::batch_norm(
        assignment, bn_weight, bn_bias, bn_mean, bn_var,
        is_training, 0.1, 1e-5, true
    );

    // Softmax over clusters, then drop the ghost clusters -> [M, cluster_size].
    assignment = at::softmax(assignment, 1);
    assignment = assignment.narrow(1, 0, cluster_size);
    assignment = assignment.reshape({B, N, cluster_size});

    // a_sum: [B, 1, cluster_size]. With clusters2 of shape [1, D, cluster_size],
    // a broadcasts to [B, D, cluster_size].
    auto a_sum = assignment.sum(1, /*keepdim=*/true);
    auto a = a_sum * clusters2;

    // VLAD aggregation: [B, K, N] x [B, N, D] -> [B, K, D], then to [B, D, K].
    auto vlad = at::bmm(assignment.transpose(1, 2), x_flat.reshape({B, N, D}));
    vlad = vlad.transpose(1, 2) - a;

    // Intra-normalize over D, then L2-normalize the flattened descriptor.
    // Dividing by the norm clamped at 1e-12 matches torch.nn.functional.normalize,
    // which the Python implementations use.
    vlad = vlad / vlad.norm(2, {1}, /*keepdim=*/true).clamp_min(1e-12);
    vlad = vlad.reshape({B, D * cluster_size});
    vlad = vlad / vlad.norm(2, {1}, /*keepdim=*/true).clamp_min(1e-12);
    return vlad;
}
// Python binding: exposes the C++ forward() as `<extension>.forward(...)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        &forward,
        "NetVLAD with ghost clusters (Modular CUDA device functions)"
    );
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|

Rule | Description |
---|---|
Operation / Metric | Value | Unit |
---|---|---|
aten::zero_ | ||
CPU Time | 511981.23 | μs |
Device Time | 2248122.11 | μs |
Self CPU Time | 92047.42 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 419958.85 | μs |
Device Time | 2248122.11 | μs |
Self CPU Time | 119983.46 | μs |
Self Device Time | 2248122.11 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::matmul | ||
CPU Time | 833267.38 | μs |
Device Time | 488343.45 | μs |
Self CPU Time | 33486.67 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::batch_norm | ||
CPU Time | 1710978.96 | μs |
Device Time | 488396.07 | μs |
Self CPU Time | 52138.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_batch_norm_impl_index | ||
CPU Time | 1658840.65 | μs |
Device Time | 488396.07 | μs |
Self CPU Time | 72498.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::native_batch_norm | ||
CPU Time | 1531467.61 | μs |
Device Time | 488396.07 | μs |
Self CPU Time | 441495.49 | μs |
Self Device Time | 420326.48 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 2385576.39 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 2385576.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 2248122.11 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 2248122.11 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45296 warnings generated when compiling for host. Suppressed 45330 warnings (45283 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.