import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    i2h_weight: torch.Tensor,
    i2h_bias: torch.Tensor,
    h2o_weight: torch.Tensor,
    h2o_bias: torch.Tensor,
    hidden: torch.Tensor,
) -> torch.Tensor:
    """
    Vanilla RNN forward pass.

    Args:
        x: Input tensor of shape (batch_size, input_size)
        i2h_weight: Weight tensor for the input-to-hidden layer
        i2h_bias: Bias tensor for the input-to-hidden layer
        h2o_weight: Weight tensor for the hidden-to-output layer
        h2o_bias: Bias tensor for the hidden-to-output layer
        hidden: Hidden state tensor

    Returns:
        Output tensor of shape (batch_size, output_size)
    """
    hidden = hidden.to(x.device)
    combined = torch.cat((x, hidden), dim=1)
    hidden = torch.tanh(F.linear(combined, i2h_weight, i2h_bias))
    output = F.linear(hidden, h2o_weight, h2o_bias)
    return output
class Model(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        """
        Initialize the Vanilla RNN model.

        :param input_size: The number of input features (int).
        :param hidden_size: The size of the hidden state (int).
        :param output_size: The number of output features (int).
        """
        super(Model, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Note: relies on the module-level batch_size defined below the class.
        self.hidden = nn.Parameter(torch.randn((batch_size, hidden_size)))
        # Extract parameters from linear layers
        i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2h_weight = nn.Parameter(i2h.weight.data.clone())
        self.i2h_bias = nn.Parameter(i2h.bias.data.clone())
        h2o = nn.Linear(hidden_size, output_size)
        self.h2o_weight = nn.Parameter(h2o.weight.data.clone())
        self.h2o_bias = nn.Parameter(h2o.bias.data.clone())

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(
            x,
            self.i2h_weight,
            self.i2h_bias,
            self.h2o_weight,
            self.h2o_bias,
            self.hidden,
        )
batch_size = 8
input_size = 1024
hidden_size = 256
output_size = 128
sequence_length = 256


def get_inputs():
    return [torch.randn(batch_size, input_size)]


def get_init_inputs():
    return [input_size, hidden_size, output_size]
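As a quick sanity check (a minimal sketch, not part of the benchmark harness), the functional module_fn should reproduce an nn.Linear-based cell exactly when both paths use the same weights. The helper below assumes module_fn and the size globals defined above are in scope; check_equivalence is a hypothetical name.

import torch
import torch.nn as nn


def check_equivalence():
    # Build one cell's worth of weights with nn.Linear, then run both paths on the same data.
    torch.manual_seed(0)
    x = torch.randn(batch_size, input_size)
    hidden = torch.randn(batch_size, hidden_size)
    i2h = nn.Linear(input_size + hidden_size, hidden_size)
    h2o = nn.Linear(hidden_size, output_size)
    # Reference path: call the linear modules directly.
    ref = h2o(torch.tanh(i2h(torch.cat((x, hidden), dim=1))))
    # Functional path: pass the extracted weights to module_fn.
    out = module_fn(x, i2h.weight, i2h.bias, h2o.weight, h2o.bias, hidden)
    assert torch.allclose(ref, out, atol=1e-6)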
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        """
        Initialize the Vanilla RNN model.

        :param input_size: The number of input features (int).
        :param hidden_size: The size of the hidden state (int).
        :param output_size: The number of output features (int).
        """
        super(Model, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Note: relies on the module-level batch_size defined below the class.
        self.hidden = torch.randn((batch_size, hidden_size))
        # Define the RNN cell components (input to hidden and hidden to output)
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)  # Input to hidden
        self.h2o = nn.Linear(hidden_size, output_size)  # Hidden to output
        self.tanh = nn.Tanh()  # Activation function for the hidden state

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Vanilla RNN.

        :param x: Input tensor of shape (batch_size, input_size).
        :return: Output tensor of shape (batch_size, output_size). The updated
            hidden state is stored in self.hidden as a side effect.
        """
        self.hidden = self.hidden.to(x.device)
        combined = torch.cat((x, self.hidden), dim=1)  # Concatenate input and hidden state
        self.hidden = self.tanh(self.i2h(combined))  # Update hidden state
        output = self.h2o(self.hidden)  # Compute output
        return output
batch_size = 8
input_size = 1024
hidden_size = 256
output_size = 128
sequence_length = 256


def get_inputs():
    return [torch.randn(batch_size, input_size)]


def get_init_inputs():
    return [input_size, hidden_size, output_size]
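Both listings implement a single recurrent step, even though sequence_length is defined among the globals. Purely for illustration (not part of the benchmark harness), the cell could be unrolled over a sequence as sketched below, relying on the fact that forward stores the updated state in self.hidden:

# Illustrative unrolling of the single-step cell over a full sequence.
model = Model(input_size, hidden_size, output_size)
sequence = torch.randn(sequence_length, batch_size, input_size)
outputs = []
for t in range(sequence_length):
    outputs.append(model(sequence[t]))  # each call reads and updates model.hidden
outputs = torch.stack(outputs)  # shape: (sequence_length, batch_size, output_size)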
/*
Optimized CUDA kernel for concatenating two tensors with vectorized memory accesses.
It copies the first tensor (x) and the second tensor (hidden) into a combined tensor along the column dimension.
The bulk of the work is performed using float4 vectorized loads/stores for improved memory throughput,
with a fallback loop to handle remainder elements that do not form a complete vector of 4.
*/
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
// Optimized concatenation kernel which first processes the bulk vectorized elements
// then handles any remainder elements if (x_size+hidden_size) is not a multiple of 4.
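// Worked example with the benchmark sizes (x_size = 1024, hidden_size = 256):
// total_width = 1280, so each row is covered by 320 float4 blocks; blocks 0-255
// fall entirely inside x and blocks 256-319 entirely inside hidden. Because both
// widths are multiples of 4, the scalar-fallback path and the remainder loop
// below never run for these shapes; they exist for widths that are not
// multiples of 4.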
__global__ void concat_kernel_optimized(
    const float* __restrict__ x,
    const float* __restrict__ hidden,
    float* __restrict__ combined,
    const int batch_size,
    const int x_size,
    const int hidden_size
) {
    // Total columns after concatenation
    const int total_width = x_size + hidden_size;
    const int vector_width = 4;
    // Number of complete float4 groups per row
    const int total_vec_count = total_width / vector_width;  // floor division
    // For vectorized access in each tensor, only complete groups can be processed.
    const int x_vec_count = x_size / vector_width;
    const int h_vec_count = hidden_size / vector_width;

    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    // Create vectorized pointers
    const float4* x4 = reinterpret_cast<const float4*>(x);
    const float4* hidden4 = reinterpret_cast<const float4*>(hidden);
    float4* combined4 = reinterpret_cast<float4*>(combined);

    // Process the bulk vectorized portion
    int total_vectorized_elements = batch_size * total_vec_count;
    for (int idx = tid; idx < total_vectorized_elements; idx += stride) {
        int row = idx / total_vec_count;
        int vec_idx = idx % total_vec_count;    // which float4 block within the combined row
        int base_col = vec_idx * vector_width;  // starting column index for this vector block
        // Determine whether this block falls in the x part or the hidden part
        if (base_col < x_size) {
            // The block is either fully within x or straddles the x/hidden boundary
            if (base_col + vector_width <= x_size) {
                // Entire block lies within x
                int x_vec_index = base_col / vector_width;
                combined4[row * total_vec_count + vec_idx] = x4[row * x_vec_count + x_vec_index];
            } else {
                // The block spans both x and hidden. Fall back to a scalar copy.
                float temp[vector_width];
                for (int i = 0; i < vector_width; i++) {
                    int col = base_col + i;
                    if (col < x_size) {
                        temp[i] = x[row * x_size + col];
                    } else {
                        temp[i] = hidden[row * hidden_size + (col - x_size)];
                    }
                }
                float4 vec;
                vec.x = temp[0]; vec.y = temp[1]; vec.z = temp[2]; vec.w = temp[3];
                combined4[row * total_vec_count + vec_idx] = vec;
            }
        } else {
            // This block lies entirely in the hidden portion
            int h_base = base_col - x_size;
            if (h_base + vector_width <= hidden_size) {
                int h_vec_index = h_base / vector_width;
                combined4[row * total_vec_count + vec_idx] = hidden4[row * h_vec_count + h_vec_index];
            } else {
                // Defensive scalar path; unreachable in practice, because every full
                // float4 block (vec_idx < total_vec_count) ends within total_width.
                float temp[vector_width];
                for (int i = 0; i < vector_width; i++) {
                    int col = base_col + i;
                    temp[i] = hidden[row * hidden_size + (col - x_size)];
                }
                float4 vec;
                vec.x = temp[0]; vec.y = temp[1]; vec.z = temp[2]; vec.w = temp[3];
                combined4[row * total_vec_count + vec_idx] = vec;
            }
        }
    }

    // Process any remaining columns that do not form a full float4
    int remainder = total_width - (total_vec_count * vector_width);
    int total_remainder_elements = batch_size * remainder;
    for (int idx = tid; idx < total_remainder_elements; idx += stride) {
        int row = idx / remainder;
        int col_offset = idx % remainder;
        int col = total_vec_count * vector_width + col_offset;
        if (col < x_size) {
            combined[row * total_width + col] = x[row * x_size + col];
        } else {
            combined[row * total_width + col] = hidden[row * hidden_size + (col - x_size)];
        }
    }
}
// Host wrapper function: prepares tensors, launches the optimized concatenation kernel,
// and proceeds with the subsequent linear operations and activation.
torch::Tensor module_fn_cuda(
    torch::Tensor x,
    torch::Tensor i2h_weight,
    torch::Tensor i2h_bias,
    torch::Tensor h2o_weight,
    torch::Tensor h2o_bias,
    torch::Tensor hidden
) {
    // Ensure all tensors are contiguous
    x = x.contiguous();
    i2h_weight = i2h_weight.contiguous();
    i2h_bias = i2h_bias.contiguous();
    h2o_weight = h2o_weight.contiguous();
    h2o_bias = h2o_bias.contiguous();
    hidden = hidden.contiguous();

    const int batch_size = x.size(0);
    const int x_size = x.size(1);
    const int hidden_size = hidden.size(1);
    int total_width = x_size + hidden_size;

    auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
    auto combined = torch::empty({batch_size, total_width}, options);

    // Set up grid and block dimensions
    const int threads = 256;
    // Grid size for the vectorized part
    int total_vec_count = total_width / 4;  // number of full 4-element blocks per row
    const int vec_elements = batch_size * total_vec_count;
    int blocks = std::min(65535, (vec_elements + threads - 1) / threads);

    // Launch the optimized concatenation kernel
    concat_kernel_optimized<<<blocks, threads>>>(
        x.data_ptr<float>(),
        hidden.data_ptr<float>(),
        combined.data_ptr<float>(),
        batch_size,
        x_size,
        hidden_size
    );

    // Compute hidden_new = tanh(i2h_bias + combined * i2h_weight^T)
    auto hidden_new = torch::tanh(torch::addmm(i2h_bias, combined, i2h_weight.t()));
    // Compute output = h2o_bias + hidden_new * h2o_weight^T
    auto output = torch::addmm(h2o_bias, hidden_new, h2o_weight.t());
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn_cuda, "Optimized module forward (CUDA)");
}
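One way to exercise the extension (a sketch, with a hypothetical source file name rnn_concat_kernel.cu) is to JIT-compile it with torch.utils.cpp_extension.load and call the exported forward with tensors shaped like the benchmark inputs:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source above; the file name is illustrative.
rnn_cuda = load(name="rnn_concat_optimized", sources=["rnn_concat_kernel.cu"], verbose=True)

batch_size, input_size, hidden_size, output_size = 8, 1024, 256, 128
x = torch.randn(batch_size, input_size, device="cuda")
hidden = torch.randn(batch_size, hidden_size, device="cuda")
i2h_weight = torch.randn(hidden_size, input_size + hidden_size, device="cuda")
i2h_bias = torch.randn(hidden_size, device="cuda")
h2o_weight = torch.randn(output_size, hidden_size, device="cuda")
h2o_bias = torch.randn(output_size, device="cuda")

out = rnn_cuda.forward(x, i2h_weight, i2h_bias, h2o_weight, h2o_bias, hidden)
print(out.shape)  # torch.Size([8, 128])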
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 0.370 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.010 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 9.862 | % | 0.237 | 5 |
Issued Ipc Active | 0.394 | inst/cycle | 0.000 | 5 |
SM Busy | 9.862 | % | 0.237 | 5 |
Memory Throughput | 15090953909.342 | byte/second | 92264231076721104.000 | 5 |
Mem Busy | 9.998 | % | 0.041 | 5 |
Max Bandwidth | 5.416 | % | 0.014 | 5 |
L1/TEX Hit Rate | 0.000 | % | 0.000 | 5 |
L2 Hit Rate | 99.760 | % | 0.020 | 5 |
Mem Pipes Busy | 0.214 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 19.990 | cycle | 0.131 | 5 |
Warp Cycles Per Executed Instruction | 21.450 | cycle | 0.151 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 28.260 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 32.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 11.716 | % | 0.059 | 5 |
Achieved Active Warps Per SM | 7.502 | warp | 0.025 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (11.5%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 679603.77 | μs |
Device Time | 66.53 | μs |
Self CPU Time | 61.00 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 679542.77 | μs |
Device Time | 66.53 | μs |
Self CPU Time | 133.56 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 679016.04 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 141.82 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 660831.71 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 660831.71 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::addmm | ||
CPU Time | 960691.95 | μs |
Device Time | 359753.97 | μs |
Self CPU Time | 462793.01 | μs |
Self Device Time | 359753.97 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize32x32x8_stage3_warpsize1x2x1_ffma_aligna4_alignc4_execute_kernel__51_cublas | ||
CPU Time | 0.00 | μs |
Device Time | 172972.74 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 172972.74 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 111899.65 | μs |
Device Time | 1145344.83 | μs |
Self CPU Time | 23897.79 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 88003.28 | μs |
Device Time | 1145344.83 | μs |
Self CPU Time | 34120.74 | μs |
Self Device Time | 1145344.83 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 1145344.83 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 1145344.83 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45288 warnings generated when compiling for host. Suppressed 45328 warnings (45281 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.