44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean
• optimized_spatial_reduction_base
import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    multiplier: float,
) -> torch.Tensor:
    """Transposed-convolve ``x``, scale the result, and reduce it to one mean.

    The pipeline is: 2-D transposed convolution, multiplication by
    ``multiplier``, two consecutive global average pools over the spatial
    axes (dims 2 and 3, kept as size-1 axes), and a final mean over every
    remaining element.

    Args:
        x (torch.Tensor): Input of shape (batch_size, in_channels, height, width).
        stride (int): Stride of the transposed convolution.
        padding (int): Padding of the transposed convolution.
        output_padding (int): Extra size added to the output shape.
        conv_transpose (torch.Tensor): Transposed-convolution weight.
        conv_transpose_bias (torch.Tensor): Transposed-convolution bias.
        multiplier (float): Scalar applied to the convolution output.

    Returns:
        torch.Tensor: Zero-dimensional tensor holding the final mean.
    """
    upsampled = F.conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    pooled = upsampled * multiplier

    # Two back-to-back global average pools; the second one operates on the
    # 1x1 spatial map produced by the first.
    for _ in range(2):
        pooled = pooled.mean(dim=[2, 3], keepdim=True)

    return pooled.mean()
class Model(nn.Module):
    """Parameter holder for the functional transposed-conv / scale / pool pipeline.

    The module owns the transposed-convolution weight and a noise-perturbed
    bias, plus the scalar multiplier. The computation itself is delegated to
    the callable passed to ``forward`` (``module_fn`` by default).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        multiplier,
    ):
        super().__init__()

        # A throwaway ConvTranspose2d is used only to obtain initialised
        # weight and bias tensors of the right shape.
        reference_conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.conv_transpose_parameter = nn.Parameter(reference_conv.weight)

        # Add Gaussian noise scaled by 0.02 to the initial bias.
        initial_bias = reference_conv.bias
        bias_noise = (
            torch.randn(
                initial_bias.shape,
                device=initial_bias.device,
                dtype=initial_bias.dtype,
            )
            * 0.02
        )
        self.conv_transpose_bias = nn.Parameter(initial_bias + bias_noise)

        self.multiplier = multiplier

    def forward(self, x, stride, padding, output_padding, fn=module_fn):
        """Call ``fn`` with the input, conv hyper-parameters and stored parameters."""
        weight = self.conv_transpose_parameter
        bias = self.conv_transpose_bias
        return fn(x, stride, padding, output_padding, weight, bias, self.multiplier)
# Shape of the random input tensor built by get_inputs().
batch_size = 128
in_channels = 3
height = 32
width = 32

# Transposed-convolution hyper-parameters.
out_channels = 16
kernel_size = 3
stride = 2
padding = 1
output_padding = 1

# Scalar applied to the convolution output.
multiplier = 0.5
def get_inputs():
    """Return one set of forward arguments: a random input tensor followed by
    the stride, padding and output_padding settings."""
    sample = torch.randn(batch_size, in_channels, height, width)
    return [sample, stride, padding, output_padding]
def get_init_inputs():
    """Return the ``Model`` constructor arguments in positional order."""
    conv_settings = [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
    ]
    return conv_settings + [multiplier]
import torch
import torch.nn as nn
class Model(nn.Module):
    """Transposed convolution, scalar multiply, two global average pools, final mean.

    The convolution bias is replaced at construction time by the initial bias
    plus Gaussian noise scaled by 0.02.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )

        # Perturb the freshly initialised bias and re-register it as a parameter.
        initial_bias = self.conv_transpose.bias
        bias_noise = (
            torch.randn(
                initial_bias.shape,
                device=initial_bias.device,
                dtype=initial_bias.dtype,
            )
            * 0.02
        )
        self.conv_transpose.bias = nn.Parameter(initial_bias + bias_noise)

        self.multiplier = multiplier

    def forward(self, x):
        """Return the zero-dimensional mean of the scaled, pooled convolution output."""
        scaled = self.conv_transpose(x) * self.multiplier
        # First global average pooling over the spatial axes.
        pooled_once = torch.mean(scaled, dim=[2, 3], keepdim=True)
        # Second global average pooling, applied to the 1x1 map from the first.
        pooled_twice = torch.mean(pooled_once, dim=[2, 3], keepdim=True)
        return torch.mean(pooled_twice)
# Input tensor dimensions used by get_inputs().
batch_size = 128
in_channels = 3
height = width = 32

# Constructor settings returned by get_init_inputs().
out_channels = 16
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5
def get_inputs():
    """Return a single-element list holding one random input batch."""
    input_shape = (batch_size, in_channels, height, width)
    return [torch.randn(*input_shape)]
def get_init_inputs():
    """Return the ``Model`` constructor arguments in positional order."""
    init_args = [in_channels, out_channels, kernel_size]
    init_args += [stride, padding, output_padding]
    init_args.append(multiplier)
    return init_args
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <limits>
/**
 * Reduces every (batch, channel) plane of a float tensor to the mean of
 * `value * multiplier`, writing one float per plane.
 *
 * Launch contract:
 *   - grid:   one block per plane; blockIdx.x indexes planes in [0, N * C)
 *   - block:  exactly BLOCK_SIZE threads in x
 *   - shared: BLOCK_SIZE * sizeof(float) bytes of dynamic shared memory
 *
 * The indexing assumes `input` is laid out plane after plane (contiguous
 * NCHW), i.e. plane p starts at element p * H * W.
 *
 * @tparam BLOCK_SIZE  Threads per block. Must be a power of two because the
 *                     tree reduction halves the active range at each step.
 * @param input        N * C * H * W input floats.
 * @param output       N * C output floats; output[blockIdx.x] receives the
 *                     mean of that plane after scaling.
 * @param N            Batch size (kept for interface compatibility; only the
 *                     grid size N * C is used).
 * @param C            Channel count (see N).
 * @param H            Plane height.
 * @param W            Plane width.
 * @param multiplier   Scalar applied to every element before summation.
 */
template<int BLOCK_SIZE>
__global__ void optimized_reduction_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int N,
    const int C,
    const int H,
    const int W,
    const float multiplier
) {
    static_assert(BLOCK_SIZE > 0 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
                  "BLOCK_SIZE must be a power of two for the tree reduction");

    extern __shared__ float sdata[];

    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int spatial_size = H * W;

    // batch_idx * C + channel_idx is exactly blockIdx.x, so the plane starts
    // at bid * spatial_size. The product is formed in 64 bits so tensors with
    // more than 2^31 elements do not overflow the offset.
    const float* block_input =
        input + static_cast<size_t>(bid) * static_cast<size_t>(spatial_size);

    // Each thread accumulates a BLOCK_SIZE-strided slice of the plane.
    float sum = 0.0f;
    #pragma unroll 8
    for (int i = tid; i < spatial_size; i += BLOCK_SIZE) {
        sum += block_input[i] * multiplier;
    }

    sdata[tid] = sum;
    __syncthreads();

    // Shared-memory tree reduction down to 64 partial sums.
    if (BLOCK_SIZE >= 1024) {
        if (tid < 512) sdata[tid] += sdata[tid + 512];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 512) {
        if (tid < 256) sdata[tid] += sdata[tid + 256];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 256) {
        if (tid < 128) sdata[tid] += sdata[tid + 128];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 128) {
        if (tid < 64) sdata[tid] += sdata[tid + 64];
        __syncthreads();
    }

    // Final stage inside the first warp. Register shuffles with an explicit
    // participation mask are used instead of unsynchronised volatile shared
    // memory accesses: implicit warp lockstep is not guaranteed on GPUs with
    // independent thread scheduling, and shuffles never index past the
    // shared buffer for small BLOCK_SIZE values.
    if (tid < 32) {
        constexpr int kLanes = BLOCK_SIZE < 32 ? BLOCK_SIZE : 32;
        constexpr unsigned kLaneMask = 0xffffffffu >> (32 - kLanes);

        float partial = sdata[tid];
        if (BLOCK_SIZE >= 64) partial += sdata[tid + 32];

        #pragma unroll
        for (int offset = kLanes / 2; offset > 0; offset >>= 1) {
            partial += __shfl_down_sync(kLaneMask, partial, offset, kLanes);
        }

        // Lane 0 holds the plane total; normalise to a mean on write.
        if (tid == 0) {
            output[bid] = partial / spatial_size;
        }
    }
}
/**
 * Transposed convolution, scalar multiply, per-plane spatial mean, final mean.
 *
 * The convolution is delegated to ATen; the multiply and the spatial
 * averaging are fused into optimized_reduction_kernel, and the remaining
 * (batch, channel) means are averaged with Tensor::mean().
 *
 * @param x                    CUDA input tensor for the transposed convolution.
 * @param stride               Stride, applied to both spatial dimensions.
 * @param padding              Padding, applied to both spatial dimensions.
 * @param output_padding       Output padding, applied to both spatial dimensions.
 * @param conv_transpose       Transposed-convolution weight.
 * @param conv_transpose_bias  Transposed-convolution bias.
 * @param multiplier           Scalar applied to the convolution output.
 * @return Zero-dimensional tensor holding the mean of the per-plane means.
 * @throws c10::Error if x is not a CUDA tensor, the convolution output is not
 *         a 4-D float32 tensor, its sizes exceed the kernel's 32-bit index
 *         range, or the kernel launch fails.
 */
at::Tensor module_fn(
    at::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    at::Tensor conv_transpose,
    at::Tensor conv_transpose_bias,
    double multiplier
) {
    TORCH_CHECK(x.is_cuda(), "module_fn: x must be a CUDA tensor");

    // Make x's device current so the raw kernel launch below targets the
    // same device as the ATen operators.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // Apply transposed convolution
    at::Tensor y = at::conv_transpose2d(
        x,
        conv_transpose,
        conv_transpose_bias,
        {stride, stride},
        {padding, padding},
        {output_padding, output_padding},
        1,
        {1, 1}
    );

    TORCH_CHECK(y.dim() == 4,
                "module_fn: expected a 4-D convolution output, got ", y.dim(), "-D");
    TORCH_CHECK(y.scalar_type() == at::kFloat,
                "module_fn: only float32 is supported, got ", y.scalar_type());

    // The kernel indexes planes assuming a dense NCHW layout; the convolution
    // may return another memory format (e.g. channels-last).
    y = y.contiguous();

    const int64_t batch = y.size(0);
    const int64_t channels = y.size(1);
    const int64_t height = y.size(2);
    const int64_t width = y.size(3);
    const int64_t num_planes = batch * channels;
    const int64_t plane_size = height * width;

    // The kernel takes its sizes as int and uses one block per plane, so both
    // the grid size and the plane size must fit in 32 bits.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(num_planes <= kIntMax && plane_size <= kIntMax,
                "module_fn: convolution output is too large for the reduction kernel");

    // Every element is written by the kernel, so no zero fill is needed.
    at::Tensor output = torch::empty({batch, channels}, y.options());

    // A zero-block launch is an invalid configuration; with no planes there
    // is nothing to reduce and mean() of the empty tensor is returned below.
    if (num_planes > 0) {
        constexpr int BLOCK_SIZE = 512;
        const size_t shared_mem_size = BLOCK_SIZE * sizeof(float);

        // Launch on PyTorch's current stream so the kernel is ordered after
        // the convolution and before the final mean.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        optimized_reduction_kernel<BLOCK_SIZE>
            <<<static_cast<unsigned int>(num_planes), BLOCK_SIZE, shared_mem_size, stream>>>(
                y.data_ptr<float>(),
                output.data_ptr<float>(),
                static_cast<int>(batch),
                static_cast<int>(channels),
                static_cast<int>(height),
                static_cast<int>(width),
                static_cast<float>(multiplier)
            );

        const cudaError_t launch_status = cudaGetLastError();
        TORCH_CHECK(launch_status == cudaSuccess,
                    "module_fn: reduction kernel launch failed: ",
                    cudaGetErrorString(launch_status));
    }

    // Compute final mean
    return output.mean();
}
// Python binding: registers module_fn under the name "forward" on the
// extension module whose name is supplied by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &module_fn, "Module function");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.540 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.208 | inst/cycle | 0.001 | 5 |
Issue Slots Busy | 38.652 | % | 0.031 | 5 |
Issued Ipc Active | 1.546 | inst/cycle | 0.000 | 5 |
SM Busy | 38.652 | % | 0.031 | 5 |
Memory Throughput | 2350054044898.164 | byte/second | 3791160778347151622144.000 | 5 |
Mem Busy | 39.506 | % | 1.114 | 5 |
Max Bandwidth | 70.234 | % | 3.497 | 5 |
L1/TEX Hit Rate | 0.010 | % | 0.000 | 5 |
L2 Hit Rate | 2.926 | % | 0.000 | 5 |
Mem Pipes Busy | 21.460 | % | 0.347 | 5 |
Warp Cycles Per Issued Instruction | 34.664 | cycle | 0.004 | 5 |
Warp Cycles Per Executed Instruction | 34.820 | cycle | 0.004 | 5 |
Avg. Active Threads Per Warp | 31.860 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 27.510 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 5.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 10.000 | block | 0.000 | 5 |
Block Limit Warps | 4.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 84.340 | % | 0.019 | 5 |
Achieved Active Warps Per SM | 53.978 | warp | 0.007 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (22.3%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (84.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::conv_transpose2d | ||
CPU Time | 6596246.32 | μs |
Device Time | 5333382.66 | μs |
Self CPU Time | 54605.28 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 6541641.04 | μs |
Device Time | 5333382.66 | μs |
Self CPU Time | 72271.19 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 6469369.85 | μs |
Device Time | 5333382.66 | μs |
Self CPU Time | 151283.79 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution_transpose | ||
CPU Time | 3965701.33 | μs |
Device Time | 4338178.22 | μs |
Self CPU Time | 739830.54 | μs |
Self Device Time | 4338178.22 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 3929087.15 | μs |
Device Time | 1244.28 | μs |
Self CPU Time | 3929087.15 | μs |
Self Device Time | 1244.28 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 510290.40 | μs |
Device Time | 2619305.19 | μs |
Self CPU Time | 110567.20 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45294 warnings generated when compiling for host. Suppressed 45327 warnings (45280 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.