
The AI CUDA Engineer 👷

12_Matmul_with_diagonal_matrices_coalesced_diag_matmul_base

Level 1 • Task 12
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    Performs a matrix multiplication of a diagonal matrix with another matrix.

    Args:
        A (torch.Tensor): A 1D tensor representing the diagonal of the diagonal matrix. Shape: (N,).
        B (torch.Tensor): A 2D tensor representing the second matrix. Shape: (N, M).

    Returns:
        torch.Tensor: The result of the matrix multiplication. Shape: (N, M).
    """
    return torch.diag(A) @ B


class Model(nn.Module):
    """
    Simple model that performs a matrix multiplication of a diagonal matrix with another matrix.
    C = diag(A) * B
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A, B, fn=module_fn):
        return fn(A, B)


M = 4096
N = 4096


def get_inputs():
    A = torch.randn(N)
    B = torch.randn(N, M)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
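
Note that torch.diag(A) @ B materializes a full N × N diagonal matrix only to scale row i of B by A[i]. A minimal PyTorch sketch of that equivalence (illustrative names, not part of the benchmark definition):

import torch

N, M = 4096, 4096
A = torch.randn(N)
B = torch.randn(N, M)

# Reference: materializes an N x N diagonal matrix (O(N^2) extra memory)
C_ref = torch.diag(A) @ B

# Equivalent row-wise scaling via broadcasting (O(N*M) work, no N x N temporary)
C_fast = A.unsqueeze(1) * B

assert torch.allclose(C_ref, C_fast)

This row-wise scaling is exactly what the CUDA kernels below implement.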

Kernel Information

Related Kernels (Level 1, Task 12 • 12_Matmul_with_diagonal_matrices_)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 flat_no_atomic_diag_matmul_base 0.05 54.40 55.46
🥇 shared_mem_diag_matmul_base_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 optimized_block_size_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_edit_1 0.05 54.40 55.46
🥇 diag_matmul_warp_sync_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_edit_1 0.05 54.40 55.46
🥇 diag_matmul_readonly_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_base 0.05 54.40 55.46
🥇 stride_loop_diag_matmul_base 0.05 54.40 55.46
🥇 shared_memory_optimized_diag_matmul_base 0.05 54.40 55.46
🥇 unified_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_edit_1 0.05 54.40 55.46
19 coalesced_diag_matmul_base 0.05 53.35 54.39
19 diag_matmul_readonly_edit_1 0.05 53.35 54.39
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Kernel for when M is divisible by 4: use float4 for coalesced memory accesses
__global__ void coalesced_diag_matmul_vec_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t M
) {
    int row = blockIdx.x; // each block processes one row
    float a_val = A[row];

    // Number of float4 elements per row
    int vec_cols = M / 4;

    // Cast row pointers to float4 pointers
    const float4* B_row = reinterpret_cast<const float4*>(B + row * M);
    float4* C_row = reinterpret_cast<float4*>(C + row * M);
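    // Note: float4 loads/stores require 16-byte alignment. That holds here because
    // fresh torch allocations are at least 256-byte aligned and M % 4 == 0, so each
    // row offset (row * M * sizeof(float)) is a multiple of 16 bytes.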

    // Each thread processes several consecutive float4 elements
    for (int v = threadIdx.x; v < vec_cols; v += blockDim.x) {
        float4 b_val = B_row[v];
        float4 c_val;
        c_val.x = a_val * b_val.x;
        c_val.y = a_val * b_val.y;
        c_val.z = a_val * b_val.z;
        c_val.w = a_val * b_val.w;
        C_row[v] = c_val;
    }
}

// Fallback kernel for when M is not divisible by 4: process element-wise
__global__ void coalesced_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t M
) {
    int row = blockIdx.x; // each block processes one row
    float a_val = A[row];
    // Use a 64-bit offset: row * M can overflow a 32-bit int for large N * M
    int64_t offset = static_cast<int64_t>(row) * M;
    
    // Each thread processes elements in the row with a fixed stride
    for (int j = threadIdx.x; j < M; j += blockDim.x) {
        C[offset + j] = a_val * B[offset + j];
    }
}

// Forward function that dispatches the appropriate kernel
at::Tensor forward(at::Tensor A, at::Tensor B) {
    TORCH_CHECK(A.dim() == 1, "A must be a 1D tensor");
    TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
    TORCH_CHECK(A.size(0) == B.size(0), "Dimension mismatch between A and B");
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == at::kFloat && B.scalar_type() == at::kFloat,
                "A and B must be float32 tensors");

    A = A.contiguous();
    B = B.contiguous();

    int64_t N = A.size(0);
    int64_t M = B.size(1);

    auto C = torch::empty({N, M}, B.options());

    // If M is divisible by 4 and large enough, use the vectorized kernel
    if (M >= 4 && (M % 4 == 0)) {
        // Use one block per row. Choose thread count based on number of float4 elements
        int threads = (M / 4) < 256 ? (int)(M / 4) : 256;
        dim3 grid(N);
        coalesced_diag_matmul_vec_kernel<<<grid, threads>>>(
            A.data_ptr<float>(),
            B.data_ptr<float>(),
            C.data_ptr<float>(),
            M
        );
    } else {
        // Fallback kernel: one block per row, processing element-wise
        int threads = M < 256 ? static_cast<int>(M) : 256;
        dim3 grid(N);
        coalesced_diag_matmul_kernel<<<grid, threads>>>(
            A.data_ptr<float>(),
            B.data_ptr<float>(),
            C.data_ptr<float>(),
            M
        );
    }

    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Diagonal matrix multiplication with coalesced memory accesses");
}
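
For completeness, a hedged sketch of how this extension could be built and checked from Python with torch.utils.cpp_extension.load; the source file name coalesced_diag_matmul.cu is an assumption for illustration, and a CUDA device is required:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension; "coalesced_diag_matmul.cu" is an assumed file name
# containing the kernel source above.
ext = load(name="coalesced_diag_matmul", sources=["coalesced_diag_matmul.cu"])

N, M = 4096, 4096
A = torch.randn(N, device="cuda", dtype=torch.float32)
B = torch.randn(N, M, device="cuda", dtype=torch.float32)

C = ext.forward(A, B)

# Validate against the PyTorch reference
assert torch.allclose(C, torch.diag(A) @ B)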
Performance Metrics
Operation | CPU Time (μs) | Device Time (μs) | Self CPU Time (μs) | Self Device Time (μs)
aten::to | 357159.90 | 7154.12 | 49.35 | 0.00
aten::_to_copy | 357110.55 | 7154.12 | 107.76 | 0.00
aten::empty_strided | 349611.14 | 0.00 | 81.71 | 0.00
cudaDeviceGetStreamPriorityRange | 349160.63 | 0.00 | 349160.63 | 0.00
cudaLaunchKernel | 843033.96 | 22841.96 | 843033.96 | 22841.96
coalesced_diag_matmul_vec_kernel(float const*, float const*, float*, long) | 0.00 | 392205.76 | 0.00 | 392205.76
cudaEventRecord | 17828.03 | 42322.72 | 17828.03 | 42322.72
aten::zero_ | 348241.41 | 632369.54 | 14177.97 | 0.00
aten::fill_ | 334065.16 | 632369.54 | 18803.31 | 632369.54
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | 0.00 | 632369.54 | 0.00 | 632369.54

All CPU and device memory usage metrics (total and self) were 0 B for every operation.
Status: Completed
45286 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:7:5: warning: 2 adjacent parameters of 'coalesced_diag_matmul_vec_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
7 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:7:31: note: the first parameter in the range is 'A'
7 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:8:31: note: the last parameter in the range is 'B'
8 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:12:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
12 | int row = blockIdx.x; // each block processes one row
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:16:20: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
16 | int vec_cols = M / 4;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:23:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
23 | for (int v = threadIdx.x; v < vec_cols; v += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:23:50: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
23 | for (int v = threadIdx.x; v < vec_cols; v += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:36:5: warning: 2 adjacent parameters of 'coalesced_diag_matmul_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
36 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
37 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:36:31: note: the first parameter in the range is 'A'
36 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:37:31: note: the last parameter in the range is 'B'
37 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:41:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
41 | int row = blockIdx.x; // each block processes one row
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:43:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
43 | int offset = row * M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:46:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
46 | for (int j = threadIdx.x; j < M; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:46:43: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
46 | for (int j = threadIdx.x; j < M; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s2_coalesced_diag_matmul/base/base.cu:78:33: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
78 | int threads = M < 256 ? M : 256;
| ^