90_Conv3d_LeakyReLU_Sum_Clamp_GELU
• atomic_minimal_usage_kernel_opt_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
    sum_tensor: torch.Tensor,
) -> torch.Tensor:
    """Run conv3d -> LeakyReLU(0.2) -> add -> clamp[-1, 1] -> GELU.

    Args:
        x: Input of shape (batch_size, in_channels, depth, height, width).
        conv_weight: Convolution kernel of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size).
        conv_bias: Per-output-channel convolution bias of shape (out_channels,).
        sum_tensor: Tensor broadcast-added after the LeakyReLU; expected shape
            (out_channels, 1, 1, 1) so it acts as a per-channel offset.

    Returns:
        The activated convolution output. ``F.conv3d`` is called with its
        default stride/padding, so it is an unpadded ("valid") convolution.
    """
    conv_out = F.conv3d(x, conv_weight, bias=conv_bias)
    # LeakyReLU with slope 0.2, then the per-channel offset.
    shifted = F.leaky_relu(conv_out, negative_slope=0.2) + sum_tensor
    # Bound the values to [-1, 1] before the final (exact, erf-based) GELU.
    bounded = torch.clamp(shifted, min=-1.0, max=1.0)
    return F.gelu(bounded)
class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies LeakyReLU, sums with a tensor, clamps, and applies GELU activation.

    The convolution's weight and bias are held directly as Parameters
    (``conv_weight`` / ``conv_bias``) so they can be handed to a functional
    implementation via ``forward(..., fn=...)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, sum_tensor_shape):
        super().__init__()
        # Build a Conv3d only to obtain its default-initialised parameters;
        # the module itself is not kept.
        template = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.conv_weight = template.weight
        self.conv_bias = template.bias
        # Small random offset added after the LeakyReLU.
        offset = torch.randn(sum_tensor_shape) * 0.02
        self.sum_tensor = nn.Parameter(offset)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` (defaults to ``module_fn``) to ``x`` and the stored parameters."""
        params = (self.conv_weight, self.conv_bias, self.sum_tensor)
        return fn(x, *params)
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = 3
sum_tensor_shape = (out_channels, 1, 1, 1)


def get_inputs():
    """Build the forward-pass arguments: one random (N, C, D, H, W) batch."""
    volume = torch.randn(batch_size, in_channels, depth, height, width)
    return [volume]


def get_init_inputs():
    """Build the positional arguments passed to ``Model(...)``."""
    return list((in_channels, out_channels, kernel_size, sum_tensor_shape))
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies LeakyReLU, sums with a tensor, clamps, and applies GELU activation.

    Parameters are registered as ``conv.weight``, ``conv.bias`` and ``sum_tensor``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, sum_tensor_shape):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        # Small random offset broadcast-added after the LeakyReLU.
        self.sum_tensor = nn.Parameter(torch.randn(sum_tensor_shape) * 0.02)

    def forward(self, x):
        """Return gelu(clamp(leaky_relu(conv(x), 0.2) + sum_tensor, -1, 1))."""
        functional = nn.functional
        y = functional.leaky_relu(self.conv(x), negative_slope=0.2)
        y = (y + self.sum_tensor).clamp(min=-1.0, max=1.0)
        return functional.gelu(y)
batch_size = 128
in_channels, out_channels = 3, 16
depth, height, width = 16, 32, 32
kernel_size = 3
sum_tensor_shape = (out_channels,) + (1, 1, 1)


def get_inputs():
    """Return the forward inputs: a list holding one random 5-D tensor."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(input_shape)]


def get_init_inputs():
    """Return the constructor arguments for ``Model``, in positional order."""
    ctor_args = [in_channels, out_channels, kernel_size]
    ctor_args.append(sum_tensor_shape)
    return ctor_args
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>
// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
// The PyTorch reference calls F.gelu(x) with its default approximate='none',
// i.e. this exact form. The tanh approximation (approximate='tanh') used
// previously does not match that reference numerically.
__device__ __forceinline__ float gelu_activation(float x) {
    const float kInvSqrt2 = 0.70710678118654752440f;  // 1 / sqrt(2)
    return 0.5f * x * (1.0f + erff(x * kInvSqrt2));
}
// Kernel: each thread handles one element of a contiguous (N, C, D, H, W)
// float tensor and applies LeakyReLU(0.2) -> + sum_tensor[c] -> clamp[-1, 1]
// -> GELU. Each thread writes only its own element, so no atomics are needed.
//
// `input` and `output` may point to the same buffer (the launcher runs this
// kernel in place). They are therefore deliberately NOT __restrict__, and the
// element is read with a plain load rather than __ldg, because __ldg is only
// valid for data that stays read-only for the whole kernel.
__global__ void kernel_atomic_minimal(
    const float* input,
    const float* __restrict__ sum_tensor,
    float* output,
    int64_t num_elements,
    int64_t channels,
    int64_t depth,
    int64_t height,
    int64_t width) {
    // 64-bit index: a 32-bit `int` overflows for tensors with more than
    // 2^31 - 1 elements and would produce negative, out-of-bounds indices.
    const int64_t idx =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= num_elements) return;

    // Channel of this element in a contiguous NCDHW layout.
    const int64_t spatial = depth * height * width;
    const int64_t c = (idx / spatial) % channels;

    const float x = input[idx];
    // LeakyReLU with negative slope 0.2.
    float y = (x > 0.0f) ? x : 0.2f * x;
    // Per-channel offset; sum_tensor is never written, so the read-only path is safe.
    y += __ldg(&sum_tensor[c]);
    // Clamp to [-1, 1].
    y = fmaxf(fminf(y, 1.0f), -1.0f);
    output[idx] = gelu_activation(y);
}
// Kernel launcher: runs kernel_atomic_minimal in place on `input` on PyTorch's
// current CUDA stream for input's device.
//
// Requirements (checked): `input` is a contiguous 5-D (N, C, D, H, W) float32
// tensor; `sum_tensor` is a contiguous float32 tensor with exactly C elements
// (one offset per channel).
//
// No host-side synchronisation is done. Later PyTorch work on the same stream
// is ordered after this kernel, and the former cudaDeviceSynchronize() blocked
// the CPU on every call.
void launch_kernel_atomic_minimal(torch::Tensor& input, torch::Tensor& sum_tensor) {
    TORCH_CHECK(input.dim() == 5,
                "input must be 5-D (N, C, D, H, W), got ", input.dim(), "-D");
    TORCH_CHECK(input.scalar_type() == at::kFloat, "input must be float32");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(sum_tensor.scalar_type() == at::kFloat, "sum_tensor must be float32");
    TORCH_CHECK(sum_tensor.is_contiguous(), "sum_tensor must be contiguous");
    // The kernel reads sum_tensor[c] for every channel c; fewer elements would be
    // an out-of-bounds read.
    TORCH_CHECK(sum_tensor.numel() == input.size(1),
                "sum_tensor must hold one value per channel (", input.size(1),
                "), got ", sum_tensor.numel());

    const int64_t num_elements = input.numel();
    if (num_elements == 0) {
        return;  // a launch with zero blocks is an invalid configuration
    }

    // Launch on input's device and on the stream PyTorch is currently using there.
    const c10::cuda::CUDAGuard device_guard(input.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    constexpr int threads = 256;
    constexpr int64_t kMaxGridX = 2147483647;  // 2^31 - 1, max gridDim.x
    const int64_t blocks = (num_elements + threads - 1) / threads;
    TORCH_CHECK(blocks <= kMaxGridX,
                "tensor too large for a 1-D launch: ", num_elements, " elements");

    kernel_atomic_minimal<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
        input.data_ptr<float>(),
        sum_tensor.data_ptr<float>(),
        input.data_ptr<float>(),  // in place: output aliases input
        num_elements,
        input.size(1),  // channels
        input.size(2),  // depth
        input.size(3),  // height
        input.size(4)   // width
    );
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "kernel_atomic_minimal launch failed: ", cudaGetErrorString(err));
}
// Forward function: cuDNN 3-D convolution, then the fused element-wise
// LeakyReLU(0.2) -> + sum_tensor -> clamp[-1, 1] -> GELU applied in place on
// the (contiguous) convolution output, which is returned.
//
// sum_tensor is treated as a per-output-channel offset and must contain exactly
// out_channels float32 values, e.g. shape (out_channels, 1, 1, 1) as in the
// Python reference.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias,
    torch::Tensor sum_tensor) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(conv_weight.is_cuda(), "conv_weight must be a CUDA tensor");
    TORCH_CHECK(conv_bias.is_cuda(), "conv_bias must be a CUDA tensor");
    TORCH_CHECK(sum_tensor.is_cuda(), "sum_tensor must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == at::kFloat, "x must be float32");
    // The element-wise kernel indexes (N, C, D, H, W). at::conv3d would also
    // accept unbatched 4-D input, which the kernel cannot handle.
    TORCH_CHECK(x.dim() == 5, "x must be 5-D (N, C, D, H, W), got ", x.dim(), "-D");
    // The kernel dereferences sum_tensor as a raw float* on x's device.
    TORCH_CHECK(sum_tensor.scalar_type() == at::kFloat, "sum_tensor must be float32");
    TORCH_CHECK(sum_tensor.device() == x.device(),
                "sum_tensor must be on the same device as x");

    // Perform 3D convolution with cuDNN
    auto x_conv = at::conv3d(x, conv_weight, conv_bias).contiguous();

    // Flat, contiguous per-channel offsets as the kernel expects.
    auto sum_flat = sum_tensor.contiguous();
    TORCH_CHECK(sum_flat.numel() == x_conv.size(1),
                "sum_tensor must hold one value per output channel (",
                x_conv.size(1), "), got ", sum_flat.numel());

    // Launch the kernel to apply LeakyReLU, addition, clamp, and GELU elementwise
    launch_kernel_atomic_minimal(x_conv, sum_flat);
    return x_conv;
}
// Python binding: exposes the C++ `forward(x, conv_weight, conv_bias, sum_tensor)`
// function as `forward` on the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Forward pass with minimal atomic operations (CUDA)");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 3.248 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 3.192 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 81.298 | % | 0.009 | 5 |
Issued Ipc Active | 3.254 | inst/cycle | 0.000 | 5 |
SM Busy | 81.298 | % | 0.009 | 5 |
Memory Throughput | 1276870245578.770 | byte/second | 2131045984608143360.000 | 5 |
Mem Busy | 21.618 | % | 0.001 | 5 |
Max Bandwidth | 38.100 | % | 0.002 | 5 |
L1/TEX Hit Rate | 55.420 | % | 0.000 | 5 |
L2 Hit Rate | 50.384 | % | 0.017 | 5 |
Mem Pipes Busy | 16.054 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 16.294 | cycle | 0.000 | 5 |
Warp Cycles Per Executed Instruction | 16.300 | cycle | 0.000 | 5 |
Avg. Active Threads Per Warp | 29.260 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 26.980 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 10.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 83.640 | % | 0.001 | 5 |
Achieved Active Warps Per SM | 53.528 | warp | 0.001 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (50.2%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (83.6%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 579908.01 | μs |
Device Time | 2557.43 | μs |
Self CPU Time | 64.67 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 579843.34 | μs |
Device Time | 2557.43 | μs |
Self CPU Time | 132.38 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 576866.41 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 135.55 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 577531.29 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 577531.29 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv3d | ||
CPU Time | 345075.51 | μs |
Device Time | 4209693.16 | μs |
Self CPU Time | 10502.52 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 334572.99 | μs |
Device Time | 4209693.16 | μs |
Self CPU Time | 14505.94 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 320067.05 | μs |
Device Time | 4209693.16 | μs |
Self CPU Time | 29067.17 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 211755.67 | μs |
Device Time | 3653586.48 | μs |
Self CPU Time | 148596.90 | μs |
Self Device Time | 3653586.48 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3653585.20 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3653585.20 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceSynchronize | ||
CPU Time | 5206266.20 | μs |
Device Time | 77519.37 | μs |
Self CPU Time | 5206266.20 | μs |
Self Device Time | 77519.37 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45283 warnings generated when compiling for host. Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.