18_Matmul_with_transposed_both
• matmul_transpose_ldg_optimization_base
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Performs a single matrix multiplication with transposed A and B (C = A.T * B.T).

    Args:
        A: Input tensor of shape (K, M).
        B: Input tensor of shape (N, K).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A.T, B.T)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with both inputs transposed (C = A.T * B.T).
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(N, K)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
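Because (B @ A).T == A.T @ B.T, module_fn is equivalent to computing the plain product B @ A and transposing the result. A minimal sketch of that identity check follows; the helper name, sizes, and tolerance are illustrative and not part of the harness.

import torch

def check_transposed_matmul_identity():
    # Illustrative sizes only; any conforming (K, M) and (N, K) shapes work.
    K_, M_, N_ = 64, 32, 48
    A = torch.randn(K_, M_)
    B = torch.randn(N_, K_)
    ref = torch.matmul(A.T, B.T)  # what module_fn computes, shape (M_, N_)
    alt = torch.matmul(B, A).T    # same product taken in the opposite order
    assert torch.allclose(ref, alt, atol=1e-5)

check_transposed_matmul_identity()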
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with both inputs transposed (C = A.T * B.T).
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix multiplication with both inputs transposed.

        Args:
            A: Input tensor of shape (K, M).
            B: Input tensor of shape (N, K).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A.T, B.T)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(N, K)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
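Given the harness shapes, A is (K, M) = (4096, 1024) and B is (N, K) = (2048, 4096), so the output is (M, N) = (1024, 2048). A small sanity check of those shapes is sketched below; it assumes it runs in the same module as the harness above, and the helper name is hypothetical.

def check_harness_shapes():
    # Uses Model, get_inputs, M, K, N defined by the harness above.
    A, B = get_inputs()
    out = Model()(A, B)
    assert A.shape == (K, M) and B.shape == (N, K)
    assert out.shape == (M, N)  # (1024, 2048)

check_harness_shapes()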
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Define tile sizes
#define BLOCK_SIZE_M 32 // Output tile height (each block computes 32 rows)
#define BLOCK_SIZE_N 16 // Output tile width (each block computes 16 columns)
#define BLOCK_SIZE_K 16 // Reduction tile depth
// Kernel: each thread computes a 2x1 sub-tile (2 rows, 1 column) of C
// A is (K x M): element A[k, m] = A[k * M + m]
// B is (N x K): element B[n, k] = B[n * K + k]
// C is (M x N): element C[m, n] = C[m * N + n]
template <typename scalar_t>
__global__ void matmul_transpose_ldg_optimization_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int M,
    int N,
    int K) {
  // Determine the starting indices for this block's tile in C
  int m_start = blockIdx.y * BLOCK_SIZE_M;  // row start in C
  int n_start = blockIdx.x * BLOCK_SIZE_N;  // col start in C

  // Thread indices within the block
  int tx = threadIdx.x;  // Expected range: [0, 15]
  int ty = threadIdx.y;  // Expected range: [0, 15]

  // Each thread computes two rows: row0 and row1
  int row0 = m_start + tx;               // first row computed by this thread
  int row1 = row0 + (BLOCK_SIZE_M / 2);  // second row computed (offset by 16)
  int col = n_start + ty;                // column index in C

  // Accumulators for the two output elements
  scalar_t acc0 = 0;
  scalar_t acc1 = 0;

  // Declare shared memory tiles
  __shared__ scalar_t A_tile[BLOCK_SIZE_K][BLOCK_SIZE_M];  // Size: 16 x 32
  __shared__ scalar_t B_tile[BLOCK_SIZE_N][BLOCK_SIZE_K];  // Size: 16 x 16

  // Total threads in a block
  int tId = threadIdx.y * blockDim.x + threadIdx.x;  // Range: 0 to 255
  int blockSize = blockDim.x * blockDim.y;           // = 256

  int numTiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
  for (int tile = 0; tile < numTiles; tile++) {
    // Load A tile into shared memory
    // A tile dimensions: BLOCK_SIZE_K x BLOCK_SIZE_M (16 x 32 = 512 elements)
    int totalAElements = BLOCK_SIZE_K * BLOCK_SIZE_M;  // 512
    for (int idx = tId; idx < totalAElements; idx += blockSize) {
      int kd = idx / BLOCK_SIZE_M;              // k-index within the tile
      int md = idx % BLOCK_SIZE_M;              // m-index within the tile
      int global_m = m_start + md;              // global m index
      int global_k = tile * BLOCK_SIZE_K + kd;  // global k index
      if (global_m < M && global_k < K)
        A_tile[kd][md] = __ldg(&A[global_k * M + global_m]);
      else
        A_tile[kd][md] = 0;
    }

    // Load B tile into shared memory
    // B tile dimensions: BLOCK_SIZE_N x BLOCK_SIZE_K (16 x 16 = 256 elements)
    int totalBElements = BLOCK_SIZE_N * BLOCK_SIZE_K;  // 256
    for (int idx = tId; idx < totalBElements; idx += blockSize) {
      int nd = idx / BLOCK_SIZE_K;              // n-index within the tile
      int kd = idx % BLOCK_SIZE_K;              // k-index within the tile
      int global_n = n_start + nd;              // global n index
      int global_k = tile * BLOCK_SIZE_K + kd;  // global k index
      if (global_n < N && global_k < K)
        B_tile[nd][kd] = __ldg(&B[global_n * K + global_k]);
      else
        B_tile[nd][kd] = 0;
    }

    __syncthreads();

    // Compute the partial results for this tile
    for (int k = 0; k < BLOCK_SIZE_K; k++) {
      scalar_t a_val0 = A_tile[k][tx];                       // for row0
      scalar_t a_val1 = A_tile[k][tx + (BLOCK_SIZE_M / 2)];  // for row1
      scalar_t b_val = B_tile[ty][k];
      acc0 += a_val0 * b_val;
      acc1 += a_val1 * b_val;
    }

    __syncthreads();
  }

  // Write the results to global memory
  if (row0 < M && col < N) {
    C[row0 * N + col] = acc0;
  }
  if (row1 < M && col < N) {
    C[row1 * N + col] = acc1;
  }
}
// PyTorch binding
torch::Tensor matmul_transpose_cuda(torch::Tensor A, torch::Tensor B) {
  // Dimensions:
  // A: (K x M), B: (N x K), therefore C: (M x N)
  int K = A.size(0);
  int M = A.size(1);
  int N = B.size(0);

  auto C = torch::empty({M, N}, A.options());

  // Define block dimensions: use 16x16 threads per block
  dim3 threads(16, 16);
  // Grid dimensions based on tile sizes
  dim3 blocks((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N,
              (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M);

  AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "matmul_transpose_ldg_optimization_kernel", ([&] {
    matmul_transpose_ldg_optimization_kernel<scalar_t><<<blocks, threads>>>(
        A.data_ptr<scalar_t>(),
        B.data_ptr<scalar_t>(),
        C.data_ptr<scalar_t>(),
        M, N, K);
  }));

  return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &matmul_transpose_cuda,
        "Matrix multiplication with transposed inputs using multi-output kernel and ldg optimization (CUDA)");
}
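One way to exercise the extension end to end is to JIT-compile it with torch.utils.cpp_extension.load and compare its output against the PyTorch reference. The sketch below is illustrative only: the source file name, module name, and tolerances are assumptions, not part of the benchmark.

import torch
from torch.utils.cpp_extension import load

# Hypothetical file/module names; adjust to the actual build layout.
ext = load(
    name="matmul_transpose_ldg",
    sources=["matmul_transpose_ldg_optimization_base.cu"],
    verbose=True,
)

M, K, N = 1024, 4096, 2048
A = torch.randn(K, M, device="cuda")
B = torch.randn(N, K, device="cuda")

out = ext.forward(A, B)        # custom kernel: C = A.T @ B.T, shape (M, N)
ref = torch.matmul(A.T, B.T)   # PyTorch reference

# Accumulation order differs from cuBLAS, so compare with a loose tolerance.
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
print("max abs diff:", (out - ref).abs().max().item())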
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 2.582 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 2.554 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 64.582 | % | 0.008 | 5 |
Issued Ipc Active | 2.582 | inst/cycle | 0.000 | 5 |
SM Busy | 64.582 | % | 0.008 | 5 |
Memory Throughput | 71068218101.104 | byte/second | 687096220118865664.000 | 5 |
Mem Busy | 83.582 | % | 0.014 | 5 |
Max Bandwidth | 75.126 | % | 0.011 | 5 |
L1/TEX Hit Rate | 1.692 | % | 0.003 | 5 |
L2 Hit Rate | 86.284 | % | 0.077 | 5 |
Mem Pipes Busy | 73.392 | % | 0.010 | 5 |
Warp Cycles Per Issued Instruction | 22.552 | cycle | 0.000 | 5 |
Warp Cycles Per Executed Instruction | 22.554 | cycle | 0.000 | 5 |
Avg. Active Threads Per Warp | 32.000 | | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 31.120 | | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 8.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 16.000 | block | 0.000 | 5 |
Block Limit Warps | 8.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 91.072 | % | 0.004 | 5 |
Achieved Active Warps Per SM | 58.288 | warp | 0.002 | 5 |
Rule | Description |
---|---|
INF HighPipeUtilization | ALU is the highest-utilized pipeline (37.1%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck. |
INF Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. |
INF CPIStall | Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 265075.92 | μs |
Device Time | 6060.28 | μs |
Self CPU Time | 50.09 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 5825758.27 | μs |
Device Time | 8842.08 | μs |
Self CPU Time | 5825758.27 | μs |
Self Device Time | 8842.08 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void matmul_transpose_ldg_optimization_kernel<float>(float const*, float const*, float*, int, int, int) | ||
CPU Time | 0.00 | μs |
Device Time | 6179881.65 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 6179881.65 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceSynchronize | ||
CPU Time | 520382.65 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 520382.65 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 11244.42 | μs |
Device Time | 16206.47 | μs |
Self CPU Time | 11244.42 | μs |
Self Device Time | 16206.47 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 5658662.74 | μs |
Device Time | 251169.60 | μs |
Self CPU Time | 7484.48 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 5651180.15 | μs |
Device Time | 251169.60 | μs |
Self CPU Time | 11331.09 | μs |
Self Device Time | 251169.60 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 251169.60 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 251169.60 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |