import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF
from typing import List
def module_fn(
    x: torch.Tensor,
    lstm_weights_ih: List[torch.Tensor],
    lstm_weights_hh: List[torch.Tensor],
    lstm_biases_ih: List[torch.Tensor],
    lstm_biases_hh: List[torch.Tensor],
    h0: torch.Tensor,
    c0: torch.Tensor,
    is_training: bool,
    dropout: float = 0.0,
) -> torch.Tensor:
    """
    Functional implementation of a stacked, unidirectional LSTM that returns h_n.

    Each layer is executed as its own single-layer ``_VF.lstm`` call; the output
    sequence of one layer is the input sequence of the next.

    Args:
        x: Input tensor of shape (batch_size, sequence_length, input_size)
        lstm_weights_ih: List of input-hidden weight tensors for each LSTM layer
        lstm_weights_hh: List of hidden-hidden weight tensors for each LSTM layer
        lstm_biases_ih: List of input-hidden bias tensors for each LSTM layer
        lstm_biases_hh: List of hidden-hidden bias tensors for each LSTM layer
        h0: Initial hidden state, indexed by layer along dim 0
        c0: Initial cell state, indexed by layer along dim 0
        is_training: Whether in training mode
        dropout: Dropout probability applied to the output of every layer except
            the last one while training. Defaults to 0.0 (no dropout).

    Returns:
        Final hidden state tensor with one slice per layer along dim 0. Slices of
        ``h0`` beyond the number of supplied layers are passed through unchanged.

    Raises:
        ValueError: If the four parameter lists do not have the same length.
    """
    num_layers = len(lstm_weights_ih)
    if not (
        len(lstm_weights_hh) == num_layers
        and len(lstm_biases_ih) == num_layers
        and len(lstm_biases_hh) == num_layers
    ):
        raise ValueError("LSTM parameter lists must all have the same length")

    h0 = h0.to(x.device)
    c0 = c0.to(x.device)
    if num_layers == 0:
        return h0

    out = x
    final_hidden = []
    per_layer = zip(lstm_weights_ih, lstm_weights_hh, lstm_biases_ih, lstm_biases_hh)
    for i, params in enumerate(per_layer):
        # h0/c0 are only ever read, never written in place, so the tensors that
        # autograd saves for the backward pass stay valid without cloning.
        out, h_i, _ = _VF.lstm(
            out,
            (h0[i : i + 1], c0[i : i + 1]),
            params,
            True,  # has_biases
            1,  # num_layers
            0.0,  # dropout: a single-layer call has no inter-layer output to drop
            is_training,  # training
            False,  # bidirectional
            True,  # batch_first
        )
        final_hidden.append(h_i)
        # Inter-layer dropout is applied here because every call above is a
        # one-layer LSTM.
        if is_training and dropout > 0.0 and i < num_layers - 1:
            out = F.dropout(out, p=dropout, training=True)

    # Keep any extra leading-dim slices of h0 that no layer consumed.
    if h0.shape[0] > num_layers:
        final_hidden.append(h0[num_layers:])
    return torch.cat(final_hidden, dim=0)
class Model(nn.Module):
    """LSTM stack that stores its parameters as flat per-layer lists and
    delegates the forward computation to a functional implementation."""

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.
        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: Dropout probability forwarded to the temporary `nn.LSTM` used for initialisation
        """
        super().__init__()

        # Random initial states; `batch_size` is read from module scope.
        state_shape = (num_layers, batch_size, hidden_size)
        self.h0 = torch.randn(state_shape)
        self.c0 = torch.randn(state_shape)

        # A throwaway nn.LSTM supplies the default parameter initialisation.
        reference = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )

        def cloned(name):
            # Independent copy of one of the reference LSTM's tensors.
            return nn.Parameter(getattr(reference, name).data.clone())

        self.lstm_weights_ih = nn.ParameterList()
        self.lstm_weights_hh = nn.ParameterList()
        self.lstm_biases_ih = nn.ParameterList()
        self.lstm_biases_hh = nn.ParameterList()
        for i in range(num_layers):
            self.lstm_weights_ih.append(cloned(f"weight_ih_l{i}"))
            self.lstm_weights_hh.append(cloned(f"weight_hh_l{i}"))
            self.lstm_biases_ih.append(cloned(f"bias_ih_l{i}"))
            self.lstm_biases_hh.append(cloned(f"bias_hh_l{i}"))

        # Linear head parameters are copied as well; forward() does not pass
        # them to `fn`.
        head = nn.Linear(hidden_size, output_size)
        self.fc_weight = nn.Parameter(head.weight.data.clone())
        self.fc_bias = nn.Parameter(head.bias.data.clone())

    def forward(self, x, fn=module_fn):
        """
        Run `fn` on `x` with this module's LSTM parameters and initial states.
        :param x: The input tensor
        :param fn: Functional LSTM implementation to call (defaults to `module_fn`)
        :return: Whatever `fn` returns
        """
        lstm_args = (
            self.lstm_weights_ih,
            self.lstm_weights_hh,
            self.lstm_biases_ih,
            self.lstm_biases_hh,
        )
        return fn(x, *lstm_args, self.h0, self.c0, self.training)
# Benchmark configuration shared by the model constructor and input factories
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0


def get_inputs():
    """Return the forward-pass arguments: a one-element list holding a random
    (batch_size, sequence_length, input_size) tensor."""
    shape = (batch_size, sequence_length, input_size)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = [input_size, hidden_size, num_layers, output_size, dropout]
    return init_args
import torch
import torch.nn as nn
class Model(nn.Module):
    """Reference multi-layer LSTM model that returns the final hidden state."""

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.
        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout`
        """
        super(Model, self).__init__()
        # Initialize hidden state with random values.
        # NOTE: `batch_size` is read from module scope, not from the arguments.
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
        self.c0 = torch.randn((num_layers, batch_size, hidden_size))
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
        # Kept so the module's parameters / state_dict layout stay unchanged;
        # forward() returns the LSTM hidden state and does not use this layer.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        Forward pass through the LSTM model.
        :param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The final hidden state of every layer, shape (num_layers, batch_size, hidden_size)
        """
        # Move each initial state to the input's device. The cell state must
        # come from c0 itself; assigning h0 here would discard it.
        self.h0 = self.h0.to(x.device)
        self.c0 = self.c0.to(x.device)
        # Forward propagate LSTM; only the final hidden state is returned.
        _, (h_n, _) = self.lstm(x, (self.h0, self.c0))
        return h_n
# Benchmark configuration: problem sizes used to build the model and its inputs
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0


def get_inputs():
    """Build the argument list for ``Model.forward``: one random input batch
    of shape (batch_size, sequence_length, input_size)."""
    sample = torch.randn(batch_size, sequence_length, input_size)
    return [sample]


def get_init_inputs():
    """Build the positional argument list for ``Model.__init__``."""
    return [
        input_size,
        hidden_size,
        num_layers,
        output_size,
        dropout,
    ]
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>
// Runs a stacked, unidirectional, batch-first LSTM and returns the final hidden
// state of every layer.
//
// x               : input sequence; the first dimension is the batch (batch_first).
// lstm_weights_ih : per-layer input-hidden weights.
// lstm_weights_hh : per-layer hidden-hidden weights.
// lstm_biases_ih  : per-layer input-hidden biases.
// lstm_biases_hh  : per-layer hidden-hidden biases.
// h0, c0          : initial hidden / cell states, indexed by layer along dim 0.
// is_training     : forwarded to the LSTM kernel as its `train` flag.
//
// Each layer is executed with a direct torch::lstm call on the caller's tensors.
// No torch::nn::LSTM module is constructed, so there is no per-call random
// weight initialisation and no in-place copy into freshly created parameters.
// Throws c10::Error (via TORCH_CHECK) on inconsistent arguments.
torch::Tensor forward(
    torch::Tensor x,
    std::vector<torch::Tensor> lstm_weights_ih,
    std::vector<torch::Tensor> lstm_weights_hh,
    std::vector<torch::Tensor> lstm_biases_ih,
    std::vector<torch::Tensor> lstm_biases_hh,
    torch::Tensor h0,
    torch::Tensor c0,
    bool is_training
) {
    const size_t num_layers = lstm_weights_ih.size();
    TORCH_CHECK(
        lstm_weights_hh.size() == num_layers &&
            lstm_biases_ih.size() == num_layers &&
            lstm_biases_hh.size() == num_layers,
        "LSTM parameter lists must all have the same length");
    TORCH_CHECK(
        h0.dim() == 3 && c0.dim() == 3,
        "h0 and c0 must be 3-dimensional (layers, batch, hidden)");
    TORCH_CHECK(
        h0.size(0) >= static_cast<int64_t>(num_layers) &&
            c0.size(0) >= static_cast<int64_t>(num_layers),
        "h0 and c0 must provide one slice per LSTM layer");

    // Optional guard: selects x's CUDA device when x is a CUDA tensor and is a
    // no-op otherwise, so CPU inputs are accepted as well.
    const at::cuda::OptionalCUDAGuard device_guard(at::device_of(x));

    const auto device = x.device();
    h0 = h0.to(device);
    c0 = c0.to(device);

    if (num_layers == 0) {
        // Nothing to run: hand back a copy of the initial hidden state.
        return h0.contiguous().clone();
    }

    auto out = x.contiguous();

    // Final hidden state of each layer, concatenated once at the end. h0 / c0
    // are only read, never written in place.
    std::vector<torch::Tensor> final_hidden;
    final_hidden.reserve(num_layers + 1);

    for (size_t i = 0; i < num_layers; ++i) {
        const auto layer = static_cast<int64_t>(i);

        std::vector<torch::Tensor> state = {
            h0.narrow(0, layer, 1).contiguous(),
            c0.narrow(0, layer, 1).contiguous(),
        };
        // Flat parameter order expected by torch::lstm when has_biases is true.
        std::vector<torch::Tensor> params = {
            lstm_weights_ih[i].to(device).contiguous(),
            lstm_weights_hh[i].to(device).contiguous(),
            lstm_biases_ih[i].to(device),
            lstm_biases_hh[i].to(device),
        };

        auto result = torch::lstm(
            out,
            state,
            params,
            /*has_biases=*/true,
            /*num_layers=*/1,
            /*dropout=*/0.0,
            /*train=*/is_training,
            /*bidirectional=*/false,
            /*batch_first=*/true);

        out = std::get<0>(result).contiguous();
        final_hidden.push_back(std::get<1>(result));
    }

    // Preserve slices of h0 that no layer consumed, so the result keeps h0's
    // leading dimension.
    const auto processed = static_cast<int64_t>(num_layers);
    if (h0.size(0) > processed) {
        final_hidden.push_back(h0.narrow(0, processed, h0.size(0) - processed));
    }

    return torch::cat(final_hidden, 0).contiguous();
}
// Python binding: registers the C++ `forward` function above as `forward` on
// the extension module, with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "LSTM forward (CUDA)");
}
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 1831424.15 | μs |
Device Time | 84583.93 | μs |
Self CPU Time | 4340.24 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::copy_ | ||
CPU Time | 1778057.70 | μs |
Device Time | 129457.46 | μs |
Self CPU Time | 180959.57 | μs |
Self Device Time | 129457.46 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::uniform_ | ||
CPU Time | 2186798.07 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 2186798.07 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 4174910.87 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 4174910.87 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::lstm | ||
CPU Time | 6043725.33 | μs |
Device Time | 5048531.51 | μs |
Self CPU Time | 19115.52 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_cudnn_rnn | ||
CPU Time | 6016545.98 | μs |
Device Time | 5048531.51 | μs |
Self CPU Time | 1461063.65 | μs |
Self Device Time | 5042849.30 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params) | ||
CPU Time | 0.00 | μs |
Device Time | 3251807.56 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3251807.56 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)2, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float) | ||
CPU Time | 0.00 | μs |
Device Time | 1791047.85 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1791047.85 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |