← Back to Leaderboard

The AI CUDA Engineer 👷

39_GRU39_gru_constant_memory_edit_1

Level 3 • Task 39
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF


def module_fn(
    x: torch.Tensor,
    gru_weights_ih: List[torch.Tensor],
    gru_weights_hh: List[torch.Tensor],
    gru_biases_ih: List[torch.Tensor],
    gru_biases_hh: List[torch.Tensor],
    h0: torch.Tensor,
    is_training: bool,
) -> torch.Tensor:
    """
    Functional implementation of GRU

    Args:
        x: Input tensor of shape (seq_len, batch_size, input_size) if batch_first=False
        gru_weights_ih: List of input-hidden weight tensors for each GRU layer
        gru_weights_hh: List of hidden-hidden weight tensors for each GRU layer
        gru_biases_ih: List of input-hidden bias tensors for each GRU layer
        gru_biases_hh: List of hidden-hidden bias tensors for each GRU layer
        h0: Initial hidden state
        is_training: Whether in training mode

    Returns:
        output tensor of shape (seq_len, batch_size, hidden_size)
    """
    h0 = h0.to(x.device)

    # Run single GRU with all layers at once
    output, _ = _VF.gru(
        x,
        h0,
        [
            w
            for layer in zip(
                gru_weights_ih, gru_weights_hh, gru_biases_ih, gru_biases_hh
            )
            for w in layer
        ],
        True,  # has_biases
        len(gru_weights_ih),  # num_layers
        0.0,  # dropout
        is_training,  # training
        False,  # bidirectional
        False,
    )  # batch_first

    return output


class Model(nn.Module):
    def __init__(
        self, input_size, hidden_size, num_layers=3, bias=True, batch_first=False
    ):
        """
        :param input_size: The number of expected features in the input x
        :param hidden_size: The number of features in the hidden state h
        :param num_layers: Number of recurrent layers (default: 1)
        :param bias: If False, then the layer does not use bias weights b_ih and b_hh (default: True)
        :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) (default: False)
        """
        super(Model, self).__init__()

        # Create GRU and extract its parameters
        gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=0,
            bidirectional=False,
        )

        # Initialize h0 exactly as in original code
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))

        # Extract and store GRU parameters
        self.gru_weights_ih = nn.ParameterList()
        self.gru_weights_hh = nn.ParameterList()
        self.gru_biases_ih = nn.ParameterList()
        self.gru_biases_hh = nn.ParameterList()

        for i in range(num_layers):
            self.gru_weights_ih.append(getattr(gru, f"weight_ih_l{i}"))
            self.gru_weights_hh.append(getattr(gru, f"weight_hh_l{i}"))
            if bias:
                self.gru_biases_ih.append(getattr(gru, f"bias_ih_l{i}"))
                self.gru_biases_hh.append(getattr(gru, f"bias_hh_l{i}"))
            else:
                self.gru_biases_ih.append(None)
                self.gru_biases_hh.append(None)

    def forward(self, x, fn=module_fn):
        return fn(
            x,
            self.gru_weights_ih,
            self.gru_weights_hh,
            self.gru_biases_ih,
            self.gru_biases_hh,
            self.h0,
            self.training,
        )


# Test code
batch_size = 10
seq_len = 512
input_size = 128
hidden_size = 256
num_layers = 6


def get_inputs():
    return [torch.randn(seq_len, batch_size, input_size)]


def get_init_inputs():
    return [input_size, hidden_size, num_layers]
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=3, bias=True, batch_first=False):
        """
        :param input_size: The number of expected features in the input x
        :param hidden_size: The number of features in the hidden state h
        :param num_layers: Number of recurrent layers (default: 1)
        :param bias: If False, then the layer does not use bias weights b_ih and b_hh (default: True)
        :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) (default: False)
        """
        super(Model, self).__init__()
        
        self.gru = nn.GRU(input_size, hidden_size, num_layers, bias, batch_first, dropout=0, bidirectional=False)
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
    
    def forward(self, x):
        """
        :param x: The input tensor, shape (seq_len, batch_size, input_size) if batch_first=False, otherwise (batch_size, seq_len, input_size)
        :param h_0: The initial hidden state for the input sequence, shape (num_layers * num_directions, batch_size, hidden_size) (default: None)
        :return: output, h_n
            - output: The output features (h_t) from the last layer of the GRU, for each t, shape (seq_len, batch_size, num_directions * hidden_size) if batch_first=False, otherwise (batch_size, seq_len, num_directions * hidden_size)
            - h_n: The hidden state for t = seq_len, shape (num_layers * num_directions, batch_size, hidden_size)
        """
        self.h0 = self.h0.to(x.device)
        output, h_n = self.gru(x, self.h0)
        return output

# Test code
batch_size = 10
seq_len = 512
input_size = 128
hidden_size = 256
num_layers = 6

def get_inputs():
    return [torch.randn(seq_len, batch_size, input_size)]

def get_init_inputs():
    return [input_size, hidden_size, num_layers]

Kernel Information

Related Kernels (Level 3, Task 39 • 39_GRU)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 39_gru_constant_memory_edit_1 27.75 1.24 1.83
🥈 gru_with_stride_loops_opt_base 28.00 1.23 1.81
🥉 39_gru_coalesced_base 28.18 1.22 1.80
4 gru_with_memory_coalescing_base_base 28.21 1.22 1.80
5 optimized_gru_forward_base 28.36 1.21 1.79
6 39_gru_constant_memory_base 29.09 1.18 1.74
7 gru_with_cuda_streams_base 30.00 1.15 1.69
8 gru_with_memory_access_optimizations_base 30.08 1.14 1.68
9 optimized_copy_base 30.08 1.14 1.68
10 gru_with_optimal_block_size_base 30.13 1.14 1.68
11 gru_3d_indexing_optim_base 30.23 1.14 1.68
12 gru_with_load_balancing_base_base_base 30.23 1.14 1.68
13 gru_with_ldg_and_alignment_base_base_base 30.27 1.14 1.67
14 gru_with_unrolled_loops_base_base 30.33 1.13 1.67
15 gru_pipeline_overlap_base 30.34 1.13 1.67
16 gru_with_uniform_control_flow_base 30.37 1.13 1.67
17 gru_loop_unroll_base 30.38 1.13 1.67
18 gru_with_minimized_warp_divergence_base_base 30.39 1.13 1.67
19 optimized_gru_stream_unroll_base 30.39 1.13 1.67
20 gru_optimized_thread_block_indexing_base_base 30.43 1.13 1.67
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

__constant__ float ih_consts[2048];
__constant__ float hh_consts[2048];

torch::Tensor forward(
    torch::Tensor x,
    std::vector<torch::Tensor> gru_weights_ih,
    std::vector<torch::Tensor> gru_weights_hh,
    std::vector<torch::Tensor> gru_biases_ih,
    std::vector<torch::Tensor> gru_biases_hh,
    torch::Tensor h0,
    bool is_training) {
    
    h0 = h0.to(x.device());
    
    // Ensure inputs are contiguous for better memory access
    x = x.contiguous();
    h0 = h0.contiguous();
    
    size_t num_layers = gru_weights_ih.size();
    int64_t input_size = x.size(2);
    int64_t hidden_size = gru_weights_hh[0].size(1);
    int64_t seq_length = x.size(0);
    int64_t batch_size = x.size(1);
    
    // Pre-allocate output tensor with optimal memory layout
    auto output = torch::empty({seq_length, batch_size, hidden_size}, 
                             x.options().layout(torch::kStrided)
                             .memory_format(torch::MemoryFormat::Contiguous));
    
    // Create GRU options
    torch::nn::GRUOptions gru_options(input_size, hidden_size);
    gru_options.num_layers(num_layers);
    gru_options.bidirectional(false);
    gru_options.batch_first(false);
    
    auto gru = torch::nn::GRU(gru_options);
    gru->to(x.device());
    gru->train(is_training);
    
    // Pre-process weights and biases for better memory access
    for (size_t l = 0; l < num_layers; ++l) {
        std::string layer_str = std::to_string(l);
        
        // Ensure weights are contiguous and properly aligned
        gru_weights_ih[l] = gru_weights_ih[l].contiguous();
        gru_weights_hh[l] = gru_weights_hh[l].contiguous();
        gru_biases_ih[l] = gru_biases_ih[l].contiguous();
        gru_biases_hh[l] = gru_biases_hh[l].contiguous();
        
        auto params = gru->named_parameters();

        // Copy weights into constant memory if small enough
        if (gru_weights_ih[l].numel() <= 2048 && gru_weights_hh[l].numel() <= 2048) {
            cudaMemcpyToSymbol(ih_consts + l * 2048, gru_weights_ih[l].data_ptr<float>(), gru_weights_ih[l].numel() * sizeof(float));
            cudaMemcpyToSymbol(hh_consts, gru_weights_hh[l].data_ptr<float>(), gru_weights_hh[l].numel() * sizeof(float));
        } else {
            params["weight_ih_l" + layer_str].copy_(gru_weights_ih[l]);
            params["weight_hh_l" + layer_str].copy_(gru_weights_hh[l]);
        }

        params["bias_ih_l" + layer_str].copy_(gru_biases_ih[l]);
        params["bias_hh_l" + layer_str].copy_(gru_biases_hh[l]);
    }
    
    // Reshape h0 with optimal memory layout
    h0 = h0.view({static_cast<int64_t>(num_layers), batch_size, hidden_size});
    
    // Forward pass with optimized memory access
    auto result = gru->forward(x, h0);
    output.copy_(std::get<0>(result));
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "GRU forward (CUDA)");
}
Operation / Metric Value Unit
aten::to
CPU Time 689366.61 μs
Device Time 93798.97 μs
Self CPU Time 3996.37 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::copy_
CPU Time 483390.03 μs
Device Time 129522.46 μs
Self CPU Time 37730.57 μs
Self Device Time 129522.46 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::uniform_
CPU Time 2348770.75 μs
Device Time 0.00 μs
Self CPU Time 2348770.75 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4873244.50 μs
Device Time 0.00 μs
Self CPU Time 4873244.50 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::gru
CPU Time 7170233.75 μs
Device Time 7323989.16 μs
Self CPU Time 5926.05 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_cudnn_rnn
CPU Time 7162350.26 μs
Device Time 7323989.16 μs
Self CPU Time 1783121.55 μs
Self Device Time 7323989.16 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params)
CPU Time 0.00 μs
Device Time 4708602.33 μs
Self CPU Time 0.00 μs
Self Device Time 4708602.33 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)3, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float)
CPU Time 0.00 μs
Device Time 2615390.26 μs
Self CPU Time 0.00 μs
Self Device Time 2615390.26 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B