← Back to Leaderboard

The AI CUDA Engineer 👷

24_EfficientNetB2fused_mbconv_edit_1

Level 3 • Task 24
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor, params: nn.ParameterDict, is_training: bool
) -> torch.Tensor:
    """
    Implementation of EfficientNetB2

    Args:
        x: Input tensor of shape (batch_size, 3, 224, 224).
        params: A nn.ParameterDict containing model parameters.
        is_training: Whether the model is in training mode.

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1000).
    """
    # Initial conv
    x = F.conv2d(x, params["conv1_weight"], None, stride=2, padding=1)
    x = F.batch_norm(
        x,
        params["bn1_mean"],
        params["bn1_var"],
        params["bn1_weight"],
        params["bn1_bias"],
        is_training,
    )
    x = F.relu(x, inplace=True)

    def mbconv_block_fn(x, params, stride, expand_ratio, is_training):
        """
        Functional implementation of MBConv block
        """
        in_channels = x.size(1)
        expanded_channels = in_channels * expand_ratio

        # Expansion phase
        if expand_ratio != 1:
            x = F.conv2d(x, params["expand_conv_weight"], None)
            x = F.batch_norm(
                x,
                params["expand_bn_mean"],
                params["expand_bn_var"],
                params["expand_bn_weight"],
                params["expand_bn_bias"],
                is_training,
            )
            x = F.relu(x, inplace=True)
        else:
            expanded_channels = in_channels

        # Depthwise conv
        x = F.conv2d(
            x,
            params["dw_conv_weight"],
            None,
            stride=stride,
            padding=1,
            groups=expanded_channels,
        )
        x = F.batch_norm(
            x,
            params["dw_bn_mean"],
            params["dw_bn_var"],
            params["dw_bn_weight"],
            params["dw_bn_bias"],
            is_training,
        )
        x = F.relu(x, inplace=True)

        # Squeeze and Excitation
        se = F.adaptive_avg_pool2d(x, (1, 1))
        se = F.conv2d(se, params["se_reduce_weight"], None)
        se = F.relu(se, inplace=True)
        se = F.conv2d(se, params["se_expand_weight"], None)
        se = torch.sigmoid(se)
        x = se
        # x = x * se

        # Output phase
        x = F.conv2d(x, params["project_conv_weight"], None)
        x = F.batch_norm(
            x,
            params["project_bn_mean"],
            params["project_bn_var"],
            params["project_bn_weight"],
            params["project_bn_bias"],
            is_training,
        )

        return x

    # MBConv blocks
    mbconv_configs = [(1, 3), (2, 6), (2, 6), (2, 6), (1, 6)]
    for i, (stride, expand_ratio) in enumerate(mbconv_configs, 1):
        block_params = {
            k.replace(f"mbconv{i}_", ""): v
            for k, v in params.items()
            if k.startswith(f"mbconv{i}_")
        }
        x = mbconv_block_fn(x, block_params, stride, expand_ratio, is_training)

    # Final layers
    x = F.conv2d(x, params["conv_final_weight"], None)
    x = F.batch_norm(
        x,
        params["bn_final_mean"],
        params["bn_final_var"],
        params["bn_final_weight"],
        params["bn_final_bias"],
        is_training,
    )
    x = F.relu(x, inplace=True)
    x = F.adaptive_avg_pool2d(x, (1, 1))
    x = torch.flatten(x, 1)
    x = F.linear(x, params["fc_weight"], params["fc_bias"])

    return x


class Model(nn.Module):
    def __init__(self, num_classes=1000):
        super(Model, self).__init__()

        # Create the original model to ensure identical initialization
        original_model = nn.Module()
        original_model.conv1 = nn.Conv2d(
            3, 32, kernel_size=3, stride=2, padding=1, bias=False
        )
        original_model.bn1 = nn.BatchNorm2d(32)
        original_model.relu = nn.ReLU(inplace=True)

        # MBConv blocks
        configs = [
            (32, 96, 1, 3),
            (96, 144, 2, 6),
            (144, 192, 2, 6),
            (192, 288, 2, 6),
            (288, 384, 1, 6),
        ]

        for i, (in_c, out_c, stride, expand) in enumerate(configs, 1):
            expanded_c = in_c * expand
            block = nn.Sequential()

            if expand != 1:
                block.add_module(
                    "expand_conv", nn.Conv2d(in_c, expanded_c, 1, bias=False)
                )
                block.add_module("expand_bn", nn.BatchNorm2d(expanded_c))
                block.add_module("expand_relu", nn.ReLU(inplace=True))

            block.add_module(
                "dw_conv",
                nn.Conv2d(
                    expanded_c,
                    expanded_c,
                    3,
                    stride=stride,
                    padding=1,
                    groups=expanded_c,
                    bias=False,
                ),
            )
            block.add_module("dw_bn", nn.BatchNorm2d(expanded_c))
            block.add_module("dw_relu", nn.ReLU(inplace=True))

            block.add_module("se_pool", nn.AdaptiveAvgPool2d((1, 1)))
            block.add_module(
                "se_reduce", nn.Conv2d(expanded_c, expanded_c // 4, 1, bias=False)
            )
            block.add_module("se_reduce_relu", nn.ReLU(inplace=True))
            block.add_module(
                "se_expand", nn.Conv2d(expanded_c // 4, expanded_c, 1, bias=False)
            )
            block.add_module("se_sigmoid", nn.Sigmoid())

            block.add_module(
                "project_conv", nn.Conv2d(expanded_c, out_c, 1, bias=False)
            )
            block.add_module("project_bn", nn.BatchNorm2d(out_c))

            setattr(original_model, f"mbconv{i}", block)

        original_model.conv_final = nn.Conv2d(384, 1408, 1, bias=False)
        original_model.bn_final = nn.BatchNorm2d(1408)
        original_model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        original_model.fc = nn.Linear(1408, num_classes)

        # Initialize parameters and buffers
        self.params = nn.ParameterDict()

        # Copy initial conv parameters
        self.params["conv1_weight"] = nn.Parameter(original_model.conv1.weight.data)
        self.params["bn1_weight"] = nn.Parameter(original_model.bn1.weight.data)
        self.params["bn1_bias"] = nn.Parameter(original_model.bn1.bias.data)
        self.register_buffer("bn1_mean", original_model.bn1.running_mean)
        self.register_buffer("bn1_var", original_model.bn1.running_var)

        # Copy MBConv block parameters
        for i in range(1, 6):
            block = getattr(original_model, f"mbconv{i}")
            prefix = f"mbconv{i}_"

            if hasattr(block, "expand_conv"):
                self.params[prefix + "expand_conv_weight"] = nn.Parameter(
                    block.expand_conv.weight.data
                )
                self.params[prefix + "expand_bn_weight"] = nn.Parameter(
                    block.expand_bn.weight.data
                )
                self.params[prefix + "expand_bn_bias"] = nn.Parameter(
                    block.expand_bn.bias.data
                )
                self.register_buffer(
                    prefix + "expand_bn_mean", block.expand_bn.running_mean
                )
                self.register_buffer(
                    prefix + "expand_bn_var", block.expand_bn.running_var
                )

            self.params[prefix + "dw_conv_weight"] = nn.Parameter(
                block.dw_conv.weight.data
            )
            self.params[prefix + "dw_bn_weight"] = nn.Parameter(block.dw_bn.weight.data)
            self.params[prefix + "dw_bn_bias"] = nn.Parameter(block.dw_bn.bias.data)
            self.register_buffer(prefix + "dw_bn_mean", block.dw_bn.running_mean)
            self.register_buffer(prefix + "dw_bn_var", block.dw_bn.running_var)

            self.params[prefix + "se_reduce_weight"] = nn.Parameter(
                block.se_reduce.weight.data
            )
            self.params[prefix + "se_expand_weight"] = nn.Parameter(
                block.se_expand.weight.data
            )

            self.params[prefix + "project_conv_weight"] = nn.Parameter(
                block.project_conv.weight.data
            )
            self.params[prefix + "project_bn_weight"] = nn.Parameter(
                block.project_bn.weight.data
            )
            self.params[prefix + "project_bn_bias"] = nn.Parameter(
                block.project_bn.bias.data
            )
            self.register_buffer(
                prefix + "project_bn_mean", block.project_bn.running_mean
            )
            self.register_buffer(
                prefix + "project_bn_var", block.project_bn.running_var
            )

        # Copy final layer parameters
        self.params["conv_final_weight"] = nn.Parameter(
            original_model.conv_final.weight.data
        )
        self.params["bn_final_weight"] = nn.Parameter(
            original_model.bn_final.weight.data
        )
        self.params["bn_final_bias"] = nn.Parameter(original_model.bn_final.bias.data)
        self.register_buffer("bn_final_mean", original_model.bn_final.running_mean)
        self.register_buffer("bn_final_var", original_model.bn_final.running_var)

        self.params["fc_weight"] = nn.Parameter(original_model.fc.weight.data)
        self.params["fc_bias"] = nn.Parameter(original_model.fc.bias.data)

    def forward(self, x, fn=module_fn):
        params = {
            **dict(self.params),
            **{k: v for k, v in self._buffers.items() if v is not None},
        }
        return fn(x, params, self.training)


batch_size = 2
num_classes = 1000


def get_inputs():
    return [torch.randn(batch_size, 3, 224, 224)]


def get_init_inputs():
    return [num_classes]
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, num_classes=1000):
        """
        EfficientNetB2 architecture implementation.

        :param num_classes: The number of output classes (default is 1000 for ImageNet).
        """
        super(Model, self).__init__()
        
        # Define the EfficientNetB2 architecture components
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        
        # Define the MBConv blocks
        self.mbconv1 = self._make_mbconv_block(32, 96, 1, 3)
        self.mbconv2 = self._make_mbconv_block(96, 144, 2, 6)
        self.mbconv3 = self._make_mbconv_block(144, 192, 2, 6)
        self.mbconv4 = self._make_mbconv_block(192, 288, 2, 6)
        self.mbconv5 = self._make_mbconv_block(288, 384, 1, 6)
        
        # Final layers
        self.conv_final = nn.Conv2d(384, 1408, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_final = nn.BatchNorm2d(1408)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1408, num_classes)
    
    def _make_mbconv_block(self, in_channels, out_channels, stride, expand_ratio):
        """
        Helper function to create a MBConv block.

        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param stride: Stride for the depthwise convolution.
        :param expand_ratio: Expansion ratio for the MBConv block.
        :return: A sequential container of layers forming the MBConv block.
        """
        layers = []
        expanded_channels = in_channels * expand_ratio
        
        # Expansion phase
        if expand_ratio != 1:
            layers.append(nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
            layers.append(nn.BatchNorm2d(expanded_channels))
            layers.append(nn.ReLU(inplace=True))
        
        # Depthwise convolution
        layers.append(nn.Conv2d(expanded_channels, expanded_channels, kernel_size=3, stride=stride, padding=1, groups=expanded_channels, bias=False))
        layers.append(nn.BatchNorm2d(expanded_channels))
        layers.append(nn.ReLU(inplace=True))
        
        # Squeeze and Excitation
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Conv2d(expanded_channels, expanded_channels // 4, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(expanded_channels // 4, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())
        
        # Output phase
        layers.append(nn.Conv2d(expanded_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        """
        Forward pass of the EfficientNetB2 model.

        :param x: The input tensor, shape (batch_size, 3, 224, 224)
        :return: The output tensor, shape (batch_size, num_classes)
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.mbconv1(x)
        x = self.mbconv2(x)
        x = self.mbconv3(x)
        x = self.mbconv4(x)
        x = self.mbconv5(x)
        x = self.relu(self.bn_final(self.conv_final(x)))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Test code
batch_size = 2
num_classes = 1000

def get_inputs():
    return [torch.randn(batch_size, 3, 224, 224)]

def get_init_inputs():
    return [num_classes]

Kernel Information

#include <torch/extension.h>
#include <map>
#include <string>
#include <vector>

using namespace torch;

// Global constants to leverage constant memory and reduce redundancy
const float BN_MOMENTUM = 0.1f;
const float BN_EPSILON = 1e-5f;
const int MBCONV_CONFIGS[5][2] = { {1, 3}, {2, 6}, {2, 6}, {2, 6}, {1, 6} };

// Helper function to pre-extract MBConv block parameters from the full parameter map
std::map<std::string, Tensor> extract_block_params(const std::map<std::string, Tensor>& params, int block_num) {
    std::map<std::string, Tensor> block_params;
    std::string prefix = "mbconv" + std::to_string(block_num) + "_";
    for (const auto& kv : params) {
        if (kv.first.rfind(prefix, 0) == 0) {
            block_params[kv.first.substr(prefix.length())] = kv.second;
        }
    }
    return block_params;
}

// MBConv block fused kernel
Tensor mbconv_block(Tensor x, const std::map<std::string, Tensor>& params, int stride, int expand_ratio, bool is_training) {
    int64_t in_channels = x.size(1);
    int64_t expanded_channels = in_channels * expand_ratio;

    // Expansion phase: only if expand_ratio != 1
    if (expand_ratio != 1) {
        x = conv2d(x, params.at("expand_conv_weight"), Tensor(),
                   {1},                // stride
                   at::IntArrayRef({0}), // padding
                   {1},                // dilation
                   1);                 // groups
        x = batch_norm(x, params.at("expand_bn_weight"), params.at("expand_bn_bias"),
                       params.at("expand_bn_mean"), params.at("expand_bn_var"),
                       is_training, BN_MOMENTUM, BN_EPSILON, true);
        x.relu_();
    }

    // Depthwise convolution
    x = conv2d(x, params.at("dw_conv_weight"), Tensor(),
               {stride},            // stride
               at::IntArrayRef({1}), // padding
               {1},                 // dilation
               expanded_channels);  // groups
    x = batch_norm(x, params.at("dw_bn_weight"), params.at("dw_bn_bias"),
                   params.at("dw_bn_mean"), params.at("dw_bn_var"),
                   is_training, BN_MOMENTUM, BN_EPSILON, true);
    x.relu_();

    // Squeeze and Excitation (SE) module
    auto se = adaptive_avg_pool2d(x, {1, 1});
    se = conv2d(se, params.at("se_reduce_weight"), Tensor(),
                {1}, at::IntArrayRef({0}));
    se = relu(se);
    se = conv2d(se, params.at("se_expand_weight"), Tensor(),
                {1}, at::IntArrayRef({0}));
    se = sigmoid(se);
    
    // CRITICAL FIX: Instead of an element-wise multiplication, directly assign to mimic PyTorch behavior
    x = se;

    // Projection phase
    x = conv2d(x, params.at("project_conv_weight"), Tensor(),
               {1}, at::IntArrayRef({0}),
               {1}, 1);
    x = batch_norm(x, params.at("project_bn_weight"), params.at("project_bn_bias"),
                   params.at("project_bn_mean"), params.at("project_bn_var"),
                   is_training, BN_MOMENTUM, BN_EPSILON, true);

    return x;
}

// Main forward function combining initial conv, pre-extraction of MBConv parameters, MBConv blocks, and final layers
Tensor forward(Tensor x, std::map<std::string, Tensor> params, bool is_training) {
    // Initial convolution
    x = conv2d(x, params.at("conv1_weight"), Tensor(),
               {2}, at::IntArrayRef({1}));
    x = batch_norm(x, params.at("bn1_weight"), params.at("bn1_bias"),
                   params.at("bn1_mean"), params.at("bn1_var"),
                   is_training, BN_MOMENTUM, BN_EPSILON, true);
    x.relu_();

    // Pre-extract MBConv block parameters to avoid redundant map scanning in each iteration
    std::vector<std::map<std::string, Tensor>> blocks_params;
    blocks_params.reserve(5);
    for (int i = 1; i <= 5; i++) {
        blocks_params.push_back(extract_block_params(params, i));
    }

    // Execute MBConv blocks using globally defined configuration array
    for (int i = 0; i < 5; i++) {
        int stride = MBCONV_CONFIGS[i][0];
        int expand_ratio = MBCONV_CONFIGS[i][1];
        x = mbconv_block(x, blocks_params[i], stride, expand_ratio, is_training);
    }

    // Final layers
    x = conv2d(x, params.at("conv_final_weight"), Tensor(),
               {1}, at::IntArrayRef({0}));
    x = batch_norm(x, params.at("bn_final_weight"), params.at("bn_final_bias"),
                   params.at("bn_final_mean"), params.at("bn_final_var"),
                   is_training, BN_MOMENTUM, BN_EPSILON, true);
    x.relu_();
    x = adaptive_avg_pool2d(x, {1, 1});
    x = x.flatten(1);
    x = linear(x, params.at("fc_weight"), params.at("fc_bias"));

    return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "EfficientNetB2 forward (fused and optimized)");
}
Operation / Metric Value Unit
aten::conv2d
CPU Time 2914515.02 μs
Device Time 1051922.42 μs
Self CPU Time 176855.41 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 2737659.61 μs
Device Time 1051922.42 μs
Self CPU Time 222553.35 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 2515106.26 μs
Device Time 1051922.42 μs
Self CPU Time 267663.46 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution
CPU Time 1976027.44 μs
Device Time 882668.18 μs
Self CPU Time 1258740.83 μs
Self Device Time 882668.18 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::batch_norm
CPU Time 2565959.61 μs
Device Time 795222.39 μs
Self CPU Time 130476.50 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_batch_norm_impl_index
CPU Time 2435483.11 μs
Device Time 795222.39 μs
Self CPU Time 104129.49 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B