
The AI CUDA Engineer 👷

2_Standard_matrix_multiplication_hybrid_matmul_base

Level 1 • Task 2
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    Performs a single general matrix multiplication (C = A * B).

    Args:
        A: Input tensor of shape (M, K).
        B: Input tensor of shape (K, N).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A, B)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """
    def __init__(self):
        super(Model, self).__init__()
    
    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix multiplication.

        Args:
            A: Input tensor of shape (M, K).
            B: Input tensor of shape (K, N).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A, B)

M = 1024
K = 4096
N = 2048

def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed
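
For orientation, here is a minimal sketch of how these harness pieces fit together: the model is constructed with get_init_inputs() (empty for this task) and its forward pass is driven with the tensors from get_inputs(). The actual benchmark driver is not shown on this page, so treat this as an illustrative assumption rather than the real harness code.

import torch

# Construct the reference model and move inputs to the GPU (assumed available)
model = Model(*get_init_inputs()).cuda()
A, B = (t.cuda() for t in get_inputs())   # shapes (M, K) and (K, N)
C = model(A, B)                           # reference result via torch.matmul
assert C.shape == (M, N)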

Kernel Information

Related Kernels (Level 1, Task 2 • 2_Standard_matrix_multiplication_)

Rank Kernel Name Runtime (ms) Speedup (vs. Native) Speedup (vs. Compile)
🥇 optimized_tiled_matmul_base 0.43 1.00 1.08
🥇 hybrid_matmul_base 0.43 1.00 1.08
🥇 hybrid_regtiled_base 0.43 1.00 1.08
🥇 double_buffered_matmul_base 0.43 1.00 1.08
5 warp_optimized_matmul_base_base 0.43 1.00 1.08
5 coalesced_hybrid_matmul_base_base 0.43 1.00 1.08
5 strided_tiled_matmul_base 0.43 1.00 1.08
8 hybrid_matmul_base 0.43 0.99 1.07
8 aligned_tiled_matmul_base_base 0.43 0.99 1.07
10 unrolled_hybrid_matmul_base 0.43 0.99 1.07
11 unrolled_hybrid_matmul_base_base 0.43 0.98 1.06
11 dynamic_blocksize_matmul_base 0.43 0.98 1.06
13 doublebuffer_tiled_matmul_base 0.43 0.98 1.06
13 optimized_single_stream_matmul_base 0.43 0.98 1.06
15 hybrid_tiled_cublas_base 0.43 0.98 1.06
16 constant_hybrid_matmul_base_base 0.45 0.96 1.03
17 streamed_pipelined_matmul_base_base 0.45 0.95 1.02
18 tiled_regtile_base 1.26 0.34 0.36
19 optimized_sync_matrix_mult_edit_1 1.92 0.22 0.24
20 divergence_free_matrix_mult_base 1.93 0.22 0.24
/*
Hybrid Matrix Multiplication Extension
This implementation combines a custom tiled CUDA kernel for small matrices and cuBLAS for larger matrices.
For small matrices (e.g., <= 128x128), the custom kernel minimizes launch and library setup overhead.
For larger matrices, cuBLAS provides highly optimized GEMM routines and can take advantage of GPU tensor cores.
*/

#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>

#define TILE_SIZE 32

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Static cuBLAS handle to avoid recreation overhead
static cublasHandle_t handle = nullptr;

// Custom tiled matrix multiplication kernel
__global__ void matmul_kernel_2d(const float* __restrict__ A,
                                 const float* __restrict__ B,
                                 float* __restrict__ C,
                                 const int M, const int N, const int K) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    // Block indices
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    // Thread indices
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // Compute row and col for C
    const int row = by * TILE_SIZE + ty;
    const int col = bx * TILE_SIZE + tx;

    float sum = 0.0f;

    // Loop over tiles
    for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; ++tile) {
        // Load A tile
        if (row < M && tile * TILE_SIZE + tx < K) {
            As[ty][tx] = A[row * K + tile * TILE_SIZE + tx];
        } else {
            As[ty][tx] = 0.0f;
        }

        // Load B tile
        if (tile * TILE_SIZE + ty < K && col < N) {
            Bs[ty][tx] = B[(tile * TILE_SIZE + ty) * N + col];
        } else {
            Bs[ty][tx] = 0.0f;
        }

        __syncthreads();

        // Compute partial dot product using the tile
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; ++k) {
            sum += As[ty][k] * Bs[k][tx];
        }

        __syncthreads();
    }

    // Write the result
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

// Hybrid matrix multiplication: chooses custom kernel for small matrices, cuBLAS for larger ones
void matrix_multiply_cuda(const torch::Tensor &A, const torch::Tensor &B, torch::Tensor &C) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    const int M = A.size(0);
    const int K = A.size(1);
    const int N = B.size(1);

    const float* d_A = A.data_ptr<float>();
    const float* d_B = B.data_ptr<float>();
    float* d_C = C.data_ptr<float>();

    // Heuristic: use custom kernel for small matrices, cuBLAS otherwise.
    if (M <= 128 && N <= 128 && K <= 128) {
        // Launch custom tiled kernel
        dim3 threadsPerBlock(TILE_SIZE, TILE_SIZE);
        dim3 numBlocks((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
        matmul_kernel_2d<<<numBlocks, threadsPerBlock>>>(d_A, d_B, d_C, M, N, K);
    } else {
        // Initialize cuBLAS handle if needed
        if (handle == nullptr) {
            cublasCreate(&handle);
            // Use the default math mode; CUBLAS_TF32_TENSOR_OP_MATH could be set here instead to allow TF32 tensor cores
            cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
        }

        const float alpha = 1.0f;
        const float beta = 0.0f;

        // Note: cuBLAS assumes column-major storage, so the row-major arrays are implicitly
        // treated as their transposes. Swapping the A and B pointers computes
        // B^T * A^T = (A * B)^T in column-major order, which is exactly C = A * B in row-major order.
        cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K,
                    &alpha,
                    d_B, N,  // B's leading dimension
                    d_A, K,  // A's leading dimension
                    &beta,
                    d_C, N); // C's leading dimension
    }
}

// PyTorch forward interface
torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);

    const int M = A.size(0);
    const int N = B.size(1);

    auto options = torch::TensorOptions()
                       .dtype(A.dtype())
                       .device(A.device())
                       .requires_grad(false);
    
    torch::Tensor C = torch::empty({M, N}, options);
    matrix_multiply_cuda(A, B, C);
    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Hybrid matrix multiplication (CUDA): custom kernel for small matrices and cuBLAS for large matrices");
}
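
A minimal sketch of one way to compile and sanity-check this extension with PyTorch's JIT C++/CUDA loader. The file name base.cu, the module name, and the explicit cuBLAS link flag are assumptions for illustration and are not part of the submission.

import torch
from torch.utils.cpp_extension import load

# Build the extension from the source above (assumed saved as base.cu)
hybrid_matmul = load(
    name="hybrid_matmul",
    sources=["base.cu"],
    extra_cuda_cflags=["-O3"],
    extra_ldflags=["-lcublas"],   # the extension calls cublasSgemm
)

M, K, N = 1024, 4096, 2048
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
B = torch.randn(K, N, device="cuda", dtype=torch.float32)

C = hybrid_matmul.forward(A, B)
# Compare against the PyTorch reference within a loose float32 tolerance
torch.testing.assert_close(C, A @ B, rtol=1e-3, atol=1e-3)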
Performance Metrics

Operation / Metric Value Unit
aten::to
CPU Time 456977.30 μs
Device Time 5148.42 μs
Self CPU Time 40.91 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 456936.40 μs
Device Time 5148.42 μs
Self CPU Time 115.13 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaStreamGetCaptureInfo
CPU Time 7632.97 μs
Device Time 44076.86 μs
Self CPU Time 7632.97 μs
Self Device Time 44076.86 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_nn_n_tilesize64x64x8_stage3_warpsize1x4x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 3162336.16 μs
Self CPU Time 0.00 μs
Self Device Time 3162336.16 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 3005967.15 μs
Device Time 580578.90 μs
Self CPU Time 15081.78 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 2990889.05 μs
Device Time 580578.90 μs
Self CPU Time 19870.42 μs
Self Device Time 580578.90 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 2971018.63 μs
Device Time 0.00 μs
Self CPU Time 2971018.63 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 580578.90 μs
Self CPU Time 0.00 μs
Self Device Time 580578.90 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
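
The operator-level timings above come from a profiler pass over the benchmark run. The exact profiling setup is not shown here; the following sketch with torch.profiler is one assumed way to collect comparable per-operator CPU/CUDA times.

import torch
from torch.profiler import profile, ProfilerActivity

A = torch.randn(1024, 4096, device="cuda")
B = torch.randn(4096, 2048, device="cuda")

# Record both CPU-side and CUDA kernel activity for the matmul
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    C = torch.matmul(A, B)   # or hybrid_matmul.forward(A, B) for the extension
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))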
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45326 warnings (45279 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:15:35: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
15 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:16:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:23:34: warning: 2 adjacent parameters of 'matmul_kernel_2d' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
23 | __global__ void matmul_kernel_2d(const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
24 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:23:60: note: the first parameter in the range is 'A'
23 | __global__ void matmul_kernel_2d(const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:24:60: note: the last parameter in the range is 'B'
24 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:31:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | const int bx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:32:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
32 | const int by = blockIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:34:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
34 | const int tx = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:35:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
35 | const int ty = threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:82:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:83:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
83 | const int K = A.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:84:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
84 | const int N = B.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:121:37: warning: the parameter 'A' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
121 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:121:54: warning: the parameter 'B' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
121 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:125:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
125 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b4_s3_hybrid_matmul/base/base.cu:126:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
126 | const int N = B.size(1);
| ^