import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF  # private bindings; _VF.lstm is the op nn.LSTM dispatches to
from typing import List
def module_fn(
x: torch.Tensor,
lstm_weights_ih: List[torch.Tensor],
lstm_weights_hh: List[torch.Tensor],
lstm_biases_ih: List[torch.Tensor],
lstm_biases_hh: List[torch.Tensor],
h0: torch.Tensor,
c0: torch.Tensor,
is_training: bool,
) -> torch.Tensor:
"""
    Functional implementation of a multi-layer LSTM that returns the final hidden state (h_n).
Args:
x: Input tensor of shape (batch_size, sequence_length, input_size)
lstm_weights_ih: List of input-hidden weight tensors for each LSTM layer
lstm_weights_hh: List of hidden-hidden weight tensors for each LSTM layer
lstm_biases_ih: List of input-hidden bias tensors for each LSTM layer
lstm_biases_hh: List of hidden-hidden bias tensors for each LSTM layer
h0: Initial hidden state
c0: Initial cell state
is_training: Whether in training mode
Returns:
        Final hidden state tensor h_n of shape (num_layers, batch_size, hidden_size)
"""
h0 = h0.to(x.device)
c0 = c0.to(x.device)
# Run LSTM layers
out = x
hn = h0
cn = c0
for i in range(len(lstm_weights_ih)):
params = (
lstm_weights_ih[i],
lstm_weights_hh[i],
lstm_biases_ih[i],
lstm_biases_hh[i],
)
result = _VF.lstm(
out,
(hn[i : i + 1], cn[i : i + 1]),
params,
True, # has_biases
1, # num_layers
            0.0,  # dropout (a no-op here: dropout is applied between stacked layers, and each call runs a single layer)
is_training, # training
False, # bidirectional
            True,  # batch_first
        )
out = result[0]
        # Update this layer's slice of the hidden/cell state (clone first to avoid in-place writes on tensors tracked by autograd)
hn = hn.clone()
cn = cn.clone()
hn[i : i + 1] = result[1]
cn[i : i + 1] = result[2]
return hn
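# For reference, a minimal sketch of the per-timestep cell update that each
# _VF.lstm call above performs (my addition; lstm_cell_step is a hypothetical
# helper, and the gate rows of w_ih/w_hh follow PyTorch's
# [input, forget, cell, output] packing order):
def lstm_cell_step(x_t, h, c, w_ih, w_hh, b_ih, b_hh):
    gates = x_t @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_next = f * c + i * g             # new cell state
    h_next = o * torch.tanh(c_next)    # new hidden state
    return h_next, c_next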
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
"""
Initialize the LSTM model.
:param input_size: The number of expected features in the input `x`
:param hidden_size: The number of features in the hidden state `h`
:param num_layers: Number of recurrent layers
:param output_size: The number of output features
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer
"""
super(Model, self).__init__()
        # Initialize hidden states (note: relies on the module-level batch_size defined below)
self.h0 = torch.randn((num_layers, batch_size, hidden_size))
self.c0 = torch.randn((num_layers, batch_size, hidden_size))
# Extract LSTM parameters
lstm = nn.LSTM(
input_size,
hidden_size,
num_layers,
batch_first=True,
dropout=dropout,
bidirectional=False,
)
# Get weights and biases for each layer
self.lstm_weights_ih = nn.ParameterList()
self.lstm_weights_hh = nn.ParameterList()
self.lstm_biases_ih = nn.ParameterList()
self.lstm_biases_hh = nn.ParameterList()
for i in range(num_layers):
self.lstm_weights_ih.append(
nn.Parameter(getattr(lstm, f"weight_ih_l{i}").data.clone())
)
self.lstm_weights_hh.append(
nn.Parameter(getattr(lstm, f"weight_hh_l{i}").data.clone())
)
self.lstm_biases_ih.append(
nn.Parameter(getattr(lstm, f"bias_ih_l{i}").data.clone())
)
self.lstm_biases_hh.append(
nn.Parameter(getattr(lstm, f"bias_hh_l{i}").data.clone())
)
        # Extract linear layer parameters (kept for parity with the reference model; unused because the functional path returns h_n directly)
fc = nn.Linear(hidden_size, output_size)
self.fc_weight = nn.Parameter(fc.weight.data.clone())
self.fc_bias = nn.Parameter(fc.bias.data.clone())
def forward(self, x, fn=module_fn):
return fn(
x,
self.lstm_weights_ih,
self.lstm_weights_hh,
self.lstm_biases_ih,
self.lstm_biases_hh,
self.h0,
self.c0,
self.training,
)
# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0
def get_inputs():
return [torch.randn(batch_size, sequence_length, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers, output_size, dropout]
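# Quick usage sketch (illustrative, not part of the benchmark harness): the
# functional forward returns h_n with one slice per layer.
if __name__ == "__main__":
    model = Model(*get_init_inputs())
    hn = model(get_inputs()[0])
    assert hn.shape == (num_layers, batch_size, hidden_size)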
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
"""
Initialize the LSTM model.
:param input_size: The number of expected features in the input `x`
:param hidden_size: The number of features in the hidden state `h`
:param num_layers: Number of recurrent layers
:param output_size: The number of output features
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout`
"""
super(Model, self).__init__()
        # Initialize hidden state with random values (note: relies on the module-level batch_size defined below)
self.h0 = torch.randn((num_layers, batch_size, hidden_size))
self.c0 = torch.randn((num_layers, batch_size, hidden_size))
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
"""
Forward pass through the LSTM model.
:param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The final hidden state h_n, shape (num_layers, batch_size, hidden_size)
"""
        self.h0 = self.h0.to(x.device)
        self.c0 = self.c0.to(x.device)
# Forward propagate LSTM
out, state = self.lstm(x, (self.h0, self.c0)) # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Decode the hidden state of the last time step (computed for parity with the full model; the return value is h_n)
        out = self.fc(out[:, -1, :])  # out: tensor of shape (batch_size, output_size)
        return state[0]
# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0
def get_inputs():
return [torch.randn(batch_size, sequence_length, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers, output_size, dropout]
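# Sanity-check sketch (my addition; assumes eval mode and dropout=0): running
# the stacked LSTM one layer at a time with the same flat weights reproduces
# h_n from the single stacked call, which is the property the functional
# module_fn above relies on.
if __name__ == "__main__":
    model = Model(*get_init_inputs()).eval()
    x = get_inputs()[0]
    with torch.no_grad():
        hn_ref = model(x)
        out, h, c = x, model.h0.clone(), model.c0.clone()
        for i in range(num_layers):
            layer = nn.LSTM(out.size(-1), hidden_size, 1, batch_first=True)
            for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
                getattr(layer, f"{name}_l0").data.copy_(
                    getattr(model.lstm, f"{name}_l{i}").data
                )
            out, (h_i, c_i) = layer(out, (h[i : i + 1], c[i : i + 1]))
            h[i : i + 1], c[i : i + 1] = h_i, c_i
        assert torch.allclose(hn_ref, h, atol=1e-5)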
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
// LSTM forward pass that runs each layer through a single-layer torch::nn::LSTM (dispatching to cuDNN on GPU)
torch::Tensor forward(
torch::Tensor x,
std::vector<torch::Tensor> lstm_weights_ih,
std::vector<torch::Tensor> lstm_weights_hh,
std::vector<torch::Tensor> lstm_biases_ih,
std::vector<torch::Tensor> lstm_biases_hh,
torch::Tensor h0,
torch::Tensor c0,
bool is_training) {
// Ensure h0 and c0 are on the same device as x
auto device = x.device();
h0 = h0.to(device);
c0 = c0.to(device);
auto out = x;
auto hn = h0.clone();
auto cn = c0.clone();
const size_t num_layers = lstm_weights_ih.size();
// Lambda to process a single LSTM layer
auto process_layer = [&](size_t i) {
// Get the weights and biases for this layer
auto weight_ih = lstm_weights_ih[i];
auto weight_hh = lstm_weights_hh[i];
auto bias_ih = lstm_biases_ih[i];
auto bias_hh = lstm_biases_hh[i];
int64_t input_size = weight_ih.size(1);
int64_t hidden_size = weight_hh.size(1);
// Instantiate a single-layer LSTM model
torch::nn::LSTM lstm_model(
torch::nn::LSTMOptions(input_size, hidden_size)
.num_layers(1)
.batch_first(true)
.bidirectional(false));
lstm_model->to(device);
        // Copy the pretrained weights and biases into this layer's parameters.
        // NoGradGuard is required: autograd rejects in-place copy_ on leaf
        // parameters that require grad.
        {
            torch::NoGradGuard no_grad;
            lstm_model->named_parameters()["weight_ih_l0"].copy_(weight_ih);
            lstm_model->named_parameters()["weight_hh_l0"].copy_(weight_hh);
            lstm_model->named_parameters()["bias_ih_l0"].copy_(bias_ih);
            lstm_model->named_parameters()["bias_hh_l0"].copy_(bias_hh);
        }
// Get the corresponding hidden and cell state slice for this layer
auto h_slice = hn.narrow(0, i, 1);
auto c_slice = cn.narrow(0, i, 1);
std::tuple<torch::Tensor, torch::Tensor> state_tuple = std::make_tuple(h_slice, c_slice);
// Set training mode accordingly
lstm_model->train(is_training);
// Forward pass through the current LSTM layer
auto output_and_state = lstm_model->forward(out, state_tuple);
auto output = std::get<0>(output_and_state);
auto h_n_c_n = std::get<1>(output_and_state);
        // Write this layer's final hidden and cell states back into the stacked state tensors
        auto h_n = std::get<0>(h_n_c_n);
        auto c_n = std::get<1>(h_n_c_n);
        hn.narrow(0, i, 1).copy_(h_n);
        cn.narrow(0, i, 1).copy_(c_n);
// Update output for the next layer
out = output;
};
// Explicitly unroll the first 4 layers for common cases, then loop for remaining layers
if (num_layers > 0) process_layer(0);
if (num_layers > 1) process_layer(1);
if (num_layers > 2) process_layer(2);
if (num_layers > 3) process_layer(3);
for (size_t i = 4; i < num_layers; ++i) {
process_layer(i);
}
return hn;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Optimized LSTM forward (CUDA)");
}
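A build-and-call sketch for the extension above, from Python. The source file name lstm_ext.cpp and the reuse of the parameter-holding Model from the first listing are assumptions:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension; assumes the C++ source above is saved as lstm_ext.cpp
lstm_ext = load(name="lstm_ext", sources=["lstm_ext.cpp"], verbose=True)

model = Model(*get_init_inputs())  # the parameter-holding Model defined earlier
x = get_inputs()[0].cuda()
hn = lstm_ext.forward(
    x,
    list(model.lstm_weights_ih),
    list(model.lstm_weights_hh),
    list(model.lstm_biases_ih),
    list(model.lstm_biases_hh),
    model.h0,
    model.c0,
    model.training,
)
assert hn.shape == (num_layers, batch_size, hidden_size)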
Operation | CPU Time (μs) | Device Time (μs) | Self CPU Time (μs) | Self Device Time (μs) |
---|---|---|---|---|
aten::to | 1343407.45 | 101070.75 | 5583.61 | 0.00 |
aten::copy_ | 912453.80 | 143350.48 | 187753.88 | 143350.48 |
aten::uniform_ | 2601594.71 | 0.00 | 2601594.71 | 0.00 |
cudaLaunchKernel | 4338114.07 | 0.00 | 4338114.07 | 0.00 |
aten::lstm | 6485475.72 | 5991149.20 | 17956.59 | 0.00 |
aten::_cudnn_rnn | 6457852.83 | 5991149.20 | 1904403.61 | 5989981.61 |
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params) | 0.00 | 3856986.07 | 0.00 | 3856986.07 |
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)2, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float) | 0.00 | 2132995.55 | 0.00 | 2132995.55 |

All memory-usage counters (CPU and device, self and total) were 0 B for every operation above.
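Numbers of this kind can be collected with the PyTorch profiler. A minimal sketch (the choice of activities and sort key is mine, not necessarily the exact invocation behind the table):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    hn = model(x)  # model and x as in the earlier sketches (hypothetical reuse)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))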