27_Conv3d_HardSwish_ReLU_Softmax_Mean
• atomic_operations_optimized_edit_1
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run conv3d -> HardSwish -> ReLU -> channel softmax -> spatial mean.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
        conv_weight (torch.Tensor): 3D convolution weight tensor of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size)
        conv_bias (torch.Tensor): Bias tensor for 3D convolution of shape (out_channels)

    Returns:
        torch.Tensor: Per-channel mean of the softmax probabilities over the three
            spatial dimensions, with shape (batch_size, out_channels)
    """
    activated = F.relu(F.hardswish(F.conv3d(x, conv_weight, bias=conv_bias)))
    # Softmax over the channel axis, then average every spatial position.
    probabilities = activated.softmax(dim=1)
    return probabilities.mean(dim=(2, 3, 4))
class Model(nn.Module):
    """
    Conv3d parameters wrapped for a functional forward pass.

    The forward pass performs a 3D convolution, applies HardSwish, ReLU and a
    channel Softmax, and then takes the mean over the spatial dimensions.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # A throwaway Conv3d supplies PyTorch's default weight/bias initialisation.
        reference_conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        bias_shift = torch.ones_like(reference_conv.bias) * 0.02
        self.conv_weight = nn.Parameter(reference_conv.weight)
        self.conv_bias = nn.Parameter(reference_conv.bias + bias_shift)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` (default: module_fn) to ``x`` with this model's conv parameters."""
        weight, bias = self.conv_weight, self.conv_bias
        return fn(x, weight, bias)
# Benchmark problem sizes.
batch_size = 128
in_channels, out_channels = 3, 16
depth, height, width = 16, 32, 32
kernel_size = 3
def get_inputs():
    """Return the forward() arguments: a single random input volume."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(input_shape)]
def get_init_inputs():
    """Return the positional arguments used to construct ``Model``."""
    return list((in_channels, out_channels, kernel_size))
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    Simple model that performs a 3D convolution, applies HardSwish, ReLU, Softmax, and then calculates the mean.

    Args:
        in_channels: Number of input channels of the Conv3d.
        out_channels: Number of output channels of the Conv3d; softmax runs over this axis.
        kernel_size: Conv3d kernel size.
        bias: Whether the Conv3d has a bias. When it does, the default-initialised
            bias is shifted by +0.02.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super(Model, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias)
        # With bias=False Conv3d.bias is None, so there is nothing to shift.
        if self.conv.bias is not None:
            self.conv.bias = nn.Parameter(self.conv.bias + torch.ones_like(self.conv.bias) * 0.02)

    def forward(self, x):
        """Return the spatial mean of the channel softmax of ReLU(HardSwish(conv(x)))."""
        x = self.conv(x)
        x = torch.nn.functional.hardswish(x)
        x = torch.relu(x)
        x = torch.softmax(x, dim=1)
        x = torch.mean(x, dim=[2, 3, 4])
        return x
# Benchmark problem sizes for the reference model.
batch_size = 128
in_channels, out_channels = 3, 16
depth, height, width = 16, 32, 32
kernel_size = 3
def get_inputs():
    """Return the forward() arguments: a single random input volume."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(input_shape)]
def get_init_inputs():
    """Return the positional arguments used to construct ``Model``."""
    return list((in_channels, out_channels, kernel_size))
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cfloat>
#include <limits>
#include <vector>
// Combined HardSwish and ReLU CUDA kernel, applied elementwise to `size` floats:
//   hardswish(x) = x * clamp(x + 3, 0, 6) / 6
//   output       = max(hardswish(x), 0)
// The element index is 64-bit so that it cannot overflow for `size` > INT_MAX.
// The grid-stride loop means any grid size covers every element. With the
// host's ceil(size / blockDim.x) launch, each thread runs the body at most once.
__global__ void hardswish_relu_kernel(const float* __restrict__ input, float* __restrict__ output, int64_t size) {
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < size; i += stride) {
        const float x = input[i];
        const float relu6 = fminf(fmaxf(x + 3.0f, 0.0f), 6.0f);
        const float hardswish = x * relu6 / 6.0f;
        output[i] = fmaxf(hardswish, 0.0f);
    }
}
// Numerically stable softmax over the channel axis of a contiguous
// (batch_size, channels, spatial_size) float buffer.
//
// Each thread owns one (batch, spatial) position. It reads that position's
// `channels` values, which are `spatial_size` elements apart, and subtracts the
// position's OWN maximum before exponentiating. It then writes exp values into
// `output` and rescales them by 1 / sum.
//
// Why there is no shared-memory reduction: a maximum taken across the block
// mixes unrelated positions (and batches). If that maximum sits far above one
// position's values, every exp() for that position underflows to 0 and the
// result is NaN. Such a reduction would also need __syncthreads(), because
// __syncwarp() does not order shared-memory accesses between different warps.
// A per-position max needs no inter-thread communication. Any dynamic shared
// memory supplied at launch is unused.
__global__ void softmax_kernel(const float* __restrict__ input, float* __restrict__ output,
                               int batch_size, int channels, int spatial_size) {
    const int64_t total_positions = static_cast<int64_t>(batch_size) * spatial_size;
    const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t pos = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         pos < total_positions; pos += grid_stride) {
        const int64_t batch_idx = pos / spatial_size;
        const int64_t spatial_idx = pos % spatial_size;
        // Offset of channel 0 for this position; channel c is c * spatial_size further.
        const int64_t base = batch_idx * channels * spatial_size + spatial_idx;

        float max_val = -FLT_MAX;
        for (int c = 0; c < channels; ++c) {
            max_val = fmaxf(max_val, input[base + static_cast<int64_t>(c) * spatial_size]);
        }

        float sum_exp = 0.0f;
        for (int c = 0; c < channels; ++c) {
            const int64_t idx = base + static_cast<int64_t>(c) * spatial_size;
            const float exp_val = expf(input[idx] - max_val);
            output[idx] = exp_val;
            sum_exp += exp_val;
        }

        // The max element contributes exp(0) = 1, so sum_exp >= 1 for finite input.
        const float inv_sum = 1.0f / sum_exp;
        for (int c = 0; c < channels; ++c) {
            output[base + static_cast<int64_t>(c) * spatial_size] *= inv_sum;
        }
    }
}
// Raise a c10::Error naming `kernel_name` if the most recent kernel launch
// reported an error.
static void check_last_launch(const char* kernel_name) {
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, kernel_name, " launch failed: ", cudaGetErrorString(err));
}

// Module forward function.
// Performs a 3D convolution (torch::conv3d), then the fused HardSwish+ReLU
// kernel, then softmax along the channel dimension, then the mean over the
// three spatial dimensions.
//
// Args:
//   x           input of shape (N, C_in, D, H, W), float32; moved to CUDA and
//               made contiguous.
//   conv_weight convolution weight; moved to CUDA and made contiguous.
//   conv_bias   convolution bias; moved to CUDA and made contiguous.
// Returns:
//   tensor of shape (N, C_out): per-channel spatial mean of the softmax.
// Throws c10::Error (via TORCH_CHECK) for non-float32 or non-5D input, for
// dimensions that do not fit the softmax kernel's `int` parameters, and when a
// kernel launch fails.
torch::Tensor module_forward(
    torch::Tensor x,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias)
{
    x = x.contiguous().cuda();
    conv_weight = conv_weight.contiguous().cuda();
    conv_bias = conv_bias.contiguous().cuda();
    // The custom kernels read and write raw float* data.
    TORCH_CHECK(x.scalar_type() == torch::kFloat32,
                "module_forward: expected a float32 input, got ", x.scalar_type());
    TORCH_CHECK(x.dim() == 5,
                "module_forward: expected a 5D (N, C, D, H, W) input, got ", x.dim(), " dims");

    // Launch on x's device and on PyTorch's current stream for that device.
    // This orders the kernels after conv3d and before mean(), even when the
    // caller has selected a non-default (non-blocking) stream.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The kernels index the buffer as dense NCDHW, so enforce contiguity.
    x = torch::conv3d(x, conv_weight, conv_bias).contiguous();

    const int64_t batch_size = x.size(0);
    const int64_t channels = x.size(1);
    const int64_t depth = x.size(2);
    const int64_t height = x.size(3);
    const int64_t width = x.size(4);
    const int64_t spatial_size = depth * height * width;
    const int64_t total_size = x.numel();
    const int64_t total_positions = batch_size * spatial_size;

    // softmax_kernel takes its sizes as `int`; refuse sizes that would be truncated.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(batch_size <= kIntMax && channels <= kIntMax && spatial_size <= kIntMax &&
                total_positions <= kIntMax,
                "module_forward: tensor dimensions exceed the softmax kernel's 32-bit range");

    constexpr int threads = 256;

    // Fused HardSwish+ReLU. Skip the launch when empty: a zero-block grid is an
    // invalid launch configuration.
    torch::Tensor x_combined = torch::empty_like(x);
    if (total_size > 0) {
        const int64_t blocks = (total_size + threads - 1) / threads;
        hardswish_relu_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            x.data_ptr<float>(), x_combined.data_ptr<float>(), total_size);
        check_last_launch("hardswish_relu_kernel");
    }

    // Softmax over channels, one thread per (batch, spatial) position.
    x_combined = x_combined.view({batch_size, channels, spatial_size});
    torch::Tensor x_softmax = torch::empty_like(x_combined);
    if (total_positions > 0) {
        const int64_t blocks = (total_positions + threads - 1) / threads;
        // One float of dynamic shared memory per thread, matching how this
        // kernel has always been launched. Harmless if the kernel does not use it.
        const size_t shared_mem_size = threads * sizeof(float);
        softmax_kernel<<<static_cast<unsigned int>(blocks), threads, shared_mem_size, stream>>>(
            x_combined.data_ptr<float>(), x_softmax.data_ptr<float>(),
            static_cast<int>(batch_size), static_cast<int>(channels), static_cast<int>(spatial_size));
        check_last_launch("softmax_kernel");
    }

    torch::Tensor output = x_softmax.view({batch_size, channels, depth, height, width}).mean({2, 3, 4});
    return output;
}
// Python binding: exposes module_forward as `forward` on the compiled extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward", &module_forward, "CUDA module forward");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.388 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.330 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 34.796 | % | 0.075 | 5 |
Issued Ipc Active | 1.392 | inst/cycle | 0.000 | 5 |
SM Busy | 34.796 | % | 0.075 | 5 |
Memory Throughput | 2173159598790.414 | byte/second | 248020572224845447168.000 | 5 |
Mem Busy | 64.846 | % | 0.247 | 5 |
Max Bandwidth | 69.162 | % | 0.223 | 5 |
L1/TEX Hit Rate | 55.160 | % | 0.000 | 5 |
L2 Hit Rate | 72.454 | % | 0.001 | 5 |
Mem Pipes Busy | 29.830 | % | 0.076 | 5 |
Warp Cycles Per Issued Instruction | 38.018 | cycle | 0.036 | 5 |
Warp Cycles Per Executed Instruction | 38.076 | cycle | 0.036 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 28.510 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 16.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 82.936 | % | 0.009 | 5 |
Achieved Active Warps Per SM | 53.078 | warp | 0.004 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (82.9%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::fill_ | ||
CPU Time | 3118785.43 | μs |
Device Time | 428964.29 | μs |
Self CPU Time | 19154.59 | μs |
Self Device Time | 428964.29 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv3d | ||
CPU Time | 1119890.08 | μs |
Device Time | 3807159.51 | μs |
Self CPU Time | 12332.96 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 1107557.12 | μs |
Device Time | 3807159.51 | μs |
Self CPU Time | 13970.94 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 1093586.18 | μs |
Device Time | 3807159.51 | μs |
Self CPU Time | 28463.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 587090.86 | μs |
Device Time | 3305144.94 | μs |
Self CPU Time | 186891.48 | μs |
Self Device Time | 3305144.94 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 3305143.43 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3305143.43 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 4128212.95 | μs |
Device Time | 99115.02 | μs |
Self CPU Time | 4128212.95 | μs |
Self Device Time | 99115.02 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 3132124.66 | μs |
Device Time | 428964.29 | μs |
Self CPU Time | 13356.64 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45288 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.