from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _VF
def module_fn(
x: torch.Tensor,
gru_weights_ih: List[torch.Tensor],
gru_weights_hh: List[torch.Tensor],
gru_biases_ih: List[torch.Tensor],
gru_biases_hh: List[torch.Tensor],
h0: torch.Tensor,
is_training: bool,
) -> torch.Tensor:
"""
Functional implementation of GRU
Args:
x: Input tensor of shape (seq_len, batch_size, input_size) if batch_first=False
gru_weights_ih: List of input-hidden weight tensors for each GRU layer
gru_weights_hh: List of hidden-hidden weight tensors for each GRU layer
gru_biases_ih: List of input-hidden bias tensors for each GRU layer
gru_biases_hh: List of hidden-hidden bias tensors for each GRU layer
h0: Initial hidden state
is_training: Whether in training mode
Returns:
output tensor of shape (seq_len, batch_size, hidden_size)
"""
h0 = h0.to(x.device)
    # Run a single fused GRU call over all layers. _VF.gru expects the
    # parameters flattened per layer as [w_ih, w_hh, b_ih, b_hh].
    flat_weights = [
        w
        for layer_params in zip(
            gru_weights_ih, gru_weights_hh, gru_biases_ih, gru_biases_hh
        )
        for w in layer_params
    ]
    output, _ = _VF.gru(
        x,
        h0,
        flat_weights,
        True,  # has_biases
        len(gru_weights_ih),  # num_layers
        0.0,  # dropout
        is_training,  # training
        False,  # bidirectional
        False,  # batch_first
    )
return output
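

# A quick sanity check for module_fn: compare it against nn.GRU running the
# same weights. This is a minimal sketch that can be called by hand; the
# sizes and tolerance below are arbitrary choices, not part of the original
# benchmark.
def _check_module_fn_matches_nn_gru():
    gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2)
    weights_ih = [getattr(gru, f"weight_ih_l{i}") for i in range(2)]
    weights_hh = [getattr(gru, f"weight_hh_l{i}") for i in range(2)]
    biases_ih = [getattr(gru, f"bias_ih_l{i}") for i in range(2)]
    biases_hh = [getattr(gru, f"bias_hh_l{i}") for i in range(2)]
    x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)
    h0 = torch.zeros(2, 3, 16)  # (num_layers, batch, hidden_size)
    expected, _ = gru(x, h0)
    actual = module_fn(x, weights_ih, weights_hh, biases_ih, biases_hh, h0, False)
    assert torch.allclose(expected, actual, atol=1e-6)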
class Model(nn.Module):
def __init__(
self, input_size, hidden_size, num_layers=3, bias=True, batch_first=False
):
"""
:param input_size: The number of expected features in the input x
:param hidden_size: The number of features in the hidden state h
        :param num_layers: Number of recurrent layers (default: 3)
:param bias: If False, then the layer does not use bias weights b_ih and b_hh (default: True)
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) (default: False)
"""
super(Model, self).__init__()
# Create GRU and extract its parameters
gru = nn.GRU(
input_size,
hidden_size,
num_layers,
bias=bias,
batch_first=batch_first,
dropout=0,
bidirectional=False,
)
        # Initialize h0 exactly as in the original code; note this reads the
        # module-level batch_size defined in the test section below.
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
# Extract and store GRU parameters
self.gru_weights_ih = nn.ParameterList()
self.gru_weights_hh = nn.ParameterList()
self.gru_biases_ih = nn.ParameterList()
self.gru_biases_hh = nn.ParameterList()
for i in range(num_layers):
self.gru_weights_ih.append(getattr(gru, f"weight_ih_l{i}"))
self.gru_weights_hh.append(getattr(gru, f"weight_hh_l{i}"))
            if bias:
                self.gru_biases_ih.append(getattr(gru, f"bias_ih_l{i}"))
                self.gru_biases_hh.append(getattr(gru, f"bias_hh_l{i}"))
            else:
                # module_fn always calls _VF.gru with has_biases=True, and
                # nn.ParameterList cannot hold None, so zero biases stand in
                # for the bias=False case (mathematically equivalent).
                self.gru_biases_ih.append(nn.Parameter(torch.zeros(3 * hidden_size)))
                self.gru_biases_hh.append(nn.Parameter(torch.zeros(3 * hidden_size)))
def forward(self, x, fn=module_fn):
return fn(
x,
self.gru_weights_ih,
self.gru_weights_hh,
self.gru_biases_ih,
self.gru_biases_hh,
self.h0,
self.training,
)
# Test code
batch_size = 10
seq_len = 512
input_size = 128
hidden_size = 256
num_layers = 6
def get_inputs():
return [torch.randn(seq_len, batch_size, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers]
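

# Example usage (sketch, not part of the original benchmark harness):
# instantiate the functional Model and run one forward pass on the test
# inputs defined above.
if __name__ == "__main__":
    model = Model(*get_init_inputs())
    out = model(*get_inputs())
    print(out.shape)  # expected: torch.Size([512, 10, 256])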
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=3, bias=True, batch_first=False):
"""
:param input_size: The number of expected features in the input x
:param hidden_size: The number of features in the hidden state h
        :param num_layers: Number of recurrent layers (default: 3)
:param bias: If False, then the layer does not use bias weights b_ih and b_hh (default: True)
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) (default: False)
"""
super(Model, self).__init__()
self.gru = nn.GRU(input_size, hidden_size, num_layers, bias, batch_first, dropout=0, bidirectional=False)
self.h0 = torch.randn((num_layers, batch_size, hidden_size))
def forward(self, x):
"""
:param x: The input tensor, shape (seq_len, batch_size, input_size) if batch_first=False, otherwise (batch_size, seq_len, input_size)
:param h_0: The initial hidden state for the input sequence, shape (num_layers * num_directions, batch_size, hidden_size) (default: None)
:return: output, h_n
- output: The output features (h_t) from the last layer of the GRU, for each t, shape (seq_len, batch_size, num_directions * hidden_size) if batch_first=False, otherwise (batch_size, seq_len, num_directions * hidden_size)
- h_n: The hidden state for t = seq_len, shape (num_layers * num_directions, batch_size, hidden_size)
"""
self.h0 = self.h0.to(x.device)
output, h_n = self.gru(x, self.h0)
return output
# Test code
batch_size = 10
seq_len = 512
input_size = 128
hidden_size = 256
num_layers = 6
def get_inputs():
return [torch.randn(seq_len, batch_size, input_size)]
def get_init_inputs():
return [input_size, hidden_size, num_layers]
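

# Shape check for the reference model (sketch; the manual seed is an added
# choice for reproducibility, not part of the original code).
if __name__ == "__main__":
    torch.manual_seed(0)
    ref = Model(*get_init_inputs())
    y = ref(get_inputs()[0])
    assert y.shape == (seq_len, batch_size, hidden_size)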
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda_runtime.h>
#include <vector>

// Constant-memory staging buffers sized for one small layer's weights
// (2048 floats each).
__constant__ float ih_consts[2048];
__constant__ float hh_consts[2048];
torch::Tensor forward(
torch::Tensor x,
std::vector<torch::Tensor> gru_weights_ih,
std::vector<torch::Tensor> gru_weights_hh,
std::vector<torch::Tensor> gru_biases_ih,
std::vector<torch::Tensor> gru_biases_hh,
torch::Tensor h0,
bool is_training) {
h0 = h0.to(x.device());
// Ensure inputs are contiguous for better memory access
x = x.contiguous();
h0 = h0.contiguous();
size_t num_layers = gru_weights_ih.size();
int64_t input_size = x.size(2);
int64_t hidden_size = gru_weights_hh[0].size(1);
int64_t seq_length = x.size(0);
int64_t batch_size = x.size(1);
    // Pre-allocate a contiguous output tensor to receive the result
    auto output = torch::empty({seq_length, batch_size, hidden_size},
                               x.options().layout(torch::kStrided)
                                   .memory_format(torch::MemoryFormat::Contiguous));
// Create GRU options
torch::nn::GRUOptions gru_options(input_size, hidden_size);
gru_options.num_layers(num_layers);
gru_options.bidirectional(false);
gru_options.batch_first(false);
auto gru = torch::nn::GRU(gru_options);
gru->to(x.device());
gru->train(is_training);
    // Copy the provided weights and biases into the GRU module's parameters.
    // copy_ on leaf parameters must run under a no-grad guard; the guard is
    // scoped so the forward pass below still records autograd state.
    {
        torch::NoGradGuard no_grad;
        auto params = gru->named_parameters();
        for (size_t l = 0; l < num_layers; ++l) {
            std::string layer_str = std::to_string(l);

            // Ensure weights are contiguous for predictable memory access
            gru_weights_ih[l] = gru_weights_ih[l].contiguous();
            gru_weights_hh[l] = gru_weights_hh[l].contiguous();
            gru_biases_ih[l] = gru_biases_ih[l].contiguous();
            gru_biases_hh[l] = gru_biases_hh[l].contiguous();

            // Stage small weight matrices in constant memory, one layer at a
            // time. No kernel in this file reads these buffers; the cuDNN
            // GRU consumes the module parameters, which are therefore always
            // populated below.
            if (gru_weights_ih[l].numel() <= 2048 &&
                gru_weights_hh[l].numel() <= 2048) {
                cudaMemcpyToSymbol(ih_consts, gru_weights_ih[l].data_ptr<float>(),
                                   gru_weights_ih[l].numel() * sizeof(float));
                cudaMemcpyToSymbol(hh_consts, gru_weights_hh[l].data_ptr<float>(),
                                   gru_weights_hh[l].numel() * sizeof(float));
            }

            params["weight_ih_l" + layer_str].copy_(gru_weights_ih[l]);
            params["weight_hh_l" + layer_str].copy_(gru_weights_hh[l]);
            params["bias_ih_l" + layer_str].copy_(gru_biases_ih[l]);
            params["bias_hh_l" + layer_str].copy_(gru_biases_hh[l]);
        }
    }
    // Reshape h0 to the (num_layers, batch, hidden) layout the GRU expects
    h0 = h0.view({static_cast<int64_t>(num_layers), batch_size, hidden_size});
    // Forward pass; cuDNN performs the full multi-layer recurrence
auto result = gru->forward(x, h0);
output.copy_(std::get<0>(result));
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "GRU forward (CUDA)");
}
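

# Building and calling the extension from Python (sketch). The file and
# module names here are illustrative choices, not given in the original, and
# `model` is assumed to be an instance of the functional Model defined
# earlier.
from torch.utils.cpp_extension import load

gru_ext = load(name="gru_cuda", sources=["gru_cuda.cu"], verbose=True)
out = gru_ext.forward(
    get_inputs()[0].cuda(),
    [w.detach().cuda() for w in model.gru_weights_ih],
    [w.detach().cuda() for w in model.gru_weights_hh],
    [b.detach().cuda() for b in model.gru_biases_ih],
    [b.detach().cuda() for b in model.gru_biases_hh],
    model.h0.cuda(),
    False,  # is_training
)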
Profiler output for the CUDA extension above:

Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 689366.61 | μs |
Device Time | 93798.97 | μs |
Self CPU Time | 3996.37 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::copy_ | ||
CPU Time | 483390.03 | μs |
Device Time | 129522.46 | μs |
Self CPU Time | 37730.57 | μs |
Self Device Time | 129522.46 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::uniform_ | ||
CPU Time | 2348770.75 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 2348770.75 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 4873244.50 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 4873244.50 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::gru | ||
CPU Time | 7170233.75 | μs |
Device Time | 7323989.16 | μs |
Self CPU Time | 5926.05 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_cudnn_rnn | ||
CPU Time | 7162350.26 | μs |
Device Time | 7323989.16 | μs |
Self CPU Time | 1783121.55 | μs |
Self Device Time | 7323989.16 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4>(cutlass_80_tensorop_s1688gemm_64x64_32x6_tn_align4::Params) | ||
CPU Time | 0.00 | μs |
Device Time | 4708602.33 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 4708602.33 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void elemWiseRNNcell<float, float, float, (cudnnRNNMode_t)3, (cudnnRNNBiasMode_t)2>(int, int, int, int, int, bool, float const*, float const*, float const*, float const*, float const*, float const*, float const*, float*, float*, float*, float*, float*, cudnnRNNClipMode_t, cudnnNanPropagation_t, float, float) | ||
CPU Time | 0.00 | μs |
Device Time | 2615390.26 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 2615390.26 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
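

# A table like the one above can be collected with torch.profiler (sketch;
# the schedule and sort key used for the original numbers are not given, so
# these are illustrative choices, and `model` and its input are assumed to
# already live on the GPU being profiled).
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    model(get_inputs()[0])
print(prof.key_averages().table(sort_by="cuda_time_total"))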