89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max
• constant_memory_optimization_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    pool_kernel_size: int,
    pool_stride: int,
    pool_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a sequence of operations:
    - ConvTranspose3d
    - MaxPool3d
    - Softmax
    - Subtract
    - Swish
    - Max

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
        stride (int): Stride for conv transpose
        padding (int): Padding for conv transpose
        output_padding (int): Output padding for conv transpose
        pool_kernel_size (int): Kernel size for max pooling
        pool_stride (int): Stride for max pooling
        pool_padding (int): Padding for max pooling
        conv_transpose (torch.Tensor): Weight tensor for transposed convolution
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        subtract (torch.Tensor): Subtraction parameter tensor

    Returns:
        torch.Tensor: Output tensor after the max reduction over the channel dimension
    """
    x = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    x = F.max_pool3d(
        x, kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
    )
    x = F.softmax(x, dim=1)
    x = x - subtract.view(1, -1, 1, 1, 1)
    x = torch.sigmoid(x) * x  # Swish activation: x * sigmoid(x)
    x = torch.max(x, dim=1)[0]  # Max over the channel dimension
    return x
class Model(nn.Module):
    """
    A model that performs a sequence of operations:
    - ConvTranspose3d
    - MaxPool3d
    - Softmax
    - Subtract
    - Swish
    - Max
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ):
        super(Model, self).__init__()
        conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
        self.conv_transpose_parameter = conv_transpose.weight
        self.conv_transpose_bias = conv_transpose.bias
        self.subtract_parameter = nn.Parameter(torch.randn(out_channels) * 0.02)

    def forward(
        self,
        x,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
        fn=module_fn,
    ):
        return fn(
            x,
            stride,
            padding,
            output_padding,
            pool_kernel_size,
            pool_stride,
            pool_padding,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.subtract_parameter,
        )
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0
def get_inputs():
    return [
        torch.randn(batch_size, in_channels, depth, height, width),
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]
def get_init_inputs():
    return [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]
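A minimal driver sketch showing how these helpers fit together (the device handling and the `__main__` guard are assumptions, not part of the original harness):

# Hypothetical driver, assuming a CUDA device may or may not be available.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(*get_init_inputs()).to(device)
    inputs = get_inputs()
    x = inputs[0].to(device)
    out = model(x, *inputs[1:])
    # The channel-wise max removes dim=1, leaving
    # (batch_size, pooled_depth, pooled_height, pooled_width)
    print(out.shape)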
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    A model that performs a sequence of operations:
    - ConvTranspose3d
    - MaxPool3d
    - Softmax
    - Subtract
    - Swish
    - Max
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
        self.subtract = nn.Parameter(torch.randn(out_channels) * 0.02)  # Assuming subtraction is element-wise across channels

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.max_pool(x)
        x = torch.softmax(x, dim=1)  # Softmax across channels (dim=1)
        x = x - self.subtract.view(1, -1, 1, 1, 1)  # Subtract per-channel parameter
        x = torch.sigmoid(x) * x  # Swish activation
        x = torch.max(x, dim=1)[0]  # Max over the channel dimension
        return x
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0
def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding]
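As a sanity check, this nn.Module variant and the functional module_fn above should agree when they share parameters; a small sketch, assuming both definitions are in scope (the helper name and tolerance are arbitrary):

# Hypothetical equivalence check between the two reference implementations.
def check_equivalence():
    torch.manual_seed(0)
    m = Model(*get_init_inputs())  # nn.Module variant defined just above
    x = get_inputs()[0]
    y_module = m(x)
    y_functional = module_fn(
        x, stride, padding, output_padding,
        pool_kernel_size, pool_stride, pool_padding,
        m.conv_transpose.weight, m.conv_transpose.bias, m.subtract,
    )
    assert torch.allclose(y_module, y_functional, atol=1e-5)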
#include <torch/extension.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
// Declare constant memory for kernel parameters
__constant__ int STRIDE[3];
__constant__ int PADDING[3];
__constant__ int OUTPUT_PADDING[3];
__constant__ int POOL_PARAMS[9]; // kernel_size, stride, padding for each spatial dimension
// Declare constant memory for the subtract tensor
__constant__ float SUBTRACT_TENSOR[1024]; // Assumes the channel count never exceeds 1024
// Note: these constant-memory buffers are populated in forward() below, but no custom
// kernel in this file reads them yet; the forward pass is composed entirely of ATen ops.
__device__ __forceinline__ float swish(float x) {
    return x * (1.0f / (1.0f + expf(-x)));  // x * sigmoid(x)
}
torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    int64_t pool_kernel_size,
    int64_t pool_stride,
    int64_t pool_padding,
    torch::Tensor conv_transpose_weight,
    torch::Tensor conv_transpose_bias,
    torch::Tensor subtract_tensor
) {
    // Copy the scalar parameters to constant memory (cast from int64_t to int to
    // avoid narrowing-conversion errors in the braced initializers)
    int h_stride[3] = {static_cast<int>(stride), static_cast<int>(stride), static_cast<int>(stride)};
    int h_padding[3] = {static_cast<int>(padding), static_cast<int>(padding), static_cast<int>(padding)};
    int h_output_padding[3] = {static_cast<int>(output_padding), static_cast<int>(output_padding), static_cast<int>(output_padding)};
    int h_pool_params[9] = {
        static_cast<int>(pool_kernel_size), static_cast<int>(pool_kernel_size), static_cast<int>(pool_kernel_size),
        static_cast<int>(pool_stride), static_cast<int>(pool_stride), static_cast<int>(pool_stride),
        static_cast<int>(pool_padding), static_cast<int>(pool_padding), static_cast<int>(pool_padding)
    };
    cudaMemcpyToSymbol(STRIDE, h_stride, sizeof(int) * 3);
    cudaMemcpyToSymbol(PADDING, h_padding, sizeof(int) * 3);
    cudaMemcpyToSymbol(OUTPUT_PADDING, h_output_padding, sizeof(int) * 3);
    cudaMemcpyToSymbol(POOL_PARAMS, h_pool_params, sizeof(int) * 9);
    // Copy the subtract tensor to constant memory (float32, must fit the buffer)
    TORCH_CHECK(subtract_tensor.numel() <= 1024, "subtract_tensor exceeds SUBTRACT_TENSOR capacity");
    cudaMemcpyToSymbol(SUBTRACT_TENSOR, subtract_tensor.data_ptr<float>(), subtract_tensor.numel() * sizeof(float));
    // Transposed convolution
    auto out = at::conv_transpose3d(
        x,
        conv_transpose_weight,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding},
        {output_padding, output_padding, output_padding},
        1,
        {1, 1, 1}
    );
    // Max pooling
    out = at::max_pool3d(
        out,
        {pool_kernel_size, pool_kernel_size, pool_kernel_size},
        {pool_stride, pool_stride, pool_stride},
        {pool_padding, pool_padding, pool_padding}
    );
    // Softmax along the channel dimension
    out = at::softmax(out, 1, /*dtype=*/c10::nullopt);
    // Subtract the per-channel parameter (this ATen path reads the tensor directly,
    // not the constant-memory copy made above)
    auto sub_view = subtract_tensor.view({1, -1, 1, 1, 1});
    out = out - sub_view;
    // Swish
    out = out * at::sigmoid(out);
    // Max over the channel dimension
    out = std::get<0>(out.max(/*dim=*/1, /*keepdim=*/false));
    return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "CUDA forward pass for the module");
}
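Note that forward() populates the constant-memory buffers but builds the forward pass entirely from ATen ops, so nothing reads SUBTRACT_TENSOR; this is consistent with cudaMemcpyToSymbol dominating CPU time in the profile below. A minimal sketch of a kernel that would actually consume it (the kernel name, the fused subtract + Swish step, and the launch shape are assumptions, not part of the original extension):

// Hypothetical fused subtract + Swish kernel reading per-channel offsets from
// constant memory. Assumes a contiguous NCDHW float tensor; spatial = D*H*W.
__global__ void subtract_swish_kernel(float* data, int channels, int spatial, long total) {
    long idx = blockIdx.x * (long)blockDim.x + threadIdx.x;
    if (idx >= total) return;
    int c = (idx / spatial) % channels;           // recover the channel index
    float v = data[idx] - SUBTRACT_TENSOR[c];     // subtract from constant memory
    data[idx] = v * (1.0f / (1.0f + expf(-v)));   // Swish: v * sigmoid(v)
}

// Possible launch, placed after the softmax in forward():
//   long total = out.numel();
//   int spatial = out.size(2) * out.size(3) * out.size(4);
//   int threads = 256;
//   int blocks = (int)((total + threads - 1) / threads);
//   subtract_swish_kernel<<<blocks, threads>>>(out.data_ptr<float>(), (int)out.size(1), spatial, total);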
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 426750.80 | μs |
Device Time | 2688.79 | μs |
Self CPU Time | 49.34 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 426701.46 | μs |
Device Time | 2688.79 | μs |
Self CPU Time | 107.29 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 423617.94 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 111.40 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 423421.41 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 423421.41 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaMemcpyToSymbol | ||
CPU Time | 8445037.19 | μs |
Device Time | 23161.06 | μs |
Self CPU Time | 8445037.19 | μs |
Self Device Time | 23161.06 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv_transpose3d | ||
CPU Time | 289720.23 | μs |
Device Time | 6544680.84 | μs |
Self CPU Time | 3375.15 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 286345.09 | μs |
Device Time | 6544680.84 | μs |
Self CPU Time | 5673.87 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 280671.22 | μs |
Device Time | 6544680.84 | μs |
Self CPU Time | 9184.26 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 243876.06 | μs |
Device Time | 5185732.36 | μs |
Self CPU Time | 99125.91 | μs |
Self Device Time | 5185732.36 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3664337.67 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3664337.67 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |