import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, margin: float
) -> torch.Tensor:
"""
Computes the Triplet Margin Loss for metric learning tasks.
Args:
anchor (torch.Tensor): Anchor values.
positive (torch.Tensor): Positive values.
negative (torch.Tensor): Negative values.
margin (float): Margin value.
Returns:
torch.Tensor: Triplet Margin Loss.
"""
return F.triplet_margin_loss(anchor, positive, negative, margin=margin)
class Model(nn.Module):
"""
A model that computes Triplet Margin Loss for metric learning tasks.
"""
def __init__(self, margin):
super(Model, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative, fn=module_fn):
return fn(anchor, positive, negative, self.margin)
batch_size = 128
input_shape = (4096,)
dim = 1
margin = 1.0
def get_inputs():
return [
torch.randn(batch_size, *input_shape),
torch.randn(batch_size, *input_shape),
torch.randn(batch_size, *input_shape),
]
def get_init_inputs():
return [margin]
import torch
import torch.nn as nn
class Model(nn.Module):
"""
A model that computes Triplet Margin Loss for metric learning tasks.
Parameters:
margin (float): The margin between the positive and negative samples.
"""
def __init__(self, margin=1.0):
super(Model, self).__init__()
self.loss_fn = torch.nn.TripletMarginLoss(margin=margin)
def forward(self, anchor, positive, negative):
return self.loss_fn(anchor, positive, negative)
batch_size = 128
input_shape = (4096, )
dim = 1
def get_inputs():
return [torch.randn(batch_size, *input_shape), torch.randn(batch_size, *input_shape), torch.randn(batch_size, *input_shape)]
def get_init_inputs():
return [1.0] # Default margin
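Both reference definitions compute the same quantity, so a quick CPU sanity check can tie them together. This is a minimal sketch that assumes the definitions above (module_fn, Model, get_inputs, get_init_inputs) are in scope:

import torch.nn.functional as F

anchor, positive, negative = get_inputs()
m = get_init_inputs()[0]
reference = F.triplet_margin_loss(anchor, positive, negative, margin=m)
# The module-based and functional variants should agree with the reference.
assert torch.allclose(Model(m)(anchor, positive, negative), reference)
assert torch.allclose(module_fn(anchor, positive, negative, m), reference)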
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <type_traits>  // std::is_same, used by the if constexpr dispatch below
// Kernel using __ldg() for read-only global memory accesses and vectorized loads aligned to 128-bit
// boundaries. The vectorized paths assume contiguous inputs whose per-sample base offsets stay
// 16-byte aligned (true here because feat_size is a multiple of the vector width).
template <typename scalar_t>
__global__ void ldg_aligned_triplet_kernel(
const scalar_t* __restrict__ anchor,
const scalar_t* __restrict__ positive,
const scalar_t* __restrict__ negative,
scalar_t* __restrict__ output,
const float margin,
const int batch_size,
const int feat_size) {
// Each block processes one batch element
int batch_idx = blockIdx.x;
if (batch_idx >= batch_size) return;
int tid = threadIdx.x;
int base_idx = batch_idx * feat_size;
scalar_t local_dist_pos = 0;
scalar_t local_dist_neg = 0;
// Use vectorized loads for read-only global memory accesses with __ldg()
if constexpr (std::is_same<scalar_t, float>::value) {
// Use float4 for 128-bit (4x32-bit) aligned loads
using vec_t = float4;
constexpr int vec_size = 4;
int vectorized_length = feat_size / vec_size;
int remainder = feat_size % vec_size;
const vec_t* anchor_vec = reinterpret_cast<const vec_t*>(anchor + base_idx);
const vec_t* positive_vec = reinterpret_cast<const vec_t*>(positive + base_idx);
const vec_t* negative_vec = reinterpret_cast<const vec_t*>(negative + base_idx);
for (int i = tid; i < vectorized_length; i += blockDim.x) {
vec_t a_vec = __ldg(&anchor_vec[i]);
vec_t p_vec = __ldg(&positive_vec[i]);
vec_t n_vec = __ldg(&negative_vec[i]);
float diff0 = a_vec.x - p_vec.x;
float diff1 = a_vec.y - p_vec.y;
float diff2 = a_vec.z - p_vec.z;
float diff3 = a_vec.w - p_vec.w;
local_dist_pos += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
float diff0n = a_vec.x - n_vec.x;
float diff1n = a_vec.y - n_vec.y;
float diff2n = a_vec.z - n_vec.z;
float diff3n = a_vec.w - n_vec.w;
local_dist_neg += diff0n * diff0n + diff1n * diff1n + diff2n * diff2n + diff3n * diff3n;
}
int offset = vectorized_length * vec_size;
for (int i = tid; i < remainder; i += blockDim.x) {
int idx = base_idx + offset + i;
float a = __ldg(&anchor[idx]);
float p = __ldg(&positive[idx]);
float n = __ldg(&negative[idx]);
float diff = a - p;
local_dist_pos += diff * diff;
float diffn = a - n;
local_dist_neg += diffn * diffn;
}
} else if constexpr (std::is_same<scalar_t, double>::value) {
// Use double2 for 128-bit (2x64-bit) aligned loads
using vec_t = double2;
constexpr int vec_size = 2;
int vectorized_length = feat_size / vec_size;
int remainder = feat_size % vec_size;
const vec_t* anchor_vec = reinterpret_cast<const vec_t*>(anchor + base_idx);
const vec_t* positive_vec = reinterpret_cast<const vec_t*>(positive + base_idx);
const vec_t* negative_vec = reinterpret_cast<const vec_t*>(negative + base_idx);
for (int i = tid; i < vectorized_length; i += blockDim.x) {
vec_t a_vec = __ldg(&anchor_vec[i]);
vec_t p_vec = __ldg(&positive_vec[i]);
vec_t n_vec = __ldg(&negative_vec[i]);
double diff0 = a_vec.x - p_vec.x;
double diff1 = a_vec.y - p_vec.y;
local_dist_pos += diff0 * diff0 + diff1 * diff1;
double diff0n = a_vec.x - n_vec.x;
double diff1n = a_vec.y - n_vec.y;
local_dist_neg += diff0n * diff0n + diff1n * diff1n;
}
int offset = vectorized_length * vec_size;
for (int i = tid; i < remainder; i += blockDim.x) {
int idx = base_idx + offset + i;
double a = __ldg(&anchor[idx]);
double p = __ldg(&positive[idx]);
double n = __ldg(&negative[idx]);
double diff = a - p;
local_dist_pos += diff * diff;
double diffn = a - n;
local_dist_neg += diffn * diffn;
}
} else {
// Fallback for other types: scalar reads using __ldg()
for (int i = tid; i < feat_size; i += blockDim.x) {
int idx = base_idx + i;
scalar_t a = __ldg(&anchor[idx]);
scalar_t p = __ldg(&positive[idx]);
scalar_t n = __ldg(&negative[idx]);
scalar_t diff = a - p;
local_dist_pos += diff * diff;
scalar_t diffn = a - n;
local_dist_neg += diffn * diffn;
}
}
// Warp-level reduction within each block
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
local_dist_pos += __shfl_down_sync(0xffffffff, local_dist_pos, offset);
local_dist_neg += __shfl_down_sync(0xffffffff, local_dist_neg, offset);
}
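    // After the shuffle loop, lane 0 of each warp holds that warp's partial sums;
    // stage them in shared memory so the first warp can combine them below.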
__shared__ scalar_t shared_sum_pos[32];
__shared__ scalar_t shared_sum_neg[32];
int lane = tid % 32;
int warp_id = tid / 32;
if (lane == 0) {
shared_sum_pos[warp_id] = local_dist_pos;
shared_sum_neg[warp_id] = local_dist_neg;
}
__syncthreads();
    // Final reduction across warps: the first warp loads the per-warp partial sums
    // (zero for lanes beyond the number of warps launched) and reduces them, so every
    // lane named in the shuffle mask actually executes the shuffle.
    int num_warps = (blockDim.x + warpSize - 1) / warpSize;
    if (warp_id == 0) {
        scalar_t block_sum_pos = (lane < num_warps) ? shared_sum_pos[lane] : scalar_t(0);
        scalar_t block_sum_neg = (lane < num_warps) ? shared_sum_neg[lane] : scalar_t(0);
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            block_sum_pos += __shfl_down_sync(0xffffffff, block_sum_pos, offset);
            block_sum_neg += __shfl_down_sync(0xffffffff, block_sum_neg, offset);
        }
        if (lane == 0) {
            scalar_t loss = sqrt(block_sum_pos) - sqrt(block_sum_neg) + margin;
            output[batch_idx] = loss < scalar_t(0) ? scalar_t(0) : loss;
        }
    }
}
// Host function to launch the kernel
torch::Tensor triplet_margin_loss_cuda(
torch::Tensor anchor,
torch::Tensor positive,
torch::Tensor negative,
float margin) {
    TORCH_CHECK(anchor.device().is_cuda(), "anchor must be a CUDA tensor");
    TORCH_CHECK(positive.device().is_cuda(), "positive must be a CUDA tensor");
    TORCH_CHECK(negative.device().is_cuda(), "negative must be a CUDA tensor");
    // The kernel indexes the inputs as flat [batch_size, feat_size] buffers, so require contiguity.
    TORCH_CHECK(anchor.is_contiguous() && positive.is_contiguous() && negative.is_contiguous(),
                "anchor, positive, and negative must be contiguous");
const int batch_size = anchor.size(0);
const int feat_size = anchor.size(1);
    // One loss value per batch sample; every element is written unconditionally by the kernel.
    auto output = torch::zeros({batch_size}, anchor.options());
// Launch one block per batch sample; use 256 threads per block
const int threads_per_block = 256;
const int num_blocks = batch_size;
AT_DISPATCH_FLOATING_TYPES(anchor.scalar_type(), "ldg_aligned_triplet_kernel", ([&] {
ldg_aligned_triplet_kernel<scalar_t><<<num_blocks, threads_per_block>>>(
anchor.data_ptr<scalar_t>(),
positive.data_ptr<scalar_t>(),
negative.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
margin,
batch_size,
feat_size);
}));
    // Averaging over the batch matches the default reduction='mean' of F.triplet_margin_loss.
    return output.mean();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &triplet_margin_loss_cuda, "Triplet margin loss forward (CUDA)");
}
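To exercise the extension against the PyTorch reference, it can be JIT-compiled with torch.utils.cpp_extension.load. The sketch below assumes the source above is saved as triplet_margin_loss.cu (a hypothetical filename); small numerical differences versus F.triplet_margin_loss are expected because PyTorch adds an eps term inside the pairwise distance.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Build the extension in-process; "triplet_margin_loss.cu" is an assumed filename.
ext = load(name="triplet_margin_loss_cuda", sources=["triplet_margin_loss.cu"])

anchor = torch.randn(128, 4096, device="cuda")
positive = torch.randn(128, 4096, device="cuda")
negative = torch.randn(128, 4096, device="cuda")

out_cuda = ext.forward(anchor, positive, negative, 1.0)
out_ref = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
print(torch.allclose(out_cuda, out_ref, atol=1e-4))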
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.290 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.170 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 7.298 | % | 0.001 | 5 |
Issued Ipc Active | 0.290 | inst/cycle | 0.000 | 5 |
SM Busy | 7.298 | % | 0.001 | 5 |
Memory Throughput | 989833350235.322 | byte/second | 72983699404739510272.000 | 5 |
Mem Busy | 17.128 | % | 0.023 | 5 |
Max Bandwidth | 29.646 | % | 0.055 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 13.132 | % | 0.020 | 5 |
Mem Pipes Busy | 2.054 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 26.576 | cycle | 1.286 | 5 |
Warp Cycles Per Executed Instruction | 26.828 | cycle | 1.311 | 5 |
Avg. Active Threads Per Warp | 29.530 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 28.940 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 25.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 11.778 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 7.538 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (11.8%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
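The low achieved occupancy follows mostly from the launch configuration rather than any per-block resource limit: the kernel launches one 256-thread block (8 warps) per batch element, i.e. 128 blocks total. A back-of-the-envelope estimate, assuming roughly one resident block per SM on a GPU with on the order of 100+ SMs, lands close to the measured figure:

# Rough occupancy estimate under the stated one-block-per-SM assumption.
warps_per_block = 256 // 32           # 8 warps per block
theoretical_warps_per_sm = 64         # "Theoretical Active Warps per SM" row above
print(warps_per_block / theoretical_warps_per_sm)  # 0.125 -> ~12.5%, near the measured 11.8%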
Operation / Metric | Value | Unit |
---|---|---|
aten::zeros | ||
CPU Time | 5040542.28 | μs |
Device Time | 212278.80 | μs |
Self CPU Time | 145491.06 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 5351407.76 | μs |
Device Time | 7288951.88 | μs |
Self CPU Time | 306486.68 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 5044922.20 | μs |
Device Time | 7288951.88 | μs |
Self CPU Time | 375383.86 | μs |
Self Device Time | 7288951.88 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 5341610.87 | μs |
Device Time | 464809.89 | μs |
Self CPU Time | 5341610.87 | μs |
Self Device Time | 464809.89 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void ldg_aligned_triplet_kernel<float>(float const*, float const*, float const*, float*, float, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 584998.81 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 584998.81 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::mean | ||
CPU Time | 1130569.78 | μs |
Device Time | 377295.99 | μs |
Self CPU Time | 717364.71 | μs |
Self Device Time | 377295.99 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 7076673.08 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 7076673.08 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
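According to the table, most of the device time is spent in fill kernels issued around the benchmark (about 7.08 s for the FillFunctor kernel) rather than in ldg_aligned_triplet_kernel itself (about 0.58 s). For a simple wall-clock comparison against the PyTorch operator, CUDA events can be used; this sketch assumes the extension has been loaded as ext as shown earlier:

import torch
import torch.nn.functional as F

def time_fn(fn, iters=100):
    # Average milliseconds per call, measured with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

a, p, n = (torch.randn(128, 4096, device="cuda") for _ in range(3))
print("extension:", time_fn(lambda: ext.forward(a, p, n, 1.0)), "ms")
print("reference:", time_fn(lambda: F.triplet_margin_loss(a, p, n, margin=1.0)), "ms")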