← Back to Leaderboard

The AI CUDA Engineer 👷

2_Standard_matrix_multiplication_coalesced_hybrid_matmul_base_base

Level 1 • Task 2
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    performs a single general matrix multiplication (C = A * B).

    Args:
        A: Input tensor of shape (M, K).
        B: Input tensor of shape (K, N).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A, B)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """
    def __init__(self):
        super(Model, self).__init__()
    
    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix multiplication.

        Args:
            A: Input tensor of shape (M, K).
            B: Input tensor of shape (K, N).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A, B)

M = 1024
K = 4096
N = 2048

def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed

Kernel Information

Related Kernels (Level 1, Task 2 • 2_Standard_matrix_multiplication_)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_tiled_matmul_base 0.43 1.00 1.08
🥇 hybrid_matmul_base 0.43 1.00 1.08
🥇 hybrid_regtiled_base 0.43 1.00 1.08
🥇 double_buffered_matmul_base 0.43 1.00 1.08
5 warp_optimized_matmul_base_base 0.43 1.00 1.08
5 coalesced_hybrid_matmul_base_base 0.43 1.00 1.08
5 strided_tiled_matmul_base 0.43 1.00 1.08
8 hybrid_matmul_base 0.43 0.99 1.07
8 aligned_tiled_matmul_base_base 0.43 0.99 1.07
10 unrolled_hybrid_matmul_base 0.43 0.99 1.07
11 unrolled_hybrid_matmul_base_base 0.43 0.98 1.06
11 dynamic_blocksize_matmul_base 0.43 0.98 1.06
13 doublebuffer_tiled_matmul_base 0.43 0.98 1.06
13 optimized_single_stream_matmul_base 0.43 0.98 1.06
15 hybrid_tiled_cublas_base 0.43 0.98 1.06
16 constant_hybrid_matmul_base_base 0.45 0.96 1.03
17 streamed_pipelined_matmul_base_base 0.45 0.95 1.02
18 tiled_regtile_base 1.26 0.34 0.36
19 optimized_sync_matrix_mult_edit_1 1.92 0.22 0.24
20 divergence_free_matrix_mult_base 1.93 0.22 0.24
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#define BLOCK_SIZE 32
#define VECTOR_SIZE 4  // Use vector loads for better memory coalescing

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

static cublasHandle_t handle = nullptr;

// Vectorized load type for better memory coalescing
typedef float4 vector_t;

__global__ void coalesced_matmul_kernel(const float* __restrict__ A,
                                       const float* __restrict__ B,
                                       float* __restrict__ C,
                                       const int M, const int N, const int K) {
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];

    // Block indices
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    
    // Thread indices
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // Compute base indices for coalesced memory access
    const int row = by * BLOCK_SIZE + ty;
    const int col = bx * BLOCK_SIZE + tx;

    float sum = 0.0f;

    // Loop over tiles with vectorized loads
    for (int tile = 0; tile < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++tile) {
        // Compute aligned addresses for vectorized loads
        const int baseIdxA = row * K + tile * BLOCK_SIZE;
        const int baseIdxB = (tile * BLOCK_SIZE) * N + col;
        
        // Load A tile with vectorized reads where possible
        if (row < M && (tile * BLOCK_SIZE + tx) < K) {
            if ((baseIdxA + tx) % VECTOR_SIZE == 0 && tx + VECTOR_SIZE <= BLOCK_SIZE) {
                vector_t v = *reinterpret_cast<const vector_t*>(&A[baseIdxA + tx]);
                As[ty][tx] = v.x;
                if (tx + 1 < BLOCK_SIZE) As[ty][tx + 1] = v.y;
                if (tx + 2 < BLOCK_SIZE) As[ty][tx + 2] = v.z;
                if (tx + 3 < BLOCK_SIZE) As[ty][tx + 3] = v.w;
            } else {
                As[ty][tx] = A[baseIdxA + tx];
            }
        } else {
            As[ty][tx] = 0.0f;
        }

        // Load B tile with vectorized reads where possible
        if ((tile * BLOCK_SIZE + ty) < K && col < N) {
            if ((baseIdxB + ty * N) % VECTOR_SIZE == 0) {
                vector_t v = *reinterpret_cast<const vector_t*>(&B[baseIdxB + ty * N]);
                Bs[ty][tx] = v.x;
            } else {
                Bs[ty][tx] = B[baseIdxB + ty * N];
            }
        } else {
            Bs[ty][tx] = 0.0f;
        }

        __syncthreads();

        // Compute partial dot product for this tile
        #pragma unroll
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            sum += As[ty][k] * Bs[k][tx];
        }

        __syncthreads();
    }

    // Write result with coalesced access
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

void matrix_multiply_cuda(const torch::Tensor &A, const torch::Tensor &B, torch::Tensor &C) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    const int M = A.size(0);
    const int K = A.size(1);
    const int N = B.size(1);

    const float* d_A = A.data_ptr<float>();
    const float* d_B = B.data_ptr<float>();
    float* d_C = C.data_ptr<float>();

    if (M <= 128 && N <= 128 && K <= 128) {
        dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
        dim3 blocks((N + BLOCK_SIZE - 1) / BLOCK_SIZE,
                   (M + BLOCK_SIZE - 1) / BLOCK_SIZE);

        coalesced_matmul_kernel<<<blocks, threads>>>(d_A, d_B, d_C, M, N, K);
    } else {
        if (handle == nullptr) {
            cublasCreate(&handle);
            cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
        }

        const float alpha = 1.0f;
        const float beta = 0.0f;

        cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K,
                    &alpha,
                    d_B, N,
                    d_A, K,
                    &beta,
                    d_C, N);
    }
}

torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);

    const int M = A.size(0);
    const int N = B.size(1);

    auto options = torch::TensorOptions()
                      .dtype(A.dtype())
                      .device(A.device())
                      .requires_grad(false);
    
    torch::Tensor C = torch::empty({M, N}, options);
    matrix_multiply_cuda(A, B, C);
    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Coalesced matrix multiplication (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Analysis Rules
Rule Description
Operation / Metric Value Unit
aten::to
CPU Time 329233.55 μs
Device Time 5073.89 μs
Self CPU Time 43.31 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaStreamGetCaptureInfo
CPU Time 8062.59 μs
Device Time 38109.67 μs
Self CPU Time 8062.59 μs
Self Device Time 38109.67 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaMemsetAsync
CPU Time 343678.55 μs
Device Time 36389.25 μs
Self CPU Time 343678.55 μs
Self Device Time 36389.25 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_nn_n_tilesize64x64x8_stage3_warpsize1x4x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 2746798.49 μs
Self CPU Time 0.00 μs
Self Device Time 2746798.49 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 2601379.12 μs
Device Time 504503.24 μs
Self CPU Time 13566.13 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 2587830.55 μs
Device Time 504503.24 μs
Self CPU Time 17040.73 μs
Self Device Time 504503.24 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 2570789.82 μs
Device Time 0.00 μs
Self CPU Time 2570789.82 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 504503.24 μs
Self CPU Time 0.00 μs
Self Device Time 504503.24 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45326 warnings (45279 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:9:35 bugprone-macro-parentheses
9 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:10:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:18:41: warning: 2 adjacent parameters of 'coalesced_matmul_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
18 | __global__ void coalesced_matmul_kernel(const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
19 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:18:67: note: the first parameter in the range is 'A'
18 | __global__ void coalesced_matmul_kernel(const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:19:66: note: the last parameter in the range is 'B'
19 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:26:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | const int bx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:27:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | const int by = blockIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:30:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
30 | const int tx = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:31:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | const int ty = threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:94:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
94 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:95:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
95 | const int K = A.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:96:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
96 | const int N = B.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:128:37: warning: the parameter 'A' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
128 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:128:54: warning: the parameter 'B' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
128 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:132:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
132 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b9_s3_coalesced_hybrid_matmul_base/base/base.cu:133:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
133 | const int N = B.size(1);
| ^