The AI CUDA Engineer 👷

24_EfficientNetB224_efficientnetb2_const_mem_edit_1

Level 3 • Task 24
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor, params: nn.ParameterDict, is_training: bool
) -> torch.Tensor:
    """
    Implementation of EfficientNetB2

    Args:
        x: Input tensor of shape (batch_size, 3, 224, 224).
        params: A nn.ParameterDict containing model parameters.
        is_training: Whether the model is in training mode.

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1000).
    """
    # Initial conv
    x = F.conv2d(x, params["conv1_weight"], None, stride=2, padding=1)
    x = F.batch_norm(
        x,
        params["bn1_mean"],
        params["bn1_var"],
        params["bn1_weight"],
        params["bn1_bias"],
        is_training,
    )
    x = F.relu(x, inplace=True)

    def mbconv_block_fn(x, params, stride, expand_ratio, is_training):
        """
        Functional implementation of MBConv block
        """
        in_channels = x.size(1)
        expanded_channels = in_channels * expand_ratio

        # Expansion phase
        if expand_ratio != 1:
            x = F.conv2d(x, params["expand_conv_weight"], None)
            x = F.batch_norm(
                x,
                params["expand_bn_mean"],
                params["expand_bn_var"],
                params["expand_bn_weight"],
                params["expand_bn_bias"],
                is_training,
            )
            x = F.relu(x, inplace=True)
        else:
            expanded_channels = in_channels

        # Depthwise conv
        x = F.conv2d(
            x,
            params["dw_conv_weight"],
            None,
            stride=stride,
            padding=1,
            groups=expanded_channels,
        )
        x = F.batch_norm(
            x,
            params["dw_bn_mean"],
            params["dw_bn_var"],
            params["dw_bn_weight"],
            params["dw_bn_bias"],
            is_training,
        )
        x = F.relu(x, inplace=True)

        # Squeeze and Excitation
        se = F.adaptive_avg_pool2d(x, (1, 1))
        se = F.conv2d(se, params["se_reduce_weight"], None)
        se = F.relu(se, inplace=True)
        se = F.conv2d(se, params["se_expand_weight"], None)
        se = torch.sigmoid(se)
        # Matches the reference nn.Sequential SE block, where the SE output
        # replaces x rather than scaling it (i.e. no `x = x * se`).
        x = se

        # Output phase
        x = F.conv2d(x, params["project_conv_weight"], None)
        x = F.batch_norm(
            x,
            params["project_bn_mean"],
            params["project_bn_var"],
            params["project_bn_weight"],
            params["project_bn_bias"],
            is_training,
        )

        return x

    # MBConv blocks
    mbconv_configs = [(1, 3), (2, 6), (2, 6), (2, 6), (1, 6)]
    for i, (stride, expand_ratio) in enumerate(mbconv_configs, 1):
        block_params = {
            k.replace(f"mbconv{i}_", ""): v
            for k, v in params.items()
            if k.startswith(f"mbconv{i}_")
        }
        x = mbconv_block_fn(x, block_params, stride, expand_ratio, is_training)

    # Final layers
    x = F.conv2d(x, params["conv_final_weight"], None)
    x = F.batch_norm(
        x,
        params["bn_final_mean"],
        params["bn_final_var"],
        params["bn_final_weight"],
        params["bn_final_bias"],
        is_training,
    )
    x = F.relu(x, inplace=True)
    x = F.adaptive_avg_pool2d(x, (1, 1))
    x = torch.flatten(x, 1)
    x = F.linear(x, params["fc_weight"], params["fc_bias"])

    return x


class Model(nn.Module):
    def __init__(self, num_classes=1000):
        super(Model, self).__init__()

        # Create the original model to ensure identical initialization
        original_model = nn.Module()
        original_model.conv1 = nn.Conv2d(
            3, 32, kernel_size=3, stride=2, padding=1, bias=False
        )
        original_model.bn1 = nn.BatchNorm2d(32)
        original_model.relu = nn.ReLU(inplace=True)

        # MBConv blocks
        configs = [
            (32, 96, 1, 3),
            (96, 144, 2, 6),
            (144, 192, 2, 6),
            (192, 288, 2, 6),
            (288, 384, 1, 6),
        ]

        for i, (in_c, out_c, stride, expand) in enumerate(configs, 1):
            expanded_c = in_c * expand
            block = nn.Sequential()

            if expand != 1:
                block.add_module(
                    "expand_conv", nn.Conv2d(in_c, expanded_c, 1, bias=False)
                )
                block.add_module("expand_bn", nn.BatchNorm2d(expanded_c))
                block.add_module("expand_relu", nn.ReLU(inplace=True))

            block.add_module(
                "dw_conv",
                nn.Conv2d(
                    expanded_c,
                    expanded_c,
                    3,
                    stride=stride,
                    padding=1,
                    groups=expanded_c,
                    bias=False,
                ),
            )
            block.add_module("dw_bn", nn.BatchNorm2d(expanded_c))
            block.add_module("dw_relu", nn.ReLU(inplace=True))

            block.add_module("se_pool", nn.AdaptiveAvgPool2d((1, 1)))
            block.add_module(
                "se_reduce", nn.Conv2d(expanded_c, expanded_c // 4, 1, bias=False)
            )
            block.add_module("se_reduce_relu", nn.ReLU(inplace=True))
            block.add_module(
                "se_expand", nn.Conv2d(expanded_c // 4, expanded_c, 1, bias=False)
            )
            block.add_module("se_sigmoid", nn.Sigmoid())

            block.add_module(
                "project_conv", nn.Conv2d(expanded_c, out_c, 1, bias=False)
            )
            block.add_module("project_bn", nn.BatchNorm2d(out_c))

            setattr(original_model, f"mbconv{i}", block)

        original_model.conv_final = nn.Conv2d(384, 1408, 1, bias=False)
        original_model.bn_final = nn.BatchNorm2d(1408)
        original_model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        original_model.fc = nn.Linear(1408, num_classes)

        # Initialize parameters and buffers
        self.params = nn.ParameterDict()

        # Copy initial conv parameters
        self.params["conv1_weight"] = nn.Parameter(original_model.conv1.weight.data)
        self.params["bn1_weight"] = nn.Parameter(original_model.bn1.weight.data)
        self.params["bn1_bias"] = nn.Parameter(original_model.bn1.bias.data)
        self.register_buffer("bn1_mean", original_model.bn1.running_mean)
        self.register_buffer("bn1_var", original_model.bn1.running_var)

        # Copy MBConv block parameters
        for i in range(1, 6):
            block = getattr(original_model, f"mbconv{i}")
            prefix = f"mbconv{i}_"

            if hasattr(block, "expand_conv"):
                self.params[prefix + "expand_conv_weight"] = nn.Parameter(
                    block.expand_conv.weight.data
                )
                self.params[prefix + "expand_bn_weight"] = nn.Parameter(
                    block.expand_bn.weight.data
                )
                self.params[prefix + "expand_bn_bias"] = nn.Parameter(
                    block.expand_bn.bias.data
                )
                self.register_buffer(
                    prefix + "expand_bn_mean", block.expand_bn.running_mean
                )
                self.register_buffer(
                    prefix + "expand_bn_var", block.expand_bn.running_var
                )

            self.params[prefix + "dw_conv_weight"] = nn.Parameter(
                block.dw_conv.weight.data
            )
            self.params[prefix + "dw_bn_weight"] = nn.Parameter(block.dw_bn.weight.data)
            self.params[prefix + "dw_bn_bias"] = nn.Parameter(block.dw_bn.bias.data)
            self.register_buffer(prefix + "dw_bn_mean", block.dw_bn.running_mean)
            self.register_buffer(prefix + "dw_bn_var", block.dw_bn.running_var)

            self.params[prefix + "se_reduce_weight"] = nn.Parameter(
                block.se_reduce.weight.data
            )
            self.params[prefix + "se_expand_weight"] = nn.Parameter(
                block.se_expand.weight.data
            )

            self.params[prefix + "project_conv_weight"] = nn.Parameter(
                block.project_conv.weight.data
            )
            self.params[prefix + "project_bn_weight"] = nn.Parameter(
                block.project_bn.weight.data
            )
            self.params[prefix + "project_bn_bias"] = nn.Parameter(
                block.project_bn.bias.data
            )
            self.register_buffer(
                prefix + "project_bn_mean", block.project_bn.running_mean
            )
            self.register_buffer(
                prefix + "project_bn_var", block.project_bn.running_var
            )

        # Copy final layer parameters
        self.params["conv_final_weight"] = nn.Parameter(
            original_model.conv_final.weight.data
        )
        self.params["bn_final_weight"] = nn.Parameter(
            original_model.bn_final.weight.data
        )
        self.params["bn_final_bias"] = nn.Parameter(original_model.bn_final.bias.data)
        self.register_buffer("bn_final_mean", original_model.bn_final.running_mean)
        self.register_buffer("bn_final_var", original_model.bn_final.running_var)

        self.params["fc_weight"] = nn.Parameter(original_model.fc.weight.data)
        self.params["fc_bias"] = nn.Parameter(original_model.fc.bias.data)

    def forward(self, x, fn=module_fn):
        params = {
            **dict(self.params),
            **{k: v for k, v in self._buffers.items() if v is not None},
        }
        return fn(x, params, self.training)


batch_size = 2
num_classes = 1000


def get_inputs():
    return [torch.randn(batch_size, 3, 224, 224)]


def get_init_inputs():
    return [num_classes]
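
A minimal smoke test (a sketch added for illustration, not part of the original listing) for the functional model above, using the provided helpers; the expected output shape follows the module_fn docstring.

# Sketch: build the functional Model, run the provided inputs, and check the output shape.
model = Model(*get_init_inputs())
x = get_inputs()[0]                            # (batch_size, 3, 224, 224)
out = model(x)                                 # dispatches to module_fn with the flattened params
assert out.shape == (batch_size, num_classes)  # (2, 1000)

The listing that follows is the reference nn.Module implementation of the same EfficientNetB2 architecture, which the functional version above and the CUDA kernel below mirror.
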
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, num_classes=1000):
        """
        EfficientNetB2 architecture implementation.

        :param num_classes: The number of output classes (default is 1000 for ImageNet).
        """
        super(Model, self).__init__()
        
        # Define the EfficientNetB2 architecture components
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        
        # Define the MBConv blocks
        self.mbconv1 = self._make_mbconv_block(32, 96, 1, 3)
        self.mbconv2 = self._make_mbconv_block(96, 144, 2, 6)
        self.mbconv3 = self._make_mbconv_block(144, 192, 2, 6)
        self.mbconv4 = self._make_mbconv_block(192, 288, 2, 6)
        self.mbconv5 = self._make_mbconv_block(288, 384, 1, 6)
        
        # Final layers
        self.conv_final = nn.Conv2d(384, 1408, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_final = nn.BatchNorm2d(1408)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1408, num_classes)
    
    def _make_mbconv_block(self, in_channels, out_channels, stride, expand_ratio):
        """
        Helper function to create a MBConv block.

        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param stride: Stride for the depthwise convolution.
        :param expand_ratio: Expansion ratio for the MBConv block.
        :return: A sequential container of layers forming the MBConv block.
        """
        layers = []
        expanded_channels = in_channels * expand_ratio
        
        # Expansion phase
        if expand_ratio != 1:
            layers.append(nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
            layers.append(nn.BatchNorm2d(expanded_channels))
            layers.append(nn.ReLU(inplace=True))
        
        # Depthwise convolution
        layers.append(nn.Conv2d(expanded_channels, expanded_channels, kernel_size=3, stride=stride, padding=1, groups=expanded_channels, bias=False))
        layers.append(nn.BatchNorm2d(expanded_channels))
        layers.append(nn.ReLU(inplace=True))
        
        # Squeeze and Excitation
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Conv2d(expanded_channels, expanded_channels // 4, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(expanded_channels // 4, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())
        
        # Output phase
        layers.append(nn.Conv2d(expanded_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        """
        Forward pass of the EfficientNetB2 model.

        :param x: The input tensor, shape (batch_size, 3, 224, 224)
        :return: The output tensor, shape (batch_size, num_classes)
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.mbconv1(x)
        x = self.mbconv2(x)
        x = self.mbconv3(x)
        x = self.mbconv4(x)
        x = self.mbconv5(x)
        x = self.relu(self.bn_final(self.conv_final(x)))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Test code
batch_size = 2
num_classes = 1000

def get_inputs():
    return [torch.randn(batch_size, 3, 224, 224)]

def get_init_inputs():
    return [num_classes]

Kernel Information

#include <torch/extension.h>
#include <map>
#include <string>
#include <vector>

using namespace torch;

// Global constants for frequently accessed parameters
const float BN_MOMENTUM = 0.1f;
const float BN_EPSILON = 1e-5f;
const int MBCONV_CONFIGS[5][2] = {{1,3}, {2,6}, {2,6}, {2,6}, {1,6}};

Tensor mbconv_block(Tensor x, std::map<std::string, Tensor>& params, int stride, int expand_ratio, bool is_training) {
    int64_t in_channels = x.size(1);
    int64_t expanded_channels = in_channels * expand_ratio;

    // Expansion phase
    if (expand_ratio != 1) {
        auto expand_conv_weight = params["expand_conv_weight"];
        x = conv2d(x, expand_conv_weight, Tensor(), 
                  {1},  // stride
                  at::IntArrayRef({0}),  // padding
                  {1},  // dilation
                  1);  // groups
        x = batch_norm(
            x, params["expand_bn_weight"], params["expand_bn_bias"],
            params["expand_bn_mean"], params["expand_bn_var"],
            is_training, BN_MOMENTUM, BN_EPSILON, true
        );
        x = relu(x);
    }

    // Depthwise conv
    auto dw_conv_weight = params["dw_conv_weight"];
    x = conv2d(x, dw_conv_weight, Tensor(), 
              {stride},  // stride
              at::IntArrayRef({1}),  // padding
              {1},  // dilation
              expanded_channels);  // groups
    x = batch_norm(
        x, params["dw_bn_weight"], params["dw_bn_bias"],
        params["dw_bn_mean"], params["dw_bn_var"],
        is_training, BN_MOMENTUM, BN_EPSILON, true
    );
    x = relu(x);

    // Squeeze and Excitation
    auto se = adaptive_avg_pool2d(x, {1, 1});
    se = conv2d(se, params["se_reduce_weight"], Tensor(),
               {1},  // stride
               at::IntArrayRef({0}));  // padding
    se = relu(se);
    se = conv2d(se, params["se_expand_weight"], Tensor(),
               {1},  // stride
               at::IntArrayRef({0}));  // padding
    se = sigmoid(se);
    // As in the PyTorch reference, the SE output replaces x here (no x * se scaling).
    x = se;

    // Projection phase
    auto project_conv_weight = params["project_conv_weight"];
    x = conv2d(x, project_conv_weight, Tensor(),
              {1},  // stride
              at::IntArrayRef({0}),  // padding
              {1},  // dilation
              1);  // groups
    x = batch_norm(
        x, params["project_bn_weight"], params["project_bn_bias"],
        params["project_bn_mean"], params["project_bn_var"],
        is_training, BN_MOMENTUM, BN_EPSILON, true
    );

    return x;
}

Tensor forward(Tensor x, std::map<std::string, Tensor> params, bool is_training) {
    // Initial conv
    x = conv2d(x, params["conv1_weight"], Tensor(),
              {2},  // stride
              at::IntArrayRef({1}));  // padding
    x = batch_norm(
        x, params["bn1_weight"], params["bn1_bias"],
        params["bn1_mean"], params["bn1_var"],
        is_training, BN_MOMENTUM, BN_EPSILON, true
    );
    x = relu(x);

    // MBConv blocks driven by the global MBCONV_CONFIGS table (host-side constants)
    for (int i = 0; i < 5; i++) {
        int block_num = i + 1;
        int stride = MBCONV_CONFIGS[i][0];
        int expand_ratio = MBCONV_CONFIGS[i][1];
        
        std::map<std::string, Tensor> block_params;
        std::string prefix = "mbconv" + std::to_string(block_num) + "_";
        
        for (const auto& pair : params) {
            if (pair.first.rfind(prefix, 0) == 0) {
                std::string key = pair.first.substr(prefix.length());
                block_params[key] = pair.second;
            }
        }
        
        x = mbconv_block(x, block_params, stride, expand_ratio, is_training);
    }

    // Final layers
    x = conv2d(x, params["conv_final_weight"], Tensor(),
              {1},  // stride
              at::IntArrayRef({0}));  // padding
    x = batch_norm(
        x, params["bn_final_weight"], params["bn_final_bias"],
        params["bn_final_mean"], params["bn_final_var"],
        is_training, BN_MOMENTUM, BN_EPSILON, true
    );
    x = relu(x);
    x = adaptive_avg_pool2d(x, {1, 1});
    x = x.flatten(1);
    x = linear(x, params["fc_weight"], params["fc_bias"]);

    return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "EfficientNetB2 forward");
}
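
A minimal sketch (not the original benchmark harness) of how the compiled extension might be driven from Python; the module name and source filename below are assumptions for illustration. The forward function expects the same flat {name: tensor} map that Model.forward assembles from its parameters and batch-norm buffers.

import torch
from torch.utils.cpp_extension import load

# Hypothetical module name and source path for the C++ listing above.
ext = load(name="efficientnetb2_const_mem", sources=["efficientnetb2_const_mem.cpp"], verbose=True)

model = Model()                          # functional Model defined earlier on this page
x = get_inputs()[0]
# Flat parameter/buffer map matching the keys read inside forward() above.
params = {**dict(model.params), **{k: v for k, v in model._buffers.items() if v is not None}}
out = ext.forward(x, params, model.training)
print(out.shape)                         # expected: torch.Size([2, 1000])

The table below reports per-operator profiler timings for this kernel.
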
Operation                      CPU Time (μs)  Device Time (μs)  Self CPU Time (μs)  Self Device Time (μs)
aten::conv2d                      2870956.95        1092273.19           169922.64                  0.00
aten::convolution                 2701034.31        1092273.19           212201.43                  0.00
aten::_convolution                2488832.88        1092273.19           260757.35                  0.00
aten::cudnn_convolution           1950735.93         916464.00          1225831.53             916464.00
aten::batch_norm                  2555070.23         833682.71           123124.69                  0.00
aten::_batch_norm_impl_index      2431945.53         833682.71           102329.20                  0.00

All CPU, device, self CPU, and self device memory usage figures are reported as 0 B for every operation listed.
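
The per-operator breakdown above matches the kind of summary produced by torch.profiler. A rough sketch (assumed; the original measurement script is not shown on this page) of collecting comparable numbers:

import torch
from torch.profiler import profile, ProfilerActivity

model = Model().cuda().eval()            # reference or functional Model from the listings above
x = get_inputs()[0].cuda()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        for _ in range(10):
            model(x)

# Aggregated CPU/device time per aten operator, comparable to the table above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
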