← Back to Leaderboard

The AI CUDA Engineer 👷

89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Maxconstant_memory_optimization_base

Level 2 • Task 89
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    stride: int,
    padding: int,
    output_padding: int,
    pool_kernel_size: int,
    pool_stride: int,
    pool_padding: int,
    conv_transpose: torch.Tensor,
    conv_transpose_bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """
    Applies sequence of operations:
        - ConvTranspose3d
        - MaxPool3d
        - Softmax
        - Subtract
        - Swish
        - Max

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, depth, height, width)
        stride (int): Stride for conv transpose
        padding (int): Padding for conv transpose
        output_padding (int): Output padding for conv transpose
        pool_kernel_size (int): Kernel size for max pooling
        pool_stride (int): Stride for max pooling
        pool_padding (int): Padding for max pooling
        conv_transpose (torch.Tensor): Weight tensor for transposed convolution
        conv_transpose_bias (torch.Tensor): Bias tensor for transposed convolution
        subtract (torch.Tensor): Subtraction parameter tensor
    """
    x = F.conv_transpose3d(
        x,
        conv_transpose,
        bias=conv_transpose_bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    x = F.max_pool3d(
        x, kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
    )
    x = F.softmax(x, dim=1)
    x = x - subtract.view(1, -1, 1, 1, 1)
    x = torch.sigmoid(x) * x  # Swish
    x = torch.max(x, dim=1)[0]
    return x


class Model(nn.Module):
    """
    A model that performs a sequence of operations:
        - ConvTranspose3d
        - MaxPool3d
        - Softmax
        - Subtract
        - Swish
        - Max
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ):
        super(Model, self).__init__()
        conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size)
        self.conv_transpose_parameter = conv_transpose.weight
        self.conv_transpose_bias = conv_transpose.bias
        self.subtract_parameter = nn.Parameter(torch.randn(out_channels) * 0.02)

    def forward(
        self,
        x,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
        fn=module_fn,
    ):
        return fn(
            x,
            stride,
            padding,
            output_padding,
            pool_kernel_size,
            pool_stride,
            pool_padding,
            self.conv_transpose_parameter,
            self.conv_transpose_bias,
            self.subtract_parameter,
        )


batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0


def get_inputs():
    return [
        torch.randn(batch_size, in_channels, depth, height, width),
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]


def get_init_inputs():
    return [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        pool_kernel_size,
        pool_stride,
        pool_padding,
    ]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    A model that performs a sequence of operations:
        - ConvTranspose3d
        - MaxPool3d
        - Softmax
        - Subtract
        - Swish
        - Max
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
        self.subtract = nn.Parameter(torch.randn(out_channels)*0.02) # Assuming subtraction is element-wise across channels

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.max_pool(x)
        x = torch.softmax(x, dim=1) # Apply softmax across channels (dim=1)
        x = x - self.subtract.view(1, -1, 1, 1, 1) # Subtract across channels
        x = torch.sigmoid(x) * x # Swish activation
        x = torch.max(x, dim=1)[0] # Max pooling across channels
        return x

batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0

def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding]

Kernel Information

#include <torch/extension.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Declare constant memory for kernel parameters
__constant__ int STRIDE[3];
__constant__ int PADDING[3];
__constant__ int OUTPUT_PADDING[3];
__constant__ int POOL_PARAMS[9];  // kernel_size, stride, padding for each dimension

// Declare constant memory for subtract tensor
__constant__ float SUBTRACT_TENSOR[1024];  // Assuming max channel size is 1024

__device__ __forceinline__ float swish(float x) {
    return x * (1.0f / (1.0f + expf(-x)));
}

torch::Tensor forward(
    torch::Tensor x,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    int64_t pool_kernel_size,
    int64_t pool_stride,
    int64_t pool_padding,
    torch::Tensor conv_transpose_weight,
    torch::Tensor conv_transpose_bias,
    torch::Tensor subtract_tensor
) {
    // Copy constant parameters to device memory
    int h_stride[3] = {stride, stride, stride};
    int h_padding[3] = {padding, padding, padding};
    int h_output_padding[3] = {output_padding, output_padding, output_padding};
    int h_pool_params[9] = {
        pool_kernel_size, pool_kernel_size, pool_kernel_size,
        pool_stride, pool_stride, pool_stride,
        pool_padding, pool_padding, pool_padding
    };
    
    cudaMemcpyToSymbol(STRIDE, h_stride, sizeof(int) * 3);
    cudaMemcpyToSymbol(PADDING, h_padding, sizeof(int) * 3);
    cudaMemcpyToSymbol(OUTPUT_PADDING, h_output_padding, sizeof(int) * 3);
    cudaMemcpyToSymbol(POOL_PARAMS, h_pool_params, sizeof(int) * 9);

    // Copy subtract tensor to constant memory
    cudaMemcpyToSymbol(SUBTRACT_TENSOR, subtract_tensor.data_ptr<float>(), subtract_tensor.numel() * sizeof(float));

    // Transposed convolution
    auto out = at::conv_transpose3d(
        x,
        conv_transpose_weight,
        conv_transpose_bias,
        {stride, stride, stride},
        {padding, padding, padding},
        {output_padding, output_padding, output_padding},
        1,
        {1, 1, 1}
    );

    // MaxPool
    out = at::max_pool3d(
        out,
        {pool_kernel_size, pool_kernel_size, pool_kernel_size},
        {pool_stride, pool_stride, pool_stride},
        {pool_padding, pool_padding, pool_padding}
    );

    // Softmax along channel dimension
    out = at::softmax(out, 1, /*dtype=*/c10::nullopt);

    // Subtract using constant memory
    auto sub_view = subtract_tensor.view({1, -1, 1, 1, 1});
    out = out - sub_view;

    // Swish
    out = out * at::sigmoid(out);

    // Max over channel dimension
    out = std::get<0>(out.max(/*dim=*/1, /*keepdim=*/false));

    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "CUDA forward pass for the module");
}
Operation / Metric Value Unit
aten::to
CPU Time 426750.80 μs
Device Time 2688.79 μs
Self CPU Time 49.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 426701.46 μs
Device Time 2688.79 μs
Self CPU Time 107.29 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 423617.94 μs
Device Time 0.00 μs
Self CPU Time 111.40 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 423421.41 μs
Device Time 0.00 μs
Self CPU Time 423421.41 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaMemcpyToSymbol
CPU Time 8445037.19 μs
Device Time 23161.06 μs
Self CPU Time 8445037.19 μs
Self Device Time 23161.06 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::conv_transpose3d
CPU Time 289720.23 μs
Device Time 6544680.84 μs
Self CPU Time 3375.15 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 286345.09 μs
Device Time 6544680.84 μs
Self CPU Time 5673.87 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 280671.22 μs
Device Time 6544680.84 μs
Self CPU Time 9184.26 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution_transpose
CPU Time 243876.06 μs
Device Time 5185732.36 μs
Self CPU Time 99125.91 μs
Self Device Time 5185732.36 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm90_xmma_dgrad_implicit_gemm_indexed_f32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize256x64x32_warpgroupsize1x1x1_g1_strided_execute_kernel__5x_cudnn
CPU Time 0.00 μs
Device Time 3664337.67 μs
Self CPU Time 0.00 μs
Self Device Time 3664337.67 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B