43_Conv3d_Max_LogSumExp_ReLU
• optimized_fused_3d_kernel_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor,
stride: int,
padding: int,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
) -> torch.Tensor:
"""
Applies 3D convolution, max pooling, log sum exp, and ReLU activation.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
stride (int): Stride of the convolution
padding (int): Padding of the convolution
conv_weight (torch.Tensor): Convolution weight tensor
conv_bias (torch.Tensor): Convolution bias tensor
Returns:
torch.Tensor: Output tensor of shape (batch_size, 1, depth', height', width') after convolution, max pooling, channel-wise logsumexp and ReLU
"""
x = F.conv3d(x, conv_weight, bias=conv_bias, stride=stride, padding=padding)
x = F.max_pool3d(x, kernel_size=2, stride=2)
x = torch.logsumexp(x, dim=1, keepdim=True)
x = F.relu(x)
return x
class Model(nn.Module):
"""
Model that performs a 3D convolution, max pooling, log sum exp, and ReLU activation.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(Model, self).__init__()
conv = nn.Conv3d(
in_channels, out_channels, kernel_size, stride=stride, padding=padding
)
self.conv_weight = nn.Parameter(conv.weight)
self.conv_bias = nn.Parameter(
conv.bias
+ torch.randn(
conv.bias.shape, device=conv.bias.device, dtype=conv.bias.dtype
)
* 0.02
)
def forward(self, x, stride, padding, fn=module_fn):
return fn(x, stride, padding, self.conv_weight, self.conv_bias)
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 1
padding = 1
def get_inputs():
return [torch.randn(batch_size, in_channels, depth, height, width), stride, padding]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding]
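A quick way to sanity-check the functional reference above is to run it through the provided helpers and confirm the expected shape: the 3x3x3 convolution with stride 1 and padding 1 preserves 16x32x32, the 2x2x2 max pool halves it to 8x16x16, and the channel-wise logsumexp with keepdim=True collapses the 16 channels to 1. A minimal sketch, assuming it runs in the same module as Model, module_fn, get_inputs and get_init_inputs:

import torch

model = Model(*get_init_inputs())
x, s, p = get_inputs()
with torch.no_grad():
    out = model(x, s, p)
# (batch_size, 1, depth/2, height/2, width/2) == (128, 1, 8, 16, 16)
assert out.shape == (batch_size, 1, depth // 2, height // 2, width // 2)
# The trailing ReLU guarantees a non-negative result.
assert (out >= 0).all()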
import torch
import torch.nn as nn
class Model(nn.Module):
"""
Model that performs a 3D convolution, max pooling, log sum exp, and ReLU activation.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(Model, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
self.conv.bias = nn.Parameter(self.conv.bias + torch.randn(self.conv.bias.shape, device=self.conv.bias.device, dtype=self.conv.bias.dtype) * 0.02)
self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2)
def forward(self, x):
"""
Args:
x: Input tensor of shape (batch_size, in_channels, depth, height, width)
Returns:
Output tensor of shape (batch_size, 1, depth', height', width'); logsumexp over dim=1 collapses the channel dimension to 1
"""
x = self.conv(x)
x = self.max_pool(x)
x = torch.logsumexp(x, dim=1, keepdim=True)
x = torch.relu(x)
return x
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 1
padding = 1
def get_inputs():
return [torch.randn(batch_size, in_channels, depth, height, width)]
def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding]
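The CUDA extension below replaces the logsumexp + ReLU tail of this pipeline with a single fused kernel. Numerically it relies on the standard max-subtraction identity logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), which is exactly what the kernel's two passes over the channel dimension compute. A minimal PyTorch sketch of the same computation (the helper name is illustrative only, not part of the extension):

import torch

def logsumexp_relu_reference(pooled: torch.Tensor) -> torch.Tensor:
    # First pass: per-voxel maximum over the channel dimension.
    max_val = pooled.max(dim=1, keepdim=True).values
    # Second pass: sum of exp(x - max); adding max back inside the log gives
    # logsumexp without exp() overflowing for large activations.
    sum_exp = torch.exp(pooled - max_val).sum(dim=1, keepdim=True)
    return torch.relu(max_val + torch.log(sum_exp))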
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cfloat>
#include <algorithm>  // std::min for clamping the grid size
// Fused kernel: numerically stable logsumexp over the channel dimension followed by ReLU, one output voxel per thread
__global__ void optimized_fused_kernel(
const float* __restrict__ input,
float* __restrict__ output,
const int N, const int C, const int D, const int H, const int W) {
extern __shared__ float shared_data[];
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int num_blocks = gridDim.x;
const int stride = D * H * W;
const int total_elements = N * D * H * W;
// Process multiple elements per thread using grid-stride loop
for (int idx = bid * blockDim.x + tid; idx < total_elements; idx += blockDim.x * num_blocks) {
// Decode indices
const int w = idx % W;
int temp = idx / W;
const int h = temp % H;
temp /= H;
const int d = temp % D;
const int n = temp / D;
// Per-thread accumulators (kept in registers)
float max_val = -FLT_MAX;
float local_sum = 0.0f;
// First pass: find maximum (coalesced memory access)
#pragma unroll 4
for (int c = 0; c < C; ++c) {
const int input_idx = n * (C * stride) + c * stride + d * (H * W) + h * W + w;
max_val = fmaxf(max_val, input[input_idx]);
}
// Stash max_val in this thread's shared-memory slot. Only the same thread reads
// it back, so no __syncthreads() is needed; a barrier here would also be unsafe,
// since threads can exit the grid-stride loop at different iterations.
shared_data[tid] = max_val;
// Second pass: compute sum of exponentials
#pragma unroll 4
for (int c = 0; c < C; ++c) {
const int input_idx = n * (C * stride) + c * stride + d * (H * W) + h * W + w;
local_sum += __expf(input[input_idx] - shared_data[tid]);
}
// Compute final result with ReLU using intrinsics for better performance
float result = shared_data[tid] + __logf(local_sum);
result = fmaxf(0.0f, result);
// Write to output (coalesced write)
output[idx] = result;
}
}
torch::Tensor forward(
torch::Tensor x,
int64_t stride,
int64_t padding,
torch::Tensor conv_weight,
torch::Tensor conv_bias) {
// Ensure input tensors are contiguous
x = x.contiguous();
conv_weight = conv_weight.contiguous();
conv_bias = conv_bias.contiguous();
// Perform 3D convolution using PyTorch
auto conv_result = torch::conv3d(x, conv_weight, conv_bias,
{stride, stride, stride},
{padding, padding, padding});
// Perform max pooling using PyTorch
auto pool_result = torch::max_pool3d(conv_result, {2, 2, 2}, {2, 2, 2});
const int N = pool_result.size(0);
const int C = pool_result.size(1);
const int D = pool_result.size(2);
const int H = pool_result.size(3);
const int W = pool_result.size(4);
auto output = torch::empty({N, 1, D, H, W}, pool_result.options());
// Optimize kernel launch configuration
const int block_size = 256;
const int num_blocks = std::min(65535, (N * D * H * W + block_size - 1) / block_size);
optimized_fused_kernel<<<num_blocks, block_size, block_size * sizeof(float)>>>(
pool_result.data_ptr<float>(),
output.data_ptr<float>(),
N, C, D, H, W
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Optimized fused 3D operations");
}
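A hedged build-and-verify sketch for the extension (the source file name optimized_fused_3d_kernel.cu and the extension name are assumptions; adjust them to the actual project layout):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source above; the file name is an assumption.
ext = load(name="optimized_fused_3d_kernel",
           sources=["optimized_fused_3d_kernel.cu"],
           verbose=True)

model = Model(*get_init_inputs()).cuda()
x, s, p = get_inputs()
x = x.cuda()

ref = module_fn(x, s, p, model.conv_weight, model.conv_bias)
out = ext.forward(x, s, p, model.conv_weight, model.conv_bias)
# __expf/__logf are fast-math intrinsics, so compare with a slightly loose tolerance.
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)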
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.714 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.228 | inst/cycle | 0.002 | 5 |
Issue Slots Busy | 43.190 | % | 0.008 | 5 |
Issued Ipc Active | 1.728 | inst/cycle | 0.000 | 5 |
SM Busy | 43.190 | % | 0.008 | 5 |
Memory Throughput | 1516485276016.050 | byte/second | 2080089818128557015040.000 | 5 |
Mem Busy | 28.898 | % | 0.683 | 5 |
Max Bandwidth | 45.372 | % | 1.745 | 5 |
L1/TEX Hit Rate | 41.868 | % | 0.008 | 5 |
L2 Hit Rate | 22.278 | % | 0.013 | 5 |
Mem Pipes Busy | 13.178 | % | 0.165 | 5 |
Warp Cycles Per Issued Instruction | 30.874 | cycle | 0.015 | 5 |
Warp Cycles Per Executed Instruction | 31.136 | cycle | 0.016 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 29.730 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 16.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 83.612 | % | 0.028 | 5 |
Achieved Active Warps Per SM | 53.512 | warp | 0.012 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (24.9%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (83.9%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
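The occupancy figures above follow directly from the launch configuration: 256 threads per block is 8 warps, and the tightest block limit (registers and warps, 8 blocks per SM) gives the 64 theoretical active warps reported. A back-of-the-envelope sketch, assuming the 64-warp-per-SM limit of the profiled sm80-class GPU:

# Pooled tensor after conv (16x32x32) and 2x2x2 max pool: N=128, C=16, D=8, H=16, W=16
N, C, D, H, W = 128, 16, 8, 16, 16
block_size = 256
total = N * D * H * W                                            # 262,144 output voxels
num_blocks = min(65535, (total + block_size - 1) // block_size)  # 1024 blocks
warps_per_block = block_size // 32                               # 8 warps
blocks_per_sm = 8                                                # min of the block limits above
theoretical_warps = blocks_per_sm * warps_per_block              # 64, i.e. 100% theoretical occupancy
print(num_blocks, theoretical_warps)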
Operation / Metric | Value | Unit |
---|---|---|
aten::conv3d | ||
CPU Time | 4905423.80 | μs |
Device Time | 4984817.89 | μs |
Self CPU Time | 14802.09 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 4890621.71 | μs |
Device Time | 4984817.89 | μs |
Self CPU Time | 20827.47 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 4869794.24 | μs |
Device Time | 4984817.89 | μs |
Self CPU Time | 44395.51 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 3921804.18 | μs |
Device Time | 4050890.28 | μs |
Self CPU Time | 192740.84 | μs |
Self Device Time | 4050890.28 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernelExC | ||
CPU Time | 3695732.16 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 3695732.16 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize128x32x8_stage3_warpsize2x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 4050887.56 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 4050887.56 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |