6_Conv3d_Softmax_MaxPool_MaxPool
• efficient_fused_pooling_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
) -> torch.Tensor:
    """Run conv3d -> channel softmax -> two successive 2x2x2 max pools.

    Args:
        x: Input volume of shape (batch_size, in_channels, depth, height, width).
        conv_weight: Convolution kernel of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size).
        conv_bias: Per-output-channel bias of shape (out_channels,).

    Returns:
        Tensor of shape (batch_size, out_channels, depth', height', width'),
        where each spatial extent is ``(extent - kernel_size + 1) // 4``: the
        unpadded stride-1 convolution shrinks it by ``kernel_size - 1`` and each
        of the two kernel_size=2 pools halves it with floor rounding.
    """
    conv_out = F.conv3d(x, conv_weight, conv_bias, stride=1, padding=0)
    result = F.softmax(conv_out, dim=1)
    # Two identical pooling stages, applied back to back.
    for _ in range(2):
        result = F.max_pool3d(result, kernel_size=2)
    return result
class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies Softmax, and performs two max pooling operations.

    The convolution weight and bias are stored as parameters, and the forward
    computation is delegated to ``fn`` (``module_fn`` by default). Any callable
    with the same ``(x, conv_weight, conv_bias)`` signature can be swapped in.
    """

    def __init__(self, in_channels, out_channels, kernel_size, pool_kernel_size):
        super().__init__()
        # The Conv3d module is used only to obtain default-initialised weight
        # and bias tensors; the module itself is discarded. No padding is passed
        # because module_fn always convolves with padding=0. A padding argument
        # here would not change the parameters and would only mislead readers.
        conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.conv_weight = nn.Parameter(conv.weight)
        self.conv_bias = nn.Parameter(conv.bias)
        # NOTE: pool_kernel_size is accepted to match get_init_inputs() but is
        # not stored or used; module_fn hard-codes kernel_size=2 for both pools.

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` to ``x`` together with this model's conv weight and bias."""
        return fn(x, self.conv_weight, self.conv_bias)
# Problem size used by get_inputs() / get_init_inputs().
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = 3
pool_kernel_size = 2
def get_inputs():
    """Return the forward-pass arguments: a single random input volume."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*input_shape)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    ctor_args = (in_channels, out_channels, kernel_size, pool_kernel_size)
    return list(ctor_args)
import torch
import torch.nn as nn
class Model(nn.Module):
    """Conv3d, then a softmax across channels, then two successive MaxPool3d layers."""

    def __init__(self, in_channels, out_channels, kernel_size, pool_kernel_size):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        # Two independent (parameter-free) pooling stages, applied in order.
        self.pool1 = nn.MaxPool3d(pool_kernel_size)
        self.pool2 = nn.MaxPool3d(pool_kernel_size)

    def forward(self, x):
        """Compute ``pool2(pool1(softmax(conv(x), dim=1)))``.

        Args:
            x: Input tensor of shape (batch_size, in_channels, depth, height, width).

        Returns:
            Tensor of shape (batch_size, out_channels, depth', height', width'),
            where the primed sizes are the spatial extents after the
            convolution and both pooling stages.
        """
        out = torch.softmax(self.conv(x), dim=1)
        for pool in (self.pool1, self.pool2):
            out = pool(out)
        return out
# Problem size used by get_inputs() / get_init_inputs().
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = 3
pool_kernel_size = 2
def get_inputs():
    """Return the forward-pass arguments: a single random input volume."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(*input_shape)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    ctor_args = (in_channels, out_channels, kernel_size, pool_kernel_size)
    return list(ctor_args)
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cstdint>
// Fused replacement for two back-to-back max_pool3d(kernel_size=2) calls.
//
// With stride == kernel_size and floor rounding, pooling twice with a 2x2x2
// window selects the same elements as one pass with a 4x4x4 window, and
// floor(floor(n / 2) / 2) == n / 4. The output extent is therefore n / 4 in
// every spatial dimension, so each 4x4x4 window lies fully inside the input
// and no bounds checks are needed.
//
// Each thread computes one output element. A grid-stride loop lets a bounded
// grid cover any number of outputs; the previous launch placed N*C*outD blocks
// in gridDim.z, which is limited to 65535. All offsets are computed in 64 bits.
//
// input : dense (N, C, D, H, W) float32 buffer
// output: dense (N, C, outD, outH, outW) float32 buffer, outX == X / 4
__global__ void improved_fused_maxpool_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int N, int C, int D, int H, int W,
    int outD, int outH, int outW
) {
    const int64_t total = static_cast<int64_t>(N) * C * outD * outH * outW;
    const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
    const int64_t plane = static_cast<int64_t>(H) * W;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total;
         idx += step) {
        // Decompose the linear output index; output is dense NCDHW, so idx is
        // also the exact write position.
        int64_t rest = idx;
        const int ow = static_cast<int>(rest % outW);
        rest /= outW;
        const int oh = static_cast<int>(rest % outH);
        rest /= outH;
        const int od = static_cast<int>(rest % outD);
        rest /= outD;
        const int64_t nc = rest;  // fused (n * C + c)

        const float* window = input
            + (nc * D + 4 * od) * plane
            + static_cast<int64_t>(4 * oh) * W
            + 4 * ow;

        // Seed with a real element (not -FLT_MAX) so all-(-inf) windows stay -inf.
        float best = window[0];
        #pragma unroll
        for (int kd = 0; kd < 4; ++kd) {
            #pragma unroll
            for (int kh = 0; kh < 4; ++kh) {
                const float* row = window + kd * plane + static_cast<int64_t>(kh) * W;
                #pragma unroll
                for (int kw = 0; kw < 4; ++kw) {
                    const float v = row[kw];
                    // Match PyTorch max pooling: a NaN anywhere in the window
                    // propagates (fmaxf would silently discard it).
                    if (v > best || isnan(v)) {
                        best = v;
                    }
                }
            }
        }
        output[idx] = best;
    }
}

// Forward pass: conv3d (stride 1, no padding) -> softmax over channels ->
// fused double max_pool3d(kernel_size=2).
//
// Raises (via TORCH_CHECK) if any argument is not on CUDA, if x is not 5-D,
// or if x is not float32. The kernel is enqueued on PyTorch's current stream.
// No host synchronisation is needed: downstream work on the same stream is
// ordered after the kernel.
torch::Tensor forward(
    torch::Tensor x,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias
) {
    TORCH_CHECK(x.is_cuda(), "forward: x must be a CUDA tensor");
    TORCH_CHECK(conv_weight.is_cuda() && conv_bias.is_cuda(),
                "forward: conv_weight and conv_bias must be CUDA tensors");
    TORCH_CHECK(x.dim() == 5,
                "forward: x must have shape (N, C, D, H, W), got ", x.dim(), " dims");
    TORCH_CHECK(x.scalar_type() == at::kFloat,
                "forward: only float32 input is supported, got ", x.scalar_type());

    // Make x's device current so the kernel and current stream both match it.
    const c10::cuda::CUDAGuard device_guard(x.device());

    x = x.contiguous();
    conv_weight = conv_weight.contiguous();
    conv_bias = conv_bias.contiguous();

    auto conv_output = at::conv3d(x, conv_weight, conv_bias, {1, 1, 1}, {0, 0, 0});
    // The kernel indexes its input as dense NCDHW, so force that layout.
    // This is a no-op when softmax already returns a contiguous tensor.
    auto softmax_output = at::softmax(conv_output, /*dim=*/1).contiguous();

    const int N = static_cast<int>(softmax_output.size(0));
    const int C = static_cast<int>(softmax_output.size(1));
    const int D = static_cast<int>(softmax_output.size(2));
    const int H = static_cast<int>(softmax_output.size(3));
    const int W = static_cast<int>(softmax_output.size(4));

    // Two floor-mode 2x pools == one 4x pool: floor(floor(n / 2) / 2) == n / 4.
    const int outD = D / 4;
    const int outH = H / 4;
    const int outW = W / 4;

    auto output = torch::empty({N, C, outD, outH, outW}, softmax_output.options());

    const int64_t total = output.numel();
    if (total == 0) {
        // Nothing to compute, and a zero-sized grid is an invalid launch.
        return output;
    }

    constexpr int kThreads = 256;
    // The kernel's grid-stride loop covers any outputs beyond this block cap.
    constexpr int64_t kMaxBlocks = 65535;
    const int64_t blocks =
        std::min<int64_t>((total + kThreads - 1) / kThreads, kMaxBlocks);

    improved_fused_maxpool_kernel<<<static_cast<unsigned int>(blocks), kThreads, 0,
                                    at::cuda::getCurrentCUDAStream()>>>(
        softmax_output.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, D, H, W,
        outD, outH, outW);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
// Python binding: exposes the host-side `forward` entry point of this extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "forward",
        &forward,
        "Optimized CUDA forward function with improved fused pooling");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 2.210 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 2.164 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 55.188 | % | 0.001 | 5 |
Issued Ipc Active | 2.210 | inst/cycle | 0.000 | 5 |
SM Busy | 55.188 | % | 0.001 | 5 |
Memory Throughput | 473839971532.406 | byte/second | 169223768719032256.000 | 5 |
Mem Busy | 50.242 | % | 0.001 | 5 |
Max Bandwidth | 34.274 | % | 0.000 | 5 |
L1/TEX Hit Rate | 0.350 | % | 0.000 | 5 |
L2 Hit Rate | 65.006 | % | 0.014 | 5 |
Mem Pipes Busy | 49.878 | % | 0.001 | 5 |
Warp Cycles Per Issued Instruction | 11.404 | cycle | 0.000 | 5 |
Warp Cycles Per Executed Instruction | 11.404 | cycle | 0.000 | 5 |
Avg. Active Threads Per Warp | 30.880 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 24.110 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 42.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 51.000 | block | 0.000 | 5 |
Block Limit Warps | 32.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 40.000 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 25.600 | warp | 0.000 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (38.9%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (40.0%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 403803.51 | μs |
Device Time | 2473.87 | μs |
Self CPU Time | 61.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 403742.12 | μs |
Device Time | 2473.87 | μs |
Self CPU Time | 108.66 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 400890.14 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 99.93 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 401993.47 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 401993.47 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv3d | ||
CPU Time | 365636.89 | μs |
Device Time | 4395901.35 | μs |
Self CPU Time | 11322.90 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 354313.99 | μs |
Device Time | 4395901.35 | μs |
Self CPU Time | 17683.98 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 336630.01 | μs |
Device Time | 4395901.35 | μs |
Self CPU Time | 30830.67 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 234457.84 | μs |
Device Time | 3810985.88 | μs |
Self CPU Time | 160106.68 | μs |
Self Device Time | 3810985.88 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3810983.36 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3810983.36 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceSynchronize | ||
CPU Time | 6639383.13 | μs |
Device Time | 259630.89 | μs |
Self CPU Time | 6639383.13 | μs |
Self Device Time | 259630.89 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45287 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.