3_ConvTranspose3d_Sum_LayerNorm_AvgPool_GELU
• warp_only_layernorm_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    conv_transpose_weight: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    sum_weight: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    stride: tuple,
    padding: tuple,
    output_padding: tuple,
    pool_kernel_size: tuple,
    norm_shape: tuple,
) -> torch.Tensor:
    """Apply transposed conv -> add -> layer norm -> average pool -> GELU.

    The stages run in this order:

    1. ``F.conv_transpose3d`` with the supplied weight, bias, stride,
       padding and output padding.
    2. Addition of ``sum_weight`` (broadcast by the ``+`` operator).
    3. ``F.layer_norm`` over the trailing dimensions described by
       ``norm_shape``, scaled by ``norm_weight`` and shifted by ``norm_bias``.
    4. ``F.avg_pool3d`` with ``pool_kernel_size`` (the stride is left at its
       default, which equals the kernel size).
    5. ``F.gelu``.

    Args:
        x: Input tensor of shape (batch_size, in_channels, depth, height, width).
        conv_transpose_weight: Weight tensor for the transposed convolution.
        conv_transpose_bias: Bias tensor for the transposed convolution.
        sum_weight: Learnable value added to the convolution output.
        norm_weight: Scale tensor for layer normalization.
        norm_bias: Shift tensor for layer normalization.
        stride: Transposed-convolution stride as (depth, height, width).
        padding: Transposed-convolution padding as (depth, height, width).
        output_padding: Transposed-convolution output padding as
            (depth, height, width).
        pool_kernel_size: Average-pooling kernel as (depth, height, width).
        norm_shape: Trailing shape normalized by the layer norm; it has to
            match the trailing dimensions of the convolution output.

    Returns:
        The tensor produced by the final GELU.
    """
    upsampled = F.conv_transpose3d(
        x,
        conv_transpose_weight,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    shifted = upsampled + sum_weight
    normalized = F.layer_norm(shifted, norm_shape, norm_weight, norm_bias)
    pooled = F.avg_pool3d(normalized, kernel_size=pool_kernel_size)
    return F.gelu(pooled)
class Model(nn.Module):
    """Parameter holder for the functional pipeline.

    The module owns the learnable tensors (transposed-convolution weight and
    bias, the additive scalar, and the layer-norm scale and shift) and hands
    them to a functional implementation in ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        sum_weight,
        norm_shape,
        pool_kernel_size,
    ):
        super().__init__()

        # A throwaway ConvTranspose3d provides the default initialization of
        # the weight and bias tensors; only those tensors are kept.
        reference_conv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_weight = nn.Parameter(reference_conv.weight)
        self.conv_transpose_bias = nn.Parameter(reference_conv.bias)

        self.sum_weight = nn.Parameter(torch.tensor(sum_weight))

        # Start from LayerNorm's default scale/shift and add small Gaussian
        # noise. The weight noise is drawn before the bias noise.
        reference_norm = nn.LayerNorm(norm_shape)
        weight_noise = torch.randn(norm_shape) * 0.02
        self.norm_weight = nn.Parameter(reference_norm.weight + weight_noise)
        bias_noise = torch.randn(norm_shape) * 0.02
        self.norm_bias = nn.Parameter(reference_norm.bias + bias_noise)

    def forward(
        self,
        x,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        norm_shape,
        fn=module_fn,
    ):
        """Call ``fn`` with the input, the stored parameters and the settings.

        ``fn`` defaults to ``module_fn``; any callable taking the same
        positional arguments can be substituted.
        """
        learned = (
            self.conv_transpose_weight,
            self.conv_transpose_bias,
            self.sum_weight,
            self.norm_weight,
            self.norm_bias,
        )
        settings = (stride, padding, output_padding, pool_kernel_size, norm_shape)
        return fn(x, *learned, *settings)
# Benchmark configuration for the functional variant.
batch_size = 128
in_channels, out_channels = 32, 64
depth = 16
height = width = 32
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
# Tuples are immutable, so sharing one object between the two names is safe.
padding = output_padding = (1, 1, 1)
sum_weight = 1.0
# Layer norm runs over the trailing dimension. With the stride/padding values
# above the convolution output width is 64, which equals out_channels.
norm_shape = (out_channels,)
pool_kernel_size = (2, 2, 2)
def get_inputs():
    """Return the positional arguments for ``Model.forward``.

    The first entry is a freshly sampled random input tensor; the remaining
    entries are the module-level convolution, pooling and normalization
    settings.
    """
    sample = torch.randn(batch_size, in_channels, depth, height, width)
    settings = [stride, padding, output_padding, pool_kernel_size, norm_shape]
    return [sample] + settings
def get_init_inputs():
    """Return the positional arguments for ``Model.__init__`` as a list."""
    constructor_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        sum_weight,
        norm_shape,
        pool_kernel_size,
    )
    return list(constructor_args)
import torch
import torch.nn as nn
class Model(nn.Module):
    """3D transposed convolution, scalar add, layer norm, average pool, GELU."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        sum_weight,
        norm_shape,
        pool_kernel_size,
    ):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.sum_weight = nn.Parameter(torch.tensor(sum_weight))

        # Replace LayerNorm's default scale/shift with copies perturbed by
        # small Gaussian noise. The weight noise is drawn before the bias noise.
        self.norm = nn.LayerNorm(norm_shape)
        weight_noise = torch.randn(norm_shape) * 0.02
        self.norm.weight = nn.Parameter(self.norm.weight + weight_noise)
        bias_noise = torch.randn(norm_shape) * 0.02
        self.norm.bias = nn.Parameter(self.norm.bias + bias_noise)

        self.avg_pool = nn.AvgPool3d(kernel_size=pool_kernel_size)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Run the five stages in order and return the activated tensor."""
        shifted = self.conv_transpose(x) + self.sum_weight
        pooled = self.avg_pool(self.norm(shifted))
        return self.gelu(pooled)
# Benchmark configuration for the nn.Module variant.
batch_size, in_channels, out_channels = 128, 32, 64
depth, height, width = 16, 32, 32
kernel_size = (3,) * 3
stride = (2,) * 3
padding = (1,) * 3
output_padding = (1,) * 3
sum_weight = 1.0
norm_shape = (out_channels,)
pool_kernel_size = (2,) * 3
def get_inputs():
    """Return a one-element list holding a random input tensor for ``forward``."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(input_shape)]
def get_init_inputs():
    """Return the positional arguments for ``Model.__init__`` as a list."""
    conv_args = [in_channels, out_channels, kernel_size, stride, padding, output_padding]
    remaining_args = [sum_weight, norm_shape, pool_kernel_size]
    return conv_args + remaining_args
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <limits>
#include <vector>
// Using a block size equal to the warp size eliminates the need for shared memory in reductions.
// NOTE: warp_only_layernorm_kernel strides its element loops by BLOCK_SIZE and
// reduces with warp shuffles starting at WARP_SIZE/2, so the two values below
// must stay equal for the reduction to cover every thread of the block.
#define BLOCK_SIZE 32
#define WARP_SIZE 32
// Layer normalization using warp-level primitives only.
//
// Launch contract: one block per normalization group, with blockDim.x equal to
// BLOCK_SIZE (== WARP_SIZE), so every reduction can be done with
// __shfl_down_sync / __shfl_sync and no shared memory.
//
// Parameters:
//   input  - n1 * n2 values; group g occupies elements [g * n2, (g + 1) * n2)
//   gamma  - n2 per-element scale factors
//   beta   - n2 per-element shifts
//   output - n1 * n2 values, written with the same layout as input
//   n1     - number of normalization groups
//   n2     - elements per group
//
// Statistics are accumulated in float; epsilon is fixed at 1e-5.
template <typename T>
__global__ void warp_only_layernorm_kernel(
    const T* __restrict__ input,
    const T* __restrict__ gamma,
    const T* __restrict__ beta,
    T* __restrict__ output,
    int n1,  // number of normalization groups
    int n2   // elements per group
) {
    const int tid = threadIdx.x;  // 0..31 when launched with one warp per block
    const int bid = blockIdx.x;   // each block processes one normalization group

    // The condition depends only on blockIdx, so the whole block leaves
    // together and no thread is left waiting in the shuffles below.
    if (bid >= n1) {
        return;
    }

    // 64-bit base offset: bid * n2 exceeds the range of int once the tensor
    // holds more than 2^31 - 1 elements.
    const long long offset = static_cast<long long>(bid) * n2;

    float local_sum = 0.0f;
    float local_sq_sum = 0.0f;

    // Each thread accumulates a strided subset of the group.
    for (int i = tid; i < n2; i += BLOCK_SIZE) {
        const float val = __ldg(&input[offset + i]);
        local_sum += val;
        local_sq_sum += val * val;
    }

    // Warp-level tree reduction; afterwards thread 0 holds both totals.
    for (int sh = WARP_SIZE / 2; sh > 0; sh /= 2) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, sh);
        local_sq_sum += __shfl_down_sync(0xffffffff, local_sq_sum, sh);
    }

    // Only thread 0's values are meaningful here; they are broadcast below.
    float mean = local_sum / n2;
    // E[x^2] - mean^2 can round to a small negative number; clamp it so that
    // rsqrtf never sees a negative argument and returns NaN.
    const float variance = fmaxf(local_sq_sum / n2 - mean * mean, 0.0f);
    float inv_std = rsqrtf(variance + 1e-5f);

    // Broadcast mean and inv_std from thread 0 to all threads in the warp.
    mean = __shfl_sync(0xffffffff, mean, 0);
    inv_std = __shfl_sync(0xffffffff, inv_std, 0);

    // Normalize each element and apply scale (gamma) and shift (beta).
    for (int i = tid; i < n2; i += BLOCK_SIZE) {
        const long long idx = offset + i;
        const float norm_val = (__ldg(&input[idx]) - mean) * inv_std;
        output[idx] = gamma[i] * norm_val + beta[i];
    }
}
// Applies, in order: 3D transposed convolution, addition of sum_weight, layer
// normalization (warp_only_layernorm_kernel), 3D average pooling and GELU.
//
// Parameters:
//   x                     - input tensor; must be a CUDA tensor
//   conv_transpose_weight - weight for at::conv_transpose3d
//   conv_transpose_bias   - bias for at::conv_transpose3d
//   sum_weight            - tensor added in place to the convolution output
//   norm_weight/norm_bias - float32 CUDA tensors with prod(norm_shape) elements
//   stride/padding/output_padding - transposed-convolution settings
//   pool_kernel_size      - average-pooling kernel, also used as its stride
//   norm_shape            - trailing dimensions to normalize over; must equal
//                           the trailing sizes of the convolution output
//
// Returns the tensor produced by the final GELU.
// Raises (via TORCH_CHECK) when a device, dtype or shape requirement is violated.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor conv_transpose_weight,
    torch::Tensor conv_transpose_bias,
    torch::Tensor sum_weight,
    torch::Tensor norm_weight,
    torch::Tensor norm_bias,
    std::vector<int64_t> stride,
    std::vector<int64_t> padding,
    std::vector<int64_t> output_padding,
    std::vector<int64_t> pool_kernel_size,
    std::vector<int64_t> norm_shape
) {
    // The custom kernel dereferences raw device pointers, so reject host
    // tensors up front instead of faulting inside the kernel.
    TORCH_CHECK(x.is_cuda(), "forward: x must be a CUDA tensor");
    TORCH_CHECK(norm_weight.is_cuda() && norm_bias.is_cuda(),
                "forward: norm_weight and norm_bias must be CUDA tensors");
    TORCH_CHECK(norm_weight.scalar_type() == at::kFloat &&
                norm_bias.scalar_type() == at::kFloat,
                "forward: norm_weight and norm_bias must be float32");

    // Ensure tensors are contiguous
    x = x.contiguous();
    conv_transpose_weight = conv_transpose_weight.contiguous();
    sum_weight = sum_weight.contiguous();
    norm_weight = norm_weight.contiguous();
    norm_bias = norm_bias.contiguous();

    at::IntArrayRef strideRef(stride);
    at::IntArrayRef paddingRef(padding);
    at::IntArrayRef outputPaddingRef(output_padding);
    at::IntArrayRef poolKernelRef(pool_kernel_size);

    // 1. 3D transposed convolution
    auto out = at::conv_transpose3d(
        x,
        conv_transpose_weight,
        conv_transpose_bias,
        strideRef,
        paddingRef,
        outputPaddingRef,
        /*groups=*/1,
        /*dilation=*/1
    );

    // 2. Elementwise addition with the sum_weight tensor
    out.add_(sum_weight);

    // The kernel indexes groups as consecutive runs of n2 elements, which is
    // only valid for a densely packed row-major tensor.
    out = out.contiguous();
    TORCH_CHECK(out.scalar_type() == at::kFloat,
                "forward: the layer-norm kernel only supports float32, got ",
                out.scalar_type());

    // 3. Custom layer normalization using the warp-only kernel
    // All dimension arithmetic is done in int64_t; mixing out.dim() with the
    // unsigned norm_shape.size() would wrap around when norm_shape is longer
    // than the tensor rank.
    const int64_t total_dims = out.dim();
    const int64_t norm_dims = static_cast<int64_t>(norm_shape.size());
    TORCH_CHECK(norm_dims >= 1 && norm_dims <= total_dims,
                "forward: norm_shape must have between 1 and ", total_dims,
                " dimensions, got ", norm_dims);
    const int64_t lead_dims = total_dims - norm_dims;

    int64_t n1 = 1;  // number of normalization groups
    for (int64_t d = 0; d < lead_dims; ++d) {
        n1 *= out.size(d);
    }
    int64_t n2 = 1;  // elements per group
    for (int64_t d = 0; d < norm_dims; ++d) {
        TORCH_CHECK(out.size(lead_dims + d) == norm_shape[d],
                    "forward: norm_shape[", d, "] = ", norm_shape[d],
                    " does not match tensor size ", out.size(lead_dims + d));
        n2 *= norm_shape[d];
    }
    TORCH_CHECK(norm_weight.numel() == n2 && norm_bias.numel() == n2,
                "forward: norm_weight and norm_bias must each hold ", n2,
                " elements");

    // The kernel takes int counts and uses one block per group.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(n1 <= kIntMax && n2 <= kIntMax,
                "forward: too many normalization groups or elements per group");

    auto output = torch::empty_like(out);

    // A grid of zero blocks is an invalid launch configuration, so skip the
    // launch entirely for empty tensors.
    if (n1 > 0 && n2 > 0) {
        const dim3 grid(static_cast<unsigned int>(n1));
        const dim3 block(BLOCK_SIZE);  // Each block is one warp

        // Launch on PyTorch's current stream so the kernel is ordered after
        // the ATen ops above and before the ones below. No device-wide
        // synchronization is needed for correctness, and omitting it keeps
        // the host from stalling on every call.
        warp_only_layernorm_kernel<float>
            <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
                out.data_ptr<float>(),
                norm_weight.data_ptr<float>(),
                norm_bias.data_ptr<float>(),
                output.data_ptr<float>(),
                static_cast<int>(n1),
                static_cast<int>(n2)
            );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // 4. 3D average pooling
    output = at::avg_pool3d(
        output,
        poolKernelRef,  // kernel_size
        poolKernelRef,  // stride (same as kernel_size)
        {0, 0, 0},      // padding
        /*ceil_mode=*/false,
        /*count_include_pad=*/true
    );

    // 5. GELU activation
    output = at::gelu(output);

    return output;
}
// Python binding: registers the C++ `forward` function under the name
// "forward" on the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Warp-level only fused layer norm forward (CUDA)");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.150 | inst/cycle | 0.000 | 4 |
Executed Ipc Elapsed | 1.150 | inst/cycle | 0.000 | 4 |
Issue Slots Busy | 28.760 | % | 0.000 | 4 |
Issued Ipc Active | 1.150 | inst/cycle | 0.000 | 4 |
SM Busy | 28.760 | % | 0.000 | 4 |
Memory Throughput | 850523394340.100 | byte/second | 1080371527722667.750 | 4 |
Mem Busy | 18.680 | % | 0.000 | 4 |
Max Bandwidth | 25.370 | % | 0.000 | 4 |
L1/TEX Hit Rate | 59.817 | % | 0.000 | 4 |
L2 Hit Rate | 50.175 | % | 0.000 | 4 |
Mem Pipes Busy | 18.240 | % | 0.000 | 4 |
Warp Cycles Per Issued Instruction | 14.430 | cycle | 0.000 | 4 |
Warp Cycles Per Executed Instruction | 14.430 | cycle | 0.000 | 4 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 4 |
Avg. Not Predicated Off Threads Per Warp | 30.680 | | 0.000 | 4 |
Max Active Clusters | 0.000 | cluster | 0.000 | 4 |
Max Cluster Size | 8.000 | block | 0.000 | 4 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 4 |
Cluster Occupancy | 0.000 | % | 0.000 | 4 |
Block Limit SM | 32.000 | block | 0.000 | 4 |
Block Limit Registers | 64.000 | block | 0.000 | 4 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 4 |
Block Limit Warps | 64.000 | block | 0.000 | 4 |
Theoretical Active Warps per SM | 32.000 | warp | 0.000 | 4 |
Theoretical Occupancy | 50.000 | % | 0.000 | 4 |
Achieved Occupancy | 26.155 | % | 0.000 | 4 |
Achieved Active Warps Per SM | 16.740 | warp | 0.000 | 4 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (25.4%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy (50.0%) is limited by the number of blocks that can fit on the SM. This kernel's theoretical occupancy (50.0%) is limited by the required amount of shared memory. The difference between calculated theoretical (50.0%) and measured achieved occupancy (26.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::randn | ||
CPU Time | 303495.74 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 100.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::normal_ | ||
CPU Time | 303356.17 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 303356.17 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::to | ||
CPU Time | 322688.10 | μs |
Device Time | 29082.74 | μs |
Self CPU Time | 75.77 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 322612.32 | μs |
Device Time | 29082.74 | μs |
Self CPU Time | 133.50 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv_transpose3d | ||
CPU Time | 168466.84 | μs |
Device Time | 3534726.07 | μs |
Self CPU Time | 886.19 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 167580.64 | μs |
Device Time | 3534726.07 | μs |
Self CPU Time | 1140.15 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 166440.49 | μs |
Device Time | 3534726.07 | μs |
Self CPU Time | 2283.06 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::add_ | ||
CPU Time | 23693.57 | μs |
Device Time | 2667350.32 | μs |
Self CPU Time | 4334.57 | μs |
Self Device Time | 2667350.32 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceSynchronize | ||
CPU Time | 9763310.21 | μs |
Device Time | 55958.35 | μs |
Self CPU Time | 9763310.21 | μs |
Self Device Time | 55958.35 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void warp_only_layernorm_kernel<float>(float const*, float const*, float const*, float*, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 3763995.90 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3763995.90 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45297 warnings generated when compiling for host. Suppressed 45331 warnings (45284 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.