16_Matmul_with_transposed_A
• streams_partitioned_matmul_base
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Performs a single matrix multiplication with transposed A (C = A.T * B).

    Args:
        A: Input tensor of shape (K, M).
        B: Input tensor of shape (K, N).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A.T, B)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with transposed A (C = A.T * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(K, N)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
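As a quick illustration of the contraction module_fn performs, the transposed-A product can be written equivalently with einsum over the shared K axis; a minimal sanity check (the small shapes below are chosen purely for illustration, not the benchmark sizes):

import torch

K, M, N = 8, 4, 3
A = torch.randn(K, M)
B = torch.randn(K, N)

C_ref = torch.matmul(A.T, B)             # what module_fn computes
C_alt = torch.einsum("km,kn->mn", A, B)  # same sum over k: C[m, n] = sum_k A[k, m] * B[k, n]

assert C_ref.shape == (M, N)
assert torch.allclose(C_ref, C_alt, atol=1e-5)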
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with transposed A (C = A.T * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs a matrix multiplication with transposed A.

        Args:
            A: Input tensor of shape (K, M).
            B: Input tensor of shape (K, N).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A.T, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(K, N)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>
#include <vector>
#include <algorithm>
#define TILE_SIZE 32
// Kernel that computes C = A.T * B for a partition of rows of C.
// A: shape (K, M) stored in row-major order
// B: shape (K, N) stored in row-major order
// C: shape (M, N) stored in row-major order, where each element C[i,j] = sum_{k=0}^{K-1} A[k*M + i] * B[k*N + j]
// row_offset: the starting row index (i) for this partition
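// Worked example of the indexing above, using the benchmark sizes M = 1024 and N = 2048:
// the element of A.T at (row i, k) is read from A[k * 1024 + i], the element of B at (k, col j)
// is read from B[k * 2048 + j], and the result is written to C[i * 2048 + j].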
__global__ void matMulKernelPartition(const float* __restrict__ A,
                                      const float* __restrict__ B,
                                      float* __restrict__ C,
                                      int K, int M, int N,
                                      int row_offset) {
    // Local row index within this partition
    int local_row = blockIdx.x * TILE_SIZE + threadIdx.y;
    // Global row index in C (and column index in A)
    int global_row = row_offset + local_row;
    int col = blockIdx.y * TILE_SIZE + threadIdx.x;

    float sum = 0.0f;

    __shared__ float tileA[TILE_SIZE][TILE_SIZE];
    __shared__ float tileB[TILE_SIZE][TILE_SIZE];

    // Number of tiles needed to cover the K dimension
    int numTiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    for (int t = 0; t < numTiles; t++) {
        // Each thread loads one element for tileA
        int aIndex = t * TILE_SIZE + threadIdx.x;  // k index for A
        if (global_row < M && aIndex < K) {
            // Note: A is stored as (K, M), so the element of A.T at (global_row, aIndex)
            // comes from A[aIndex * M + global_row]
            tileA[threadIdx.y][threadIdx.x] = A[aIndex * M + global_row];
        } else {
            tileA[threadIdx.y][threadIdx.x] = 0.0f;
        }

        // Each thread loads one element for tileB
        int bIndex = t * TILE_SIZE + threadIdx.y;  // k index for B
        if (bIndex < K && col < N) {
            tileB[threadIdx.y][threadIdx.x] = B[bIndex * N + col];
        } else {
            tileB[threadIdx.y][threadIdx.x] = 0.0f;
        }

        __syncthreads();

        // Compute the partial dot product for this tile
        #pragma unroll
        for (int k_inner = 0; k_inner < TILE_SIZE; k_inner++) {
            sum += tileA[threadIdx.y][k_inner] * tileB[k_inner][threadIdx.x];
        }

        __syncthreads();
    }

    if (global_row < M && col < N) {
        C[global_row * N + col] = sum;
    }
}
// The forward function exposed via PyBind11.
// This version partitions the output matrix along the M dimension and launches one kernel per
// partition on its own CUDA stream. Since there are no explicit host-device transfers in this path,
// the streams mainly allow the partitioned kernels to execute concurrently, which can improve
// utilization when a single partition does not fill the GPU on its own.
// Input A: Tensor with shape (K, M), B: Tensor with shape (K, N).
// Output: Tensor C with shape (M, N) computed as C = A.T * B.
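// Worked example of the partitioning arithmetic below: with the benchmark shape M = 1024 and
// num_streams = 2, rows_per_partition = (1024 + 2 - 1) / 2 = 512, so stream 0 computes
// rows [0, 512) of C and stream 1 computes rows [512, 1024).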
torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
    // Ensure inputs are contiguous CUDA tensors of type float32
    // (the kernel indexes A and B as densely packed row-major arrays).
    TORCH_CHECK(A.is_cuda(), "Input A must be a CUDA tensor");
    TORCH_CHECK(B.is_cuda(), "Input B must be a CUDA tensor");
    TORCH_CHECK(A.dtype() == torch::kFloat32, "Input A must be float32");
    TORCH_CHECK(B.dtype() == torch::kFloat32, "Input B must be float32");
    TORCH_CHECK(A.is_contiguous(), "Input A must be contiguous");
    TORCH_CHECK(B.is_contiguous(), "Input B must be contiguous");

    int K = A.size(0);
    int M = A.size(1);
    TORCH_CHECK(B.size(0) == K, "Dimension mismatch: A and B must have the same first dimension (K)");
    int N = B.size(1);

    // Allocate the output tensor with torch::empty to avoid the cost of zero initialization
    auto C = torch::empty({M, N}, torch::device(A.device()).dtype(A.dtype()));

    // Use multiple CUDA streams so the per-partition kernels can run concurrently.
    const int num_streams = 2;  // Can be tuned further
    std::vector<cudaStream_t> streams(num_streams);
    for (int i = 0; i < num_streams; i++) {
        cudaStreamCreate(&streams[i]);
    }

    // Partition the M dimension (rows of C) among the available streams.
    int rows_per_partition = (M + num_streams - 1) / num_streams;

    // Launch the kernel for each partition on its own stream
    for (int s = 0; s < num_streams; s++) {
        int row_offset = s * rows_per_partition;
        int rows_in_partition = std::min(rows_per_partition, M - row_offset);
        if (rows_in_partition <= 0) continue;

        // Compute grid dimensions from the number of rows in this partition and the full N
        dim3 blockDim(TILE_SIZE, TILE_SIZE);
        dim3 gridDim((rows_in_partition + TILE_SIZE - 1) / TILE_SIZE,
                     (N + TILE_SIZE - 1) / TILE_SIZE);

        matMulKernelPartition<<<gridDim, blockDim, 0, streams[s]>>>(
            A.data_ptr<float>(),
            B.data_ptr<float>(),
            C.data_ptr<float>(),
            K, M, N,
            row_offset
        );
    }

    // Synchronize and destroy the streams
    for (int i = 0; i < num_streams; i++) {
        cudaStreamSynchronize(streams[i]);
        cudaStreamDestroy(streams[i]);
    }

    return C;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward,
          "Compute C = A.T * B with the output rows partitioned across concurrent CUDA streams");
}
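A minimal sketch of how this extension could be built and checked against the PyTorch reference via torch.utils.cpp_extension.load; the source file name streams_partitioned_matmul.cu and the tolerance are assumptions, not part of the benchmark harness:

import torch
from torch.utils.cpp_extension import load

# Assumed file name for the CUDA source above; adjust to wherever it is saved.
ext = load(
    name="streams_partitioned_matmul",
    sources=["streams_partitioned_matmul.cu"],
    verbose=True,
)

K, M, N = 4096, 1024, 2048
A = torch.randn(K, M, device="cuda", dtype=torch.float32)
B = torch.randn(K, N, device="cuda", dtype=torch.float32)

C = ext.forward(A, B)
C_ref = torch.matmul(A.T, B)
print(torch.allclose(C, C_ref, atol=1e-3))  # small float32 accumulation differences are expected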
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.060 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 1.020 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 26.448 | % | 0.000 | 5 |
Issued Ipc Active | 1.060 | inst/cycle | 0.000 | 5 |
SM Busy | 26.670 | % | 0.000 | 5 |
Memory Throughput | 25126659597.696 | byte/second | 9731810717913500.000 | 5 |
Mem Busy | 89.306 | % | 0.001 | 5 |
Max Bandwidth | 61.032 | % | 0.000 | 5 |
L1/TEX Hit Rate | 77.710 | % | 0.000 | 5 |
L2 Hit Rate | 92.906 | % | 0.020 | 5 |
Mem Pipes Busy | 47.236 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 57.244 | cycle | 0.000 | 5 |
Warp Cycles Per Executed Instruction | 57.244 | cycle | 0.000 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 31.990 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 2.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 3.000 | block | 0.000 | 5 |
Block Limit Warps | 2.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 94.608 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 60.550 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
INF Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 255917.52 | μs |
Device Time | 5103.43 | μs |
Self CPU Time | 50.56 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 255866.96 | μs |
Device Time | 5103.43 | μs |
Self CPU Time | 139.47 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 250232.87 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 106.67 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 249739.96 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 249739.96 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaStreamSynchronize | ||
CPU Time | 7057990.55 | μs |
Device Time | 13967.01 | μs |
Self CPU Time | 7057990.55 | μs |
Self Device Time | 13967.01 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
matMulKernelPartition(float const*, float const*, float*, int, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 7799841.75 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 7799841.75 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 33929.43 | μs |
Device Time | 180878.59 | μs |
Self CPU Time | 5591.51 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 28340.00 | μs |
Device Time | 180878.59 | μs |
Self CPU Time | 9824.76 | μs |
Self Device Time | 180878.59 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 180878.59 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 180878.59 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |