78_ConvTranspose3d_Max_Max_Sum
• unroll_conv3d_max_sum_edit_1
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a 3D transposed convolution operation followed by two max pooling layers and a sum operation.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
        stride (int): Stride of the transposed convolution
        padding (int): Padding of the transposed convolution
        conv_transpose (torch.Tensor): Transposed convolution weight tensor
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution

    Returns:
        torch.Tensor: Output tensor after the transposed convolution, two max pooling stages, and sum reduction over channels
    """
    x = F.conv_transpose3d(
        x, conv_transpose, bias=conv_transpose_bias, stride=stride, padding=padding
    )
    x = F.max_pool3d(x, kernel_size=2)
    x = F.max_pool3d(x, kernel_size=3)
    x = torch.sum(x, dim=1, keepdim=True)
    return x
class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by two max pooling layers and a sum operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(Model, self).__init__()
        # The ConvTranspose3d module is only used to obtain default-initialized weight and bias;
        # stride and padding are applied later in the functional call inside module_fn.
        conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
        self.conv_transpose_parameter = nn.Parameter(conv.weight)
        self.conv_transpose_bias = nn.Parameter(conv.bias)

    def forward(self, x, stride, padding, fn=module_fn):
        return fn(
            x, stride, padding, self.conv_transpose_parameter, self.conv_transpose_bias
        )
batch_size = 16
in_channels = 8
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width), stride, padding]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding]
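As a sanity check on the pipeline above, the intermediate shapes for the default sizes can be worked out by hand: the transposed convolution maps (16, 8, 16, 32, 32) to (16, 16, 31, 63, 63), since (D - 1) * stride - 2 * padding + kernel_size gives (16 - 1) * 2 - 2 + 3 = 31 and (32 - 1) * 2 - 2 + 3 = 63; the kernel-2 pool then yields (16, 16, 15, 31, 31), the kernel-3 pool yields (16, 16, 5, 10, 10), and the channel sum leaves (16, 1, 5, 10, 10). A minimal sketch that verifies this, assuming it is run in the same scope as the listing above (so `Model`, `get_inputs`, and the size constants refer to the definitions just shown):

```python
# Hypothetical shape check, not part of the benchmark entry itself.
import torch

model = Model(in_channels, out_channels, kernel_size, stride, padding)
x, s, p = get_inputs()
with torch.no_grad():
    out = model(x, s, p)
# conv_transpose3d: (16-1)*2 - 2*1 + 3 = 31 depth, (32-1)*2 - 2*1 + 3 = 63 height/width;
# kernel-2 pool: 31//2 = 15, 63//2 = 31; kernel-3 pool: 15//3 = 5, 31//3 = 10; sum keeps dim 1.
assert out.shape == (batch_size, 1, 5, 10, 10)
```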
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Model that performs a 3D transposed convolution, followed by two max pooling layers and a sum operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose3d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        self.max_pool1 = nn.MaxPool3d(kernel_size=2)
        self.max_pool2 = nn.MaxPool3d(kernel_size=3)

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.max_pool1(x)
        x = self.max_pool2(x)
        x = torch.sum(x, dim=1, keepdim=True)
        return x
batch_size = 16
in_channels = 8
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding]
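The functional variant and this reference module should agree whenever they share the same weight, bias, stride, and padding. A minimal sketch of such an equivalence check, assuming `module_fn` from the first listing and the size constants are in scope, and that `Model` here refers to the reference nn.Module just above:

```python
# Hypothetical equivalence check between the reference module and module_fn.
import torch

ref = Model(in_channels, out_channels, kernel_size, stride, padding)
x = torch.randn(batch_size, in_channels, depth, height, width)
with torch.no_grad():
    expected = ref(x)
    actual = module_fn(x, stride, padding, ref.conv_transpose.weight, ref.conv_transpose.bias)
torch.testing.assert_close(actual, expected)
```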
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cfloat>  // for FLT_MAX used in max_pool3d_window
// Device function for max pooling operation with loop unrolling
__device__ float max_pool3d_window(const float* input, const int x, const int y, const int z,
                                   const int pool_size, const int D, const int H, const int W) {
    float max_val = -FLT_MAX;
    // Clamp the pooling window so it never reads past the tensor boundary
    const int max_pz = min(pool_size, D - z);
    const int max_py = min(pool_size, H - y);
    const int max_px = min(pool_size, W - x);
    // Unroll loops for performance
    #pragma unroll
    for (int pz = 0; pz < max_pz; pz++) {
        const int z_offset = (z + pz) * H * W;
        #pragma unroll
        for (int py = 0; py < max_py; py++) {
            const int y_offset = (y + py) * W;
            #pragma unroll
            for (int px = 0; px < max_px; px++) {
                max_val = fmaxf(max_val, input[z_offset + y_offset + (x + px)]);
            }
        }
    }
    return max_val;
}
// Device function for channel-wise sum reduction
__device__ void channel_sum(float* output, const float* input,
                            const int C, const int D, const int H, const int W) {
    // Sum the same spatial location across all C channels (input points at channel 0)
    float sum = 0.0f;
    for (int c = 0; c < C; c++) {
        sum += input[c * D * H * W];
    }
    *output = sum;
}
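// Note: the two device helpers above are standalone building blocks; the forward()
// entry point below does not launch a custom kernel and instead composes the same
// pipeline from ATen ops (conv_transpose3d, max_pool3d, sum).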
torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    torch::Tensor conv_transpose,
    torch::Tensor conv_transpose_bias) {
    // Ensure inputs are contiguous
    x = x.contiguous();
    conv_transpose = conv_transpose.contiguous();
    conv_transpose_bias = conv_transpose_bias.contiguous();

    // Check that inputs are on CUDA
    TORCH_CHECK(x.is_cuda(), "Input x must be a CUDA tensor");
    TORCH_CHECK(conv_transpose.is_cuda(), "conv_transpose must be a CUDA tensor");
    TORCH_CHECK(conv_transpose_bias.is_cuda(), "conv_transpose_bias must be a CUDA tensor");

    // Apply transposed convolution using ATen
    auto conv_out = at::conv_transpose3d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding}
    );

    // First max pooling with kernel size 2
    auto pool1_out = at::max_pool3d(conv_out, {2, 2, 2});

    // Second max pooling with kernel size 3
    auto pool2_out = at::max_pool3d(pool1_out, {3, 3, 3});

    // Sum reduction over channels
    auto output = pool2_out.sum(1, /*keepdim=*/true);
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Module function forward");
}
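The extension can be built and exercised from Python with torch.utils.cpp_extension.load. A minimal sketch, assuming the source above is saved as conv3d_max_sum.cu (the file name is an assumption; a .cu extension is needed so nvcc compiles the __device__ helpers):

```python
# Hypothetical build-and-run example for the extension above.
import torch
from torch.utils.cpp_extension import load

ext = load(name="conv3d_max_sum", sources=["conv3d_max_sum.cu"], verbose=True)  # file name assumed

x = torch.randn(16, 8, 16, 32, 32, device="cuda")
weight = torch.randn(8, 16, 3, 3, 3, device="cuda")  # ConvTranspose3d weight: (in_channels, out_channels, kD, kH, kW)
bias = torch.randn(16, device="cuda")
out = ext.forward(x, 2, 1, weight, bias)              # (x, stride, padding, weight, bias)
print(out.shape)                                      # expected: torch.Size([16, 1, 5, 10, 10])
```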
Operation / Metric | Value | Unit |
---|---|---|
aten::conv_transpose3d | ||
CPU Time | 3001795.92 | μs |
Device Time | 3461117.19 | μs |
Self CPU Time | 12490.99 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 2989304.94 | μs |
Device Time | 3461117.19 | μs |
Self CPU Time | 17886.92 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 2971418.01 | μs |
Device Time | 3461117.19 | μs |
Self CPU Time | 36760.70 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 1195294.21 | μs |
Device Time | 2710688.55 | μs |
Self CPU Time | 204524.02 | μs |
Self Device Time | 2710688.55 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 3922386.40 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 3922386.40 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 1821644.72 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1821644.72 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::add_ | ||
CPU Time | 1729854.64 | μs |
Device Time | 750428.64 | μs |
Self CPU Time | 30234.05 | μs |
Self Device Time | 750428.64 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |