19_ConvTranspose2d_GELU_GroupNorm
• opt_convtrans_gelu_gn_even_distribution_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
stride: int,
conv_transpose: torch.Tensor,
conv_transpose_bias: torch.Tensor,
group_norm_weight: torch.Tensor,
group_norm_bias: torch.Tensor,
num_groups: int,
) -> torch.Tensor:
"""
Applies transposed convolution, GELU activation, and group normalization.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
stride (int): Stride of the transposed convolution
conv_transpose (torch.Tensor): Transposed convolution weight tensor
conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
group_norm_weight (torch.Tensor): Weight tensor for group normalization
group_norm_bias (torch.Tensor): Bias tensor for group normalization
num_groups (int): Number of groups for group normalization
Returns:
torch.Tensor: Output tensor after applying transposed convolution, GELU and group norm
"""
x = F.conv_transpose2d(x, conv_transpose, bias=conv_transpose_bias, stride=stride)
x = F.gelu(x)
x = F.group_norm(
x, num_groups=num_groups, weight=group_norm_weight, bias=group_norm_bias
)
return x
class Model(nn.Module):
"""
Model that performs a transposed convolution, applies GELU, and normalizes with GroupNorm.
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride, groups, num_groups
):
super(Model, self).__init__()
conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride=stride
)
group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
self.conv_transpose_parameter = conv_transpose.weight
self.conv_transpose_bias = nn.Parameter(
conv_transpose.bias + torch.ones_like(conv_transpose.bias) * 0.02
) # ensure the bias is nonzero
self.group_norm_weight = group_norm.weight
self.group_norm_bias = nn.Parameter(
group_norm.bias + torch.ones_like(group_norm.bias) * 0.02
) # ensure the bias is nonzero
def forward(self, x, stride, num_groups, fn=module_fn):
return fn(
x,
stride,
self.conv_transpose_parameter,
self.conv_transpose_bias,
self.group_norm_weight,
self.group_norm_bias,
num_groups,
)
batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32
kernel_size = 4
stride = 2
groups = 8
num_groups = 8
def get_inputs():
return [torch.randn(batch_size, in_channels, height, width), stride, num_groups]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, groups, num_groups]
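A minimal usage sketch for the functional path above (the helper name run_functional_reference is illustrative and assumes the definitions in this listing are in scope; the 66x66 output size follows from (H - 1) * stride + kernel_size with no padding):
# Sketch only: exercise Model.forward, which dispatches to module_fn by default.
def run_functional_reference():
    model = Model(in_channels, out_channels, kernel_size, stride, groups, num_groups)
    x, s, g = get_inputs()
    out = model(x, s, g)
    # (32 - 1) * 2 + 4 = 66, so the transposed convolution upsamples 32x32 -> 66x66
    assert out.shape == (batch_size, out_channels, 66, 66)
    return out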
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Model that performs a transposed convolution, applies GELU, and normalizes with GroupNorm.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, groups, num_groups):
super(Model, self).__init__()
self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
# Apply the same constant 0.02 bias offset as in the functional implementation
self.conv_transpose.bias = nn.Parameter(self.conv_transpose.bias + torch.ones_like(self.conv_transpose.bias) * 0.02)
self.group_norm.bias = nn.Parameter(self.group_norm.bias + torch.ones_like(self.group_norm.bias) * 0.02)
def forward(self, x):
x = self.conv_transpose(x)
x = torch.nn.functional.gelu(x)
x = self.group_norm(x)
return x
batch_size = 128
in_channels = 32
out_channels = 64
height, width = 32, 32
kernel_size = 4
stride = 2
groups = 8
num_groups = 8
def get_inputs():
return [torch.randn(batch_size, in_channels, height, width)]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, groups, num_groups]
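One detail worth noting before the CUDA code: F.gelu defaults to the exact erf-based GELU, while the fused kernel below uses the tanh approximation. A short sketch of the expected numerical gap, assuming a PyTorch version (>= 1.12) where F.gelu accepts approximate='tanh':
# Sketch only: compare exact GELU with the tanh approximation used by the fused kernel.
import torch
import torch.nn.functional as F

t = torch.randn(1 << 14)
exact = F.gelu(t)                        # default approximate='none' (erf form)
approx = F.gelu(t, approximate='tanh')   # 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))
# The tanh form deviates from the exact GELU only slightly (well below 1e-2),
# which is why a loose tolerance is appropriate when validating the fused kernel.
print((exact - approx).abs().max())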
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <cmath>
// Kernel: Each block processes one group from the fused convTranspose output.
// Workload is distributed evenly by dynamically choosing the number of threads per block
// based on the group size. Grid-stride loops and optional vectorized loads ensure balanced work.
__global__ void fused_gelu_group_norm_kernel(
const float* __restrict__ in,
float* __restrict__ out,
int group_size, // = channels_per_group * (H*W)
int hw, // H * W
int channels_per_group,
int C, // Total channels
int num_groups,
float eps,
const float* __restrict__ gn_weight,
const float* __restrict__ gn_bias) {
// Each block processes one group. Calculate group indices.
int group_global = blockIdx.x; // global group index
int n = group_global / num_groups; // batch index
int g = group_global % num_groups; // group index
int base = n * C * hw + g * channels_per_group * hw; // starting offset for this group
float local_sum = 0.0f;
float local_sum_sq = 0.0f;
int tid = threadIdx.x;
int block_stride = blockDim.x;
// Check if group_size is vectorizable: process 4 elements at a time if group_size is divisible by 4
bool use_vector = (group_size % 4 == 0);
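// Note: base = (n * num_groups + g) * group_size, so group_size % 4 == 0 also makes base a
// multiple of 4, keeping (in + base) 16-byte aligned for the float4 accesses below.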
if (use_vector) {
const float4* in_vec = reinterpret_cast<const float4*>(in + base);
float4* out_vec = reinterpret_cast<float4*>(out + base);
int vec_count = group_size / 4;
for (int idx = tid; idx < vec_count; idx += block_stride) {
float4 vals = in_vec[idx];
float4 gelu_vals;
#pragma unroll
for (int j = 0; j < 4; j++) {
float v = ((float*)&vals)[j];
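// Tanh-based GELU approximation: 0.7978845608 ~= sqrt(2/pi)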
float gelu = 0.5f * v * (1.0f + tanhf(0.7978845608f * (v + 0.044715f * v * v * v)));
((float*)&gelu_vals)[j] = gelu;
local_sum += gelu;
local_sum_sq += gelu * gelu;
}
out_vec[idx] = gelu_vals;
}
} else {
// Scalar processing if vector load is not applicable
for (int idx = tid; idx < group_size; idx += block_stride) {
float v = in[base + idx];
float gelu = 0.5f * v * (1.0f + tanhf(0.7978845608f * (v + 0.044715f * v * v * v)));
out[base + idx] = gelu;
local_sum += gelu;
local_sum_sq += gelu * gelu;
}
}
// Warp-level reduction using shuffle for sum and sum of squares
int lane = tid & 31;
for (int offset = 16; offset > 0; offset /= 2) {
local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
local_sum_sq += __shfl_down_sync(0xffffffff, local_sum_sq, offset);
}
// Shared memory to hold per-warp partial sums (reserve space for up to 32 warps)
__shared__ float smem_sum[32];
__shared__ float smem_sum_sq[32];
int warp_id = tid / 32;
if (lane == 0) {
smem_sum[warp_id] = local_sum;
smem_sum_sq[warp_id] = local_sum_sq;
}
__syncthreads();
// Final reduction from warp sums done by thread 0
float group_mean = 0.0f;
float group_inv_std = 0.0f;
if (tid == 0) {
int num_warps = (blockDim.x + 31) / 32;
float sum_tot = 0.0f;
float sum_sq_tot = 0.0f;
for (int i = 0; i < num_warps; i++) {
sum_tot += smem_sum[i];
sum_sq_tot += smem_sum_sq[i];
}
group_mean = sum_tot / group_size;
float variance = sum_sq_tot / group_size - group_mean * group_mean;
group_inv_std = rsqrtf(variance + eps);
smem_sum[0] = group_mean; // reuse shared memory to broadcast
smem_sum[1] = group_inv_std;
}
__syncthreads();
group_mean = smem_sum[0];
group_inv_std = smem_sum[1];
// Normalize and apply affine transformation with grid-stride loop
if (use_vector) {
float4* out_vec = reinterpret_cast<float4*>(out + base);
int vec_count = group_size / 4;
for (int idx = tid; idx < vec_count; idx += block_stride) {
float4 vals = out_vec[idx];
#pragma unroll
for (int j = 0; j < 4; j++) {
float gelu = ((float*)&vals)[j];
float norm = (gelu - group_mean) * group_inv_std;
// Compute channel index: each channel has 'hw' elements
int k = idx * 4 + j; // overall element index within the group
int ch = k / hw; // channel index within the group
int global_ch = g * channels_per_group + ch; // global channel index for group norm params
float alpha = gn_weight[global_ch];
float beta = gn_bias[global_ch];
((float*)&vals)[j] = norm * alpha + beta;
}
out_vec[idx] = vals;
}
} else {
for (int idx = tid; idx < group_size; idx += block_stride) {
float gelu = out[base + idx];
float norm = (gelu - group_mean) * group_inv_std;
int ch = idx / hw;
int global_ch = g * channels_per_group + ch;
out[base + idx] = norm * gn_weight[global_ch] + gn_bias[global_ch];
}
}
}
torch::Tensor forward(
torch::Tensor x,
int64_t stride,
torch::Tensor conv_transpose_weight,
torch::Tensor conv_transpose_bias,
torch::Tensor group_norm_weight,
torch::Tensor group_norm_bias,
int64_t num_groups) {
// Ensure tensors are contiguous and on CUDA
x = x.contiguous();
conv_transpose_weight = conv_transpose_weight.contiguous();
conv_transpose_bias = conv_transpose_bias.contiguous();
group_norm_weight = group_norm_weight.contiguous();
group_norm_bias = group_norm_bias.contiguous();
if (!x.is_cuda()) x = x.cuda();
if (!conv_transpose_weight.is_cuda()) conv_transpose_weight = conv_transpose_weight.cuda();
if (!conv_transpose_bias.is_cuda()) conv_transpose_bias = conv_transpose_bias.cuda();
if (!group_norm_weight.is_cuda()) group_norm_weight = group_norm_weight.cuda();
if (!group_norm_bias.is_cuda()) group_norm_bias = group_norm_bias.cuda();
// Perform transposed convolution
auto conv_out = at::conv_transpose2d(x, conv_transpose_weight, conv_transpose_bias, {stride});
auto output = at::empty_like(conv_out);
int N = conv_out.size(0);
int C = conv_out.size(1);
int H = conv_out.size(2);
int W = conv_out.size(3);
int hw = H * W;
int channels_per_group = C / num_groups;
int group_size = channels_per_group * hw;
// Dynamically determine block size to evenly distribute the workload for each group
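// Note: the warp-shuffle reduction in the kernel assumes a warp-aligned block size; for this
// problem size group_size greatly exceeds 256, so 256 threads per block are always selected.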
int threads = (group_size < 256) ? ((group_size < 32) ? 32 : group_size) : 256;
int total_groups = N * num_groups;
int shared_mem_size = 64 * sizeof(float); // Dynamic shared memory is not strictly required here: the kernel's reduction buffers are statically declared __shared__ arrays
// Launch one block per group
fused_gelu_group_norm_kernel<<<total_groups, threads, shared_mem_size>>>(
conv_out.data_ptr<float>(),
output.data_ptr<float>(),
group_size,
hw,
channels_per_group,
C,
num_groups,
1e-5f,
group_norm_weight.data_ptr<float>(),
group_norm_bias.data_ptr<float>()
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused ConvTranspose2d with GELU+GroupNorm with Even Workload Distribution (CUDA)");
}
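A minimal sketch of building and validating the extension against the PyTorch reference (the file name fused_gelu_gn.cu is an assumption; module_fn, the functional-form Model, and the size constants from the first listing are assumed to be in scope):
# Sketch only: JIT-compile the CUDA source above and compare it with module_fn.
import torch
from torch.utils.cpp_extension import load

ext = load(name="fused_gelu_gn", sources=["fused_gelu_gn.cu"], verbose=True)

model = Model(in_channels, out_channels, kernel_size, stride, groups, num_groups).cuda()
x = torch.randn(batch_size, in_channels, height, width, device="cuda")

args = (x, stride, model.conv_transpose_parameter, model.conv_transpose_bias,
        model.group_norm_weight, model.group_norm_bias, num_groups)
ref = module_fn(*args)
out = ext.forward(*args)
# Loose tolerance: the kernel applies the tanh GELU approximation, the reference the exact form.
print(torch.allclose(ref, out, atol=1e-3, rtol=1e-3))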
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.374 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.272 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 34.336 | % | 0.008 | 5 |
Issued Ipc Active | 1.374 | inst/cycle | 0.000 | 5 |
SM Busy | 34.336 | % | 0.008 | 5 |
Memory Throughput | 2658103412848.900 | byte/second | 372477460704455753728.000 | 5 |
Mem Busy | 42.836 | % | 0.094 | 5 |
Max Bandwidth | 79.304 | % | 0.326 | 5 |
L1/TEX Hit Rate | 30.132 | % | 0.000 | 5 |
L2 Hit Rate | 50.558 | % | 0.003 | 5 |
Mem Pipes Busy | 8.070 | % | 0.004 | 5 |
Warp Cycles Per Issued Instruction | 43.126 | cycle | 0.035 | 5 |
Warp Cycles Per Executed Instruction | 43.150 | cycle | 0.036 | 5 |
Avg. Active Threads Per Warp | 31.730 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.060 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 21.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 92.710 | % | 0.013 | 5 |
Achieved Active Warps Per SM | 59.336 | warp | 0.006 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (22.8%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
INF Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. |
Operation / Metric | Value | Unit |
---|---|---|
aten::fill_ | ||
CPU Time | 1564465.90 | μs |
Device Time | 588051.33 | μs |
Self CPU Time | 26668.75 | μs |
Self Device Time | 588051.33 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 1580993.09 | μs |
Device Time | 588051.33 | μs |
Self CPU Time | 16553.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv_transpose2d | ||
CPU Time | 1469355.53 | μs |
Device Time | 2522924.26 | μs |
Self CPU Time | 13515.91 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 1455839.63 | μs |
Device Time | 2522924.26 | μs |
Self CPU Time | 18250.09 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 1437589.53 | μs |
Device Time | 2522924.26 | μs |
Self CPU Time | 36478.23 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 793976.38 | μs |
Device Time | 1586775.50 | μs |
Self CPU Time | 206000.26 | μs |
Self Device Time | 1586775.50 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 3240858.36 | μs |
Device Time | 40781.29 | μs |
Self CPU Time | 3240858.36 | μs |
Self Device Time | 40781.29 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
fused_gelu_group_norm_kernel(float const*, float*, int, int, int, int, int, float, float const*, float const*) | ||
CPU Time | 0.00 | μs |
Device Time | 1573543.82 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1573543.82 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45295 warnings generated when compiling for host. Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.