import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, margin: float
) -> torch.Tensor:
    """Computes the Triplet Margin Loss for metric learning tasks.

    Thin wrapper around ``F.triplet_margin_loss``. Only ``margin`` is set
    explicitly; every other option keeps PyTorch's default (see the
    ``torch.nn.functional.triplet_margin_loss`` documentation).

    Args:
        anchor (torch.Tensor): Anchor values.
        positive (torch.Tensor): Positive values.
        negative (torch.Tensor): Negative values.
        margin (float): Margin value.

    Returns:
        torch.Tensor: Triplet Margin Loss.
    """
    return F.triplet_margin_loss(anchor, positive, negative, margin=margin)
class Model(nn.Module):
    """A model that computes Triplet Margin Loss for metric learning tasks.

    The actual loss computation is delegated to a pluggable function
    (``module_fn`` by default), so an alternative implementation can be
    swapped in through ``forward(..., fn=...)``.

    Args:
        margin (float): Margin forwarded to ``fn`` on every call.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, fn=module_fn):
        """Return ``fn(anchor, positive, negative, self.margin)``."""
        return fn(anchor, positive, negative, self.margin)
# Benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 128  # leading dimension of every tensor built by get_inputs()
input_shape = (4096,)  # per-sample shape appended after batch_size
dim = 1  # NOTE(review): not referenced by the code visible here
margin = 1.0  # returned by get_init_inputs() as the Model margin
def get_inputs():
    """Return random ``[anchor, positive, negative]`` inputs for ``Model.forward``.

    Each tensor has shape ``(batch_size, *input_shape)`` and is sampled with
    ``torch.randn``.
    """
    return [
        torch.randn(batch_size, *input_shape),
        torch.randn(batch_size, *input_shape),
        torch.randn(batch_size, *input_shape),
    ]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` (its margin)."""
    init_args = [margin]
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """A model that computes Triplet Margin Loss for metric learning tasks.

    Args:
        margin (float): The margin between the positive and negative samples.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.loss_fn = torch.nn.TripletMarginLoss(margin=margin)

    def forward(self, anchor, positive, negative):
        """Return the loss computed by the wrapped ``TripletMarginLoss``."""
        return self.loss_fn(anchor, positive, negative)
# Benchmark configuration read by get_inputs().
batch_size = 128  # leading dimension of every tensor built by get_inputs()
input_shape = (4096, )  # per-sample shape appended after batch_size
dim = 1  # NOTE(review): not referenced by the code visible here
def get_inputs():
    """Build the three forward inputs (anchor, positive, negative).

    Every tensor is an independent ``torch.randn`` draw of shape
    ``(batch_size, *input_shape)``.
    """
    return [torch.randn(batch_size, *input_shape) for _ in range(3)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    default_margin = 1.0
    return [default_margin]
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

#include <cstdint>
#include <limits>
// Per-sample triplet margin loss:
//   output[b] = max(||a_b - p_b + eps||_2 - ||a_b - n_b + eps||_2 + margin, 0)
// Launch layout: one block per sample (blockIdx.x is the row index). The block's
// threads stride over that row's feat_size elements, then combine partial sums
// with warp shuffles and a 32-slot shared-memory scratch.
// Caller must guarantee:
//   * anchor/positive/negative are contiguous float32 buffers laid out as
//     (batch_size, feat_size), and output holds batch_size floats;
//   * blockDim.x is a multiple of 32 and at most 1024, because the shuffles use a
//     full 32-lane mask and shared_pos/shared_neg have one slot per warp.
__global__ void triplet_margin_loss_combined(
    const float* __restrict__ anchor,
    const float* __restrict__ positive,
    const float* __restrict__ negative,
    float* __restrict__ output,
    const float margin,
    const int batch_size,
    const int feat_size) {
    // torch.nn.functional.triplet_margin_loss measures distances with
    // pairwise_distance, which computes ||x1 - x2 + eps|| using eps = 1e-6.
    // Adding it here keeps results consistent with PyTorch (e.g. anchor == positive).
    constexpr float kEps = 1e-6f;
    constexpr unsigned int kFullMask = 0xffffffffu;

    const int batch_idx = blockIdx.x;
    // Block-uniform exit: either every thread of the block returns or none does,
    // so the __syncthreads() below is still reached by all remaining threads.
    if (batch_idx >= batch_size) return;

    // 64-bit offset: batch_idx * feat_size can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(batch_idx) * feat_size;
    const float* a_row = anchor + row_offset;
    const float* p_row = positive + row_offset;
    const float* n_row = negative + row_offset;

    const int tid = threadIdx.x;
    float sum_pos = 0.f;
    float sum_neg = 0.f;

    // float4 loads require 16-byte aligned addresses. A row start is only aligned
    // when its base pointer is and feat_size is a multiple of 4, so test the real
    // row pointers. If any is misaligned, the whole row goes through the scalar
    // loop below. The condition is identical for every thread in the block.
    const bool rows_aligned =
        ((reinterpret_cast<uintptr_t>(a_row) |
          reinterpret_cast<uintptr_t>(p_row) |
          reinterpret_cast<uintptr_t>(n_row)) & 15u) == 0;
    const int num_vec = rows_aligned ? feat_size / 4 : 0;
    const int vectorized_end = num_vec * 4;

    const float4* a_vec = reinterpret_cast<const float4*>(a_row);
    const float4* p_vec = reinterpret_cast<const float4*>(p_row);
    const float4* n_vec = reinterpret_cast<const float4*>(n_row);

    // Vectorized body: 128-bit read-only loads, four features per iteration.
    for (int i = tid; i < num_vec; i += blockDim.x) {
        const float4 a = __ldg(&a_vec[i]);
        const float4 p = __ldg(&p_vec[i]);
        const float4 n = __ldg(&n_vec[i]);

        // Anchor-positive squared distance.
        float d = a.x - p.x + kEps; sum_pos += d * d;
        d = a.y - p.y + kEps; sum_pos += d * d;
        d = a.z - p.z + kEps; sum_pos += d * d;
        d = a.w - p.w + kEps; sum_pos += d * d;

        // Anchor-negative squared distance.
        d = a.x - n.x + kEps; sum_neg += d * d;
        d = a.y - n.y + kEps; sum_neg += d * d;
        d = a.z - n.z + kEps; sum_neg += d * d;
        d = a.w - n.w + kEps; sum_neg += d * d;
    }

    // Scalar path: the tail after the last full float4, or the entire row when
    // the vectorized path was skipped because of misalignment.
    for (int i = vectorized_end + tid; i < feat_size; i += blockDim.x) {
        const float a = __ldg(a_row + i);
        const float p = __ldg(p_row + i);
        const float n = __ldg(n_row + i);
        float d = a - p + kEps;
        sum_pos += d * d;
        d = a - n + kEps;
        sum_neg += d * d;
    }

    // Stage 1: reduce within each warp. Lane 0 ends up with the warp total.
    for (int delta = 16; delta > 0; delta >>= 1) {
        sum_pos += __shfl_down_sync(kFullMask, sum_pos, delta);
        sum_neg += __shfl_down_sync(kFullMask, sum_neg, delta);
    }

    // Stage 2: gather one partial sum per warp (at most 1024 / 32 = 32 warps).
    __shared__ float shared_pos[32];
    __shared__ float shared_neg[32];
    const int lane = tid % warpSize;
    const int warp_id = tid / warpSize;
    if (lane == 0) {
        shared_pos[warp_id] = sum_pos;
        shared_neg[warp_id] = sum_neg;
    }
    // Every warp's slot must be written before warp 0 reads the slots.
    __syncthreads();

    if (warp_id == 0) {
        const int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        sum_pos = lane < num_warps ? shared_pos[lane] : 0.f;
        sum_neg = lane < num_warps ? shared_neg[lane] : 0.f;
        for (int delta = 16; delta > 0; delta >>= 1) {
            sum_pos += __shfl_down_sync(kFullMask, sum_pos, delta);
            sum_neg += __shfl_down_sync(kFullMask, sum_neg, delta);
        }
        if (lane == 0) {
            const float loss = sqrtf(sum_pos) - sqrtf(sum_neg) + margin;
            output[batch_idx] = fmaxf(loss, 0.0f);
        }
    }
}
// Host entry point. It validates the inputs and launches
// triplet_margin_loss_combined with one block per sample. It returns the mean of
// the per-sample losses, i.e. the equivalent of reduction='mean'.
//
// anchor, positive and negative must be 2-D (batch, features) float32 CUDA
// tensors of identical shape on the same device. Non-contiguous inputs are
// copied to contiguous memory first, because the kernel indexes rows as
// row * feat_size.
torch::Tensor triplet_margin_loss_cuda_combined(
    torch::Tensor anchor,
    torch::Tensor positive,
    torch::Tensor negative,
    float margin) {
    TORCH_CHECK(anchor.device().is_cuda(), "anchor must be a CUDA tensor");
    TORCH_CHECK(positive.device().is_cuda(), "positive must be a CUDA tensor");
    TORCH_CHECK(negative.device().is_cuda(), "negative must be a CUDA tensor");
    TORCH_CHECK(positive.device() == anchor.device() && negative.device() == anchor.device(),
                "anchor, positive and negative must be on the same device");
    TORCH_CHECK(anchor.dim() == 2, "anchor must be 2-D (batch, features)");
    TORCH_CHECK(positive.sizes() == anchor.sizes() && negative.sizes() == anchor.sizes(),
                "anchor, positive and negative must have the same shape");
    TORCH_CHECK(anchor.scalar_type() == at::kFloat &&
                    positive.scalar_type() == at::kFloat &&
                    negative.scalar_type() == at::kFloat,
                "anchor, positive and negative must be float32");
    TORCH_CHECK(anchor.size(0) <= std::numeric_limits<int>::max() &&
                    anchor.size(1) <= std::numeric_limits<int>::max(),
                "tensor dimensions must fit in a 32-bit int");

    // The kernel assumes dense row-major rows.
    anchor = anchor.contiguous();
    positive = positive.contiguous();
    negative = negative.contiguous();

    // Launch on the inputs' device and on PyTorch's current stream, so that work
    // is ordered correctly with surrounding ATen ops.
    const c10::cuda::CUDAGuard device_guard(anchor.device());

    const int batch_size = static_cast<int>(anchor.size(0));
    const int feat_size = static_cast<int>(anchor.size(1));
    auto output = torch::empty({batch_size}, anchor.options());
    if (batch_size == 0) {
        // A zero-sized grid is an invalid launch; the mean of an empty tensor is
        // NaN, the same as the PyTorch reference.
        return output.mean();
    }

    // Must stay a multiple of 32 and <= 1024 (kernel reduction requirement).
    constexpr int threads = 256;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    triplet_margin_loss_combined<<<batch_size, threads, 0, stream>>>(
        anchor.data_ptr<float>(),
        positive.data_ptr<float>(),
        negative.data_ptr<float>(),
        output.data_ptr<float>(),
        margin,
        batch_size,
        feat_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output.mean();
}
m.def("forward", &triplet_margin_loss_cuda_combined, "Triplet margin loss combined optimized (CUDA)");
Metric | Value | Unit | Variance | Samples |
Executed Ipc Active | 0.260 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.156 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 6.624 | % | 0.000 | 5 |
Issued Ipc Active | 0.264 | inst/cycle | 0.000 | 5 |
SM Busy | 6.624 | % | 0.000 | 5 |
Memory Throughput | 1009202239055.476 | byte/second | 128523310498447523840.000 | 5 |
Mem Busy | 17.504 | % | 0.036 | 5 |
Max Bandwidth | 30.226 | % | 0.101 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 13.104 | % | 0.001 | 5 |
Mem Pipes Busy | 2.098 | % | 0.001 | 5 |
Warp Cycles Per Issued Instruction | 28.984 | cycle | 0.355 | 5 |
Warp Cycles Per Executed Instruction | 29.278 | cycle | 0.364 | 5 |
Avg. Active Threads Per Warp | 29.400 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 28.200 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 25.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 11.810 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 7.560 | warp | 0.000 | 5 |
Rule | Description |
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (11.8%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
aten::to | ||
CPU Time | 578277.17 | μs |
Device Time | 542.68 | μs |
Self CPU Time | 36.90 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 578240.27 | μs |
Device Time | 542.68 | μs |
Self CPU Time | 101.26 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 577081.39 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 104.53 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 576434.58 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 576434.58 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 763386.01 | μs |
Device Time | 24687.82 | μs |
Self CPU Time | 763386.01 | μs |
Self Device Time | 24687.82 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
triplet_margin_loss_combined(float const*, float const*, float const*, float*, float, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 78718.38 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 78718.38 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 31316.46 | μs |
Device Time | 49216.40 | μs |
Self CPU Time | 31316.46 | μs |
Self Device Time | 49216.40 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 94905.14 | μs |
Device Time | 975174.64 | μs |
Self CPU Time | 20033.92 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 74873.28 | μs |
Device Time | 975174.64 | μs |
Self CPU Time | 27039.90 | μs |
Self Device Time | 975174.64 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 975174.64 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 975174.64 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45288 warnings generated when compiling for host. Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.