60_ConvTranspose3d_Swish_GroupNorm_HardSwish
• optimized_reduction_fused_actnorm_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
stride: int,
padding: int,
groups: int,
eps: float,
conv_transpose: torch.Tensor,
conv_transpose_bias: torch.Tensor,
group_norm_weight: torch.Tensor,
group_norm_bias: torch.Tensor,
) -> torch.Tensor:
"""
Applies a 3D transposed convolution, Swish activation, group normalization, and HardSwish activation.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
stride (int): Stride of the transposed convolution
padding (int): Padding of the transposed convolution
groups (int): Number of groups for group normalization
eps (float): Epsilon value for group normalization
conv_transpose (torch.Tensor): Transposed convolution weight tensor
conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
group_norm_weight (torch.Tensor): Weight tensor for group normalization
group_norm_bias (torch.Tensor): Bias tensor for group normalization
Returns:
torch.Tensor: Output tensor after applying all operations
"""
x = F.conv_transpose3d(
x, conv_transpose, bias=conv_transpose_bias, stride=stride, padding=padding
)
x = torch.sigmoid(x) * x # Swish activation
x = F.group_norm(
x, num_groups=groups, weight=group_norm_weight, bias=group_norm_bias, eps=eps
)
x = F.hardswish(x)
return x
class Model(nn.Module):
"""
Model that performs a 3D transposed convolution, applies Swish activation,
group normalization, and then HardSwish activation.
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride, padding, groups, eps
):
super(Model, self).__init__()
conv = nn.ConvTranspose3d(
in_channels, out_channels, kernel_size, stride=stride, padding=padding
)
self.conv_transpose_parameter = nn.Parameter(conv.weight)
self.conv_transpose_bias = nn.Parameter(conv.bias)
gn = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps)
self.group_norm_weight = nn.Parameter(gn.weight)
self.group_norm_bias = nn.Parameter(gn.bias + torch.randn(out_channels) * 0.02)
def forward(self, x, stride, padding, groups, eps, fn=module_fn):
return fn(
x,
stride,
padding,
groups,
eps,
self.conv_transpose_parameter,
self.conv_transpose_bias,
self.group_norm_weight,
self.group_norm_bias,
)
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
groups = 4
eps = 1e-5
def get_inputs():
return [
torch.randn(batch_size, in_channels, depth, height, width),
stride,
padding,
groups,
eps,
]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding, groups, eps]
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Model that performs a 3D transposed convolution, applies Swish activation,
group normalization, and then HardSwish activation.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, eps, bias=True):
super(Model, self).__init__()
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps)
# Add noise to group norm bias to match functional implementation
self.group_norm.bias = nn.Parameter(self.group_norm.bias + torch.randn(out_channels) * 0.02)
def forward(self, x):
x = self.conv_transpose(x)
x = torch.sigmoid(x) * x # Swish activation
x = self.group_norm(x)
x = torch.nn.functional.hardswish(x) # HardSwish activation
return x
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
groups = 4
eps = 1e-5
def get_inputs():
return [torch.randn(batch_size, in_channels, depth, height, width)]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding, groups, eps]
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// This kernel fuses Swish activation, Group Normalization, and HardSwish activation
// It optimizes reductions by using shared memory for intra-block partial sums
// and warp-level primitives (__shfl_down_sync) for the final reduction.
__global__ void fused_opt_red_kernel(
const float* __restrict__ input,
float* __restrict__ output,
const float* __restrict__ gamma,
const float* __restrict__ beta,
int N, int C, int D, int H, int W,
int groups,
float eps
) {
// Each block processes one (sample, group) pair
int n = blockIdx.x; // sample index
int g = blockIdx.y; // group index
int channels_per_group = C / groups;
int group_elements = channels_per_group * D * H * W; // Total elements in the group
int base = n * (C * D * H * W) + g * group_elements;
int tid = threadIdx.x;
int blockSize = blockDim.x;
constexpr int VECTOR_SIZE = 4; // Vectorization factor
int aligned_size = (group_elements / VECTOR_SIZE) * VECTOR_SIZE;
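// Note: the float4 loads below assume the group base offset is 16-byte aligned, which holds
// when group_elements (= channels_per_group * D * H * W) is a multiple of 4, as it is here.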
// Phase 1: Each thread accumulates partial sums of the Swish-activated values
float local_sum = 0.0f;
float local_sum_sq = 0.0f;
// Use vectorized loads for aligned elements
for (int i = tid * VECTOR_SIZE; i < aligned_size; i += blockSize * VECTOR_SIZE) {
int idx = base + i;
float4 data = *reinterpret_cast<const float4*>(input + idx);
#pragma unroll
for (int j = 0; j < VECTOR_SIZE; j++) {
float x = ((float*)&data)[j];
float sw = x / (1.0f + expf(-x)); // Swish activation
local_sum += sw;
local_sum_sq += sw * sw;
}
}
// Process remaining tail elements
for (int i = aligned_size + tid; i < group_elements; i += blockSize) {
int idx = base + i;
float x = __ldg(input + idx);
float sw = x / (1.0f + expf(-x));
local_sum += sw;
local_sum_sq += sw * sw;
}
// Phase 2: Intra-warp reduction using warp-level shuffles
unsigned int mask = 0xffffffff;
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
local_sum += __shfl_down_sync(mask, local_sum, offset);
local_sum_sq += __shfl_down_sync(mask, local_sum_sq, offset);
}
// Allocate shared memory for warp-level partial results
extern __shared__ float shared[]; // first half: warp sums; second half: warp sum-sqs
int numWarps = blockSize / warpSize;
float* s_sum = shared; // length = numWarps
float* s_sum_sq = shared + numWarps; // length = numWarps
int warpId = tid / warpSize;
if ((tid & (warpSize - 1)) == 0) {
s_sum[warpId] = local_sum;
s_sum_sq[warpId] = local_sum_sq;
}
__syncthreads();
// Final reduction: let first few threads load warp-level sums
float block_sum = 0.0f;
float block_sum_sq = 0.0f;
if (tid < numWarps) {
block_sum = s_sum[tid];
block_sum_sq = s_sum_sq[tid];
}
__syncthreads();
if (tid == 0) {
for (int i = 1; i < numWarps; i++) {
block_sum += s_sum[i];
block_sum_sq += s_sum_sq[i];
}
// Compute mean and variance of the Swish-activated values
float mean = block_sum / group_elements;
float variance = block_sum_sq / group_elements - mean * mean;
float inv_std = rsqrtf(variance + eps);
// Store computed mean and inv_std in shared memory for broadcasting
s_sum[0] = mean;
s_sum_sq[0] = inv_std;
}
__syncthreads();
float mean = s_sum[0];
float inv_std = s_sum_sq[0];
// Phase 3: Compute final output by applying Group Norm and HardSwish
// Recompute the Swish activation and normalize
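// (recomputing Swish trades a second pass of expf() for not materializing the activated tensor in global memory)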
for (int i = tid * VECTOR_SIZE; i < aligned_size; i += blockSize * VECTOR_SIZE) {
int idx = base + i;
float4 in_vec = *reinterpret_cast<const float4*>(input + idx);
float4 out_vec;
#pragma unroll
for (int j = 0; j < VECTOR_SIZE; j++) {
int elem = i + j; // index within group
float x = ((float*)&in_vec)[j];
float sw = x / (1.0f + expf(-x));
int local_channel = elem / (D * H * W);
int global_channel = g * channels_per_group + local_channel;
float norm = (sw - mean) * inv_std;
float y = norm * __ldg(gamma + global_channel) + __ldg(beta + global_channel);
float hs = y * fminf(fmaxf(y + 3.0f, 0.0f), 6.0f) / 6.0f;
((float*)&out_vec)[j] = hs;
}
*reinterpret_cast<float4*>(output + idx) = out_vec;
}
// Process remaining tail elements
for (int i = aligned_size + tid; i < group_elements; i += blockSize) {
int idx = base + i;
float x = __ldg(input + idx);
float sw = x / (1.0f + expf(-x));
int local_channel = i / (D * H * W);
int global_channel = g * channels_per_group + local_channel;
float norm = (sw - mean) * inv_std;
float y = norm * __ldg(gamma + global_channel) + __ldg(beta + global_channel);
output[idx] = y * fminf(fmaxf(y + 3.0f, 0.0f), 6.0f) / 6.0f;
}
}
// Host function that applies conv_transpose3d and then launches the fused kernel
torch::Tensor forward(
torch::Tensor x,
int stride,
int padding,
int groups,
float eps,
torch::Tensor conv_transpose,
torch::Tensor conv_transpose_bias,
torch::Tensor group_norm_weight,
torch::Tensor group_norm_bias
) {
CHECK_INPUT(x);
CHECK_INPUT(conv_transpose);
CHECK_INPUT(conv_transpose_bias);
CHECK_INPUT(group_norm_weight);
CHECK_INPUT(group_norm_bias);
// Apply 3D transposed convolution
x = torch::conv_transpose3d(x, conv_transpose, conv_transpose_bias, stride, padding);
torch::Tensor output = torch::empty_like(x);
int N = x.size(0);
int C = x.size(1);
int D = x.size(2);
int H = x.size(3);
int W = x.size(4);
dim3 grid(N, groups); // one block for each (sample, group) pair
int blockSize = 256;
int numWarps = blockSize / 32;
// Shared memory: two arrays of numWarps floats for the per-warp partial sums; their first slots are reused to broadcast mean and inv_std
size_t sharedMem = (numWarps + numWarps) * sizeof(float);
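// Note: the launch below uses the default stream; at::cuda::getCurrentCUDAStream() could be passed
// as the fourth launch-configuration argument to honor the caller's stream.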
fused_opt_red_kernel<<<grid, blockSize, sharedMem>>>(
x.data_ptr<float>(),
output.data_ptr<float>(),
group_norm_weight.data_ptr<float>(),
group_norm_bias.data_ptr<float>(),
N, C, D, H, W, groups, eps
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused kernel with optimized shared memory reduction");
}
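To sanity-check the fused path, an unfused reference built from standard ATen ops can be compared against the extension's forward. The helper below is a sketch: reference_forward and the tolerances are illustrative additions, not part of the benchmark, and it is assumed to live in the same translation unit as forward.
// Sketch: unfused ATen reference for validating the fused kernel's output.
// reference_forward is a hypothetical helper (not part of the extension above).
torch::Tensor reference_forward(
    torch::Tensor x,
    int stride,
    int padding,
    int groups,
    float eps,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias,
    torch::Tensor group_norm_weight,
    torch::Tensor group_norm_bias
) {
    auto y = torch::conv_transpose3d(x, conv_transpose, conv_transpose_bias, stride, padding);
    y = torch::sigmoid(y) * y;  // Swish
    y = torch::group_norm(y, groups, group_norm_weight, group_norm_bias, eps);
    return torch::hardswish(y); // HardSwish
}
// Example check (illustrative tolerances):
//   auto ref = reference_forward(x, stride, padding, groups, eps, w, b, gamma, beta);
//   auto out = forward(x, stride, padding, groups, eps, w, b, gamma, beta);
//   TORCH_CHECK(torch::allclose(out, ref, /*rtol=*/1e-4, /*atol=*/1e-5), "fused kernel mismatch");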
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.448 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.404 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 36.156 | % | 0.006 | 5 |
Issued Ipc Active | 1.448 | inst/cycle | 0.000 | 5 |
SM Busy | 36.156 | % | 0.006 | 5 |
Memory Throughput | 2466897960717.034 | byte/second | 24679674909802942464.000 | 5 |
Mem Busy | 40.246 | % | 0.007 | 5 |
Max Bandwidth | 73.592 | % | 0.022 | 5 |
L1/TEX Hit Rate | 16.754 | % | 0.000 | 5 |
L2 Hit Rate | 35.654 | % | 0.000 | 5 |
Mem Pipes Busy | 8.402 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 20.912 | cycle | 0.002 | 5 |
Warp Cycles Per Executed Instruction | 20.916 | cycle | 0.001 | 5 |
Avg. Active Threads Per Warp | 31.990 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.210 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 28.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 47.146 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 30.174 | warp | 0.000 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (26.0%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (47.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
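The occupancy gap flagged above can be cross-checked programmatically. The helper below is a sketch using the standard CUDA occupancy API with the same 256-thread, 16-float dynamic shared-memory configuration used in forward; report_occupancy is an illustrative addition and assumes it is compiled in the same translation unit as fused_opt_red_kernel.
#include <cstdio>
// Sketch: report theoretical blocks per SM for fused_opt_red_kernel at 256 threads/block.
// report_occupancy is a hypothetical helper mirroring the launch configuration in forward().
void report_occupancy() {
    const int blockSize = 256;
    const size_t dynamicSmem = 2 * (blockSize / 32) * sizeof(float); // two arrays of numWarps floats
    int blocksPerSm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSm, fused_opt_red_kernel, blockSize, dynamicSmem);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::printf("blocks/SM: %d, theoretical active warps/SM: %d of %d\n",
                blocksPerSm,
                blocksPerSm * blockSize / 32,
                prop.maxThreadsPerMultiProcessor / 32);
}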
Operation / Metric | Value | Unit |
---|---|---|
aten::conv_transpose3d | ||
CPU Time | 3856256.44 | μs |
Device Time | 6803290.82 | μs |
Self CPU Time | 3258.45 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 3852997.99 | μs |
Device Time | 6803290.82 | μs |
Self CPU Time | 4537.79 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 3848460.20 | μs |
Device Time | 6803290.82 | μs |
Self CPU Time | 9379.11 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 458954.29 | μs |
Device Time | 5424460.27 | μs |
Self CPU Time | 137277.61 | μs |
Self Device Time | 5424460.27 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 6738482.06 | μs |
Device Time | 62350.54 | μs |
Self CPU Time | 6738482.06 | μs |
Self Device Time | 62350.54 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3586177.82 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3586177.82 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::add_ | ||
CPU Time | 3377472.93 | μs |
Device Time | 1378830.54 | μs |
Self CPU Time | 8573.52 | μs |
Self Device Time | 1378830.54 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45316 warnings generated when compiling for host. Suppressed 45344 warnings (45297 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.