89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max
• fused_unroll_edit_1
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
stride: int,
padding: int,
output_padding: int,
pool_kernel_size: int,
pool_stride: int,
pool_padding: int,
conv_transpose: torch.Tensor,
conv_transpose_bias: torch.Tensor,
subtract: torch.Tensor,
) -> torch.Tensor:
"""
Applies a sequence of operations:
- ConvTranspose3d
- MaxPool3d
- Softmax
- Subtract
- Swish
- Max
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
stride (int): Stride for conv transpose
padding (int): Padding for conv transpose
output_padding (int): Output padding for conv transpose
pool_kernel_size (int): Kernel size for max pooling
pool_stride (int): Stride for max pooling
pool_padding (int): Padding for max pooling
conv_transpose (torch.Tensor): Weight tensor for transposed convolution
conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
subtract (torch.Tensor): Per-channel subtraction parameter of shape (out_channels,)
Returns:
torch.Tensor: Output tensor of shape (batch_size, depth', height', width') after the channel-wise max
"""
x = F.conv_transpose3d(
x,
conv_transpose,
bias=conv_transpose_bias,
stride=stride,
padding=padding,
output_padding=output_padding,
)
x = F.max_pool3d(
x, kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
)
x = F.softmax(x, dim=1)
x = x - subtract.view(1, -1, 1, 1, 1)
x = torch.sigmoid(x) * x # Swish
x = torch.max(x, dim=1)[0]
return x
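# Shape arithmetic sketch (illustrative helpers, not part of the benchmark):
# ConvTranspose3d: out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
# MaxPool3d:       out = (in + 2 * pool_padding - pool_kernel_size) // pool_stride + 1
def _deconv_out(size, stride=2, padding=1, kernel_size=3, output_padding=1):
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding
def _pool_out(size, kernel_size=2, stride=2, padding=0):
    return (size + 2 * padding - kernel_size) // stride + 1
# With the configuration defined below (depth=16, height=width=32):
# deconv: 16 -> 32 and 32 -> 64; pool: 32 -> 16 and 64 -> 32.
# Softmax, subtract, and swish preserve the shape, and the final max over dim=1
# removes the channel axis, so module_fn returns a (batch_size, 16, 32, 32) tensor.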
class Model(nn.Module):
"""
A model that performs a sequence of operations:
- ConvTranspose3d
- MaxPool3d
- Softmax
- Subtract
- Swish
- Max
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
pool_kernel_size,
pool_stride,
pool_padding,
):
super(Model, self).__init__()
conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
self.conv_transpose_parameter = conv_transpose.weight
self.conv_transpose_bias = conv_transpose.bias
self.subtract_parameter = nn.Parameter(torch.randn(out_channels) * 0.02)
def forward(
self,
x,
stride,
padding,
output_padding,
pool_kernel_size,
pool_stride,
pool_padding,
fn=module_fn,
):
return fn(
x,
stride,
padding,
output_padding,
pool_kernel_size,
pool_stride,
pool_padding,
self.conv_transpose_parameter,
self.conv_transpose_bias,
self.subtract_parameter,
)
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0
def get_inputs():
return [
torch.randn(batch_size, in_channels, depth, height, width),
stride,
padding,
output_padding,
pool_kernel_size,
pool_stride,
pool_padding,
]
def get_init_inputs():
return [
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
pool_kernel_size,
pool_stride,
pool_padding,
]
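# Illustrative smoke test of the functional path above (an assumption about how a
# harness might wire these helpers together; not part of the benchmark definition,
# and heavy when run on CPU).
if __name__ == "__main__":
    model = Model(*get_init_inputs())
    out = model(*get_inputs())
    print(out.shape)  # expected: torch.Size([128, 16, 32, 32])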
import torch
import torch.nn as nn
class Model(nn.Module):
"""
A model that performs a sequence of operations:
- ConvTranspose3d
- MaxPool3d
- Softmax
- Subtract
- Swish
- Max
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
super(Model, self).__init__()
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
self.subtract = nn.Parameter(torch.randn(out_channels)*0.02) # Assuming subtraction is element-wise across channels
def forward(self, x):
x = self.conv_transpose(x)
x = self.max_pool(x)
x = torch.softmax(x, dim=1) # Apply softmax across channels (dim=1)
x = x - self.subtract.view(1, -1, 1, 1, 1) # Subtract across channels
x = torch.sigmoid(x) * x # Swish activation
x = torch.max(x, dim=1)[0] # Max reduction over channels
return x
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0
def get_inputs():
return [torch.randn(batch_size, in_channels, depth, height, width)]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding]
#include <torch/extension.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>
#include <pybind11/pybind11.h>
namespace py = pybind11;
// This CUDA kernel fuses the softmax (along channel dimension), subtraction, swish activation,
// and channel-wise max reduction into a single kernel. The loops over the channel dimension
// are unrolled with #pragma unroll to reduce loop overhead. It operates on the tensor produced
// by the ConvTranspose3d and MaxPool3d operations.
// Assumptions:
// 1. Input tensor is in NCDHW layout.
// 2. 'subtract' tensor has size [C] and is broadcast along spatial dimensions.
// 3. The output tensor is of shape [N, D, H, W] representing the maximum over channels after the
// fused operations.
__global__ void fused_softmax_subtract_swish_max_kernel(
const float* __restrict__ input, // [N, C, D, H, W]
const float* __restrict__ subtract, // [C]
float* __restrict__ output, // [N, D, H, W]
int N, int C, int D, int H, int W
) {
// Compute a linear index over the output spatial locations
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = N * D * H * W;
if (idx >= total) return;
// Decode the index into (n, d, h, w) assuming contiguous layout for D, H, W
int spatialSize = D * H * W;
int n = idx / spatialSize;
int rem = idx % spatialSize;
int d = rem / (H * W);
rem = rem % (H * W);
int h = rem / W;
int w = rem % W;
// In a tensor with shape [N, C, D, H, W] in NCDHW format, the stride for the channel dimension
// is (D*H*W). For a given (n, d, h, w), the element at channel c is at:
// index = n*(C*D*H*W) + c*(D*H*W) + (d*H*W + h*W + w)
int spatialOffset = d * (H * W) + h * W + w;
int base = n * C * spatialSize + spatialOffset;
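    // Worked example (illustrative), using the post-pool shape from this benchmark
    // (N=128, C=16, D=16, H=32, W=32, so spatialSize = 16*32*32 = 16384): for
    // (n=1, d=3, h=4, w=5), spatialOffset = 3*1024 + 4*32 + 5 = 3205 and
    // base = 1*16*16384 + 3205 = 265349; channel c of that voxel is then read
    // from base + c*16384.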
// Step 1: Compute the maximum value over the channel dimension for numerical stability.
float max_val = -FLT_MAX;
#pragma unroll
for (int c = 0; c < 64; c++) { // fixed 64-iteration bound so the compiler can fully unroll; the guard below restricts work to c < C (assumes C <= 64)
if (c < C) {
int in_index = base + c * spatialSize;
float val = input[in_index];
if (val > max_val)
max_val = val;
}
}
// Step 2: Compute the sum of exponentials for softmax.
float sum_exp = 0.0f;
#pragma unroll
for (int c = 0; c < 64; c++) {
if (c < C) {
int in_index = base + c * spatialSize;
float tmp = input[in_index] - max_val;
float exp_val = expf(tmp);
sum_exp += exp_val;
}
}
// Step 3: For each channel, compute the softmax value, subtract the channel-specific bias,
// apply the swish activation, and reduce to find the maximum activation across channels.
float max_swish = -FLT_MAX;
#pragma unroll
for (int c = 0; c < 64; c++) {
if (c < C) {
int in_index = base + c * spatialSize;
float exp_val = expf(input[in_index] - max_val);
float softmax_val = exp_val / sum_exp; // Softmax output
float subtracted = softmax_val - subtract[c];
float sigmoid = 1.0f / (1.0f + expf(-subtracted));
float swish = subtracted * sigmoid;
if (swish > max_swish)
max_swish = swish;
}
}
output[idx] = max_swish;
}
// Wrapper to launch the fused kernel. The kernel fuses four operations: softmax along the channel
// dimension, subtraction of a broadcast per-channel tensor, the swish activation (x * sigmoid(x)),
// and a max reduction over channels.
// The channel loops use a fixed 64-iteration bound with #pragma unroll so the compiler can fully
// unroll them; the "if (c < C)" guard handles C < 64. Note that if C > 64, channels beyond the
// first 64 are never visited and the result would be wrong, so the kernel assumes C <= 64.
// For best performance, C should be small and known at compile time.
torch::Tensor fused_forward(
torch::Tensor input, // [N, C, D, H, W]
torch::Tensor subtract_tensor // [C]
) {
// Basic validity checks; ensure the tensors are contiguous
TORCH_CHECK(input.is_cuda() && subtract_tensor.is_cuda(), "inputs must be CUDA tensors");
TORCH_CHECK(input.scalar_type() == at::kFloat, "fused kernel only supports float32");
input = input.contiguous();
subtract_tensor = subtract_tensor.contiguous();
auto sizes = input.sizes();
int N = sizes[0];
int C = sizes[1];
int D = sizes[2];
int H = sizes[3];
int W = sizes[4];
TORCH_CHECK(C <= 64, "fused_softmax_subtract_swish_max_kernel assumes C <= 64");
auto output = torch::empty({N, D, H, W}, input.options());
int total_spatial = N * D * H * W;
int threads = 256;
int blocks = (total_spatial + threads - 1) / threads;
fused_softmax_subtract_swish_max_kernel<<<blocks, threads>>>(
input.data_ptr<float>(),
subtract_tensor.data_ptr<float>(),
output.data_ptr<float>(),
N, C, D, H, W
);
return output;
}
// The complete forward function performs the following operations in sequence:
// 1. ConvTranspose3d
// 2. MaxPool3d
// 3. Fused operations (softmax along channel, subtract, swish, max over channel) using a custom CUDA kernel
torch::Tensor forward(
torch::Tensor x,
int64_t stride,
int64_t padding,
int64_t output_padding,
int64_t pool_kernel_size,
int64_t pool_stride,
int64_t pool_padding,
torch::Tensor conv_transpose_weight,
torch::Tensor conv_transpose_bias,
torch::Tensor subtract_tensor
) {
// 1. Transposed convolution
auto out = at::conv_transpose3d(
x,
conv_transpose_weight,
conv_transpose_bias,
{stride, stride, stride},
{padding, padding, padding},
{output_padding, output_padding, output_padding},
1, // groups
{1, 1, 1} // dilation
);
// 2. MaxPool3d
out = at::max_pool3d(
out,
{pool_kernel_size, pool_kernel_size, pool_kernel_size},
{pool_stride, pool_stride, pool_stride},
{pool_padding, pool_padding, pool_padding}
);
// 3. Fused kernel: softmax (along channels) -> subtract -> swish -> max over channels
// The subtract_tensor is reshaped to a 1D tensor of size [C]
auto result = fused_forward(out, subtract_tensor.view({-1}));
return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused CUDA forward pass for 89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max with loop unrolling");
}
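For reference, the snippet below sketches one way the extension could be built and sanity-checked against the eager PyTorch path; the source filename, module name, and tolerance are assumptions, not part of the submission.
# Hypothetical build-and-check sketch (filename and tolerance are assumptions).
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

fused = load(name="fused_unroll_edit_1", sources=["fused_unroll_edit_1.cu"], verbose=False)

x = torch.randn(128, 3, 16, 32, 32, device="cuda")
weight = torch.randn(3, 16, 3, 3, 3, device="cuda")  # ConvTranspose3d weight: (C_in, C_out, kD, kH, kW)
bias = torch.randn(16, device="cuda")
subtract = torch.randn(16, device="cuda") * 0.02

# Fused extension path (argument order matches the pybind "forward" above).
out = fused.forward(x, 2, 1, 1, 2, 2, 0, weight, bias, subtract)

# Eager PyTorch reference of the same pipeline.
ref = F.conv_transpose3d(x, weight, bias, stride=2, padding=1, output_padding=1)
ref = F.max_pool3d(ref, kernel_size=2, stride=2, padding=0)
ref = torch.softmax(ref, dim=1) - subtract.view(1, -1, 1, 1, 1)
ref = torch.sigmoid(ref) * ref
ref = ref.max(dim=1).values

print(out.shape, torch.allclose(out, ref, atol=1e-5))  # expected: torch.Size([128, 16, 32, 32]) True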
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.890 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.870 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 22.342 | % | 0.000 | 5 |
Issued Ipc Active | 0.890 | inst/cycle | 0.000 | 5 |
SM Busy | 30.422 | % | 0.000 | 5 |
Memory Throughput | 207956348308.540 | byte/second | 23633605407968464.000 | 5 |
Mem Busy | 11.762 | % | 0.000 | 5 |
Max Bandwidth | 11.762 | % | 0.000 | 5 |
L1/TEX Hit Rate | 67.920 | % | 0.000 | 5 |
L2 Hit Rate | 6.890 | % | 0.001 | 5 |
Mem Pipes Busy | 11.762 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 8.790 | cycle | 0.000 | 5 |
Warp Cycles Per Executed Instruction | 8.790 | cycle | 0.000 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 14.190 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 1.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 8.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 8.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 12.500 | % | 0.000 | 5 |
Achieved Occupancy | 12.290 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 7.860 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN ThreadDivergence | Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 32.0 threads being active per cycle. This is further reduced to 14.2 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp(). |
WRN Occupancy | This kernel's theoretical occupancy (12.5%) is limited by the number of required registers. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
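The 12.5% theoretical occupancy follows directly from the launch statistics above: with 256-thread blocks and a register-limited block limit of 1 per SM, only 8 warps can be resident, and 8 out of the usual 64-warp-per-SM limit (assumed for this architecture) is 12.5%.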
Operation / Metric | Value | Unit |
---|---|---|
aten::conv_transpose3d | ||
CPU Time | 8198072.10 | μs |
Device Time | 6719391.82 | μs |
Self CPU Time | 4144.28 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 8193927.83 | μs |
Device Time | 6719391.82 | μs |
Self CPU Time | 6032.21 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 8187895.62 | μs |
Device Time | 6719391.82 | μs |
Self CPU Time | 14193.21 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 8139328.92 | μs |
Device Time | 5321011.62 | μs |
Self CPU Time | 114369.33 | μs |
Self Device Time | 5321011.62 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaMemsetAsync | ||
CPU Time | 5648566.07 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 5648566.07 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3755539.27 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3755539.27 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |