27_Conv3d_HardSwish_ReLU_Softmax_Mean
• ldg_aligned_fused_kernel_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run the Conv3d -> HardSwish -> ReLU -> Softmax -> spatial-mean pipeline.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, depth, height, width).
        conv_weight (torch.Tensor): Convolution kernel of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size).
        conv_bias (torch.Tensor): Convolution bias of shape (out_channels).

    Returns:
        torch.Tensor: Channel probabilities averaged over the three spatial
            dimensions, with shape (batch_size, out_channels).
    """
    conv_out = F.conv3d(x, conv_weight, bias=conv_bias)
    # ReLU after HardSwish clamps the negative lobe of HardSwish to zero.
    activated = F.relu(F.hardswish(conv_out))
    # Softmax is taken across channels (dim=1) independently at every voxel.
    channel_probs = F.softmax(activated, dim=1)
    return channel_probs.mean(dim=(2, 3, 4))
class Model(nn.Module):
    """
    Functional-style model: 3D convolution followed by HardSwish, ReLU,
    channel-wise Softmax and a mean over the spatial dimensions.

    The convolution parameters are stored directly on the module and handed to
    a forward function, so an alternative implementation can be swapped in
    through the ``fn`` argument of :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # A throwaway Conv3d is used only to obtain default-initialised parameters.
        reference_conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        bias_shift = torch.ones_like(reference_conv.bias) * 0.02
        self.conv_weight = nn.Parameter(reference_conv.weight)
        self.conv_bias = nn.Parameter(reference_conv.bias + bias_shift)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` to ``x`` using this module's convolution weight and bias."""
        return fn(x, self.conv_weight, self.conv_bias)
# Problem size consumed by get_inputs() and get_init_inputs().
batch_size = 128
in_channels, out_channels = 3, 16
depth = 16
height = width = 32
kernel_size = 3
def get_inputs():
    """Return the forward-pass inputs: one standard-normal tensor of the configured size."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*input_shape)]
def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    init_args = [in_channels, out_channels, kernel_size]
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies HardSwish, ReLU and a
    channel-wise Softmax, and then averages over the spatial dimensions.

    Args:
        in_channels: Number of channels in the input volume.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel (forwarded to ``nn.Conv3d``).
        bias: Whether the convolution has a learnable bias. When enabled, the
            default-initialised bias is shifted by +0.02; when disabled the
            convolution runs without a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias)
        # With bias=False nn.Conv3d stores ``None`` as its bias, so the shift
        # must be skipped instead of calling torch.ones_like(None).
        if self.conv.bias is not None:
            self.conv.bias = nn.Parameter(
                self.conv.bias + torch.ones_like(self.conv.bias) * 0.02
            )

    def forward(self, x):
        """
        Apply convolution, HardSwish, ReLU, Softmax over dim 1 and a mean over
        dims 2-4.

        For a batched 5D input the result has shape (batch_size, out_channels).
        """
        x = self.conv(x)
        x = torch.nn.functional.hardswish(x)
        x = torch.relu(x)
        # Softmax normalises across channels independently at every voxel.
        x = torch.softmax(x, dim=1)
        x = torch.mean(x, dim=[2, 3, 4])
        return x
# Benchmark configuration read by get_inputs() and get_init_inputs().
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height, width = 32, 32
kernel_size = 3
def get_inputs():
    """Build the list of forward arguments: a single random-normal input volume."""
    volume = torch.randn(batch_size, in_channels, depth, height, width)
    return [volume]
def get_init_inputs():
    """Build the list of constructor arguments for Model."""
    constructor_args = [in_channels, out_channels, kernel_size]
    return constructor_args
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cfloat>
// Threads per block; module_forward launches one thread per spatial position.
#define BLOCK_SIZE 256
// HardSwish parameters held in constant memory and filled in by initialize_constants():
// d_hswish_constants[0] = offset added to x (3.0f), d_hswish_constants[1] = upper clamp (6.0f)
__constant__ float d_hswish_constants[2];
__constant__ float d_hswish_div; // = 1.0f / 6.0f, the final HardSwish scale factor
// Upload the HardSwish parameters (offset 3, cap 6 and the 1/6 scale) to the
// __constant__ symbols read by the fused kernel. Intended to be called once
// before the first kernel launch.
//
// Throws (via TORCH_CHECK) if either copy fails, instead of silently leaving
// the constants unset and letting the kernel compute with stale values.
void initialize_constants() {
    const float h_constants[2] = {3.0f, 6.0f};
    cudaError_t err = cudaMemcpyToSymbol(d_hswish_constants, h_constants, sizeof(h_constants));
    TORCH_CHECK(err == cudaSuccess,
                "cudaMemcpyToSymbol(d_hswish_constants) failed: ", cudaGetErrorString(err));

    const float div = 1.0f / 6.0f;
    err = cudaMemcpyToSymbol(d_hswish_div, &div, sizeof(div));
    TORCH_CHECK(err == cudaSuccess,
                "cudaMemcpyToSymbol(d_hswish_div) failed: ", cudaGetErrorString(err));
}
// Fused HardSwish followed by ReLU for a single element:
//   act = max( x * min(max(x + offset, 0), cap) * scale, 0 )
// with offset, cap and scale read from constant memory.
static __device__ __forceinline__ float hswish_relu_activation(float x) {
    float relu6 = fminf(fmaxf(x + d_hswish_constants[0], 0.0f), d_hswish_constants[1]);
    float hswish = x * relu6 * d_hswish_div;
    return fmaxf(hswish, 0.0f);
}

// Fused kernel: applies HardSwish, ReLU, and Softmax in three passes over the channel dimension.
//
// Layout: `input` and `output` are indexed as dense [batch, channels, spatial] buffers.
// Each thread handles one spatial position of one batch element (blockIdx.y selects the
// batch element) and walks all channels of that position.
//
// Input values are read with scalar __ldg() loads (read-only data path); no alignment
// beyond that of a float is required. Flat offsets are computed in 64-bit arithmetic so
// tensors with more than 2^31 elements do not overflow the index.
__global__ void ldg_aligned_fused_kernel(const float* __restrict__ input,
                                         float* __restrict__ output,
                                         int batch_size,
                                         int channels,
                                         int spatial_size) {
    // Widen before multiplying so the thread's spatial index cannot wrap.
    const long long spatial_idx =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int batch_idx = blockIdx.y;
    if (spatial_idx >= spatial_size || batch_idx >= batch_size) return;

    // Offset of channel 0 for this (batch, spatial) pair and the distance between channels.
    const size_t channel_stride = static_cast<size_t>(spatial_size);
    const size_t base = static_cast<size_t>(batch_idx) * channels * channel_stride
                        + static_cast<size_t>(spatial_idx);

    // Pass 1: maximum activation across channels, for a numerically stable softmax.
    float max_val = -FLT_MAX;
    for (int c = 0; c < channels; ++c) {
        const float act = hswish_relu_activation(__ldg(&input[base + c * channel_stride]));
        max_val = fmaxf(max_val, act);
    }

    // Pass 2: exponentials and their sum; the exponentials are parked in `output`.
    float sum_exp = 0.0f;
    for (int c = 0; c < channels; ++c) {
        const size_t idx = base + c * channel_stride;
        const float act = hswish_relu_activation(__ldg(&input[idx]));
        const float exp_val = expf(act - max_val);
        sum_exp += exp_val;
        output[idx] = exp_val;
    }

    // Pass 3: normalise the stored exponentials into softmax probabilities.
    for (int c = 0; c < channels; ++c) {
        const size_t idx = base + c * channel_stride;
        output[idx] = output[idx] / sum_exp;
    }
}
// Module forward function: conv3d (via PyTorch), then the fused HardSwish/ReLU/channel-softmax
// kernel, then a mean over the three spatial dimensions.
//
// Args:
//   x           - input volume; moved to CUDA and made contiguous if necessary.
//   conv_weight - convolution weight, forwarded to torch::conv3d.
//   conv_bias   - convolution bias, forwarded to torch::conv3d.
// Returns:
//   Tensor holding the softmax probabilities averaged over dims 2, 3 and 4.
// Throws (TORCH_CHECK) when the convolution output is not a 5D float32 tensor, when its
// sizes do not fit the kernel's launch/index limits, or when the kernel launch fails.
torch::Tensor module_forward(
    torch::Tensor x,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias) {
    // Initialize constant memory once.
    // NOTE(review): __constant__ symbols are per-device, so this one-time flag only covers
    // the device that is current on the first call - confirm single-GPU usage.
    static bool constants_initialized = false;
    if (!constants_initialized) {
        initialize_constants();
        constants_initialized = true;
    }

    // Ensure tensors are contiguous and on CUDA
    x = x.contiguous().cuda();
    conv_weight = conv_weight.contiguous().cuda();
    conv_bias = conv_bias.contiguous().cuda();

    // Perform 3D convolution via PyTorch's conv3d
    x = torch::conv3d(x, conv_weight, conv_bias);
    TORCH_CHECK(x.dim() == 5,
                "module_forward expects a batched 5D convolution output, got ", x.dim(), "D");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32,
                "module_forward only supports float32 tensors");
    // The kernel indexes the buffer as dense [batch, channels, spatial]; no-op if already dense.
    x = x.contiguous();

    // Retrieve tensor dimensions
    const int64_t batch_size = x.size(0);
    const int64_t channels = x.size(1);
    const int64_t depth = x.size(2);
    const int64_t height = x.size(3);
    const int64_t width = x.size(4);
    const int64_t spatial_size = depth * height * width;

    // Allocate intermediate tensor for softmax result
    torch::Tensor x_softmax = torch::empty_like(x);

    // A zero-sized grid is an invalid launch configuration, so skip the kernel for empty
    // outputs and let mean() produce the result for the empty reduction.
    if (x.numel() > 0) {
        // The kernel takes its sizes as int and maps the batch onto gridDim.y.
        const auto fits_int = [](int64_t v) { return v >= 0 && v <= 0x7fffffffLL; };
        TORCH_CHECK(fits_int(channels) && fits_int(spatial_size),
                    "channel count and spatial size must each fit in a 32-bit int");
        TORCH_CHECK(batch_size <= 65535,
                    "batch size ", batch_size, " exceeds the CUDA gridDim.y limit of 65535");

        // Launch kernel: 2D grid (spatial index and batch index)
        const dim3 threads(BLOCK_SIZE);
        const dim3 blocks(static_cast<unsigned int>((spatial_size + BLOCK_SIZE - 1) / BLOCK_SIZE),
                          static_cast<unsigned int>(batch_size));
        ldg_aligned_fused_kernel<<<blocks, threads>>>(x.data_ptr<float>(),
                                                      x_softmax.data_ptr<float>(),
                                                      static_cast<int>(batch_size),
                                                      static_cast<int>(channels),
                                                      static_cast<int>(spatial_size));
        // Without this check a failed launch would return the mean of uninitialised memory.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "ldg_aligned_fused_kernel launch failed: ", cudaGetErrorString(launch_err));
    }

    // x_softmax already has the [batch, channels, depth, height, width] shape of x.
    return x_softmax.mean({2, 3, 4});
}
// Python binding: exposes module_forward as `forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &module_forward,
        "Fused CUDA module forward with __ldg() and aligned global memory accesses");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.706 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.624 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 42.752 | % | 0.041 | 5 |
Issued Ipc Active | 1.710 | inst/cycle | 0.000 | 5 |
SM Busy | 42.752 | % | 0.041 | 5 |
Memory Throughput | 2168777127310.022 | byte/second | 145973059539643105280.000 | 5 |
Mem Busy | 64.764 | % | 0.137 | 5 |
Max Bandwidth | 67.710 | % | 0.157 | 5 |
L1/TEX Hit Rate | 55.014 | % | 0.001 | 5 |
L2 Hit Rate | 71.478 | % | 0.002 | 5 |
Mem Pipes Busy | 22.782 | % | 0.016 | 5 |
Warp Cycles Per Issued Instruction | 30.634 | cycle | 0.051 | 5 |
Warp Cycles Per Executed Instruction | 30.678 | cycle | 0.053 | 5 |
Avg. Active Threads Per Warp | 31.980 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 31.200 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 81.676 | % | 0.015 | 5 |
Achieved Active Warps Per SM | 52.270 | warp | 0.006 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (27.3%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (81.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::conv3d | ||
CPU Time | 5280443.12 | μs |
Device Time | 5388021.08 | μs |
Self CPU Time | 13625.59 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 5266817.52 | μs |
Device Time | 5388021.08 | μs |
Self CPU Time | 18359.41 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 5248458.12 | μs |
Device Time | 5388021.08 | μs |
Self CPU Time | 42457.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 4510040.75 | μs |
Device Time | 4670261.97 | μs |
Self CPU Time | 183453.29 | μs |
Self Device Time | 4670261.97 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernelExC | ||
CPU Time | 4295562.02 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 4295562.02 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 4670259.28 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 4670259.28 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45282 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.