← Back to Leaderboard

The AI CUDA Engineer 👷

2_Standard_matrix_multiplication_warp_optimized_matmul_base_base

Level 1 • Task 2
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    performs a single general matrix multiplication (C = A * B).

    Args:
        A: Input tensor of shape (M, K).
        B: Input tensor of shape (K, N).

    Returns:
        Output tensor of shape (M, N).
    """
    return torch.matmul(A, B)


class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor, fn=module_fn) -> torch.Tensor:
        return fn(A, B)


M = 1024
K = 4096
N = 2048


def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a single matrix multiplication (C = A * B)
    """
    def __init__(self):
        super(Model, self).__init__()
    
    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix multiplication.

        Args:
            A: Input tensor of shape (M, K).
            B: Input tensor of shape (K, N).

        Returns:
            Output tensor of shape (M, N).
        """
        return torch.matmul(A, B)

M = 1024
K = 4096
N = 2048

def get_inputs():
    A = torch.randn(M, K)
    B = torch.randn(K, N)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed

Kernel Information

Related Kernels (Level 1, Task 2 • 2_Standard_matrix_multiplication_)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 optimized_tiled_matmul_base 0.43 1.00 1.08
🥇 hybrid_matmul_base 0.43 1.00 1.08
🥇 hybrid_regtiled_base 0.43 1.00 1.08
🥇 double_buffered_matmul_base 0.43 1.00 1.08
5 warp_optimized_matmul_base_base 0.43 1.00 1.08
5 coalesced_hybrid_matmul_base_base 0.43 1.00 1.08
5 strided_tiled_matmul_base 0.43 1.00 1.08
8 hybrid_matmul_base 0.43 0.99 1.07
8 aligned_tiled_matmul_base_base 0.43 0.99 1.07
10 unrolled_hybrid_matmul_base 0.43 0.99 1.07
11 unrolled_hybrid_matmul_base_base 0.43 0.98 1.06
11 dynamic_blocksize_matmul_base 0.43 0.98 1.06
13 doublebuffer_tiled_matmul_base 0.43 0.98 1.06
13 optimized_single_stream_matmul_base 0.43 0.98 1.06
15 hybrid_tiled_cublas_base 0.43 0.98 1.06
16 constant_hybrid_matmul_base_base 0.45 0.96 1.03
17 streamed_pipelined_matmul_base_base 0.45 0.95 1.02
18 tiled_regtile_base 1.26 0.34 0.36
19 optimized_sync_matrix_mult_edit_1 1.92 0.22 0.24
20 divergence_free_matrix_mult_base 1.93 0.22 0.24
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>

#define WARP_SIZE 32
#define BLOCK_SIZE 128
#define WARPS_PER_BLOCK (BLOCK_SIZE/WARP_SIZE)
#define TILE_SIZE 32

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

static cublasHandle_t handle = nullptr;

__global__ void warp_optimized_matmul_kernel(const float* __restrict__ A,
                                            const float* __restrict__ B,
                                            float* __restrict__ C,
                                            const int M, const int N, const int K) {
    const int row = blockIdx.y * TILE_SIZE + (threadIdx.x / WARP_SIZE) * (TILE_SIZE/WARPS_PER_BLOCK) + (threadIdx.x % (TILE_SIZE/WARPS_PER_BLOCK));
    const int col = blockIdx.x * TILE_SIZE + threadIdx.y;
    
    float sum = 0.0f;
    
    const int lane = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    
    __shared__ float shared_data[WARPS_PER_BLOCK][WARP_SIZE];
    
    for (int tile = 0; tile < K; tile += WARP_SIZE) {
        float a_reg = (row < M && (tile + lane) < K) ? A[row * K + tile + lane] : 0.0f;
        float b_reg = ((tile + lane) < K && col < N) ? B[(tile + lane) * N + col] : 0.0f;
        
        #pragma unroll
        for (int k = 0; k < WARP_SIZE; ++k) {
            float a_bc = __shfl_sync(0xffffffff, a_reg, k);
            sum += a_bc * b_reg;
            b_reg = __shfl_up_sync(0xffffffff, b_reg, 1);
        }
    }
    
    if (lane < WARP_SIZE) {
        shared_data[warp_id][lane] = sum;
    }
    __syncthreads();
    
    if (lane < WARPS_PER_BLOCK) {
        float warp_sum = 0.0f;
        #pragma unroll
        for (int i = 0; i < WARP_SIZE; ++i) {
            warp_sum += shared_data[i][lane];
        }
        
        if (row < M && col < N) {
            C[row * N + col] = warp_sum;
        }
    }
}

void matrix_multiply_cuda(const torch::Tensor &A, const torch::Tensor &B, torch::Tensor &C) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    const int M = A.size(0);
    const int K = A.size(1);
    const int N = B.size(1);

    const float* d_A = A.data_ptr<float>();
    const float* d_B = B.data_ptr<float>();
    float* d_C = C.data_ptr<float>();

    if (M <= 128 && N <= 128 && K <= 128) {
        dim3 block(BLOCK_SIZE);
        dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
        warp_optimized_matmul_kernel<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
    } else {
        if (handle == nullptr) {
            cublasCreate(&handle);
            cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
        }

        const float alpha = 1.0f;
        const float beta = 0.0f;
        cublasSgemm(handle,
                    CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K,
                    &alpha,
                    d_B, N,
                    d_A, K,
                    &beta,
                    d_C, N);
    }
}

torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
    CHECK_INPUT(A);
    CHECK_INPUT(B);

    const int M = A.size(0);
    const int N = B.size(1);

    auto options = torch::TensorOptions()
                      .dtype(A.dtype())
                      .device(A.device())
                      .requires_grad(false);
    
    torch::Tensor C = torch::empty({M, N}, options);
    matrix_multiply_cuda(A, B, C);
    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Warp-optimized matrix multiplication (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Analysis Rules
Rule Description
Operation / Metric Value Unit
aten::to
CPU Time 219314.98 μs
Device Time 5199.04 μs
Self CPU Time 46.87 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaStreamGetCaptureInfo
CPU Time 6941.69 μs
Device Time 40654.40 μs
Self CPU Time 6941.69 μs
Self Device Time 40654.40 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaMemsetAsync
CPU Time 359954.28 μs
Device Time 38783.48 μs
Self CPU Time 359954.28 μs
Self Device Time 38783.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_nn_n_tilesize64x64x8_stage3_warpsize1x4x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 2924502.68 μs
Self CPU Time 0.00 μs
Self Device Time 2924502.68 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 2769483.28 μs
Device Time 536755.97 μs
Self CPU Time 14189.99 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 2755294.94 μs
Device Time 536755.97 μs
Self CPU Time 18902.65 μs
Self Device Time 536755.97 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 2736392.29 μs
Device Time 0.00 μs
Self CPU Time 2736392.29 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 536755.97 μs
Self CPU Time 0.00 μs
Self Device Time 536755.97 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45326 warnings (45279 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:11:35 bugprone-macro-parentheses
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:12:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:17:46: warning: 2 adjacent parameters of 'warp_optimized_matmul_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
17 | __global__ void warp_optimized_matmul_kernel(const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:17:72: note: the first parameter in the range is 'A'
17 | __global__ void warp_optimized_matmul_kernel(const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:18:71: note: the last parameter in the range is 'B'
18 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:21:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
21 | const int row = blockIdx.y * TILE_SIZE + (threadIdx.x / WARP_SIZE) * (TILE_SIZE/WARPS_PER_BLOCK) + (threadIdx.x % (TILE_SIZE/WARPS_PER_BLOCK));
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:22:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | const int col = blockIdx.x * TILE_SIZE + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:26:22: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
26 | const int lane = threadIdx.x % WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:27:25: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | const int warp_id = threadIdx.x / WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:66:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
66 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:67:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
67 | const int K = A.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:68:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
68 | const int N = B.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:97:37: warning: the parameter 'A' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
97 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:97:54: warning: the parameter 'B' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
97 | torch::Tensor forward(torch::Tensor A, torch::Tensor B) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:101:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
101 | const int M = A.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_2/b10_s3_warp_optimized_matmul_base/base/base.cu:102:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | const int N = B.size(1);
| ^