← Back to Leaderboard

The AI CUDA Engineer 👷

36_LTSMHncombined_unroll_base

Level 3 • Task 36
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF
from typing import List


def module_fn(
    x: torch.Tensor,
    lstm_weights_ih: List[torch.Tensor],
    lstm_weights_hh: List[torch.Tensor],
    lstm_biases_ih: List[torch.Tensor],
    lstm_biases_hh: List[torch.Tensor],
    h0: torch.Tensor,
    c0: torch.Tensor,
    is_training: bool,
) -> torch.Tensor:
    """
    Functional implementation of LSTM with Hn

    Args:
        x: Input tensor of shape (batch_size, sequence_length, input_size)
        lstm_weights_ih: List of input-hidden weight tensors for each LSTM layer
        lstm_weights_hh: List of hidden-hidden weight tensors for each LSTM layer
        lstm_biases_ih: List of input-hidden bias tensors for each LSTM layer
        lstm_biases_hh: List of hidden-hidden bias tensors for each LSTM layer
        h0: Initial hidden state
        c0: Initial cell state
        is_training: Whether in training mode

    Returns:
        Final hidden state tensor
    """
    h0 = h0.to(x.device)
    c0 = c0.to(x.device)

    # Run LSTM layers
    out = x
    hn = h0
    cn = c0

    for i in range(len(lstm_weights_ih)):
        params = (
            lstm_weights_ih[i],
            lstm_weights_hh[i],
            lstm_biases_ih[i],
            lstm_biases_hh[i],
        )
        result = _VF.lstm(
            out,
            (hn[i : i + 1], cn[i : i + 1]),
            params,
            True,  # has_biases
            1,  # num_layers
            0.0 if not is_training else dropout,  # dropout
            is_training,  # training
            False,  # bidirectional
            True,
        )  # batch_first

        out = result[0]
        # Update the corresponding layer's hidden state
        hn = hn.clone()
        cn = cn.clone()
        hn[i : i + 1] = result[1]
        cn[i : i + 1] = result[2]

    return hn


class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.

        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer
        """
        super(Model, self).__init__()

        # Initialize hidden states
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
        self.c0 = torch.randn((num_layers, batch_size, hidden_size))

        # Extract LSTM parameters
        lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )

        # Get weights and biases for each layer
        self.lstm_weights_ih = nn.ParameterList()
        self.lstm_weights_hh = nn.ParameterList()
        self.lstm_biases_ih = nn.ParameterList()
        self.lstm_biases_hh = nn.ParameterList()

        for i in range(num_layers):
            self.lstm_weights_ih.append(
                nn.Parameter(getattr(lstm, f"weight_ih_l{i}").data.clone())
            )
            self.lstm_weights_hh.append(
                nn.Parameter(getattr(lstm, f"weight_hh_l{i}").data.clone())
            )
            self.lstm_biases_ih.append(
                nn.Parameter(getattr(lstm, f"bias_ih_l{i}").data.clone())
            )
            self.lstm_biases_hh.append(
                nn.Parameter(getattr(lstm, f"bias_hh_l{i}").data.clone())
            )

        # Extract linear layer parameters
        fc = nn.Linear(hidden_size, output_size)
        self.fc_weight = nn.Parameter(fc.weight.data.clone())
        self.fc_bias = nn.Parameter(fc.bias.data.clone())

    def forward(self, x, fn=module_fn):
        return fn(
            x,
            self.lstm_weights_ih,
            self.lstm_weights_hh,
            self.lstm_biases_ih,
            self.lstm_biases_hh,
            self.h0,
            self.c0,
            self.training,
        )


# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0


def get_inputs():
    return [torch.randn(batch_size, sequence_length, input_size)]


def get_init_inputs():
    return [input_size, hidden_size, num_layers, output_size, dropout]
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.

        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout`
        """
        super(Model, self).__init__()
        # Initialize hidden state with random values
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
        self.c0 = torch.randn((num_layers, batch_size, hidden_size))
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        """
        Forward pass through the LSTM model.

        :param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The output tensor, shape (batch_size, sequence_length, output_size)
        """
        self.h0 = self.h0.to(x.device)
        self.c0 = self.h0.to(x.device)
        
        # Forward propagate LSTM
        out, state = self.lstm(x, (self.h0, self.c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])  # out: tensor of shape (batch_size, output_size)
        
        return state[0]

# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0

def get_inputs():
    return [torch.randn(batch_size, sequence_length, input_size)]

def get_init_inputs():
    return [input_size, hidden_size, num_layers, output_size, dropout]

Kernel Information

Related Kernels (Level 3, Task 36 • 36_LTSMHn)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_lstm_base 36.44 0.76 1.59
🥈 unrolled_lstm_optimized_base 37.00 0.75 1.56
🥉 36_ltsmh_n_modular_base 37.61 0.74 1.54
4 optimized_lstm_forward_base 37.82 0.73 1.53
5 36_LTSMHn 37.86 0.73 1.53
6 combined_unroll_base 38.18 0.73 1.52
7 36_LTSMHn_unrolled_base 38.19 0.73 1.51
8 optimized_ltsmh_coalesced_base 38.38 0.72 1.51
9 warp_divergence_optimized_lstm_base 41.31 0.67 1.40
10 fused_lstm_edit_1 49.73 0.56 1.16
11 fused_lstm_base 49.76 0.56 1.16
12 36_ltsmhn_coalesced_mem_edit_1 49.92 0.56 1.16
13 36_ltsmhn_warp_aligned_base 50.08 0.55 1.15
14 36_ltsmhn_coalesced_mem_base 50.09 0.55 1.15
15 optimized_lstm_forward_base 50.47 0.55 1.15
16 fused_lstm_sync_opt_edit_1 813.07 0.03 0.07
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

// A combined and optimized LSTM forward kernel
torch::Tensor forward(
    torch::Tensor x,
    std::vector<torch::Tensor> lstm_weights_ih,
    std::vector<torch::Tensor> lstm_weights_hh,
    std::vector<torch::Tensor> lstm_biases_ih,
    std::vector<torch::Tensor> lstm_biases_hh,
    torch::Tensor h0,
    torch::Tensor c0,
    bool is_training) {

    // Ensure h0 and c0 are on the same device as x
    auto device = x.device();
    h0 = h0.to(device);
    c0 = c0.to(device);

    auto out = x;
    auto hn = h0.clone();
    auto cn = c0.clone();

    const size_t num_layers = lstm_weights_ih.size();

    // Lambda to process a single LSTM layer
    auto process_layer = [&](size_t i) {
        // Get the weights and biases for this layer
        auto weight_ih = lstm_weights_ih[i];
        auto weight_hh = lstm_weights_hh[i];
        auto bias_ih = lstm_biases_ih[i];
        auto bias_hh = lstm_biases_hh[i];

        int64_t input_size = weight_ih.size(1);
        int64_t hidden_size = weight_hh.size(1);

        // Instantiate a single-layer LSTM model
        torch::nn::LSTM lstm_model(
            torch::nn::LSTMOptions(input_size, hidden_size)
                .num_layers(1)
                .batch_first(true)
                .bidirectional(false));
        lstm_model->to(device);

        // Copy weights and biases using a pragma unroll to help compile-time unrolling
        #pragma unroll
        {
            lstm_model->named_parameters()["weight_ih_l0"].copy_(weight_ih);
            lstm_model->named_parameters()["weight_hh_l0"].copy_(weight_hh);
            lstm_model->named_parameters()["bias_ih_l0"].copy_(bias_ih);
            lstm_model->named_parameters()["bias_hh_l0"].copy_(bias_hh);
        }

        // Get the corresponding hidden and cell state slice for this layer
        auto h_slice = hn.narrow(0, i, 1);
        auto c_slice = cn.narrow(0, i, 1);
        std::tuple<torch::Tensor, torch::Tensor> state_tuple = std::make_tuple(h_slice, c_slice);

        // Set training mode accordingly
        lstm_model->train(is_training);

        // Forward pass through the current LSTM layer
        auto output_and_state = lstm_model->forward(out, state_tuple);
        auto output = std::get<0>(output_and_state);
        auto h_n_c_n = std::get<1>(output_and_state);

        // Update hidden and cell state slices
        #pragma unroll
        {
            auto h_n = std::get<0>(h_n_c_n);
            auto c_n = std::get<1>(h_n_c_n);
            hn.narrow(0, i, 1).copy_(h_n);
            cn.narrow(0, i, 1).copy_(c_n);
        }
        
        // Update output for the next layer
        out = output;
    };

    // Explicitly unroll the first 4 layers for common cases, then loop for remaining layers
    if (num_layers > 0) process_layer(0);
    if (num_layers > 1) process_layer(1);
    if (num_layers > 2) process_layer(2);
    if (num_layers > 3) process_layer(3);
    for (size_t i = 4; i < num_layers; ++i) {
        process_layer(i);
    }

    return hn;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Optimized LSTM forward (CUDA)");
}
Operation / Metric Value Unit
aten::to
CPU Time 1343407.45 μs
Device Time 101070.75 μs
Self CPU Time 5583.61 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::copy_
CPU Time 912453.80 μs
Device Time 143350.48 μs
Self CPU Time 187753.88 μs
Self Device Time 143350.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::uniform_
CPU Time 2601594.71 μs
Device Time 0.00 μs
Self CPU Time 2601594.71 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4338114.07 μs
Device Time 0.00 μs
Self CPU Time 4338114.07 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::lstm
CPU Time 6485475.72 μs
Device Time 5991149.20 μs
Self CPU Time 17956.59 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_cudnn_rnn
CPU Time 6457852.83 μs
Device Time 5991149.20 μs
Self CPU Time 1904403.61 μs
Self Device Time 5989981.61 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params)
CPU Time 0.00 μs
Device Time 3856986.07 μs
Self CPU Time 0.00 μs
Self Device Time 3856986.07 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)2, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float)
CPU Time 0.00 μs
Device Time 2132995.55 μs
Self CPU Time 0.00 μs
Self Device Time 2132995.55 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B