import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    *,
    beta: float = 1.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Compute the Smooth L1 (Huber-style) loss between predictions and targets.

    Thin wrapper around ``torch.nn.functional.smooth_l1_loss``. The keyword-only
    ``beta`` and ``reduction`` arguments default to PyTorch's own defaults, so
    ``module_fn(predictions, targets)`` behaves exactly as a plain two-argument
    call to ``smooth_l1_loss`` would.

    Args:
        predictions (torch.Tensor): Predicted values.
        targets (torch.Tensor): Target values; expected to have the same shape
            as ``predictions``.
        beta (float): Threshold at which the per-element loss switches from
            quadratic (``0.5 * d**2 / beta`` for ``|d| < beta``) to linear
            (``|d| - 0.5 * beta``).
        reduction (str): ``"mean"``, ``"sum"`` or ``"none"``.

    Returns:
        torch.Tensor: The reduced loss (0-dim for ``"mean"`` / ``"sum"``), or
        the element-wise loss tensor for ``"none"``.
    """
    return F.smooth_l1_loss(predictions, targets, reduction=reduction, beta=beta)
class Model(nn.Module):
    """Parameter-free module that evaluates a Smooth L1 (Huber) loss.

    All computation is delegated to the callable ``fn`` given to
    :meth:`forward`, which defaults to ``module_fn``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets, fn=module_fn):
        """Return ``fn(predictions, targets)``.

        Args:
            predictions: Predicted values.
            targets: Target values.
            fn: Loss callable invoked as ``fn(predictions, targets)``;
                defaults to ``module_fn`` (bound when the class is defined).
        """
        loss = fn(predictions, targets)
        return loss
# Shape configuration for the generated benchmark inputs:
#   batch_size  -> leading dimension of each tensor built by get_inputs()
#   input_shape -> trailing dimensions of each tensor
#   dim         -> not referenced by the visible code (NOTE(review): possibly vestigial)
batch_size, input_shape, dim = 128, (4096,), 1
def get_inputs():
    """Build the forward-pass arguments: random predictions and targets.

    Returns a two-element list of tensors, each of shape
    ``(batch_size, *input_shape)`` drawn with ``torch.randn``.
    """
    shape = (batch_size, *input_shape)
    return [torch.randn(shape) for _ in range(2)]
def get_init_inputs():
    """Return the constructor arguments: none.

    Presumably consumed as ``Model(*get_init_inputs())`` by a harness — verify
    against the caller.
    """
    init_args = []
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """Parameter-free module computing the Smooth L1 (Huber) loss of two tensors.

    Calls ``torch.nn.functional.smooth_l1_loss`` with its default settings
    (``reduction="mean"``, ``beta=1.0``).
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """Return ``smooth_l1_loss(predictions, targets)`` with default settings."""
        smooth_l1 = torch.nn.functional.smooth_l1_loss
        return smooth_l1(predictions, targets)
# Benchmark input configuration (consumed by get_inputs below):
#   batch_size  -> leading dimension of each generated tensor
#   input_shape -> trailing dimensions of each generated tensor
#   dim         -> not referenced by the visible code (NOTE(review): possibly vestigial)
batch_size, input_shape, dim = 128, (4096,), 1
def get_inputs():
    """Return ``[predictions, targets]``, two standard-normal tensors.

    Each tensor has shape ``(batch_size, *input_shape)``; predictions are
    drawn first, then targets.
    """
    predictions = torch.randn(batch_size, *input_shape)
    targets = torch.randn(batch_size, *input_shape)
    return [predictions, targets]
def get_init_inputs():
    """Return an empty argument list (the module's constructor takes none)."""
    return list()
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

#include <algorithm>
#include <cstdint>
#include <limits>
// Warp-level reduction using shuffle
// Tree-sum of `val` across the calling warp using __shfl_down_sync.
// Each step adds the value held `delta` lanes higher, halving `delta` until 1;
// afterwards lane 0 holds the sum of all lanes (other lanes hold partial sums).
// Uses the full 0xffffffff mask, so every lane of the warp must participate.
__device__ __forceinline__ float warp_reduce(float val) {
    #pragma unroll
    for (int delta = warpSize >> 1; delta >= 1; delta /= 2) {
        const float neighbour = __shfl_down_sync(0xffffffff, val, delta);
        val = val + neighbour;
    }
    return val;
}
// Per-element Smooth L1 term with beta = 1:
//   0.5 * d^2    if |d| < 1
//   |d| - 0.5    otherwise
__device__ __forceinline__ float smooth_l1_term(float diff) {
    const float abs_diff = fabsf(diff);
    return (abs_diff < 1.0f) ? 0.5f * diff * diff : abs_diff - 0.5f;
}

// Adds sum_i smooth_l1(predictions[i] - targets[i]) / n_elements into *output.
// Each block contributes its partial (already divided by n_elements) with one
// atomicAdd, so *output must be zero before launch. The floating-point
// summation order across blocks is not fixed, so low-order bits may vary
// between runs.
// Assumptions, enforced by the host wrapper rather than here:
//   - predictions / targets are 16-byte aligned (float4 loads below);
//   - blockDim.x is a multiple of warpSize and <= 1024 (full-mask shuffles,
//     32-slot warp_sums array).
__global__ void smooth_l1_loss_optimized_sync_kernel(
    const float* __restrict__ predictions,
    const float* __restrict__ targets,
    float* output,
    const int n_elements
) {
    const int tid = threadIdx.x;
    const int gid = blockIdx.x * blockDim.x + tid;
    const int stride = gridDim.x * blockDim.x;
    const int lane = tid % warpSize;
    const int wid = tid / warpSize;

    float thread_sum = 0.0f;

    // Vectorized grid-stride loop: each iteration consumes one float4
    // (4 consecutive elements) from each input.
    const int vec_elements = n_elements / 4;
    const float4* pred4 = reinterpret_cast<const float4*>(predictions);
    const float4* targ4 = reinterpret_cast<const float4*>(targets);
    for (int i = gid; i < vec_elements; i += stride) {
        const float4 p = __ldg(pred4 + i);
        const float4 t = __ldg(targ4 + i);
        thread_sum += smooth_l1_term(p.x - t.x);
        thread_sum += smooth_l1_term(p.y - t.y);
        thread_sum += smooth_l1_term(p.z - t.z);
        thread_sum += smooth_l1_term(p.w - t.w);
    }

    // Scalar tail: the final n_elements % 4 elements not covered above.
    for (int i = vec_elements * 4 + gid; i < n_elements; i += stride) {
        thread_sum += smooth_l1_term(__ldg(predictions + i) - __ldg(targets + i));
    }

    // Level 1: reduce within each warp; lane 0 ends up with the warp total.
    thread_sum = warp_reduce(thread_sum);

    __shared__ float warp_sums[32];  // one slot per warp (<= 32 warps per block)
    if (lane == 0) {
        warp_sums[wid] = thread_sum;
    }

    // The barrier must be reached by every thread of the block: __syncthreads()
    // inside a branch taken by only some warps is undefined behavior, and this
    // barrier is what makes the other warps' warp_sums writes visible to warp 0.
    __syncthreads();

    // Level 2: the first warp reduces the per-warp totals.
    if (wid == 0) {
        const int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        float val = (lane < num_warps) ? warp_sums[lane] : 0.0f;
        val = warp_reduce(val);
        if (lane == 0) {
            atomicAdd(output, val / n_elements);
        }
    }
}
// Host entry point: mean Smooth L1 loss (beta = 1) of two same-shaped,
// contiguous float32 CUDA tensors on the same device.
// Returns a 1-element tensor of shape {1}; for empty inputs the result is NaN
// (the mean over zero elements), matching torch.nn.functional.smooth_l1_loss.
torch::Tensor smooth_l1_loss_optimized_sync(
    torch::Tensor predictions,
    torch::Tensor targets
) {
    TORCH_CHECK(predictions.sizes() == targets.sizes(), "Input tensors must have the same shape");
    TORCH_CHECK(predictions.is_contiguous() && targets.is_contiguous(), "Input tensors must be contiguous");
    TORCH_CHECK(predictions.device().is_cuda() && targets.device().is_cuda(), "Inputs must be CUDA tensors");
    TORCH_CHECK(predictions.device() == targets.device(), "Inputs must be on the same CUDA device");
    TORCH_CHECK(
        predictions.scalar_type() == torch::kFloat32 && targets.scalar_type() == torch::kFloat32,
        "Inputs must be float32 tensors");
    // The kernel indexes with 32-bit ints; refuse rather than silently truncate.
    TORCH_CHECK(
        predictions.numel() <= static_cast<int64_t>(std::numeric_limits<int>::max()),
        "Input tensors are too large for 32-bit indexing");

    // Launch on the inputs' device and on PyTorch's current stream, so the
    // kernel is ordered after the zero-fill below and after whatever produced
    // the inputs.
    const c10::cuda::CUDAGuard device_guard(predictions.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int n_elements = static_cast<int>(predictions.numel());
    auto output = torch::zeros({1}, predictions.options());

    if (n_elements == 0) {
        // Mean over zero elements; a zero-sized grid would also be an invalid launch.
        output.fill_(std::numeric_limits<float>::quiet_NaN());
        return output;
    }

    // The kernel reads float4 vectors, which need 16-byte alignment. A contiguous
    // view with a storage offset can be misaligned; copy it to a fresh
    // (allocator-aligned) buffer instead of faulting.
    auto is_vec_aligned = [](const torch::Tensor& t) {
        return reinterpret_cast<std::uintptr_t>(t.data_ptr()) % alignof(float4) == 0;
    };
    if (!is_vec_aligned(predictions)) {
        predictions = predictions.clone();
    }
    if (!is_vec_aligned(targets)) {
        targets = targets.clone();
    }

    const int block_size = 256;  // multiple of warpSize and <= 1024, as the kernel requires
    const int vec_elements = n_elements / 4;
    // At least one block: with fewer than 4 elements vec_elements is 0, and the
    // kernel's scalar tail loop must still run.
    const int grid_size =
        std::max(1, std::min(65535, (vec_elements + block_size - 1) / block_size));

    smooth_l1_loss_optimized_sync_kernel<<<grid_size, block_size, 0, stream>>>(
        predictions.data_ptr<float>(),
        targets.data_ptr<float>(),
        output.data_ptr<float>(),
        n_elements
    );
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "smooth_l1_loss_optimized_sync_kernel launch failed: ", cudaGetErrorString(err));

    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Expose the CUDA implementation to Python as `forward`, taking the
    // predictions and targets tensors positionally.
    m.def(
        "forward",
        &smooth_l1_loss_optimized_sync,
        "Optimized Smooth L1 Loss with selective synchronization (CUDA)"
    );
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.726 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.390 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 18.660 | % | 0.260 | 5 |
Issued Ipc Active | 0.748 | inst/cycle | 0.000 | 5 |
SM Busy | 18.660 | % | 0.260 | 5 |
Memory Throughput | 871379597661.710 | byte/second | 60697990818340388864.000 | 5 |
Mem Busy | 15.134 | % | 0.006 | 5 |
Max Bandwidth | 26.130 | % | 0.025 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 18.608 | % | 0.002 | 5 |
Mem Pipes Busy | 7.112 | % | 0.001 | 5 |
Warp Cycles Per Issued Instruction | 31.704 | cycle | 0.931 | 5 |
Warp Cycles Per Executed Instruction | 32.646 | cycle | 0.987 | 5 |
Avg. Active Threads Per Warp | 31.710 |  | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 27.060 |  | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 28.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 37.264 | % | 0.048 | 5 |
Achieved Active Warps Per SM | 23.850 | warp | 0.020 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (37.3%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::zeros | ||
CPU Time | 5779227.69 | μs |
Device Time | 225555.16 | μs |
Self CPU Time | 157975.35 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 6108894.26 | μs |
Device Time | 7662261.09 | μs |
Self CPU Time | 305788.01 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 5803110.78 | μs |
Device Time | 7662261.09 | μs |
Self CPU Time | 391199.41 | μs |
Self Device Time | 7662261.09 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 5782598.49 | μs |
Device Time | 2919.90 | μs |
Self CPU Time | 5782598.49 | μs |
Self Device Time | 2919.90 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
smooth_l1_loss_optimized_sync_kernel(float const*, float const*, float*, int) | ||
CPU Time | 0.00 | μs |
Device Time | 443009.54 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 443009.54 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 265386.67 | μs |
Device Time | 1233420.59 | μs |
Self CPU Time | 265386.67 | μs |
Self Device Time | 1233420.59 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 7437568.24 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 7437568.24 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventElapsedTime | ||
CPU Time | 321733.24 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 321733.24 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45284 warnings generated when compiling for host. Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.