import torch
import torch.nn as nn
import torch.nn.functional as F
def _batch_norm(x, store, prefix, is_training):
    """Batch-normalize ``x`` with the ``{prefix}_m/_v/_w/_b`` entries of ``store``.

    ``store`` is any mapping indexable by string key. Each entry is moved to
    ``x.device`` with ``.to`` before being handed to ``F.batch_norm`` as
    running mean, running variance, weight and bias respectively.
    """
    device = x.device
    return F.batch_norm(
        x,
        store[f"{prefix}_m"].to(device),
        store[f"{prefix}_v"].to(device),
        store[f"{prefix}_w"].to(device),
        store[f"{prefix}_b"].to(device),
        training=is_training,
    )


def _bottleneck_fn(x, block_params, stride, is_training):
    """Run one bottleneck block: conv-bn-relu, conv-bn-relu, conv-bn, add, relu.

    ``stride`` is applied to the second (padded) convolution and, when the
    block carries ``downsample_*`` entries, to the shortcut convolution.
    """
    device = x.device

    out = F.conv2d(x, block_params["conv1_w"].to(device), bias=None)
    out = F.relu(_batch_norm(out, block_params, "bn1", is_training))

    out = F.conv2d(
        out, block_params["conv2_w"].to(device), bias=None, stride=stride, padding=1
    )
    out = F.relu(_batch_norm(out, block_params, "bn2", is_training))

    out = F.conv2d(out, block_params["conv3_w"].to(device), bias=None)
    out = _batch_norm(out, block_params, "bn3", is_training)

    # The shortcut is the raw input unless the block provides a projection.
    if "downsample_conv_w" in block_params:
        identity = F.conv2d(
            x, block_params["downsample_conv_w"].to(device), bias=None, stride=stride
        )
        identity = _batch_norm(identity, block_params, "downsample_bn", is_training)
    else:
        identity = x

    out += identity
    return F.relu(out)


def module_fn(
    x: torch.Tensor, params: nn.ParameterDict, is_training: bool
) -> torch.Tensor:
    """
    Functional forward pass of a bottleneck ResNet.

    The network depth is whatever ``params`` describes: a stem
    (conv, batch norm, relu, max pool), four layer groups read from
    ``params["layer1_blocks"]`` .. ``params["layer4_blocks"]``, then global
    average pooling and a linear classifier.

    Args:
        x (torch.Tensor): Input tensor, shape (batch_size, in_channels, height, width)
        params (nn.ParameterDict): Mapping holding ``conv1_w``, ``bn1_{w,b,m,v}``,
            ``fc_w``, ``fc_b`` and, per layer, a sequence of per-block mappings
        is_training (bool): Forwarded to every ``F.batch_norm`` call as ``training``
    Returns:
        torch.Tensor: Output tensor, shape (batch_size, num_classes)
    """
    device = x.device

    # Stem
    x = F.conv2d(x, params["conv1_w"].to(device), bias=None, stride=2, padding=3)
    x = _batch_norm(x, params, "bn1", is_training)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

    # Layer 1-4: only the first block of layers 2-4 uses stride 2.
    for layer_idx in range(1, 5):
        for block_idx, block_params in enumerate(params[f"layer{layer_idx}_blocks"]):
            stride = 2 if block_idx == 0 and layer_idx > 1 else 1
            x = _bottleneck_fn(x, block_params, stride, is_training)

    # Classifier head
    x = F.adaptive_avg_pool2d(x, (1, 1))
    x = torch.flatten(x, 1)
    return F.linear(x, params["fc_w"].to(device), params["fc_b"].to(device))
class Model(nn.Module):
    """Parameter holder for the functional bottleneck-ResNet forward pass.

    All weights live in ``self.params``:

    * ``conv1_w`` and ``bn1_{w,b,m,v}`` for the stem,
    * ``layer{1..4}_blocks``: a list with one plain ``dict`` per block,
    * ``fc_w`` / ``fc_b`` for the classifier.

    Initial values are cloned from freshly constructed ``nn.Conv2d``,
    ``nn.BatchNorm2d`` and ``nn.Linear`` modules.
    """

    def __init__(self, layers, num_classes=1000):
        """
        :param layers: Number of bottleneck blocks per layer group
            (paired with the widths 64, 128, 256, 512)
        :param num_classes: Number of output classes
        """
        super(Model, self).__init__()
        self.params = nn.ParameterDict()
        in_channels = 64
        expansion = 4

        # Initial layers
        conv1 = nn.Conv2d(
            3, in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.params["conv1_w"] = nn.Parameter(conv1.weight.data.clone())
        self._store_bn(self.params, "bn1", nn.BatchNorm2d(in_channels))

        # Layers 1-4
        channels = [64, 128, 256, 512]
        for layer_idx, (out_channels, num_blocks) in enumerate(
            zip(channels, layers), 1
        ):
            expanded = out_channels * expansion
            layer_blocks = []
            for block_idx in range(num_blocks):
                block_in_channels = in_channels if block_idx == 0 else expanded
                block_params = {}

                # Only the first block of a layer may need a projection
                # shortcut (spatial stride or channel-count change).
                if block_idx == 0 and (
                    layer_idx > 1 or block_in_channels != expanded
                ):
                    downsample_conv = nn.Conv2d(
                        block_in_channels,
                        expanded,
                        kernel_size=1,
                        stride=2 if layer_idx > 1 else 1,
                        bias=False,
                    )
                    block_params["downsample_conv_w"] = nn.Parameter(
                        downsample_conv.weight.data.clone()
                    )
                    self._store_bn(
                        block_params, "downsample_bn", nn.BatchNorm2d(expanded)
                    )

                conv1 = nn.Conv2d(
                    block_in_channels, out_channels, kernel_size=1, bias=False
                )
                block_params["conv1_w"] = nn.Parameter(conv1.weight.data.clone())
                self._store_bn(block_params, "bn1", nn.BatchNorm2d(out_channels))

                conv2 = nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2 if block_idx == 0 and layer_idx > 1 else 1,
                    padding=1,
                    bias=False,
                )
                block_params["conv2_w"] = nn.Parameter(conv2.weight.data.clone())
                self._store_bn(block_params, "bn2", nn.BatchNorm2d(out_channels))

                conv3 = nn.Conv2d(out_channels, expanded, kernel_size=1, bias=False)
                block_params["conv3_w"] = nn.Parameter(conv3.weight.data.clone())
                self._store_bn(block_params, "bn3", nn.BatchNorm2d(expanded))

                layer_blocks.append(block_params)

            # NOTE(review): this stores a plain list of dicts inside the
            # ParameterDict; verify that the installed PyTorch tracks these
            # tensors in parameters()/state_dict()/.to() as intended.
            self.params[f"layer{layer_idx}_blocks"] = layer_blocks
            in_channels = expanded

        # Final FC layer
        fc = nn.Linear(512 * expansion, num_classes)
        self.params["fc_w"] = nn.Parameter(fc.weight.data.clone())
        self.params["fc_b"] = nn.Parameter(fc.bias.data.clone())

    @staticmethod
    def _store_bn(store, prefix, bn):
        """Copy the state of ``bn`` into ``store`` under ``{prefix}_{w,b,m,v}``.

        The affine weight and bias stay trainable. The running mean and
        variance are statistics, not learnable weights, so they are stored
        with ``requires_grad=False`` to keep autograd from producing
        gradients for them.
        """
        store[f"{prefix}_w"] = nn.Parameter(bn.weight.data.clone())
        store[f"{prefix}_b"] = nn.Parameter(bn.bias.data.clone())
        store[f"{prefix}_m"] = nn.Parameter(
            bn.running_mean.data.clone(), requires_grad=False
        )
        store[f"{prefix}_v"] = nn.Parameter(
            bn.running_var.data.clone(), requires_grad=False
        )

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` to ``x`` with this module's parameters and training flag."""
        return fn(x, self.params, self.training)
# Benchmark configuration: input geometry and the ResNet-101 block layout.
batch_size = 10
height = 224
width = 224
layers = [3, 4, 23, 3]
num_classes = 1000


def get_inputs():
    """Return the forward-pass arguments: one random RGB image batch."""
    images = torch.randn(batch_size, 3, height, width)
    return [images]


def get_init_inputs():
    """Return the constructor arguments: block counts and class count."""
    init_args = [layers, num_classes]
    return init_args
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv plus a shortcut.

    The last convolution widens the features to
    ``out_channels * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        """
        :param in_channels: Number of input channels
        :param out_channels: Width of the two inner convolutions
        :param stride: Stride of the 3x3 convolution (``conv2``)
        :param downsample: Optional module applied to the input to form the shortcut
        """
        super().__init__()
        widened = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        :param x: Input tensor, shape (batch_size, in_channels, height, width)
        :return: Output tensor with ``out_channels * expansion`` channels
        """
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut: the raw input, or its projection when one was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class Model(nn.Module):
    """ResNet classifier assembled from ``Bottleneck`` blocks."""

    def __init__(self, layers, num_classes=1000):
        """
        :param layers: List of integers specifying the number of blocks in each layer
        :param num_classes: Number of output classes
        """
        super().__init__()
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first halves the resolution.
        block = Bottleneck
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (stage_width, stage_stride) in enumerate(stage_specs):
            stage = self._make_layer(
                block, stage_width, layers[idx], stride=stage_stride
            )
            setattr(self, f"layer{idx + 1}", stage)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.in_channels``.

        :param block: Block class to instantiate
        :param out_channels: Inner width of each block
        :param blocks: Number of blocks in the stage
        :param stride: Stride of the first block
        :return: ``nn.Sequential`` holding the stage
        """
        widened = out_channels * block.expansion

        # The first block needs a projection shortcut whenever the spatial
        # size or the channel count changes across it.
        downsample = None
        if stride != 1 or self.in_channels != widened:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    widened,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(widened),
            )

        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = widened
        stage.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """
        :param x: Input tensor, shape (batch_size, 3, height, width)
        :return: Output tensor, shape (batch_size, num_classes)
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
# Benchmark configuration: input geometry and the ResNet-101 block layout.
batch_size = 10
height = 224
width = 224
layers = [3, 4, 23, 3]
num_classes = 1000


def get_inputs():
    """Return the forward-pass arguments: one random RGB image batch."""
    images = torch.randn(batch_size, 3, height, width)
    return [images]


def get_init_inputs():
    """Return the constructor arguments: block counts and class count."""
    init_args = [layers, num_classes]
    return init_args
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
// Shortcut (projection) branch of a bottleneck block.
//
// Convolves `x` with `conv_w` (no bias, given stride), converts the result to
// x's dtype in contiguous memory format, then applies batch norm with
// momentum 0.1 and eps 1e-5.
//
// bn_w / bn_b are the affine weight and bias; bn_m / bn_v are the running
// mean and variance.
inline torch::Tensor compute_downsample(
    const torch::Tensor& x,
    const torch::Tensor& conv_w,
    const torch::Tensor& bn_w,
    const torch::Tensor& bn_b,
    const torch::Tensor& bn_m,
    const torch::Tensor& bn_v,
    int64_t stride,
    bool is_training
) {
    torch::Tensor projected = torch::conv2d(x, conv_w, /*bias=*/torch::Tensor(), stride);
    projected = projected.to(
        x.dtype(),
        /*non_blocking=*/true,
        /*copy=*/false,
        torch::MemoryFormat::Contiguous);

    return torch::batch_norm(
        projected, bn_w, bn_b, bn_m, bn_v,
        is_training, /*momentum=*/0.1, /*eps=*/1e-5, /*cudnn_enabled=*/true);
}
// Residual (main) branch of a bottleneck block, without the final add/relu.
//
// Sequence: conv1 -> bn1 -> relu -> conv2 (given stride, padding 1) -> bn2
// -> relu -> conv3 -> bn3. Every convolution is bias-free and its output is
// converted to x's dtype in contiguous memory format before normalization.
// For each batch norm, *_w / *_b are the affine weight and bias and
// *_m / *_v are the running mean and variance.
inline torch::Tensor compute_main_path(
    const torch::Tensor& x,
    const torch::Tensor& conv1_w,
    const torch::Tensor& conv2_w,
    const torch::Tensor& conv3_w,
    const torch::Tensor& bn1_w,
    const torch::Tensor& bn1_b,
    const torch::Tensor& bn1_m,
    const torch::Tensor& bn1_v,
    const torch::Tensor& bn2_w,
    const torch::Tensor& bn2_b,
    const torch::Tensor& bn2_m,
    const torch::Tensor& bn2_v,
    const torch::Tensor& bn3_w,
    const torch::Tensor& bn3_b,
    const torch::Tensor& bn3_m,
    const torch::Tensor& bn3_v,
    int64_t stride,
    bool is_training
) {
    // Batch-norm hyper-parameters shared by all three normalizations.
    const double momentum = 0.1;
    const double eps = 1e-5;

    // Same post-convolution conversion applied after each conv2d call.
    const auto as_input_layout = [&x](const torch::Tensor& t) {
        return t.to(
            x.dtype(),
            /*non_blocking=*/true,
            /*copy=*/false,
            torch::MemoryFormat::Contiguous);
    };

    // Stage 1: pointwise reduction.
    torch::Tensor h = as_input_layout(torch::conv2d(x, conv1_w, /*bias=*/torch::Tensor()));
    h = torch::batch_norm(h, bn1_w, bn1_b, bn1_m, bn1_v, is_training, momentum, eps, true);
    h = torch::relu(h);

    // Stage 2: strided, padded convolution.
    h = as_input_layout(
        torch::conv2d(h, conv2_w, /*bias=*/torch::Tensor(), stride, /*padding=*/1));
    h = torch::batch_norm(h, bn2_w, bn2_b, bn2_m, bn2_v, is_training, momentum, eps, true);
    h = torch::relu(h);

    // Stage 3: pointwise expansion; the caller applies the add and relu.
    h = as_input_layout(torch::conv2d(h, conv3_w, /*bias=*/torch::Tensor()));
    return torch::batch_norm(h, bn3_w, bn3_b, bn3_m, bn3_v, is_training, momentum, eps, true);
}
// Forward pass of the bottleneck ResNet driven by a Python parameter container.
//
// x           : input tensor; every parameter is moved to x's device.
// params      : Python object exposing `.get(key)`. It must provide
//               "conv1_w", "bn1_{w,b,m,v}", "fc_w", "fc_b" and
//               "layer{1..4}_blocks". Each "layerN_blocks" value must be
//               convertible to a Python list of per-block objects that
//               support `.get(key)` and `__contains__`.
// is_training : forwarded to every batch_norm call.
//
// Returns the output of the final linear layer.
torch::Tensor forward(
    torch::Tensor x,
    py::object params,
    bool is_training
) {
    auto device = x.device();

    // Looks `name` up in a Python container and returns it as a contiguous
    // tensor on the input's device.
    // NOTE(review): if the parameters are not already on `device`, this
    // copies them on every call; confirm the caller keeps them there.
    auto fetch = [&device](const py::object& store, const char* name) {
        return store.attr("get")(py::str(name))
            .cast<torch::Tensor>()
            .contiguous()
            .to(device, /*non_blocking=*/true);
    };

    // Tensors of one bottleneck block. `main_path` follows the order of
    // `main_names` below; `downsample` follows `ds_names`.
    struct BlockParams {
        std::vector<torch::Tensor> main_path;
        std::vector<torch::Tensor> downsample;
        bool has_downsample = false;
    };

    // Order matters: it must match the argument order of compute_main_path().
    const char* const main_names[] = {
        "conv1_w", "conv2_w", "conv3_w",
        "bn1_w", "bn1_b", "bn1_m", "bn1_v",
        "bn2_w", "bn2_b", "bn2_m", "bn2_v",
        "bn3_w", "bn3_b", "bn3_m", "bn3_v"
    };
    // Order matters: it must match the argument order of compute_downsample().
    const char* const ds_names[] = {
        "downsample_conv_w", "downsample_bn_w",
        "downsample_bn_b", "downsample_bn_m", "downsample_bn_v"
    };

    // Initial stem parameters
    auto conv1_w = fetch(params, "conv1_w");
    auto bn1_w = fetch(params, "bn1_w");
    auto bn1_b = fetch(params, "bn1_b");
    auto bn1_m = fetch(params, "bn1_m");
    auto bn1_v = fetch(params, "bn1_v");

    // Pre-fetch all layer parameters before running any compute.
    std::vector<std::vector<BlockParams>> all_layer_params;
    all_layer_params.reserve(4);
    for (int layer_idx = 1; layer_idx <= 4; ++layer_idx) {
        const std::string key = "layer" + std::to_string(layer_idx) + "_blocks";
        py::list blocks = params.attr("get")(py::str(key)).cast<py::list>();

        std::vector<BlockParams> layer_blocks;
        layer_blocks.reserve(blocks.size());
        for (auto block : blocks) {
            py::object bp = block.cast<py::object>();
            // Named `entry` so it does not shadow the `params` argument.
            BlockParams entry;

            entry.main_path.reserve(15);
            for (const char* name : main_names) {
                entry.main_path.push_back(fetch(bp, name));
            }

            entry.has_downsample =
                bp.attr("__contains__")("downsample_conv_w").cast<bool>();
            if (entry.has_downsample) {
                entry.downsample.reserve(5);
                for (const char* name : ds_names) {
                    entry.downsample.push_back(fetch(bp, name));
                }
            }

            layer_blocks.push_back(std::move(entry));
        }
        all_layer_params.push_back(std::move(layer_blocks));
    }

    // Initial convolution and pooling
    x = torch::conv2d(x, conv1_w, /*bias=*/torch::Tensor(), 2, 3)
            .to(x.dtype(), /*non_blocking=*/true, /*copy=*/false, torch::MemoryFormat::Contiguous);
    x = torch::batch_norm(x, bn1_w, bn1_b, bn1_m, bn1_v, is_training, 0.1, 1e-5, true);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    // Residual stages
    for (size_t layer_idx = 0; layer_idx < all_layer_params.size(); ++layer_idx) {
        const auto& layer = all_layer_params[layer_idx];
        for (size_t block_idx = 0; block_idx < layer.size(); ++block_idx) {
            const BlockParams& blk = layer[block_idx];
            const auto& mp = blk.main_path;
            const auto& ds = blk.downsample;

            // Only the first block of layers 2-4 uses stride 2.
            const int64_t stride = (block_idx == 0 && layer_idx > 0) ? 2 : 1;

            // Compute main path
            auto main_out = compute_main_path(
                x,
                mp[0], mp[1], mp[2],
                mp[3], mp[4], mp[5], mp[6],
                mp[7], mp[8], mp[9], mp[10],
                mp[11], mp[12], mp[13], mp[14],
                stride, is_training
            );

            // Identity path (with or without downsample)
            torch::Tensor identity = blk.has_downsample
                ? compute_downsample(x, ds[0], ds[1], ds[2], ds[3], ds[4], stride, is_training)
                : x.to(main_out.dtype());

            x = torch::relu(main_out + identity);
        }
    }

    // Classifier head: global average pool, flatten, linear.
    x = torch::adaptive_avg_pool2d(x, {1, 1}).contiguous();
    x = x.view({x.size(0), -1});

    auto fc_w = fetch(params, "fc_w");
    auto fc_b = fetch(params, "fc_b");
    return torch::linear(x, fc_w, fc_b);
}
// Python bindings: registers the C++ forward() above as the extension
// module's `forward` function, with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "ResNet101 forward with uniform control flow");
}
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 7262960.79 | μs |
Device Time | 3550442.01 | μs |
Self CPU Time | 125261.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 7137699.08 | μs |
Device Time | 3550442.01 | μs |
Self CPU Time | 389398.97 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::copy_ | ||
CPU Time | 7081593.85 | μs |
Device Time | 3550442.01 | μs |
Self CPU Time | 1411973.55 | μs |
Self Device Time | 3550442.01 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaMemcpyAsync | ||
CPU Time | 5669510.54 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 5669510.54 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
Memcpy HtoD (Pageable -> Device) | ||
CPU Time | 0.00 | μs |
Device Time | 3550443.81 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3550443.81 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::conv2d | ||
CPU Time | 1136735.84 | μs |
Device Time | 780718.13 | μs |
Self CPU Time | 43253.81 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |