import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF
from typing import List
def module_fn(
x: torch.Tensor,
lstm_weights_ih: List[torch.Tensor],
lstm_weights_hh: List[torch.Tensor],
lstm_biases_ih: List[torch.Tensor],
lstm_biases_hh: List[torch.Tensor],
h0: torch.Tensor,
c0: torch.Tensor,
is_training: bool,
) -> torch.Tensor:
"""
    Functional implementation of a multi-layer LSTM that returns the final hidden state (h_n).
Args:
x: Input tensor of shape (batch_size, sequence_length, input_size)
lstm_weights_ih: List of input-hidden weight tensors for each LSTM layer
lstm_weights_hh: List of hidden-hidden weight tensors for each LSTM layer
lstm_biases_ih: List of input-hidden bias tensors for each LSTM layer
lstm_biases_hh: List of hidden-hidden bias tensors for each LSTM layer
h0: Initial hidden state
c0: Initial cell state
is_training: Whether in training mode
Returns:
        Final hidden state tensor h_n of shape (num_layers, batch_size, hidden_size)
"""
h0 = h0.to(x.device)
c0 = c0.to(x.device)
# Run LSTM layers
out = x
hn = h0
cn = c0
for i in range(len(lstm_weights_ih)):
params = (
lstm_weights_ih[i],
lstm_weights_hh[i],
lstm_biases_ih[i],
lstm_biases_hh[i],
)
        result = _VF.lstm(
            out,
            (hn[i : i + 1], cn[i : i + 1]),
            params,
            True,  # has_biases
            1,  # num_layers
            0.0,  # dropout (no effect with a single layer per call; also avoids depending on the global `dropout`)
            is_training,  # training
            False,  # bidirectional
            True,  # batch_first
        )
out = result[0]
        # Update this layer's slice of the hidden and cell states (cloning avoids in-place writes on autograd-tracked tensors)
hn = hn.clone()
cn = cn.clone()
hn[i : i + 1] = result[1]
cn[i : i + 1] = result[2]
return hn
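
# A sanity-check sketch (not part of the original module): assuming the weight and
# bias lists come from an nn.LSTM with batch_first=True, the layer-by-layer
# _VF.lstm unrolling in module_fn should reproduce the stacked module's final
# hidden state. The helper name `check_against_nn_lstm` is hypothetical.
def check_against_nn_lstm():
    torch.manual_seed(0)
    lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True)
    x = torch.randn(4, 5, 8)
    h0 = torch.zeros(2, 4, 16)
    c0 = torch.zeros(2, 4, 16)
    w_ih = [getattr(lstm, f"weight_ih_l{i}") for i in range(2)]
    w_hh = [getattr(lstm, f"weight_hh_l{i}") for i in range(2)]
    b_ih = [getattr(lstm, f"bias_ih_l{i}") for i in range(2)]
    b_hh = [getattr(lstm, f"bias_hh_l{i}") for i in range(2)]
    hn = module_fn(x, w_ih, w_hh, b_ih, b_hh, h0, c0, is_training=False)
    _, (hn_ref, _) = lstm(x, (h0, c0))
    assert torch.allclose(hn, hn_ref, atol=1e-5)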
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
"""
Initialize the LSTM model.
:param input_size: The number of expected features in the input `x`
:param hidden_size: The number of features in the hidden state `h`
:param num_layers: Number of recurrent layers
:param output_size: The number of output features
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer
"""
super(Model, self).__init__()
        # Initialize hidden and cell states (note: shapes depend on the global batch_size defined in the test code below)
self.h0 = torch.randn((num_layers, batch_size, hidden_size))
self.c0 = torch.randn((num_layers, batch_size, hidden_size))
# Extract LSTM parameters
lstm = nn.LSTM(
input_size,
hidden_size,
num_layers,
batch_first=True,
dropout=dropout,
bidirectional=False,
)
# Get weights and biases for each layer
self.lstm_weights_ih = nn.ParameterList()
self.lstm_weights_hh = nn.ParameterList()
self.lstm_biases_ih = nn.ParameterList()
self.lstm_biases_hh = nn.ParameterList()
for i in range(num_layers):
self.lstm_weights_ih.append(
nn.Parameter(getattr(lstm, f"weight_ih_l{i}").data.clone())
)
self.lstm_weights_hh.append(
nn.Parameter(getattr(lstm, f"weight_hh_l{i}").data.clone())
)
self.lstm_biases_ih.append(
nn.Parameter(getattr(lstm, f"bias_ih_l{i}").data.clone())
)
self.lstm_biases_hh.append(
nn.Parameter(getattr(lstm, f"bias_hh_l{i}").data.clone())
)
        # Extract linear layer parameters (kept for parity with the reference model; module_fn does not apply them)
fc = nn.Linear(hidden_size, output_size)
self.fc_weight = nn.Parameter(fc.weight.data.clone())
self.fc_bias = nn.Parameter(fc.bias.data.clone())
def forward(self, x, fn=module_fn):
return fn(
x,
self.lstm_weights_ih,
self.lstm_weights_hh,
self.lstm_biases_ih,
self.lstm_biases_hh,
self.h0,
self.c0,
self.training,
)
# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0
def get_inputs():
return [torch.randn(batch_size, sequence_length, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers, output_size, dropout]
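
# Example usage (a sketch, not part of the original harness): the functional Model
# returns h_n, so the output has shape (num_layers, batch_size, hidden_size).
# The helper name `example_usage` is hypothetical.
def example_usage():
    model = Model(*get_init_inputs())
    hn = model(*get_inputs())
    assert hn.shape == (num_layers, batch_size, hidden_size)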
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
"""
Initialize the LSTM model.
:param input_size: The number of expected features in the input `x`
:param hidden_size: The number of features in the hidden state `h`
:param num_layers: Number of recurrent layers
:param output_size: The number of output features
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout`
"""
super(Model, self).__init__()
        # Initialize hidden and cell states with random values (shapes depend on the global batch_size defined below)
self.h0 = torch.randn((num_layers, batch_size, hidden_size))
self.c0 = torch.randn((num_layers, batch_size, hidden_size))
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
"""
Forward pass through the LSTM model.
        :param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The final hidden state h_n, shape (num_layers, batch_size, hidden_size)
        """
        self.h0 = self.h0.to(x.device)
        self.c0 = self.c0.to(x.device)
        # Forward propagate LSTM
        out, state = self.lstm(x, (self.h0, self.c0))  # out: (batch_size, seq_length, hidden_size)
        # Decode the hidden state of the last time step (computed for parity with the full model; the return value is h_n)
        out = self.fc(out[:, -1, :])  # out: (batch_size, output_size)
        return state[0]
# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0
def get_inputs():
return [torch.randn(batch_size, sequence_length, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers, output_size, dropout]
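
# A parity sketch (hypothetical: assumes the functional and reference Model classes
# above are importable under distinct names and are passed in as arguments): with
# shared LSTM weights and initial states, both should return the same h_n.
def check_parity(FunctionalModel, ReferenceModel):
    torch.manual_seed(0)
    ref = ReferenceModel(*get_init_inputs())
    fun = FunctionalModel(*get_init_inputs())
    with torch.no_grad():
        for i in range(num_layers):
            fun.lstm_weights_ih[i].copy_(getattr(ref.lstm, f"weight_ih_l{i}"))
            fun.lstm_weights_hh[i].copy_(getattr(ref.lstm, f"weight_hh_l{i}"))
            fun.lstm_biases_ih[i].copy_(getattr(ref.lstm, f"bias_ih_l{i}"))
            fun.lstm_biases_hh[i].copy_(getattr(ref.lstm, f"bias_hh_l{i}"))
    fun.h0, fun.c0 = ref.h0.clone(), ref.c0.clone()
    ref.eval()
    fun.eval()
    x = get_inputs()[0]
    assert torch.allclose(fun(x), ref(x), atol=1e-5)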
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <tuple>
// CUDA kernel to perform coalesced copy
template <typename scalar_t>
__global__ void coalesced_copy_kernel(const scalar_t* __restrict__ src, scalar_t* __restrict__ dst, size_t numel) {
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;  // widen before multiplying to avoid 32-bit overflow on large tensors
if (idx < numel) {
dst[idx] = src[idx];
}
}
// Helper function to launch the custom copy kernel ensuring memory coalescing
void coalesced_copy_tensor(const torch::Tensor &src, torch::Tensor &dst) {
    const auto numel = src.numel();
    const int threads = 256;
    const int blocks = static_cast<int>((numel + threads - 1) / threads);  // ceil(numel / threads)
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "coalesced_copy", ([&] {
coalesced_copy_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
src.data_ptr<scalar_t>(),
dst.data_ptr<scalar_t>(),
numel
);
}));
}
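
// Design note (an observation, not from the original): at::Tensor::copy_ or
// cudaMemcpyAsync would also perform this update; the hand-written kernel simply
// keeps the copy on the current CUDA stream with one thread per element, which
// is fully coalesced when both tensors are contiguous.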
// Main forward function for the optimized LSTMHn kernel with memory coalescing
torch::Tensor forward(
torch::Tensor x,
std::vector<torch::Tensor> lstm_weights_ih,
std::vector<torch::Tensor> lstm_weights_hh,
std::vector<torch::Tensor> lstm_biases_ih,
std::vector<torch::Tensor> lstm_biases_hh,
torch::Tensor h0,
torch::Tensor c0,
bool is_training
) {
// Ensure all tensors are on the proper device and are contiguous
auto device = x.device();
x = x.to(device).contiguous();
h0 = h0.to(device).contiguous();
c0 = c0.to(device).contiguous();
// Clone initial states for manipulation
torch::Tensor out = x;
torch::Tensor hn = h0.clone();
torch::Tensor cn = c0.clone();
const size_t num_layers = lstm_weights_ih.size();
// Lambda to process each LSTM layer
auto process_layer = [&](size_t i) {
// Extract parameters for layer i
auto weight_ih = lstm_weights_ih[i];
auto weight_hh = lstm_weights_hh[i];
auto bias_ih = lstm_biases_ih[i];
auto bias_hh = lstm_biases_hh[i];
int64_t input_size = weight_ih.size(1);
int64_t hidden_size = weight_hh.size(1);
// Create a one-layer LSTM module
torch::nn::LSTM lstm_model(
torch::nn::LSTMOptions(input_size, hidden_size)
.num_layers(1)
.batch_first(true)
.bidirectional(false)
);
lstm_model->to(device);
        // Copy the layer parameters (assumed to be stored contiguously);
        // guard with NoGradGuard so the in-place copies into leaf parameters are allowed
        {
            torch::NoGradGuard no_grad;
            lstm_model->named_parameters()["weight_ih_l0"].copy_(weight_ih);
            lstm_model->named_parameters()["weight_hh_l0"].copy_(weight_hh);
            lstm_model->named_parameters()["bias_ih_l0"].copy_(bias_ih);
            lstm_model->named_parameters()["bias_hh_l0"].copy_(bias_hh);
        }
// Extract the corresponding hidden and cell state slices
torch::Tensor h_slice = hn.narrow(0, i, 1);
torch::Tensor c_slice = cn.narrow(0, i, 1);
std::tuple<torch::Tensor, torch::Tensor> state_tuple = std::make_tuple(h_slice, c_slice);
lstm_model->train(is_training);
// Forward pass through the current LSTM layer
auto output_and_state = lstm_model->forward(out, state_tuple);
auto output = std::get<0>(output_and_state);
auto state = std::get<1>(output_and_state);
auto h_n = std::get<0>(state);
auto c_n = std::get<1>(state);
// Ensure outputs are contiguous for coalesced memory access
h_n = h_n.contiguous();
c_n = c_n.contiguous();
// Update the global hidden and cell states using the coalesced copy kernel
torch::Tensor h_target = hn.narrow(0, i, 1);
torch::Tensor c_target = cn.narrow(0, i, 1);
coalesced_copy_tensor(h_n, h_target);
coalesced_copy_tensor(c_n, c_target);
// The output of the current layer becomes the input for the next layer
out = output;
};
// Explicitly unroll the first four layers if present
if (num_layers > 0) process_layer(0);
if (num_layers > 1) process_layer(1);
if (num_layers > 2) process_layer(2);
if (num_layers > 3) process_layer(3);
for (size_t i = 4; i < num_layers; ++i) {
process_layer(i);
}
return hn;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Optimized LSTM forward (CUDA) with memory coalescing");
}
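
A usage sketch (the source file name lstm_hn_coalesced.cu and the harness wiring below are assumptions, not part of the original): the extension can be JIT-compiled with torch.utils.cpp_extension.load and driven with the parameters extracted by the Python Model above.

from torch.utils.cpp_extension import load

ext = load(name="lstm_hn_coalesced", sources=["lstm_hn_coalesced.cu"], verbose=True)
model = Model(*get_init_inputs())
x = get_inputs()[0].cuda()
hn = ext.forward(
    x,
    [w.detach() for w in model.lstm_weights_ih],
    [w.detach() for w in model.lstm_weights_hh],
    [b.detach() for b in model.lstm_biases_ih],
    [b.detach() for b in model.lstm_biases_hh],
    model.h0.cuda(),
    model.c0.cuda(),
    model.training,
)
assert hn.shape == (num_layers, batch_size, hidden_size)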
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.106 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.000 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 2.960 | % | 0.009 | 5 |
Issued Ipc Active | 0.116 | inst/cycle | 0.000 | 5 |
SM Busy | 2.960 | % | 0.009 | 5 |
Memory Throughput | 4069562026.356 | byte/second | 18606954756246856.000 | 5 |
Mem Busy | 10.170 | % | 0.076 | 5 |
Max Bandwidth | 5.234 | % | 0.030 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 100.204 | % | 0.112 | 5 |
Mem Pipes Busy | 0.124 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 64.882 | cycle | 8.751 | 5 |
Warp Cycles Per Executed Instruction | 71.062 | cycle | 10.484 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 30.480 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 16.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 11.996 | % | 0.042 | 5 |
Achieved Active Warps Per SM | 7.676 | warp | 0.018 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.1%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 1146603.73 | μs |
Device Time | 98820.92 | μs |
Self CPU Time | 5229.82 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::copy_ | ||
CPU Time | 900234.71 | μs |
Device Time | 134630.59 | μs |
Self CPU Time | 207029.68 | μs |
Self Device Time | 134630.59 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::uniform_ | ||
CPU Time | 2615948.65 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 2615948.65 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 4441291.76 | μs |
Device Time | 5068.97 | μs |
Self CPU Time | 4441291.76 | μs |
Self Device Time | 5068.97 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::lstm | ||
CPU Time | 6577211.30 | μs |
Device Time | 5951440.73 | μs |
Self CPU Time | 17393.20 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_cudnn_rnn | ||
CPU Time | 6551467.20 | μs |
Device Time | 5951440.73 | μs |
Self CPU Time | 1678223.07 | μs |
Self Device Time | 5950284.63 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params) | ||
CPU Time | 0.00 | μs |
Device Time | 3819032.22 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 3819032.22 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)2, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float) | ||
CPU Time | 0.00 | μs |
Device Time | 2131255.94 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 2131255.94 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |