← Back to Leaderboard

The AI CUDA Engineer 👷

12_Matmul_with_diagonal_matrices_flat_no_atomic_diag_matmul_base

Level 1 • Task 12
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    Multiply a diagonal matrix, given by its diagonal, with a dense matrix.

    Args:
        A (torch.Tensor): 1D tensor holding the diagonal entries. Shape: (N,).
        B (torch.Tensor): 2D right-hand matrix. Shape: (N, M).

    Returns:
        torch.Tensor: The product diag(A) @ B. Shape: (N, M).
    """
    # Materialize the full (N, N) diagonal matrix, then do a regular matmul.
    diagonal_matrix = torch.diag(A)
    return torch.matmul(diagonal_matrix, B)


class Model(nn.Module):
    """
    Parameter-free module computing C = diag(A) * B.

    The actual computation is delegated to a callable so that an alternative
    implementation can be substituted at call time.
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B, fn=module_fn):
        """
        Apply ``fn`` to the inputs and return its result.

        Args:
            A: First argument forwarded to ``fn``.
            B: Second argument forwarded to ``fn``.
            fn: Callable invoked as ``fn(A, B)``; defaults to ``module_fn``.

        Returns:
            Whatever ``fn(A, B)`` returns.
        """
        result = fn(A, B)
        return result


# Benchmark problem size: A has N entries and B has shape (N, M).
N = 4096
M = 4096


def get_inputs():
    """Return ``[A, B]``: random tensors of shape (N,) and (N, M)."""
    # A is drawn before B (list elements evaluate left to right), so the
    # random stream is consumed in the same order as two separate statements.
    return [torch.randn(N), torch.randn(N, M)]


def get_init_inputs():
    """Return the constructor arguments for ``Model`` (none are required)."""
    init_args = list()  # No special initialization inputs needed
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Parameter-free module that multiplies a diagonal matrix with a dense matrix.
    C = diag(A) * B
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B):
        """
        Compute diag(A) @ B.

        Args:
            A (torch.Tensor): 1D tensor holding the diagonal entries. Shape: (N,).
            B (torch.Tensor): 2D right-hand matrix. Shape: (N, M).

        Returns:
            torch.Tensor: The matrix product. Shape: (N, M).
        """
        # Build the explicit (N, N) diagonal matrix, then do a regular matmul.
        diagonal_matrix = torch.diag(A)
        return torch.matmul(diagonal_matrix, B)

# Benchmark problem size: A has N entries and B has shape (N, M).
N = 4096
M = 4096

def get_inputs():
    """Return ``[A, B]``: random tensors of shape (N,) and (N, M)."""
    # A is drawn before B (list elements evaluate left to right), so the
    # random stream is consumed in the same order as two separate statements.
    return [torch.randn(N), torch.randn(N, M)]

def get_init_inputs():
    """Return the constructor arguments for ``Model`` (none are required)."""
    init_args = list()  # No special initialization inputs needed
    return init_args

Kernel Information

Related Kernels (Level 1, Task 12 • 12_Matmul_with_diagonal_matrices_)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 flat_no_atomic_diag_matmul_base 0.05 54.40 55.46
🥇 shared_mem_diag_matmul_base_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 optimized_block_size_diag_matmul_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_edit_1 0.05 54.40 55.46
🥇 diag_matmul_warp_sync_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_edit_1 0.05 54.40 55.46
🥇 diag_matmul_readonly_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_base 0.05 54.40 55.46
🥇 stride_loop_diag_matmul_base 0.05 54.40 55.46
🥇 shared_memory_optimized_diag_matmul_base 0.05 54.40 55.46
🥇 unified_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_edit_1 0.05 54.40 55.46
19 coalesced_diag_matmul_base 0.05 53.35 54.39
19 diag_matmul_readonly_edit_1 0.05 53.35 54.39
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>

// Computes C = diag(A) * B for row-major B/C, four floats at a time.
//
// This kernel is used when the number of columns (M) is divisible by 4, so
// every float4 group lies entirely inside one row and shares one A[row].
// It uses vectorized loads/stores (float4) for improved memory throughput.
// Note: No atomic operations are used, as each thread writes a distinct
// float4 group of the output.
//
// Parameters:
//   A         - diagonal entries, indexed by row.
//   B, C      - input / output matrices, reinterpreted as float4 arrays; the
//               caller must ensure both pointers are 16-byte aligned.
//   N         - number of rows (unused by the kernel body; kept for interface).
//   M         - number of columns; must be a multiple of 4.
//   vec_total - number of float4 groups to process (N * M / 4).
//
// All index arithmetic is done in 64 bits: with 32-bit ints, `idx * 4` and the
// thread-index product overflow once the tensor holds more than 2^31 elements.
__global__ void flat_vectorized_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t N,
    const int64_t M,
    const int64_t vec_total) {
  // Widen before multiplying so the products are evaluated in 64 bits.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const float4* B_vec = reinterpret_cast<const float4*>(B);
  float4* C_vec = reinterpret_cast<float4*>(C);

  // Grid-stride loop: correct for any grid size, including a clamped one.
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < vec_total;
       idx += stride) {
    // idx * 4 is the flat index of the group's first scalar element.
    const int64_t row = (idx * 4) / M;
    const float a_val = A[row];

    const float4 b_val = B_vec[idx];
    float4 c_val;
    c_val.x = a_val * b_val.x;
    c_val.y = a_val * b_val.y;
    c_val.z = a_val * b_val.z;
    c_val.w = a_val * b_val.w;

    C_vec[idx] = c_val;
  }
}

// Computes C = diag(A) * B for row-major B/C, one float per loop iteration.
//
// This kernel is used when vectorized access is not possible (e.g. M is not
// divisible by 4). Each output element is written by exactly one thread via a
// flat grid-stride loop, so atomic operations are not needed.
//
// Parameters:
//   A     - diagonal entries, indexed by row.
//   B, C  - input / output matrices addressed by flat element index.
//   N     - number of rows (unused by the kernel body; kept for interface).
//   M     - number of columns; flat index / M gives the row.
//   total - number of elements to process (N * M).
//
// Index arithmetic is done in 64 bits so that tensors with more than 2^31
// elements are indexed correctly instead of overflowing a 32-bit int.
__global__ void flat_scalar_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t N,
    const int64_t M,
    const int64_t total) {
  // Widen before multiplying so the products are evaluated in 64 bits.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total;
       idx += stride) {
    const int64_t row = idx / M;
    C[idx] = A[row] * B[idx];
  }
}

// Host entry point: computes C = diag(A) * B, i.e. C[i][j] = A[i] * B[i][j].
//
// Args:
//   A: 1D float32 CUDA tensor of length N (the diagonal).
//   B: 2D float32 CUDA tensor of shape (N, M) on the same device as A.
// Returns:
//   A newly allocated (N, M) tensor created with B's options.
// Raises (via TORCH_CHECK):
//   if the ranks or leading sizes do not match, if either tensor is not a
//   float32 CUDA tensor, if the tensors are on different devices, or if the
//   kernel launch reports an error.
//
// NOTE(review): as in the original code, the kernels are launched with the
// default launch configuration (no explicit stream or device guard).
at::Tensor forward(at::Tensor A, at::Tensor B) {
  TORCH_CHECK(A.dim() == 1, "A must be a 1D tensor");
  TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
  TORCH_CHECK(A.size(0) == B.size(0), "Dimension mismatch: A.size(0) must match B.size(0)");
  // The kernels dereference raw float device pointers, so reject anything else
  // up front instead of faulting inside the kernel.
  TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
  TORCH_CHECK(A.device() == B.device(), "A and B must be on the same device");
  TORCH_CHECK(A.scalar_type() == at::kFloat && B.scalar_type() == at::kFloat,
              "A and B must be float32 tensors");

  A = A.contiguous();
  B = B.contiguous();

  const int64_t N = A.size(0);
  const int64_t M = B.size(1);
  const int64_t total = N * M;
  auto C = torch::empty({N, M}, B.options());

  // Nothing to compute; a launch with zero blocks is an invalid configuration.
  if (total == 0) {
    return C;
  }

  const float* a_ptr = A.data_ptr<float>();
  const float* b_ptr = B.data_ptr<float>();
  float* c_ptr = C.data_ptr<float>();

  constexpr int threads = 256;
  // Largest gridDim.x accepted by the hardware; the kernels use grid-stride
  // loops, so clamping the grid only changes how much work each thread does.
  constexpr int64_t max_blocks = 2147483647;

  // float4 loads/stores require 16-byte aligned addresses. A contiguous view
  // with a storage offset can violate that even when M is a multiple of 4.
  const bool aligned16 =
      reinterpret_cast<std::uintptr_t>(b_ptr) % 16 == 0 &&
      reinterpret_cast<std::uintptr_t>(c_ptr) % 16 == 0;

  // If M is divisible by 4 (and the buffers are suitably aligned), use the
  // vectorized kernel for improved throughput.
  if (M % 4 == 0 && aligned16) {
    const int64_t vec_total = total / 4;
    const int64_t wanted = (vec_total + threads - 1) / threads;
    const int blocks = static_cast<int>(wanted < max_blocks ? wanted : max_blocks);
    flat_vectorized_diag_matmul_kernel<<<blocks, threads>>>(
        a_ptr, b_ptr, c_ptr, N, M, vec_total);
  } else {
    const int64_t wanted = (total + threads - 1) / threads;
    const int blocks = static_cast<int>(wanted < max_blocks ? wanted : max_blocks);
    flat_scalar_diag_matmul_kernel<<<blocks, threads>>>(
        a_ptr, b_ptr, c_ptr, N, M, total);
  }

  // Surface launch-configuration errors instead of returning garbage.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "diag matmul kernel launch failed: ", cudaGetErrorString(launch_status));

  return C;
}

// Python bindings: exposes the host function `forward` as `<module>.forward`,
// where the module name is supplied by the TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Flat diagonal matrix multiplication without unnecessary atomic operations");
}
Performance Metrics
Metric Value Unit Variance Samples
Analysis Rules
Rule Description
Operation / Metric Value Unit
aten::to
CPU Time 398134.46 μs
Device Time 7079.00 μs
Self CPU Time 44.74 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 398089.72 μs
Device Time 7079.00 μs
Self CPU Time 102.73 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 390667.25 μs
Device Time 0.00 μs
Self CPU Time 98.22 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 390211.26 μs
Device Time 0.00 μs
Self CPU Time 390211.26 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 771482.40 μs
Device Time 19689.85 μs
Self CPU Time 771482.40 μs
Self Device Time 19689.85 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
flat_vectorized_diag_matmul_kernel(float const*, float const*, float*, long, long, long)
CPU Time 0.00 μs
Device Time 355308.11 μs
Self CPU Time 0.00 μs
Self Device Time 355308.11 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 16512.41 μs
Device Time 39285.87 μs
Self CPU Time 16512.41 μs
Self Device Time 39285.87 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 313526.36 μs
Device Time 586576.03 μs
Self CPU Time 12857.71 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 300672.30 μs
Device Time 586576.03 μs
Self CPU Time 15785.41 μs
Self Device Time 586576.03 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 586576.03 μs
Self CPU Time 0.00 μs
Self Device Time 586576.03 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45286 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
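/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:9:5: warning: 2 adjacent parameters of 'flat_vectorized_diag_matmul_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]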
9 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
10 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:9:31: note: the first parameter in the range is 'A'
9 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:10:31: note: the last parameter in the range is 'B'
10 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:12:5: warning: 3 adjacent parameters of 'flat_vectorized_diag_matmul_kernel' of similar type ('const int64_t') are easily swapped by mistake [bugprone-easily-swappable-parameters]
12 | const int64_t N,
| ^~~~~~~~~~~~~~~~
13 | const int64_t M,
| ~~~~~~~~~~~~~~~~
14 | const int64_t vec_total) {
| ~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:12:19: note: the first parameter in the range is 'N'
12 | const int64_t N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:14:19: note: the last parameter in the range is 'vec_total'
14 | const int64_t vec_total) {
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:15:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
15 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:16:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
16 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:22:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
22 | int row = base_idx / M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:43:5: warning: 3 adjacent parameters of 'flat_scalar_diag_matmul_kernel' of similar type ('const int64_t') are easily swapped by mistake [bugprone-easily-swappable-parameters]
43 | const int64_t N,
| ^~~~~~~~~~~~~~~~
44 | const int64_t M,
| ~~~~~~~~~~~~~~~~
45 | const int64_t total) {
| ~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:43:19: note: the first parameter in the range is 'N'
43 | const int64_t N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:45:19: note: the last parameter in the range is 'total'
45 | const int64_t total) {
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:46:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
46 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:47:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
47 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:49:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
49 | int row = idx / M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:72:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
72 | int blocks = (vec_total + threads - 1) / threads;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b6_s3_flat_no_atomic_diag_matmul/base/base.cu:76:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
76 | int blocks = (total + threads - 1) / threads;
| ^