← Back to Leaderboard

The AI CUDA Engineer 👷

9_Tall_skinny_matrix_multiplication_predicated_tile_loading_unrolled_base

Level 1 • Task 9
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """Compute the matrix product C = A * B.

    Intended for the tall-and-skinny case, where one operand has far more
    rows than columns (M >> N) or vice versa.

    Args:
        A (torch.Tensor): Left operand, shape (M, K) or (K, M).
        B (torch.Tensor): Right operand, shape (K, N) or (N, K).

    Returns:
        torch.Tensor: The product, shape (M, N) or (N, M).
    """
    # `@` dispatches to torch.matmul for tensor operands.
    return A @ B


class Model(nn.Module):
    """Thin nn.Module wrapper around a single matrix product C = A * B,
    where one operand is tall and skinny (M >> N or N >> M).
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B, fn=module_fn):
        # Delegate to the (swappable) functional implementation so the
        # computation can be replaced, e.g. by a custom CUDA kernel.
        return fn(A, B)


M = 16384  # tall dimension: rows of A and columns of B
N = 16  # skinny dimension: columns of A and rows of B (M >> N)


def get_inputs():
    """Build the benchmark operands: A of shape (M, N), B of shape (N, M)."""
    return [torch.randn(M, N), torch.randn(N, M)]


def get_init_inputs():
    """Return constructor arguments for Model (none are needed)."""
    return list()
import torch
import torch.nn as nn

class Model(nn.Module):
    """Thin nn.Module wrapper around a single matrix product C = A * B,
    where one operand is tall and skinny (M >> N or N >> M).
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B):
        """Multiply the two operand matrices.

        Args:
            A (torch.Tensor): Left operand, shape (M, K) or (K, M).
            B (torch.Tensor): Right operand, shape (K, N) or (N, K).

        Returns:
            torch.Tensor: The product, shape (M, N) or (N, M).
        """
        # Tensor.matmul is equivalent to torch.matmul(A, B).
        return A.matmul(B)

M = 16384  # tall dimension: rows of A and columns of B
N = 16  # skinny dimension: columns of A and rows of B (M >> N)

def get_inputs():
    """Build the benchmark operands: A of shape (M, N), B of shape (N, M)."""
    return [torch.randn(M, N), torch.randn(N, M)]

def get_init_inputs():
    """Return constructor arguments for Model (none are needed)."""
    return list()

Kernel Information

Related Kernels (Level 1, Task 9 • 9_Tall_skinny_matrix_multiplication_)

Rank | Kernel Name | Runtime (ms) | Speedup (vs. Native) | Speedup (vs. Compile)
🥇 unrolled_loop_matmul_base 0.68 0.78 0.59
🥈 constant_mem_matmul_base_base 0.69 0.78 0.58
🥉 unrolled_matmul_kernel_base 0.69 0.77 0.58
4 balanced_workload_matmul_base_base 0.71 0.75 0.56
4 multi_tile_mapping_base 0.71 0.75 0.56
6 optimized_tiled_gemm_base 0.71 0.75 0.56
6 optimized_matmul_kernel_base 0.71 0.75 0.56
8 streamed_balanced_matmul_base 0.75 0.71 0.53
9 streamed_balanced_matmul_base 0.75 0.71 0.53
9 streamed_pipelined_matmul_base 0.75 0.71 0.53
11 predicated_tile_loading_unrolled_edit_1 1.26 0.42 0.32
11 unrolled_loop_optimization_base 1.26 0.42 0.32
11 unrolled_loop_optimization_edit_1 1.26 0.42 0.32
11 modular_device_functions_edit_1 1.26 0.42 0.32
15 uniform_flow_matmul_base 1.26 0.42 0.32
15 warp_optimized_reduction_edit_1 1.26 0.42 0.32
17 predicated_tile_loading_unrolled_base 1.26 0.42 0.32
18 modular_device_functions_base 1.26 0.42 0.32
19 warp_divergence_optimized_base_base 1.27 0.42 0.32
20 coalesced_memory_access_base_base 1.27 0.42 0.32
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <stdexcept>

#define BLOCK_SIZE 16

// Fetch matrix(row, col) given leading dimension `ld`; when `transpose`
// is set, the logical (row, col) indices are swapped in storage.
__device__ float get_element(const float* __restrict__ matrix, int row, int col, int ld, bool transpose) {
    if (transpose) {
        return matrix[col * ld + row];
    }
    return matrix[row * ld + col];
}

// Cooperatively stage one BLOCK_SIZE x BLOCK_SIZE tile of A and of B into
// shared memory for reduction step `t`. Out-of-range elements are predicated
// to 0.0f so no divergent branch is needed later in the inner product.
// Each thread loads exactly one element of each tile.
__device__ void load_tiles(const float* __restrict__ A, const float* __restrict__ B,
                           float As[BLOCK_SIZE][BLOCK_SIZE], float Bs[BLOCK_SIZE][BLOCK_SIZE],
                           int row, int col, int t, int M, int N, int K, int lda, int ldb,
                           bool transA, bool transB) {
    // K-dimension index this thread reads from A (varies along x) and B (along y).
    const int k_a = t * BLOCK_SIZE + threadIdx.x;
    const int k_b = t * BLOCK_SIZE + threadIdx.y;
    
    // Predicates: true only when the element lies inside the real matrices.
    const bool valid_a = (row < M) && (k_a < K);
    const bool valid_b = (col < N) && (k_b < K);
    
    // Zero-padding keeps the unrolled dot product branch-free and correct.
    As[threadIdx.y][threadIdx.x] = valid_a ? get_element(A, row, k_a, lda, transA) : 0.0f;
    Bs[threadIdx.y][threadIdx.x] = valid_b ? get_element(B, k_b, col, ldb, transB) : 0.0f;
}

// One thread's dot-product contribution for the current pair of shared-memory
// tiles: row threadIdx.y of As times column threadIdx.x of Bs.
__device__ float compute_partial_product(float As[BLOCK_SIZE][BLOCK_SIZE], float Bs[BLOCK_SIZE][BLOCK_SIZE]) {
    float sum = 0.0f;
    // BLOCK_SIZE is a compile-time constant, so this unrolls fully.
    #pragma unroll
    for (int k = 0; k < BLOCK_SIZE; ++k) {
        sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
    }
    return sum;
}

// Tiled GEMM: C = op(A) * op(B), where op is an optional transpose selected
// by transA/transB. Each thread block computes one BLOCK_SIZE x BLOCK_SIZE
// tile of C; each thread accumulates a single C element.
__global__ void matmul_kernel(const float* __restrict__ A,
                              const float* __restrict__ B,
                              float* __restrict__ C,
                              int M, int N, int K,
                              int lda, int ldb, int ldc,
                              bool transA, bool transB) {
    // Global row/column of the C element owned by this thread.
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Shared-memory staging tiles for the current K-slice of A and B.
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
    
    float acc = 0.0f;

    // March over the K dimension one BLOCK_SIZE-wide tile at a time.
    for (int t = 0; t < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++t) {
        load_tiles(A, B, As, Bs, row, col, t, M, N, K, lda, ldb, transA, transB);
        __syncthreads();  // all tile loads must land before any thread reads them
        
        acc += compute_partial_product(As, Bs);
        __syncthreads();  // finish reading before the next iteration overwrites the tiles
    }

    // Predicated store: out-of-range threads participated in the loads above
    // (with zero padding) but must not write outside C.
    if (row < M && col < N) {
        C[row * ldc + col] = acc;
    }
}

// Host launcher: validates inputs, classifies the orientation of A and B so
// logically-transposed views can be fed to the kernel without a copy, and
// dispatches the tiled GEMM on the current CUDA stream.
//
// Returns a newly allocated (M, N) float32 tensor C = A * B.
// Throws std::invalid_argument on non-CUDA, non-2D, non-float32, or
// dimensionally incompatible inputs; std::runtime_error on launch failure.
torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B) {
    if (!A.is_cuda() || !B.is_cuda())
        throw std::invalid_argument("Inputs must be CUDA tensors");
    if (A.dim() != 2 || B.dim() != 2)
        throw std::invalid_argument("Inputs must be 2D");
    // The kernel reads raw float pointers via data_ptr<float>(); reject other
    // dtypes up front instead of silently reinterpreting their bytes.
    if (A.scalar_type() != torch::kFloat32 || B.scalar_type() != torch::kFloat32)
        throw std::invalid_argument("Inputs must be float32");

    int64_t M, N, K;
    bool transA = false, transB = false;
    int lda, ldb, ldc;

    const auto A_rows = A.size(0), A_cols = A.size(1);
    const auto B_rows = B.size(0), B_cols = B.size(1);

    // Pick the orientation that makes the reduction dimensions agree; the
    // "tall" dimension is always mapped to M.
    // NOTE(review): the stride-based leading dimensions assume the other
    // stride is 1 (contiguous rows/columns); confirm callers never pass
    // arbitrarily strided tensors.
    if (A_rows >= A_cols && B_rows == A_cols) {
        M = A_rows; K = A_cols; N = B_cols;
        lda = A.stride(0); ldb = B.stride(0);
    } else if (A_cols > A_rows && B_rows == A_rows) {
        transA = true; M = A_cols; K = A_rows; N = B_cols;
        lda = A.stride(1); ldb = B.stride(0);
    } else if (A_rows >= A_cols && B_cols == A_cols) {
        transB = true; M = A_rows; K = A_cols; N = B_rows;
        lda = A.stride(0); ldb = B.stride(1);
    } else if (A_cols > A_rows && B_cols == A_rows) {
        transA = transB = true; M = A_cols; K = A_rows; N = B_rows;
        lda = A.stride(1); ldb = B.stride(1);
    } else {
        throw std::invalid_argument("Dimensions mismatch");
    }

    auto C = torch::empty({M, N}, A.options());
    ldc = N;  // C is freshly allocated and contiguous, so its leading dim is N

    const dim3 block(BLOCK_SIZE, BLOCK_SIZE);
    const dim3 grid((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);

    matmul_kernel<<<grid, block>>>(A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(),
                                  M, N, K, lda, ldb, ldc, transA, transB);

    // Surface launch-configuration errors immediately. The original blocking
    // cudaDeviceSynchronize() is dropped: profiling showed it dominating host
    // time, and it is unnecessary because downstream ops on the same stream
    // are ordered after the kernel automatically.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        throw std::runtime_error(cudaGetErrorString(err));

    return C;
}

// Expose the host launcher to Python as `module.forward(A, B)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &matmul_cuda, "Optimized tall-skinny matmul with predicated loading");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 3.100 inst/cycle 0.000 5
Executed Ipc Elapsed 3.090 inst/cycle 0.000 5
Issue Slots Busy 77.418 % 0.000 5
Issued Ipc Active 3.100 inst/cycle 0.000 5
SM Busy 77.418 % 0.000 5
Memory Throughput 686498742439.206 byte/second 75198003007403136.000 5
Mem Busy 92.880 % 0.000 5
Max Bandwidth 72.796 % 0.000 5
L1/TEX Hit Rate 30.402 % 0.000 5
L2 Hit Rate 99.278 % 0.004 5
Mem Pipes Busy 72.760 % 0.000 5
Warp Cycles Per Issued Instruction 18.490 cycle 0.000 5
Warp Cycles Per Executed Instruction 18.490 cycle 0.000 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 31.190 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 21.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 90.248 % 0.000 5
Achieved Active Warps Per SM 57.760 warp 0.000 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (35.1%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF Occupancy This kernel's theoretical occupancy is not impacted by any block limit.
Operation / Metric Value Unit
aten::to
CPU Time 211086.61 μs
Device Time 79.81 μs
Self CPU Time 41.47 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 211045.14 μs
Device Time 79.81 μs
Self CPU Time 132.16 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 210466.14 μs
Device Time 0.00 μs
Self CPU Time 104.17 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 210159.19 μs
Device Time 0.00 μs
Self CPU Time 210159.19 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
matmul_kernel(float const*, float const*, float*, int, int, int, int, int, int, bool, bool)
CPU Time 0.00 μs
Device Time 4389576.87 μs
Self CPU Time 0.00 μs
Self Device Time 4389576.87 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceSynchronize
CPU Time 4622223.49 μs
Device Time 393.53 μs
Self CPU Time 4622223.49 μs
Self Device Time 393.53 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 8663.27 μs
Device Time 49872.11 μs
Self CPU Time 8663.27 μs
Self Device Time 49872.11 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 39260.87 μs
Device Time 267132.85 μs
Self CPU Time 7663.40 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 31599.59 μs
Device Time 267132.85 μs
Self CPU Time 9935.22 μs
Self Device Time 267132.85 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 267132.85 μs
Self CPU Time 0.00 μs
Self Device Time 267132.85 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45296 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:28 bugprone-easily-swappable-parameters
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:32: note: the first parameter in the range is 'row'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^~~
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:50: note: the last parameter in the range is 't'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:53: warning: 2 adjacent parameters of 'load_tiles' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:57: note: the first parameter in the range is 'M'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:64: note: the last parameter in the range is 'N'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:67: warning: 2 adjacent parameters of 'load_tiles' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:71: note: the first parameter in the range is 'K'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:13:78: note: the last parameter in the range is 'lda'
13 | int row, int col, int t, int M, int N, int K, int lda, int ldb,
| ^~~
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:15:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
15 | const int k_a = t * BLOCK_SIZE + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:16:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
16 | const int k_b = t * BLOCK_SIZE + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:40:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
40 | const int row = blockIdx.y * blockDim.y + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:41:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
41 | const int col = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:61:41: warning: the parameter 'A' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
61 | torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:61:58: warning: the parameter 'B' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
61 | torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:76:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
76 | lda = A.stride(0); ldb = B.stride(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:76:34: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
76 | lda = A.stride(0); ldb = B.stride(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:79:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | lda = A.stride(1); ldb = B.stride(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:79:34: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | lda = A.stride(1); ldb = B.stride(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:82:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | lda = A.stride(0); ldb = B.stride(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:82:34: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | lda = A.stride(0); ldb = B.stride(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:85:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
85 | lda = A.stride(1); ldb = B.stride(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:85:34: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
85 | lda = A.stride(1); ldb = B.stride(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:91:11: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
91 | ldc = N;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:97:35: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
97 | M, N, K, lda, ldb, ldc, transA, transB);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:97:38: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
97 | M, N, K, lda, ldb, ldc, transA, transB);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250211_optimize_b5_s4_e1_v2/level_1/task_9/b4_s3_predicated_tile_loading_unrolled/base/base.cu:97:41: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
97 | M, N, K, lda, ldb, ldc, transA, transB);
| ^