import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(x, params, is_training):
    """
    Functional DenseNet forward pass.

    Runs, in order: the stem (conv -> batch norm -> ReLU -> max pool), every
    dense block in ``params["dense_blocks"]`` with a transition layer after
    each block except the last, the final batch norm + ReLU, global average
    pooling and the linear classifier.

    :param x: Input tensor; presumably (batch, 3, height, width) -- confirm against caller
    :param params: Mapping that provides the keys read below
        (``features_*``, ``dense_blocks``, ``transition_layers``,
        ``final_bn_*``, ``classifier_*``)
    :param is_training: Forwarded as ``training=`` to batch norm and dropout
    :return: Output of the final ``F.linear`` call
    """

    def dense_layer_fn(
        x, bn_weight, bn_bias, bn_mean, bn_var, conv_weight, is_training
    ):
        """
        One dense layer: batch norm -> ReLU -> 3x3-style conv (padding=1) -> dropout(p=0)
        """
        normed = F.batch_norm(
            x, bn_mean, bn_var, bn_weight, bn_bias, training=is_training
        )
        # In-place ReLU only touches the fresh batch-norm output, never `x`.
        activated = F.relu(normed, inplace=True)
        grown = F.conv2d(activated, conv_weight, bias=None, padding=1)
        return F.dropout(grown, p=0.0, training=is_training)

    def dense_block_fn(x, layer_params, is_training):
        """
        One dense block: every layer consumes the concatenation of the block
        input and all features produced so far.
        """
        collected = [x]
        current = x
        for single_layer in layer_params:
            collected.append(dense_layer_fn(current, *single_layer, is_training))
            current = torch.cat(collected, 1)
        return current

    def transition_layer_fn(
        x, bn_weight, bn_bias, bn_mean, bn_var, conv_weight, is_training
    ):
        """
        One transition layer: batch norm -> ReLU -> conv (no padding) -> 2x2 average pool
        """
        normed = F.batch_norm(
            x, bn_mean, bn_var, bn_weight, bn_bias, training=is_training
        )
        activated = F.relu(normed, inplace=True)
        reduced = F.conv2d(activated, conv_weight, bias=None)
        return F.avg_pool2d(reduced, kernel_size=2, stride=2)

    # Stem
    stem = F.conv2d(
        x, params["features_conv_weight"], bias=None, stride=2, padding=3
    )
    stem = F.batch_norm(
        stem,
        params["features_bn_mean"],
        params["features_bn_var"],
        params["features_bn_weight"],
        params["features_bn_bias"],
        training=is_training,
    )
    stem = F.relu(stem, inplace=True)
    out = F.max_pool2d(stem, kernel_size=3, stride=2, padding=1)

    # Dense blocks, each followed by a transition except the last one
    blocks = params["dense_blocks"]
    last_index = len(blocks) - 1
    for idx, block in enumerate(blocks):
        out = dense_block_fn(out, block, is_training)
        if idx != last_index:
            out = transition_layer_fn(
                out, *params["transition_layers"][idx], is_training
            )

    # Head
    out = F.batch_norm(
        out,
        params["final_bn_mean"],
        params["final_bn_var"],
        params["final_bn_weight"],
        params["final_bn_bias"],
        training=is_training,
    )
    out = F.relu(out, inplace=True)
    pooled = F.adaptive_avg_pool2d(out, (1, 1))
    flat = pooled.view(out.size(0), -1)
    return F.linear(flat, params["classifier_weight"], params["classifier_bias"])
class Model(nn.Module):
    """
    Parameter container for the functional DenseNet forward pass.

    ``__init__`` instantiates throw-away ``nn.Conv2d`` / ``nn.BatchNorm2d`` /
    ``nn.Linear`` modules only to borrow their initial values, clones those
    values into ``self.params`` and discards the modules. ``forward`` hands
    ``self.params`` to a functional implementation (``module_fn`` by default).
    """

    def __init__(self, growth_rate=32, num_classes=1000):
        """
        :param growth_rate: Number of channels each dense layer adds
        :param num_classes: Output size of the classifier
        """
        super(Model, self).__init__()
        self.params = nn.ParameterDict()
        num_features = 64
        block_layers = [6, 12, 48, 32]
        # Use the GPU when one is present; fall back to CPU instead of
        # crashing at construction time on CUDA-less machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def as_param(tensor):
            # Move to the device *before* wrapping. `nn.Parameter(t).to(device)`
            # returns an ordinary non-leaf tensor whenever `.to` has to copy,
            # which silently drops the Parameter wrapper.
            return nn.Parameter(tensor.data.clone().to(device))

        def bn_conv_params(bn, conv):
            # Layout expected by the functional layers:
            # bn_weight, bn_bias, bn_mean, bn_var, conv_weight
            return [
                as_param(bn.weight),
                as_param(bn.bias),
                as_param(bn.running_mean),
                as_param(bn.running_var),
                as_param(conv.weight),
            ]

        # Extract initial features parameters
        conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        bn = nn.BatchNorm2d(64)
        self.params["features_conv_weight"] = as_param(conv.weight)
        self.params["features_bn_weight"] = as_param(bn.weight)
        self.params["features_bn_bias"] = as_param(bn.bias)
        self.params["features_bn_mean"] = as_param(bn.running_mean)
        self.params["features_bn_var"] = as_param(bn.running_var)

        # Extract dense block and transition parameters. The nested lists are
        # built locally and stored once, so this does not rely on
        # `self.params[...]` handing back the very same list object.
        dense_blocks = []
        transition_layers = []
        last_block = len(block_layers) - 1
        for block_idx, num_layers in enumerate(block_layers):
            block_params = []
            for i in range(num_layers):
                in_features = num_features + i * growth_rate
                bn = nn.BatchNorm2d(in_features)
                conv = nn.Conv2d(
                    in_features, growth_rate, kernel_size=3, padding=1, bias=False
                )
                block_params.append(bn_conv_params(bn, conv))
            dense_blocks.append(block_params)
            num_features = num_features + num_layers * growth_rate

            # Every block except the last is followed by a transition that
            # halves the channel count.
            if block_idx != last_block:
                bn = nn.BatchNorm2d(num_features)
                conv = nn.Conv2d(
                    num_features, num_features // 2, kernel_size=1, bias=False
                )
                transition_layers.append(bn_conv_params(bn, conv))
                num_features = num_features // 2

        self.params["dense_blocks"] = dense_blocks
        self.params["transition_layers"] = transition_layers

        # Extract final layers parameters
        bn = nn.BatchNorm2d(num_features)
        self.params["final_bn_weight"] = as_param(bn.weight)
        self.params["final_bn_bias"] = as_param(bn.bias)
        self.params["final_bn_mean"] = as_param(bn.running_mean)
        self.params["final_bn_var"] = as_param(bn.running_var)

        linear = nn.Linear(num_features, num_classes)
        self.params["classifier_weight"] = as_param(linear.weight)
        self.params["classifier_bias"] = as_param(linear.bias)

    def forward(self, x, fn=module_fn):
        """
        :param x: Input tensor passed through unchanged to ``fn``
        :param fn: Callable invoked as ``fn(x, self.params, self.training)``
        :return: Whatever ``fn`` returns
        """
        return fn(x, self.params, self.training)
# Sizes used by the input factories below.
batch_size = 10
num_classes = 10
height = 224
width = 224


def get_inputs():
    """
    Return a one-element list holding a random tensor of shape
    (batch_size, 3, height, width).
    """
    shape = (batch_size, 3, height, width)
    return [torch.randn(*shape)]


def get_init_inputs():
    """
    Return ``[32, num_classes]``.

    Presumably the positional arguments for ``Model(growth_rate, num_classes)``
    -- confirm against the calling harness.
    """
    growth_rate = 32
    return [growth_rate, num_classes]
import torch
import torch.nn as nn
import torch.nn.functional as F
class DenseBlock(nn.Module):
    """
    Stack of dense layers in which every layer receives the channel-wise
    concatenation of the block input and all previously produced features.
    """

    def __init__(self, num_layers: int, num_input_features: int, growth_rate: int):
        """
        :param num_layers: The number of layers in the dense block
        :param num_input_features: The number of input feature maps
        :param growth_rate: The growth rate for the dense block (new features added per layer)
        """
        super().__init__()
        # Layer k sees the original input plus k earlier outputs of
        # `growth_rate` channels each.
        self.layers = nn.ModuleList(
            [
                self._make_layer(num_input_features + k * growth_rate, growth_rate)
                for k in range(num_layers)
            ]
        )

    def _make_layer(self, in_features: int, growth_rate: int):
        """
        Build one BatchNorm -> ReLU -> Conv2d(3x3, padding=1) -> Dropout(0.0) layer.
        """
        norm = nn.BatchNorm2d(in_features)
        activation = nn.ReLU(inplace=True)
        conv = nn.Conv2d(in_features, growth_rate, kernel_size=3, padding=1, bias=False)
        dropout = nn.Dropout(0.0)
        return nn.Sequential(norm, activation, conv, dropout)

    def forward(self, x):
        """
        :param x: Input tensor of shape (batch_size, num_input_features, height, width)
        :return: Concatenated output tensor with shape (batch_size, num_output_features, height, width)
        """
        collected = [x]
        current = x
        for layer in self.layers:
            collected.append(layer(current))
            # Concatenate along the channel axis so the next layer sees everything
            current = torch.cat(collected, 1)
        return current
class TransitionLayer(nn.Module):
    """
    BatchNorm -> ReLU -> 1x1 convolution -> 2x2 average pooling.
    """

    def __init__(self, num_input_features: int, num_output_features: int):
        """
        :param num_input_features: The number of input feature maps
        :param num_output_features: The number of output feature maps
        """
        super().__init__()
        norm = nn.BatchNorm2d(num_input_features)
        activation = nn.ReLU(inplace=True)
        # The 1x1 convolution changes the channel count only
        channel_mix = nn.Conv2d(
            num_input_features, num_output_features, kernel_size=1, bias=False
        )
        # The average pool halves both spatial dimensions
        downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        self.transition = nn.Sequential(norm, activation, channel_mix, downsample)

    def forward(self, x):
        """
        :param x: Input tensor of shape (batch_size, num_input_features, height, width)
        :return: Downsampled tensor with reduced number of feature maps
        """
        reduced = self.transition(x)
        return reduced
class Model(nn.Module):
    """
    DenseNet built from a convolutional stem, four dense blocks separated by
    transition layers, a final batch norm and a linear classifier.
    """

    def __init__(self, growth_rate: int = 32, num_classes: int = 1000):
        """
        :param growth_rate: The growth rate of the DenseNet (new features added per layer)
        :param num_classes: The number of output classes for classification
        """
        super().__init__()

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool
        stem_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        stem_norm = nn.BatchNorm2d(64)
        stem_act = nn.ReLU(inplace=True)
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(stem_conv, stem_norm, stem_act, stem_pool)

        # Layer counts per dense block (DenseNet201 configuration)
        block_layers = [6, 12, 48, 32]
        channels = 64

        self.dense_blocks = nn.ModuleList()
        self.transition_layers = nn.ModuleList()

        final_index = len(block_layers) - 1
        for index, depth in enumerate(block_layers):
            self.dense_blocks.append(
                DenseBlock(
                    num_layers=depth,
                    num_input_features=channels,
                    growth_rate=growth_rate,
                )
            )
            channels = channels + depth * growth_rate

            # No transition after the last dense block
            if index == final_index:
                continue

            halved = channels // 2
            self.transition_layers.append(
                TransitionLayer(num_input_features=channels, num_output_features=halved)
            )
            channels = halved

        # Final batch norm and classifier
        self.final_bn = nn.BatchNorm2d(channels)
        self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Input tensor of shape (batch_size, 3, height, width)
        :return: Output tensor of shape (batch_size, num_classes)
        """
        out = self.features(x)

        final_index = len(self.dense_blocks) - 1
        for index, block in enumerate(self.dense_blocks):
            out = block(out)
            if index != final_index:
                out = self.transition_layers[index](out)

        out = self.final_bn(out)
        out = F.relu(out, inplace=True)
        pooled = F.adaptive_avg_pool2d(out, (1, 1))
        flat = pooled.view(out.size(0), -1)
        return self.classifier(flat)
# Sizes used when exercising the DenseNet201 model
batch_size = 10
num_classes = 10
height = 224  # Standard input size for DenseNet
width = 224


def get_inputs():
    """
    Return a one-element list holding a random tensor of shape
    (batch_size, 3, height, width).
    """
    image_batch = torch.randn(batch_size, 3, height, width)
    return [image_batch]


def get_init_inputs():
    """
    Return ``[32, num_classes]``.

    Presumably the positional arguments for ``Model(growth_rate, num_classes)``
    -- confirm against the calling harness.
    """
    init_args = [32, num_classes]
    return init_args
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>
namespace py = pybind11;
// Fused BatchNorm + ReLU CUDA kernel for inference mode.
//
// Treats `input`/`output` as dense N x C x H x W arrays (the channel of a flat
// element is (index / (h * w)) % c) and writes
//   max(0, bn_weight[ch] * (input - bn_mean[ch]) / sqrt(bn_var[ch] + eps) + bn_bias[ch]).
// Each thread walks the array with a grid-stride loop, so any launch
// configuration covers every element.
//
// Index arithmetic is done in 64 bits: the product n * c * h * w and the
// running `index += stride` would overflow a 32-bit int for large tensors.
__global__ void fused_bn_relu_kernel(const float* __restrict__ input,
                                     float* __restrict__ output,
                                     const float* __restrict__ bn_weight,
                                     const float* __restrict__ bn_bias,
                                     const float* __restrict__ bn_mean,
                                     const float* __restrict__ bn_var,
                                     float eps,
                                     int n, int c, int h, int w) {
    // Loop invariants, computed once per thread.
    const long long hw = static_cast<long long>(h) * w;
    const long long total = static_cast<long long>(n) * c * hw;
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;

    for (long long index = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total;
         index += stride) {
        const int channel = static_cast<int>((index / hw) % c);
        const float val = input[index];
        const float norm = (val - bn_mean[channel]) / sqrtf(bn_var[channel] + eps);
        const float bnval = bn_weight[channel] * norm + bn_bias[channel];
        // Fused ReLU activation. The comparison is false for NaN, so a NaN
        // intermediate is written out as 0.
        output[index] = bnval > 0.0f ? bnval : 0.0f;
    }
}
// Host function to launch the fused BN+ReLU kernel.
//
// x must be a 4-D float32 CUDA tensor; the four batch-norm tensors must be
// float32 CUDA tensors on the same device with one value per channel
// (x.size(1)). Returns a new tensor shaped like the contiguous form of x.
//
// Differences from a bare launch:
//   * inputs are validated, so a CPU or wrongly sized tensor raises an error
//     instead of making the kernel read invalid memory;
//   * an empty input returns immediately (a 0-block launch is an error);
//   * the kernel is enqueued on PyTorch's current stream for x's device, which
//     orders it with surrounding ATen ops without a device-wide synchronize;
//   * the launch status is checked.
torch::Tensor fused_bn_relu(torch::Tensor x,
                            torch::Tensor bn_weight,
                            torch::Tensor bn_bias,
                            torch::Tensor bn_mean,
                            torch::Tensor bn_var,
                            float eps) {
    TORCH_CHECK(x.is_cuda(), "fused_bn_relu: input must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4, "fused_bn_relu: expected a 4-D input, got ", x.dim(), "-D");
    TORCH_CHECK(x.scalar_type() == at::kFloat, "fused_bn_relu: input must be float32");

    const int64_t channels = x.size(1);
    auto check_channel_param = [&](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda() && t.device() == x.device(),
                    "fused_bn_relu: ", name, " must be on the same CUDA device as the input");
        TORCH_CHECK(t.scalar_type() == at::kFloat, "fused_bn_relu: ", name, " must be float32");
        TORCH_CHECK(t.numel() == channels,
                    "fused_bn_relu: ", name, " must have ", channels, " elements, got ", t.numel());
    };
    check_channel_param(bn_weight, "bn_weight");
    check_channel_param(bn_bias, "bn_bias");
    check_channel_param(bn_mean, "bn_mean");
    check_channel_param(bn_var, "bn_var");

    // The kernel indexes raw pointers, so every buffer has to be dense.
    x = x.contiguous();
    bn_weight = bn_weight.contiguous();
    bn_bias = bn_bias.contiguous();
    bn_mean = bn_mean.contiguous();
    bn_var = bn_var.contiguous();

    auto output = torch::empty_like(x);

    const int64_t total = x.numel();
    if (total == 0) {
        return output;
    }
    // The kernel receives its extents as int.
    TORCH_CHECK(total <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "fused_bn_relu: input has too many elements (", total, ")");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);

    // Launch on x's device and on the stream PyTorch is currently using there.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    fused_bn_relu_kernel<<<blocks, threads, 0, stream>>>(
        x.data_ptr<float>(), output.data_ptr<float>(),
        bn_weight.data_ptr<float>(), bn_bias.data_ptr<float>(),
        bn_mean.data_ptr<float>(), bn_var.data_ptr<float>(),
        eps,
        static_cast<int>(x.size(0)), static_cast<int>(x.size(1)),
        static_cast<int>(x.size(2)), static_cast<int>(x.size(3)));

    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "fused_bn_relu: kernel launch failed: ", cudaGetErrorString(launch_status));
    return output;
}
// Dense layer: BN+ReLU, then a convolution with stride 1 / padding 1, then
// dropout with probability 0. Outside training the BN+ReLU pair goes through
// the fused kernel; in training it uses at::batch_norm followed by at::relu.
torch::Tensor dense_layer_fn(
    torch::Tensor x,
    torch::Tensor bn_weight,   // scale (gamma)
    torch::Tensor bn_bias,     // shift (beta)
    torch::Tensor bn_mean,     // running mean
    torch::Tensor bn_var,      // running variance
    torch::Tensor conv_weight,
    bool is_training) {
    const std::vector<int64_t> conv_stride{1, 1};
    const std::vector<int64_t> conv_padding{1, 1};

    torch::Tensor activated;
    if (is_training) {
        torch::Tensor normalized = at::batch_norm(
            x, bn_weight, bn_bias, bn_mean, bn_var, is_training, 0.1, 1e-5, true);
        activated = at::relu(normalized);
    } else {
        // Use the fused kernel in inference mode
        activated = fused_bn_relu(x, bn_weight, bn_bias, bn_mean, bn_var, 1e-5);
    }

    torch::Tensor grown = at::conv2d(activated,
                                     conv_weight,
                                     c10::nullopt,
                                     at::IntArrayRef(conv_stride),
                                     at::IntArrayRef(conv_padding));
    return at::dropout(grown, 0.0, is_training);
}
// Dense block: runs each dense layer on the concatenation of the block input
// and all features produced so far, and returns the final concatenation.
// Every entry of layer_params must convert to a 5-tuple
// (bn_weight, bn_bias, bn_mean, bn_var, conv_weight).
torch::Tensor dense_block_fn(torch::Tensor x, py::list layer_params, bool is_training) {
    std::vector<torch::Tensor> features{x};

    for (py::handle entry : layer_params) {
        py::tuple layer = entry.cast<py::tuple>();
        if (layer.size() != 5) {
            throw std::runtime_error("Each dense layer parameter set must have 5 elements.");
        }
        torch::Tensor gamma = layer[0].cast<torch::Tensor>();
        torch::Tensor beta = layer[1].cast<torch::Tensor>();
        torch::Tensor running_mean = layer[2].cast<torch::Tensor>();
        torch::Tensor running_var = layer[3].cast<torch::Tensor>();
        torch::Tensor kernel = layer[4].cast<torch::Tensor>();

        features.push_back(
            dense_layer_fn(x, gamma, beta, running_mean, running_var, kernel, is_training));
        // Concatenate along the channel dimension for the next layer
        x = at::cat(features, 1);
    }
    return x;
}
// Transition layer: BN+ReLU, a convolution with stride 1 and no padding, then
// 2x2 average pooling with stride 2. Outside training the BN+ReLU pair goes
// through the fused kernel.
torch::Tensor transition_layer_fn(
    torch::Tensor x,
    torch::Tensor bn_weight,   // scale (gamma)
    torch::Tensor bn_bias,     // shift (beta)
    torch::Tensor bn_mean,     // running mean
    torch::Tensor bn_var,      // running variance
    torch::Tensor conv_weight,
    bool is_training) {
    const std::vector<int64_t> conv_stride{1, 1};
    const std::vector<int64_t> conv_padding{0, 0};
    const std::vector<int64_t> pool_window{2, 2};
    const std::vector<int64_t> pool_stride{2, 2};

    torch::Tensor activated;
    if (is_training) {
        torch::Tensor normalized = at::batch_norm(
            x, bn_weight, bn_bias, bn_mean, bn_var, is_training, 0.1, 1e-5, true);
        activated = at::relu(normalized);
    } else {
        activated = fused_bn_relu(x, bn_weight, bn_bias, bn_mean, bn_var, 1e-5);
    }

    torch::Tensor reduced = at::conv2d(activated,
                                       conv_weight,
                                       c10::nullopt,
                                       at::IntArrayRef(conv_stride),
                                       at::IntArrayRef(conv_padding));
    return at::avg_pool2d(reduced,
                          at::IntArrayRef(pool_window),
                          at::IntArrayRef(pool_stride));
}
// BN followed by ReLU: the fused kernel outside training, at::batch_norm +
// at::relu in training. Shared by the stem and the final classifier block.
static torch::Tensor forward_bn_relu(const torch::Tensor& x,
                                     const torch::Tensor& bn_weight,
                                     const torch::Tensor& bn_bias,
                                     const torch::Tensor& bn_mean,
                                     const torch::Tensor& bn_var,
                                     bool is_training) {
    if (!is_training) {
        return fused_bn_relu(x, bn_weight, bn_bias, bn_mean, bn_var, 1e-5);
    }
    return at::relu(
        at::batch_norm(x, bn_weight, bn_bias, bn_mean, bn_var, is_training, 0.1, 1e-5, true));
}

// Forward function orchestrates the DenseNet201 forward pass.
//
// params_obj must be convertible to a dict providing: features_conv_weight,
// features_bn_{mean,var,weight,bias}, dense_blocks (list of lists of 5-tuples),
// transition_layers (list of 5-tuples, one per dense block except the last),
// final_bn_{mean,var,weight,bias}, classifier_weight and classifier_bias.
//
// Throws std::runtime_error when a transition or dense-layer entry does not
// have exactly 5 elements. Transition entries are extracted and validated
// before any dense block runs.
torch::Tensor forward(torch::Tensor x, py::object params_obj, bool is_training) {
    py::dict params = params_obj.cast<py::dict>();

    // Initial convolution and BN+ReLU block
    torch::Tensor features_conv_weight = params["features_conv_weight"].cast<torch::Tensor>();
    torch::Tensor features_bn_mean = params["features_bn_mean"].cast<torch::Tensor>();
    torch::Tensor features_bn_var = params["features_bn_var"].cast<torch::Tensor>();
    torch::Tensor features_bn_weight = params["features_bn_weight"].cast<torch::Tensor>();
    torch::Tensor features_bn_bias = params["features_bn_bias"].cast<torch::Tensor>();

    x = at::conv2d(x,
                   features_conv_weight,
                   c10::nullopt,
                   at::IntArrayRef(std::vector<int64_t>{2, 2}),
                   at::IntArrayRef(std::vector<int64_t>{3, 3}));
    x = forward_bn_relu(x, features_bn_weight, features_bn_bias,
                        features_bn_mean, features_bn_var, is_training);
    x = at::max_pool2d(x,
                       at::IntArrayRef(std::vector<int64_t>{3, 3}),
                       at::IntArrayRef(std::vector<int64_t>{2, 2}),
                       at::IntArrayRef(std::vector<int64_t>{1, 1}));

    // Dense blocks and transition layers
    py::list dense_blocks = params["dense_blocks"].cast<py::list>();
    py::list transition_layers = params["transition_layers"].cast<py::list>();
    const size_t num_dense_blocks = dense_blocks.size();
    // Every dense block except the last one is followed by a transition.
    const size_t num_transitions = num_dense_blocks > 0 ? num_dense_blocks - 1 : 0;

    // Pull the transition tensors out of Python up front so malformed entries
    // are reported before any dense block is computed. This is ordinary host
    // code, so a plain branch in the loop below is all that is needed.
    std::vector<std::vector<torch::Tensor>> transitions;
    transitions.reserve(num_transitions);
    for (size_t i = 0; i < num_transitions; ++i) {
        py::tuple trans_tuple = transition_layers[i].cast<py::tuple>();
        if (trans_tuple.size() != 5) {
            throw std::runtime_error("Each transition layer parameter set must have 5 elements.");
        }
        transitions.emplace_back();
        std::vector<torch::Tensor>& tensors = transitions.back();
        tensors.reserve(5);
        for (size_t k = 0; k < 5; ++k) {
            tensors.push_back(trans_tuple[k].cast<torch::Tensor>());
        }
    }

    for (size_t i = 0; i < num_dense_blocks; ++i) {
        py::list block_params = dense_blocks[i].cast<py::list>();
        x = dense_block_fn(x, block_params, is_training);
        if (i < num_transitions) {
            // Order: bn_weight, bn_bias, bn_mean, bn_var, conv_weight
            const std::vector<torch::Tensor>& t = transitions[i];
            x = transition_layer_fn(x, t[0], t[1], t[2], t[3], t[4], is_training);
        }
    }

    // Final classifier block with BN+ReLU
    torch::Tensor final_bn_mean = params["final_bn_mean"].cast<torch::Tensor>();
    torch::Tensor final_bn_var = params["final_bn_var"].cast<torch::Tensor>();
    torch::Tensor final_bn_weight = params["final_bn_weight"].cast<torch::Tensor>();
    torch::Tensor final_bn_bias = params["final_bn_bias"].cast<torch::Tensor>();
    x = forward_bn_relu(x, final_bn_weight, final_bn_bias,
                        final_bn_mean, final_bn_var, is_training);

    x = at::adaptive_avg_pool2d(x, at::IntArrayRef(std::vector<int64_t>{1, 1}));
    x = x.view({x.size(0), -1});

    torch::Tensor classifier_weight = params["classifier_weight"].cast<torch::Tensor>();
    torch::Tensor classifier_bias = params["classifier_bias"].cast<torch::Tensor>();
    x = at::linear(x, classifier_weight, classifier_bias);
    return x;
}
// Python bindings: registers the C++ `forward` function as the extension
// module's `forward` attribute, with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Custom CUDA forward function with fused BN and ReLU for reduced warp divergence");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Rule | Description |
---|---|
Operation / Metric | Value | Unit |
---|---|---|
aten::conv2d | ||
CPU Time | 3778946.68 | μs |
Device Time | 3553785.81 | μs |
Self CPU Time | 157241.89 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::convolution | ||
CPU Time | 3621704.79 | μs |
Device Time | 3553785.81 | μs |
Self CPU Time | 178641.83 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_convolution | ||
CPU Time | 3443062.96 | μs |
Device Time | 3553785.81 | μs |
Self CPU Time | 215029.79 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::cudnn_convolution | ||
CPU Time | 3228033.17 | μs |
Device Time | 3553785.81 | μs |
Self CPU Time | 1588469.00 | μs |
Self Device Time | 3553785.81 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::batch_norm | ||
CPU Time | 3369741.17 | μs |
Device Time | 1751852.51 | μs |
Self CPU Time | 173409.64 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_fprop_implicit_gemm_tf32f32_tf32f32_f32_nhwckrsc_nchw_tilesize64x32x64_stage5_warpsize2x2x1_g1_tensor16x8x8_alignc4_execute_kernel__5x_cudnn | ||
CPU Time | 0.00 | μs |
Device Time | 1820285.72 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1820285.72 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45296 warnings generated when compiling for host. Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.