
The AI CUDA Engineer 👷

18_Matmul_with_transposed_both • optimized_matmul_transpose_base

Level 1 • Task 18
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Performs a single matrix multiplication with transposed A and B (C = A.T * B.T).

    Args:
        A: Input tensor of shape (K, M).
        B: Input tensor of shape (N, K).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A.T, B.T)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with transposed inputs (C = A.T * B.T)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(N, K)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication with transposed inputs (C = A.T * B.T)
    """
    def __init__(self):
        super(Model, self).__init__()
    
    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs a single matrix multiplication with transposed A and B (C = A.T * B.T).

        Args:
            A: Input tensor of shape (K, M).
            B: Input tensor of shape (N, K).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A.T, B.T)

M = 1024
K = 4096
N = 2048

def get_inputs():
    A = torch.randn(K, M)
    B = torch.randn(N, K)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed
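
As a quick sanity check on the reference operation, the short CPU-only sketch below (reusing the M, K, N constants above) confirms that A.T @ B.T yields an (M, N) result and agrees with the equivalent formulation (B @ A).T. It is an illustrative snippet, not part of the benchmark.

import torch

K, M, N = 4096, 1024, 2048

A = torch.randn(K, M)   # as returned by get_inputs(): A is (K, M)
B = torch.randn(N, K)   # as returned by get_inputs(): B is (N, K)

C = torch.matmul(A.T, B.T)   # (M, K) @ (K, N) -> (M, N)
print(C.shape)               # torch.Size([1024, 2048])

# (B @ A).T equals A.T @ B.T by the transpose identity (XY)^T = Y^T X^T;
# loose tolerances because the two paths may accumulate in different orders.
print(torch.allclose(C, (B @ A).T, rtol=1e-4, atol=1e-4))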

Kernel Information

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Define tile sizes
#define BLOCK_SIZE_M 32  // Output tile height (each block computes 32 rows)
#define BLOCK_SIZE_N 16  // Output tile width (each block computes 16 columns)
#define BLOCK_SIZE_K 16  // Reduction tile depth
#define UNROLL_FACTOR 4

// Kernel: each thread computes a 2x1 sub-tile (2 rows, 1 column) of C
// A is (K x M): element A[k, m] = A[k * M + m]
// B is (N x K): element B[n, k] = B[n * K + k]
// C is (M x N): element C[m, n] = C[m * N + n]
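// Example: with M = 1024 (as in this task), element A[k=3, m=7] is stored at A[3 * 1024 + 7].
// Consecutive threads read consecutive m for A (and consecutive k for B), so the
// shared-memory tile loads below are coalesced.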

template <typename scalar_t>
__global__ void optimized_matmul_transpose_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int M,
    int N,
    int K) {

    // Determine the starting indices for this block's tile in C
    int m_start = blockIdx.y * BLOCK_SIZE_M;  // row start in C
    int n_start = blockIdx.x * BLOCK_SIZE_N;  // col start in C

    // Thread indices within the block
    int tx = threadIdx.x; // Expected range: [0, 15]
    int ty = threadIdx.y; // Expected range: [0, 15]

    // Each thread computes two rows: row0 and row1
    int row0 = m_start + tx;             // first row computed by this thread
    int row1 = row0 + (BLOCK_SIZE_M / 2);  // second row computed (offset by 16)
    int col = n_start + ty;              // column index in C

    // Accumulators for the two output elements
    scalar_t acc0 = 0;
    scalar_t acc1 = 0;

    // Declare shared memory tiles
    __shared__ scalar_t A_tile[BLOCK_SIZE_K][BLOCK_SIZE_M]; // Size: 16 x 32
    __shared__ scalar_t B_tile[BLOCK_SIZE_N][BLOCK_SIZE_K];   // Size: 16 x 16

    // Total threads in a block
    int tId = threadIdx.y * blockDim.x + threadIdx.x; // Range: 0 to 255
    int blockSize = blockDim.x * blockDim.y;            // = 256

    int numTiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
    for (int tile = 0; tile < numTiles; tile++) {
        // Cooperatively load the A tile into shared memory (all 256 threads stride over its 512 elements)
        int totalAElements = BLOCK_SIZE_K * BLOCK_SIZE_M; // 512
        for (int idx = tId; idx < totalAElements; idx += blockSize) {
            int kd = idx / BLOCK_SIZE_M;  // k-index within the tile
            int md = idx % BLOCK_SIZE_M;  // m-index within the tile
            int global_m = m_start + md;  // global m index
            int global_k = tile * BLOCK_SIZE_K + kd;  // global k index
            if (global_m < M && global_k < K)
                A_tile[kd][md] = __ldg(&A[global_k * M + global_m]);
            else
                A_tile[kd][md] = 0;
        }

        // Cooperatively load the B tile into shared memory via the read-only data cache (__ldg)
        int totalBElements = BLOCK_SIZE_N * BLOCK_SIZE_K; // 256
        for (int idx = tId; idx < totalBElements; idx += blockSize) {
            int nd = idx / BLOCK_SIZE_K;  // n-index within the tile
            int kd = idx % BLOCK_SIZE_K;  // k-index within the tile
            int global_n = n_start + nd;  // global n index
            int global_k = tile * BLOCK_SIZE_K + kd;  // global k index
            if (global_n < N && global_k < K)
                B_tile[nd][kd] = __ldg(&B[global_n * K + global_k]);
            else
                B_tile[nd][kd] = 0;
        }

        __syncthreads();

        // Compute this tile's partial products, unrolling the k-loop in chunks of UNROLL_FACTOR
        for (int k = 0; k < BLOCK_SIZE_K; k += UNROLL_FACTOR) {
            scalar_t a0[UNROLL_FACTOR], a1[UNROLL_FACTOR];
            scalar_t b[UNROLL_FACTOR];
            #pragma unroll
            for (int u = 0; u < UNROLL_FACTOR; u++) {
                if (k + u < BLOCK_SIZE_K) {
                    a0[u] = A_tile[k + u][tx];
                    a1[u] = A_tile[k + u][tx + (BLOCK_SIZE_M / 2)];
                    b[u] = B_tile[ty][k + u];
                }
            }

            #pragma unroll
            for (int u = 0; u < UNROLL_FACTOR; u++) {
                if (k + u < BLOCK_SIZE_K) {
                    // Plain multiply-add keeps the code correct for both float and double
                    // (__fmaf_rn is float-only); nvcc fuses this into an FMA by default.
                    acc0 += a0[u] * b[u];
                    acc1 += a1[u] * b[u];
                }
            }
        }
        __syncthreads();
    }

    // Write the results to global memory
    if (row0 < M && col < N) {
        C[row0 * N + col] = acc0;
    }
    if (row1 < M && col < N) {
        C[row1 * N + col] = acc1;
    }
}


// PyTorch binding

torch::Tensor matmul_transpose_cuda(torch::Tensor A, torch::Tensor B) {
    // Dimensions:
    // A: (K x M), B: (N x K), therefore C: (M x N)
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous(), "A and B must be contiguous");
    TORCH_CHECK(A.size(0) == B.size(1), "Inner dimension mismatch: A is (K, M), B is (N, K)");

    int K = A.size(0);
    int M = A.size(1);
    int N = B.size(0);

    auto C = torch::empty({M, N}, A.options());

    // Define block dimensions: use 16x16 threads per block
    dim3 threads(16, 16);
    // Grid dimensions based on tile sizes
    dim3 blocks((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M);

    AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "optimized_matmul_transpose_kernel", ([&] {
        optimized_matmul_transpose_kernel<scalar_t><<<blocks, threads>>>(
            A.data_ptr<scalar_t>(),
            B.data_ptr<scalar_t>(),
            C.data_ptr<scalar_t>(),
            M, N, K);
    }));

    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &matmul_transpose_cuda, "Optimized matrix multiplication with transpose (CUDA)");
}
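
To exercise the extension end to end, one option is to JIT-compile the listing above with torch.utils.cpp_extension.load and compare it against the PyTorch reference. The sketch below assumes the CUDA source has been saved as matmul_transpose.cu (file and module names are placeholders, not part of the original listing) and that a CUDA-capable GPU and nvcc toolchain are available.

import torch
from torch.utils.cpp_extension import load

# Placeholder name/path; the CUDA listing above must be saved to this file first.
ext = load(name="matmul_transpose_ext", sources=["matmul_transpose.cu"])

K, M, N = 4096, 1024, 2048
A = torch.randn(K, M, device="cuda")
B = torch.randn(N, K, device="cuda")

ref = torch.matmul(A.T, B.T)   # cuBLAS reference, shape (M, N)
out = ext.forward(A, B)        # custom tiled kernel

# Loose tolerance: the custom kernel accumulates in a different order than cuBLAS.
print(torch.allclose(out, ref, rtol=1e-3, atol=1e-3))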
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.580 inst/cycle 0.000 5
Executed Ipc Elapsed 2.548 inst/cycle 0.000 5
Issue Slots Busy 64.520 % 0.007 5
Issued Ipc Active 2.580 inst/cycle 0.000 5
SM Busy 64.520 % 0.007 5
Memory Throughput 70771506548.074 byte/second 4437835498168402944.000 5
Mem Busy 83.462 % 0.015 5
Max Bandwidth 75.020 % 0.012 5
L1/TEX Hit Rate 1.680 % 0.002 5
L2 Hit Rate 86.760 % 0.272 5
Mem Pipes Busy 73.288 % 0.012 5
Warp Cycles Per Issued Instruction 22.540 cycle 0.001 5
Warp Cycles Per Executed Instruction 22.540 cycle 0.001 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 31.120 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 91.002 % 0.007 5
Achieved Active Warps Per SM 58.242 warp 0.003 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (37.1%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::to
CPU Time 608072.15 μs
Device Time 5403.39 μs
Self CPU Time 47.34 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 608024.81 μs
Device Time 5403.39 μs
Self CPU Time 127.41 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 5345400.52 μs
Device Time 7516.30 μs
Self CPU Time 5345400.52 μs
Self Device Time 7516.30 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void optimized_matmul_transpose_kernel<float>(float const*, float const*, float*, int, int, int)
CPU Time 0.00 μs
Device Time 5706786.33 μs
Self CPU Time 0.00 μs
Self Device Time 5706786.33 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 10442.63 μs
Device Time 14805.71 μs
Self CPU Time 10442.63 μs
Self Device Time 14805.71 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 5190384.16 μs
Device Time 231938.39 μs
Self CPU Time 6448.82 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 5183936.96 μs
Device Time 231938.39 μs
Self CPU Time 8951.10 μs
Self Device Time 231938.39 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 231938.39 μs
Self CPU Time 0.00 μs
Self Device Time 231938.39 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45288 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:18:5 bugprone-easily-swappable-parameters
18 | const scalar_t* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
19 | const scalar_t* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:18:34: note: the first parameter in the range is 'A'
18 | const scalar_t* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:19:34: note: the last parameter in the range is 'B'
19 | const scalar_t* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:26:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | int m_start = blockIdx.y * BLOCK_SIZE_M; // row start in C
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:27:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | int n_start = blockIdx.x * BLOCK_SIZE_N; // col start in C
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:30:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
30 | int tx = threadIdx.x; // Expected range: [0, 15]
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:31:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | int ty = threadIdx.y; // Expected range: [0, 15]
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:47:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
47 | int tId = threadIdx.y * blockDim.x + threadIdx.x; // Range: 0 to 255
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:48:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
48 | int blockSize = blockDim.x * blockDim.y; // = 256
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:120:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
120 | int K = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:121:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
121 | int M = A.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:122:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
122 | int N = B.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_18/b8_s3_optimized_matmul_transpose/base/base.cu:131:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
131 | AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "optimized_matmul_transpose_kernel", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:237:34: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES'
237 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:233:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES'
233 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^