49_ConvTranspose3d_Softmax_Sigmoid
• stride_loop_fused_softmax_sigmoid_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    bias_flag: bool,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
) -> torch.Tensor:
    """Run a 3D transposed convolution, then softmax over dim 1, then sigmoid.

    Args:
        x: Input tensor of shape (batch_size, in_channels, D, H, W).
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.
        output_padding: Extra size added to one side of each output spatial dim.
        bias_flag: When False, ``conv_transpose_bias`` is ignored and no bias is added.
        conv_transpose: Transposed-convolution weight tensor.
        conv_transpose_bias: Bias tensor, used only when ``bias_flag`` is True.

    Returns:
        sigmoid(softmax(conv_transpose3d(x), dim=1)), same shape as the convolution output.
    """
    effective_bias = conv_transpose_bias if bias_flag else None
    return torch.sigmoid(
        F.softmax(
            F.conv_transpose3d(
                x,
                conv_transpose,
                bias=effective_bias,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            dim=1,
        )
    )
class Model(nn.Module):
    """Transposed 3D convolution followed by Softmax(dim=1) and Sigmoid, functional style.

    The weight and (optionally) the bias of a freshly initialized
    ``nn.ConvTranspose3d`` are stored as standalone parameters. The bias is
    perturbed with Gaussian noise scaled by 0.02. The computation itself is
    delegated to ``fn``, which defaults to ``module_fn``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias,
    ):
        super().__init__()
        # Only used to obtain PyTorch's default initialization for weight/bias.
        template_layer = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        self.conv_transpose_parameter = nn.Parameter(template_layer.weight)

        if not bias:
            self.conv_transpose_bias = None
        else:
            base_bias = template_layer.bias
            jitter = torch.randn(
                base_bias.shape,
                device=base_bias.device,
                dtype=base_bias.dtype,
            )
            self.conv_transpose_bias = nn.Parameter(base_bias + jitter * 0.02)

    def forward(self, x, stride, padding, output_padding, bias, fn=module_fn):
        """Apply ``fn`` to ``x`` with the stored weight and bias."""
        weight = self.conv_transpose_parameter
        bias_tensor = self.conv_transpose_bias
        return fn(x, stride, padding, output_padding, bias, weight, bias_tensor)
# Benchmark problem size for the functional Model.
batch_size = 16
in_channels = 32
out_channels = 64
D = 16
H = 32
W = 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
bias = True


def get_inputs():
    """Return forward() arguments: a random input tensor, then the conv settings and bias flag."""
    sample = torch.randn(batch_size, in_channels, D, H, W)
    conv_settings = [stride, padding, output_padding, bias]
    return [sample] + conv_settings


def get_init_inputs():
    """Return Model constructor arguments in positional order."""
    init_args = (
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        bias,
    )
    return list(init_args)
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, applies Softmax and Sigmoid.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        if bias:
            # Perturb the default-initialized bias with Gaussian noise scaled by 0.02.
            default_bias = self.conv_transpose.bias
            noise = torch.randn(
                default_bias.shape,
                device=default_bias.device,
                dtype=default_bias.dtype,
            )
            self.conv_transpose.bias = nn.Parameter(default_bias + noise * 0.02)
        else:
            self.conv_transpose.bias = None
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, D, H, W).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, D_out, H_out, W_out).
                For each spatial dim (ConvTranspose3d with dilation 1):
                size_out = (size_in - 1) * stride - 2 * padding + kernel_size + output_padding.
                Because the sigmoid is applied to softmax probabilities in [0, 1],
                every value lies in [0.5, sigmoid(1)].
        """
        x = self.conv_transpose(x)
        x = self.softmax(x)
        x = self.sigmoid(x)
        return x
# Problem size used by get_inputs / get_init_inputs.
batch_size = 16
in_channels = 32
out_channels = 64
D = 16
H = 32
W = 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1


def get_inputs():
    """Return forward() arguments: one random tensor of shape (batch_size, in_channels, D, H, W)."""
    input_shape = (batch_size, in_channels, D, H, W)
    return [torch.randn(*input_shape)]


def get_init_inputs():
    """Return Model constructor arguments; ``bias`` is left at its default."""
    ctor_args = (in_channels, out_channels, kernel_size, stride, padding, output_padding)
    return list(ctor_args)
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <math.h>
#include <vector>
// Threads per block used when launching the fused kernel.
#define BLOCK_SIZE 256

// Fused CUDA kernel: softmax across the channel dimension followed by an
// elementwise sigmoid, for a dense NCDHW tensor.
//
// Each loop iteration handles one (batch, d, h, w) location ("pixel") and
// walks all of its channels, which are `spatial = D*H*W` elements apart.
// Consecutive pixel indices map to consecutive addresses within a channel.
//
// Pass 1 computes the max and the sum of exponentials in a single sweep with
// the online (streaming) softmax recurrence:
//   m = x[0], t = 1                       // t == exp(x[0] - m)
//   for c in 1..C-1:
//     if x[c] > m:  t = t * exp(m - x[c]) + 1;  m = x[c]
//     else:         t += exp(x[c] - m)
// Pass 2 writes, for every channel c:
//   softmax(x[c]) = exp(x[c] - m) / t
//   out[c]        = 1 / (1 + exp(-softmax(x[c])))
//
// All offsets are computed in int64_t so tensors with more than INT_MAX
// elements are indexed without overflow.
template <typename scalar_t>
__global__ void stride_loop_fused_softmax_sigmoid_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int64_t batch,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width) {
  // With no channels there is nothing to normalize, and reading the first
  // channel would be out of bounds.
  if (channels <= 0) {
    return;
  }

  const int64_t spatial = depth * height * width;
  const int64_t total_pixels = batch * spatial;
  const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Grid-stride loop: covers every pixel for any grid size.
  for (int64_t pixel_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       pixel_idx < total_pixels;
       pixel_idx += grid_stride) {
    // pixel_idx == b * spatial + offset, and offset (= d*H*W + h*W + w) is
    // already the element offset inside one channel plane.
    const int64_t b = pixel_idx / spatial;
    const int64_t offset = pixel_idx - b * spatial;
    const int64_t base = b * channels * spatial + offset;
    const scalar_t* in_px = input + base;
    scalar_t* out_px = output + base;

    // Pass 1: online max / sum-of-exponentials over channels.
    scalar_t max_val = in_px[0];
    scalar_t sum_exp = scalar_t(1);  // exp(in_px[0] - max_val) == 1
    for (int64_t c = 1; c < channels; ++c) {
      const scalar_t val = in_px[c * spatial];
      if (val > max_val) {
        // Rescale the running sum to the new maximum.
        sum_exp = sum_exp * exp(max_val - val) + scalar_t(1);
        max_val = val;
      } else {
        sum_exp += exp(val - max_val);
      }
    }

    // Pass 2: normalized softmax, then sigmoid.
    for (int64_t c = 0; c < channels; ++c) {
      const scalar_t softmax_val = exp(in_px[c * spatial] - max_val) / sum_exp;
      out_px[c * spatial] = scalar_t(1) / (scalar_t(1) + exp(-softmax_val));
    }
  }
}
// Forward function: applies conv_transpose3d using PyTorch built-in, then the
// fused softmax(dim=1) + sigmoid CUDA kernel.
//
// Args:
//   input               CUDA tensor, expected shape (N, C_in, D, H, W).
//   stride, padding, output_padding
//                       Scalar settings forwarded to torch::conv_transpose3d.
//   bias_flag           When false, conv_transpose_bias is ignored.
//   conv_transpose      Transposed-convolution weight.
//   conv_transpose_bias Bias, used only when bias_flag is true.
// Returns:
//   A new contiguous tensor with the convolution output's shape and dtype.
torch::Tensor forward(
    torch::Tensor input,
    int stride,
    int padding,
    int output_padding,
    bool bias_flag,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias) {
  TORCH_CHECK(input.is_cuda(), "forward: input must be a CUDA tensor");

  // Perform the 3D transpose convolution using PyTorch's conv_transpose3d.
  auto x = torch::conv_transpose3d(
      input,
      conv_transpose,
      bias_flag ? conv_transpose_bias : torch::Tensor(),
      stride,
      padding,
      output_padding);

  // The kernel indexes a dense NCDHW layout. The convolution output may carry
  // another memory format (e.g. channels_last_3d), so normalize it. This is a
  // no-op when x is already contiguous.
  x = x.contiguous();
  auto output = torch::empty_like(x);
  if (output.numel() == 0) {
    // Nothing to compute, and a 0-block launch is an invalid configuration.
    return output;
  }

  // Launch on x's device and on PyTorch's current stream, so the kernel is
  // ordered with the convolution and with other stream-aware work.
  const c10::cuda::CUDAGuard device_guard(x.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t batch = x.size(0);
  const int64_t channels = x.size(1);
  const int64_t depth = x.size(2);
  const int64_t height = x.size(3);
  const int64_t width = x.size(4);

  const int64_t total_pixels = batch * depth * height * width;
  const int threads = BLOCK_SIZE;
  // The kernel uses a grid-stride loop, so capping the grid keeps the block
  // count in range without losing coverage.
  const int64_t max_blocks = 65535;
  const int blocks = static_cast<int>(
      std::min<int64_t>((total_pixels + threads - 1) / threads, max_blocks));

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "stride_loop_fused_softmax_sigmoid_kernel", ([&] {
    stride_loop_fused_softmax_sigmoid_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        batch,
        channels,
        depth,
        height,
        width);
  }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return output;
}
// Python binding: exposes the C++ forward() as `forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  constexpr const char* kForwardDoc =
      "Fused ConvTranspose3d with Softmax and Sigmoid using stride loops";
  m.def("forward", &forward, kForwardDoc);
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.060 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.040 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 26.480 | % | 0.000 | 5 |
Issued Ipc Active | 1.060 | inst/cycle | 0.000 | 5 |
SM Busy | 38.180 | % | 0.001 | 5 |
Memory Throughput | 2801958849023.250 | byte/second | 8938389564859716608.000 | 5 |
Mem Busy | 45.658 | % | 0.003 | 5 |
Max Bandwidth | 83.588 | % | 0.008 | 5 |
L1/TEX Hit Rate | 0.268 | % | 0.000 | 5 |
L2 Hit Rate | 33.782 | % | 0.000 | 5 |
Mem Pipes Busy | 10.662 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 56.622 | cycle | 0.002 | 5 |
Warp Cycles Per Executed Instruction | 56.632 | cycle | 0.002 | 5 |
Avg. Active Threads Per Warp | 32.000 |  | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 26.400 |  | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 93.520 | % | 0.001 | 5 |
Achieved Active Warps Per SM | 59.850 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
INF Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. |
Operation / Metric | Value | Unit |
---|---|---|
aten::conv_transpose3d | ||
CPU Time | 1593546.43 | μs |
Device Time | 4195882.75 | μs |
Self CPU Time | 6500.80 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 1587045.63 | μs |
Device Time | 4195882.75 | μs |
Self CPU Time | 9148.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 1577896.92 | μs |
Device Time | 4195882.75 | μs |
Self CPU Time | 18415.62 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 505178.78 | μs |
Device Time | 2580109.70 | μs |
Self CPU Time | 160072.86 | μs |
Self Device Time | 2580109.70 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 4407167.84 | μs |
Device Time | 67411.46 | μs |
Self CPU Time | 4407167.84 | μs |
Self Device Time | 67411.46 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void stride_loop_fused_softmax_sigmoid_kernel<float>(float const*, float*, int, int, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 1965448.26 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1965448.26 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 1919871.77 | μs |
Device Time | 269114.58 | μs |
Self CPU Time | 8962.93 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 1910910.72 | μs |
Device Time | 269114.58 | μs |
Self CPU Time | 14023.22 | μs |
Self Device Time | 269114.58 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45294 warnings generated when compiling for host. Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.