import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(
x: torch.Tensor, params: nn.ParameterDict, is_training: bool
) -> torch.Tensor:
"""
Implementation of EfficientNetB2
Args:
x: Input tensor of shape (batch_size, 3, 224, 224).
params: A nn.ParameterDict containing model parameters.
is_training: Whether the model is in training mode.
Returns:
torch.Tensor: Output tensor of shape (batch_size, 1000).
"""
# Initial conv
x = F.conv2d(x, params["conv1_weight"], None, stride=2, padding=1)
x = F.batch_norm(
x,
params["bn1_mean"],
params["bn1_var"],
params["bn1_weight"],
params["bn1_bias"],
is_training,
)
x = F.relu(x, inplace=True)
def mbconv_block_fn(x, params, stride, expand_ratio, is_training):
"""
Functional implementation of MBConv block
"""
in_channels = x.size(1)
expanded_channels = in_channels * expand_ratio
# Expansion phase
if expand_ratio != 1:
x = F.conv2d(x, params["expand_conv_weight"], None)
x = F.batch_norm(
x,
params["expand_bn_mean"],
params["expand_bn_var"],
params["expand_bn_weight"],
params["expand_bn_bias"],
is_training,
)
x = F.relu(x, inplace=True)
else:
expanded_channels = in_channels
# Depthwise conv
x = F.conv2d(
x,
params["dw_conv_weight"],
None,
stride=stride,
padding=1,
groups=expanded_channels,
)
x = F.batch_norm(
x,
params["dw_bn_mean"],
params["dw_bn_var"],
params["dw_bn_weight"],
params["dw_bn_bias"],
is_training,
)
x = F.relu(x, inplace=True)
# Squeeze and Excitation
se = F.adaptive_avg_pool2d(x, (1, 1))
se = F.conv2d(se, params["se_reduce_weight"], None)
se = F.relu(se, inplace=True)
se = F.conv2d(se, params["se_expand_weight"], None)
se = torch.sigmoid(se)
        # NOTE: the SE branch output replaces x here (rather than scaling it via
        # x = x * se) so that this functional version matches the reference Model,
        # whose SE layers sit inside the nn.Sequential block and feed the projection directly.
        x = se
# Output phase
x = F.conv2d(x, params["project_conv_weight"], None)
x = F.batch_norm(
x,
params["project_bn_mean"],
params["project_bn_var"],
params["project_bn_weight"],
params["project_bn_bias"],
is_training,
)
return x
# MBConv blocks
    # (stride, expand_ratio) for each of the five MBConv stages
    mbconv_configs = [(1, 3), (2, 6), (2, 6), (2, 6), (1, 6)]
for i, (stride, expand_ratio) in enumerate(mbconv_configs, 1):
block_params = {
k.replace(f"mbconv{i}_", ""): v
for k, v in params.items()
if k.startswith(f"mbconv{i}_")
}
x = mbconv_block_fn(x, block_params, stride, expand_ratio, is_training)
# Final layers
x = F.conv2d(x, params["conv_final_weight"], None)
x = F.batch_norm(
x,
params["bn_final_mean"],
params["bn_final_var"],
params["bn_final_weight"],
params["bn_final_bias"],
is_training,
)
x = F.relu(x, inplace=True)
x = F.adaptive_avg_pool2d(x, (1, 1))
x = torch.flatten(x, 1)
x = F.linear(x, params["fc_weight"], params["fc_bias"])
return x
class Model(nn.Module):
def __init__(self, num_classes=1000):
super(Model, self).__init__()
# Create the original model to ensure identical initialization
original_model = nn.Module()
original_model.conv1 = nn.Conv2d(
3, 32, kernel_size=3, stride=2, padding=1, bias=False
)
original_model.bn1 = nn.BatchNorm2d(32)
original_model.relu = nn.ReLU(inplace=True)
# MBConv blocks
        # (in_channels, out_channels, stride, expand_ratio) per MBConv stage
        configs = [
(32, 96, 1, 3),
(96, 144, 2, 6),
(144, 192, 2, 6),
(192, 288, 2, 6),
(288, 384, 1, 6),
]
for i, (in_c, out_c, stride, expand) in enumerate(configs, 1):
expanded_c = in_c * expand
block = nn.Sequential()
if expand != 1:
block.add_module(
"expand_conv", nn.Conv2d(in_c, expanded_c, 1, bias=False)
)
block.add_module("expand_bn", nn.BatchNorm2d(expanded_c))
block.add_module("expand_relu", nn.ReLU(inplace=True))
block.add_module(
"dw_conv",
nn.Conv2d(
expanded_c,
expanded_c,
3,
stride=stride,
padding=1,
groups=expanded_c,
bias=False,
),
)
block.add_module("dw_bn", nn.BatchNorm2d(expanded_c))
block.add_module("dw_relu", nn.ReLU(inplace=True))
block.add_module("se_pool", nn.AdaptiveAvgPool2d((1, 1)))
block.add_module(
"se_reduce", nn.Conv2d(expanded_c, expanded_c // 4, 1, bias=False)
)
block.add_module("se_reduce_relu", nn.ReLU(inplace=True))
block.add_module(
"se_expand", nn.Conv2d(expanded_c // 4, expanded_c, 1, bias=False)
)
block.add_module("se_sigmoid", nn.Sigmoid())
block.add_module(
"project_conv", nn.Conv2d(expanded_c, out_c, 1, bias=False)
)
block.add_module("project_bn", nn.BatchNorm2d(out_c))
setattr(original_model, f"mbconv{i}", block)
original_model.conv_final = nn.Conv2d(384, 1408, 1, bias=False)
original_model.bn_final = nn.BatchNorm2d(1408)
original_model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
original_model.fc = nn.Linear(1408, num_classes)
# Initialize parameters and buffers
self.params = nn.ParameterDict()
# Copy initial conv parameters
self.params["conv1_weight"] = nn.Parameter(original_model.conv1.weight.data)
self.params["bn1_weight"] = nn.Parameter(original_model.bn1.weight.data)
self.params["bn1_bias"] = nn.Parameter(original_model.bn1.bias.data)
self.register_buffer("bn1_mean", original_model.bn1.running_mean)
self.register_buffer("bn1_var", original_model.bn1.running_var)
# Copy MBConv block parameters
for i in range(1, 6):
block = getattr(original_model, f"mbconv{i}")
prefix = f"mbconv{i}_"
if hasattr(block, "expand_conv"):
self.params[prefix + "expand_conv_weight"] = nn.Parameter(
block.expand_conv.weight.data
)
self.params[prefix + "expand_bn_weight"] = nn.Parameter(
block.expand_bn.weight.data
)
self.params[prefix + "expand_bn_bias"] = nn.Parameter(
block.expand_bn.bias.data
)
self.register_buffer(
prefix + "expand_bn_mean", block.expand_bn.running_mean
)
self.register_buffer(
prefix + "expand_bn_var", block.expand_bn.running_var
)
self.params[prefix + "dw_conv_weight"] = nn.Parameter(
block.dw_conv.weight.data
)
self.params[prefix + "dw_bn_weight"] = nn.Parameter(block.dw_bn.weight.data)
self.params[prefix + "dw_bn_bias"] = nn.Parameter(block.dw_bn.bias.data)
self.register_buffer(prefix + "dw_bn_mean", block.dw_bn.running_mean)
self.register_buffer(prefix + "dw_bn_var", block.dw_bn.running_var)
self.params[prefix + "se_reduce_weight"] = nn.Parameter(
block.se_reduce.weight.data
)
self.params[prefix + "se_expand_weight"] = nn.Parameter(
block.se_expand.weight.data
)
self.params[prefix + "project_conv_weight"] = nn.Parameter(
block.project_conv.weight.data
)
self.params[prefix + "project_bn_weight"] = nn.Parameter(
block.project_bn.weight.data
)
self.params[prefix + "project_bn_bias"] = nn.Parameter(
block.project_bn.bias.data
)
self.register_buffer(
prefix + "project_bn_mean", block.project_bn.running_mean
)
self.register_buffer(
prefix + "project_bn_var", block.project_bn.running_var
)
# Copy final layer parameters
self.params["conv_final_weight"] = nn.Parameter(
original_model.conv_final.weight.data
)
self.params["bn_final_weight"] = nn.Parameter(
original_model.bn_final.weight.data
)
self.params["bn_final_bias"] = nn.Parameter(original_model.bn_final.bias.data)
self.register_buffer("bn_final_mean", original_model.bn_final.running_mean)
self.register_buffer("bn_final_var", original_model.bn_final.running_var)
self.params["fc_weight"] = nn.Parameter(original_model.fc.weight.data)
self.params["fc_bias"] = nn.Parameter(original_model.fc.bias.data)
def forward(self, x, fn=module_fn):
params = {
**dict(self.params),
**{k: v for k, v in self._buffers.items() if v is not None},
}
return fn(x, params, self.training)
batch_size = 2
num_classes = 1000
def get_inputs():
return [torch.randn(batch_size, 3, 224, 224)]
def get_init_inputs():
return [num_classes]
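A minimal smoke test for the functional variant above might look like the following sketch; it assumes CPU execution, relies on the default `fn=module_fn` path, and expects an output of shape (batch_size, 1000).

# Smoke-test sketch for the ParameterDict-based Model above (assumes CPU execution).
if __name__ == "__main__":
    model = Model(*get_init_inputs())
    model.eval()
    with torch.no_grad():
        (x,) = get_inputs()
        out = model(x)
    print(out.shape)  # expected: torch.Size([2, 1000])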
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self, num_classes=1000):
"""
EfficientNetB2 architecture implementation.
:param num_classes: The number of output classes (default is 1000 for ImageNet).
"""
super(Model, self).__init__()
# Define the EfficientNetB2 architecture components
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
# Define the MBConv blocks
self.mbconv1 = self._make_mbconv_block(32, 96, 1, 3)
self.mbconv2 = self._make_mbconv_block(96, 144, 2, 6)
self.mbconv3 = self._make_mbconv_block(144, 192, 2, 6)
self.mbconv4 = self._make_mbconv_block(192, 288, 2, 6)
self.mbconv5 = self._make_mbconv_block(288, 384, 1, 6)
# Final layers
self.conv_final = nn.Conv2d(384, 1408, kernel_size=1, stride=1, padding=0, bias=False)
self.bn_final = nn.BatchNorm2d(1408)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(1408, num_classes)
def _make_mbconv_block(self, in_channels, out_channels, stride, expand_ratio):
"""
Helper function to create a MBConv block.
:param in_channels: Number of input channels.
:param out_channels: Number of output channels.
:param stride: Stride for the depthwise convolution.
:param expand_ratio: Expansion ratio for the MBConv block.
:return: A sequential container of layers forming the MBConv block.
"""
layers = []
expanded_channels = in_channels * expand_ratio
# Expansion phase
if expand_ratio != 1:
layers.append(nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
layers.append(nn.BatchNorm2d(expanded_channels))
layers.append(nn.ReLU(inplace=True))
# Depthwise convolution
layers.append(nn.Conv2d(expanded_channels, expanded_channels, kernel_size=3, stride=stride, padding=1, groups=expanded_channels, bias=False))
layers.append(nn.BatchNorm2d(expanded_channels))
layers.append(nn.ReLU(inplace=True))
        # Squeeze and Excitation (appended to the same Sequential, so its output feeds the projection directly)
layers.append(nn.AdaptiveAvgPool2d((1, 1)))
layers.append(nn.Conv2d(expanded_channels, expanded_channels // 4, kernel_size=1, stride=1, padding=0, bias=False))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(expanded_channels // 4, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
layers.append(nn.Sigmoid())
# Output phase
layers.append(nn.Conv2d(expanded_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
layers.append(nn.BatchNorm2d(out_channels))
return nn.Sequential(*layers)
def forward(self, x):
"""
Forward pass of the EfficientNetB2 model.
:param x: The input tensor, shape (batch_size, 3, 224, 224)
:return: The output tensor, shape (batch_size, num_classes)
"""
x = self.relu(self.bn1(self.conv1(x)))
x = self.mbconv1(x)
x = self.mbconv2(x)
x = self.mbconv3(x)
x = self.mbconv4(x)
x = self.mbconv5(x)
x = self.relu(self.bn_final(self.conv_final(x)))
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# Test code
batch_size = 2
num_classes = 1000
def get_inputs():
return [torch.randn(batch_size, 3, 224, 224)]
def get_init_inputs():
return [num_classes]
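Both variants are intended to hold identical tensors, just exposed under different names (nested module attributes here versus flattened ParameterDict keys earlier). A rough consistency check, sketched below under the assumption that the two `Model` classes are importable under the hypothetical aliases `ReferenceModel` and `FunctionalModel`, is to compare their trainable parameter counts (BatchNorm running statistics are buffers in both, so they are excluded automatically).

# Parameter-count check sketch; ReferenceModel / FunctionalModel are hypothetical
# aliases for the two Model classes defined above.
ref = ReferenceModel(num_classes)   # the nn.Module version above
fun = FunctionalModel(num_classes)  # the ParameterDict-based version earlier

n_ref = sum(p.numel() for p in ref.parameters())
n_fun = sum(p.numel() for p in fun.parameters())
assert n_ref == n_fun, (n_ref, n_fun)
print(f"trainable parameters: {n_ref}")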
#include <torch/extension.h>
#include <map>
#include <string>
#include <vector>
using namespace torch;
// Global constants: BatchNorm hyperparameters and per-block {stride, expand_ratio} configs
const float BN_MOMENTUM = 0.1f;
const float BN_EPSILON = 1e-5f;
const int MBCONV_CONFIGS[5][2] = { {1, 3}, {2, 6}, {2, 6}, {2, 6}, {1, 6} };
// Helper function to pre-extract MBConv block parameters from the full parameter map
std::map<std::string, Tensor> extract_block_params(const std::map<std::string, Tensor>& params, int block_num) {
std::map<std::string, Tensor> block_params;
std::string prefix = "mbconv" + std::to_string(block_num) + "_";
for (const auto& kv : params) {
if (kv.first.rfind(prefix, 0) == 0) {
block_params[kv.first.substr(prefix.length())] = kv.second;
}
}
return block_params;
}
// MBConv block expressed as a sequence of ATen ops (expansion, depthwise conv, SE, projection)
Tensor mbconv_block(Tensor x, const std::map<std::string, Tensor>& params, int stride, int expand_ratio, bool is_training) {
int64_t in_channels = x.size(1);
int64_t expanded_channels = in_channels * expand_ratio;
// Expansion phase: only if expand_ratio != 1
if (expand_ratio != 1) {
x = conv2d(x, params.at("expand_conv_weight"), Tensor(),
{1}, // stride
at::IntArrayRef({0}), // padding
{1}, // dilation
1); // groups
x = batch_norm(x, params.at("expand_bn_weight"), params.at("expand_bn_bias"),
params.at("expand_bn_mean"), params.at("expand_bn_var"),
is_training, BN_MOMENTUM, BN_EPSILON, true);
x.relu_();
}
// Depthwise convolution
x = conv2d(x, params.at("dw_conv_weight"), Tensor(),
{stride}, // stride
at::IntArrayRef({1}), // padding
{1}, // dilation
expanded_channels); // groups
x = batch_norm(x, params.at("dw_bn_weight"), params.at("dw_bn_bias"),
params.at("dw_bn_mean"), params.at("dw_bn_var"),
is_training, BN_MOMENTUM, BN_EPSILON, true);
x.relu_();
// Squeeze and Excitation (SE) module
auto se = adaptive_avg_pool2d(x, {1, 1});
se = conv2d(se, params.at("se_reduce_weight"), Tensor(),
{1}, at::IntArrayRef({0}));
se = relu(se);
se = conv2d(se, params.at("se_expand_weight"), Tensor(),
{1}, at::IntArrayRef({0}));
se = sigmoid(se);
    // NOTE: the SE output replaces x (no x = x * se multiplication) to match the PyTorch reference implementation above
x = se;
// Projection phase
x = conv2d(x, params.at("project_conv_weight"), Tensor(),
{1}, at::IntArrayRef({0}),
{1}, 1);
x = batch_norm(x, params.at("project_bn_weight"), params.at("project_bn_bias"),
params.at("project_bn_mean"), params.at("project_bn_var"),
is_training, BN_MOMENTUM, BN_EPSILON, true);
return x;
}
// Main forward function combining initial conv, pre-extraction of MBConv parameters, MBConv blocks, and final layers
Tensor forward(Tensor x, std::map<std::string, Tensor> params, bool is_training) {
// Initial convolution
x = conv2d(x, params.at("conv1_weight"), Tensor(),
{2}, at::IntArrayRef({1}));
x = batch_norm(x, params.at("bn1_weight"), params.at("bn1_bias"),
params.at("bn1_mean"), params.at("bn1_var"),
is_training, BN_MOMENTUM, BN_EPSILON, true);
x.relu_();
// Pre-extract MBConv block parameters to avoid redundant map scanning in each iteration
std::vector<std::map<std::string, Tensor>> blocks_params;
blocks_params.reserve(5);
for (int i = 1; i <= 5; i++) {
blocks_params.push_back(extract_block_params(params, i));
}
// Execute MBConv blocks using globally defined configuration array
for (int i = 0; i < 5; i++) {
int stride = MBCONV_CONFIGS[i][0];
int expand_ratio = MBCONV_CONFIGS[i][1];
x = mbconv_block(x, blocks_params[i], stride, expand_ratio, is_training);
}
// Final layers
x = conv2d(x, params.at("conv_final_weight"), Tensor(),
{1}, at::IntArrayRef({0}));
x = batch_norm(x, params.at("bn_final_weight"), params.at("bn_final_bias"),
params.at("bn_final_mean"), params.at("bn_final_var"),
is_training, BN_MOMENTUM, BN_EPSILON, true);
x.relu_();
x = adaptive_avg_pool2d(x, {1, 1});
x = x.flatten(1);
x = linear(x, params.at("fc_weight"), params.at("fc_bias"));
return x;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "EfficientNetB2 forward (fused and optimized)");
}
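To exercise the extension from Python, one option is torch.utils.cpp_extension.load. The sketch below assumes the C++ source above was saved as efficientnet_b2_ext.cpp (a hypothetical file name) and reuses the functional Model defined earlier to assemble the parameter/buffer dictionary the forward entry point expects.

# Build-and-run sketch; the source file name is an assumption.
import torch
from torch.utils.cpp_extension import load

ext = load(name="efficientnet_b2_ext", sources=["efficientnet_b2_ext.cpp"])

model = Model(num_classes)  # the ParameterDict-based Model defined earlier
params = {**dict(model.params), **{k: v for k, v in model._buffers.items()}}
x = torch.randn(batch_size, 3, 224, 224)

out = ext.forward(x, params, model.training)
print(out.shape)  # expected: torch.Size([batch_size, 1000])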
| Operation | CPU Time (μs) | Device Time (μs) | Self CPU Time (μs) | Self Device Time (μs) |
|---|---|---|---|---|
| aten::conv2d | 2914515.02 | 1051922.42 | 176855.41 | 0.00 |
| aten::convolution | 2737659.61 | 1051922.42 | 222553.35 | 0.00 |
| aten::_convolution | 2515106.26 | 1051922.42 | 267663.46 | 0.00 |
| aten::cudnn_convolution | 1976027.44 | 882668.18 | 1258740.83 | 882668.18 |
| aten::batch_norm | 2565959.61 | 795222.39 | 130476.50 | 0.00 |
| aten::_batch_norm_impl_index | 2435483.11 | 795222.39 | 104129.49 | 0.00 |

All CPU and device memory-usage counters (including the self variants) were 0 B for every operation listed above.
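Operator-level timings like those above are what torch.profiler reports; a minimal collection sketch follows (the device, workload, and sort key are assumptions, and the absolute numbers depend on hardware).

# Profiling sketch (assumes a CUDA device and the functional Model from earlier).
import torch
from torch.profiler import profile, ProfilerActivity

model = Model(num_classes).cuda().eval()
x = torch.randn(batch_size, 3, 224, 224, device="cuda")

with torch.no_grad(), profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
) as prof:
    model(x)

# Aggregate per-operator CPU/CUDA times, similar to the table above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))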