← Back to Leaderboard

The AI CUDA Engineer 👷

12_Matmul_with_diagonal_matrices_stride_loop_diag_matmul_base

Level 1 • Task 12
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    Left-multiply a matrix by a diagonal matrix given by its diagonal entries.

    Args:
        A (torch.Tensor): 1D tensor holding the diagonal entries. Shape: (N,).
        B (torch.Tensor): 2D tensor to be multiplied. Shape: (N, M).

    Returns:
        torch.Tensor: The product diag(A) @ B. Shape: (N, M).
    """
    # Materialize the (N, N) diagonal matrix, then do a regular matmul.
    diagonal_matrix = torch.diag(A)
    return torch.matmul(diagonal_matrix, B)


class Model(nn.Module):
    """
    Parameter-free module that forwards its inputs to a functional
    implementation; by default that is module_fn, i.e. C = diag(A) @ B.
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B, fn=module_fn):
        # `fn` is swappable so an alternative implementation can be run
        # through the same module interface.
        result = fn(A, B)
        return result


# Problem size used by get_inputs(): A has N entries and B has shape (N, M).
M = 4096
N = 4096


def get_inputs():
    """Sample random inputs [A, B] with A of shape (N,) and B of shape (N, M)."""
    diagonal = torch.randn(N)
    matrix = torch.randn(N, M)
    return [diagonal, matrix]


def get_init_inputs():
    """Return the list of initialization inputs, which is empty."""
    init_inputs = []  # No special initialization inputs needed
    return init_inputs
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Parameter-free module computing the product of a diagonal matrix and a
    dense matrix: C = diag(A) @ B.
    """
    def __init__(self):
        super().__init__()

    def forward(self, A, B):
        """
        Multiply B from the left by the diagonal matrix built from A.

        Args:
            A (torch.Tensor): 1D tensor holding the diagonal entries. Shape: (N,).
            B (torch.Tensor): 2D tensor to be multiplied. Shape: (N, M).

        Returns:
            torch.Tensor: The product diag(A) @ B. Shape: (N, M).
        """
        # Materialize the (N, N) diagonal matrix, then do a regular matmul.
        diagonal_matrix = torch.diag(A)
        return torch.matmul(diagonal_matrix, B)

# Problem size used by get_inputs(): A has N entries and B has shape (N, M).
M = 4096
N = 4096

def get_inputs():
    """Sample random inputs [A, B] with A of shape (N,) and B of shape (N, M)."""
    diagonal = torch.randn(N)
    matrix = torch.randn(N, M)
    return [diagonal, matrix]

def get_init_inputs():
    """Return the list of initialization inputs, which is empty."""
    init_inputs = []  # No special initialization inputs needed
    return init_inputs

Kernel Information

Related Kernels (Level 1, Task 12 • 12_Matmul_with_diagonal_matrices_)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 flat_no_atomic_diag_matmul_base 0.05 54.40 55.46
🥇 shared_mem_diag_matmul_base_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 optimized_block_size_diag_matmul_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_edit_1 0.05 54.40 55.46
🥇 diag_matmul_warp_sync_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_edit_1 0.05 54.40 55.46
🥇 diag_matmul_readonly_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_base 0.05 54.40 55.46
🥇 stride_loop_diag_matmul_base 0.05 54.40 55.46
🥇 shared_memory_optimized_diag_matmul_base 0.05 54.40 55.46
🥇 unified_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_edit_1 0.05 54.40 55.46
19 coalesced_diag_matmul_base 0.05 53.35 54.39
19 diag_matmul_readonly_edit_1 0.05 53.35 54.39
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>

// Vectorized kernel using grid-stride loops: computes C[r, :] = A[r] * B[r, :]
// four floats at a time.
//
// Parameters:
//   A         - diagonal entries, N floats
//   B         - input matrix, N * M floats, row-major
//   C         - output matrix, N * M floats, row-major
//   N, M      - number of rows / columns of B and C
//   vec_total - number of float4 elements to process (N * M / 4)
//
// Preconditions (the caller must guarantee them): M is divisible by 4, so a
// float4 never straddles two rows, and B and C are aligned for float4 access.
__global__ void stride_loop_vectorized_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t N,
    const int64_t M,
    const int64_t vec_total) {
  // All index arithmetic is done in 64 bits. The bounds are int64_t, so 32-bit
  // counters would overflow (and `i * 4` even earlier) once N * M exceeds
  // INT_MAX, producing out-of-range accesses or a loop that never terminates.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t total = N * M;

  // Reinterpret B and C as float4 pointers.
  const float4* B_vec = reinterpret_cast<const float4*>(B);
  float4* C_vec = reinterpret_cast<float4*>(C);

  // Each thread handles float4 elements idx, idx + stride, idx + 2*stride, ...
  for (int64_t i = idx; i < vec_total; i += stride) {
    // Index of the first scalar covered by this float4.
    const int64_t base_idx = i * 4;
    // Defensive bound check; always true when vec_total == N * M / 4.
    if (base_idx < total) {
      // All four scalars share one row because M is divisible by 4.
      const int64_t row = base_idx / M;
      const float a_val = A[row];

      // Load 4 consecutive floats from B, scale them, and store to C.
      const float4 b_val = B_vec[i];
      float4 c_val;
      c_val.x = a_val * b_val.x;
      c_val.y = a_val * b_val.y;
      c_val.z = a_val * b_val.z;
      c_val.w = a_val * b_val.w;

      C_vec[i] = c_val;
    }
  }
}

// Scalar kernel using grid-stride loops: computes C[i] = A[i / M] * B[i] for
// every flat index i in [0, total).
//
// Parameters:
//   A     - diagonal entries, one float per row
//   B     - input matrix, `total` floats, row-major with M columns
//   C     - output matrix, same layout as B
//   total - number of elements in B and C
//   M     - number of columns of B and C
__global__ void stride_loop_scalar_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t total,
    const int64_t M) {
  // 64-bit indices: `total` is int64_t, so a 32-bit loop counter would
  // overflow before reaching it when total exceeds INT_MAX.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Each thread handles elements idx, idx + stride, idx + 2*stride, ...
  for (int64_t i = idx; i < total; i += stride) {
    const int64_t row = i / M;
    C[i] = A[row] * B[i];
  }
}

// Host entry point: computes C = diag(A) @ B on the GPU.
//
// Arguments:
//   A - 1D float32 CUDA tensor with N elements (the diagonal)
//   B - 2D float32 CUDA tensor of shape (N, M), on the same device as A
// Returns:
//   A new tensor of shape (N, M) with B's options.
// Raises (via TORCH_CHECK) on wrong rank, mismatched sizes, non-CUDA or
// non-float32 inputs, or if the kernel launch reports an error.
at::Tensor forward(at::Tensor A, at::Tensor B) {
  TORCH_CHECK(A.dim() == 1, "A must be a 1D tensor");
  TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
  TORCH_CHECK(A.size(0) == B.size(0), "Dimension mismatch: A.size(0) must match B.size(0)");
  // The kernels dereference raw float device pointers, so reject anything
  // else up front instead of faulting inside the kernel.
  TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
  TORCH_CHECK(A.device() == B.device(), "A and B must be on the same device");
  TORCH_CHECK(A.scalar_type() == at::kFloat && B.scalar_type() == at::kFloat,
              "A and B must be float32 tensors");

  A = A.contiguous();
  B = B.contiguous();

  const int64_t N = A.size(0);
  const int64_t M = B.size(1);
  const int64_t total = N * M;

  auto C = torch::empty({N, M}, B.options());

  // Nothing to compute; a launch with zero blocks is an invalid configuration.
  if (total == 0) {
    return C;
  }

  const int threads = 256;
  // Largest grid x-dimension we request. The kernels use grid-stride loops,
  // so a capped grid still covers every element.
  const int64_t max_blocks = 2147483647;

  const float* a_ptr = A.data_ptr<float>();
  const float* b_ptr = B.data_ptr<float>();
  float* c_ptr = C.data_ptr<float>();

  // float4 loads/stores need suitably aligned addresses. A contiguous view
  // with a non-zero storage offset may not satisfy that, so check explicitly.
  const bool vec_aligned =
      reinterpret_cast<std::uintptr_t>(b_ptr) % alignof(float4) == 0 &&
      reinterpret_cast<std::uintptr_t>(c_ptr) % alignof(float4) == 0;

  // Use the vectorized kernel if M is divisible by 4 and pointers are aligned.
  if (M % 4 == 0 && vec_aligned) {
    const int64_t vec_total = total / 4;  // total number of float4 elements
    const int64_t wanted = (vec_total + threads - 1) / threads;
    const int blocks = static_cast<int>(wanted < max_blocks ? wanted : max_blocks);
    stride_loop_vectorized_diag_matmul_kernel<<<blocks, threads>>>(
        a_ptr, b_ptr, c_ptr, N, M, vec_total);
  } else {
    const int64_t wanted = (total + threads - 1) / threads;
    const int blocks = static_cast<int>(wanted < max_blocks ? wanted : max_blocks);
    stride_loop_scalar_diag_matmul_kernel<<<blocks, threads>>>(
        a_ptr, b_ptr, c_ptr, total, M);
  }

  // Surface launch failures instead of returning an unwritten buffer.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "diag matmul kernel launch failed: ",
              cudaGetErrorString(launch_status));

  return C;
}

// Python binding: registers the host function `forward` under the name
// "forward" in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Stride loop diagonal matrix multiplication");
}
Performance Metrics
Metric Value Unit Variance Samples
Analysis Rules
Rule Description
Operation / Metric Value Unit
aten::to
CPU Time 262277.90 μs
Device Time 7419.80 μs
Self CPU Time 59.49 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 262218.41 μs
Device Time 7419.80 μs
Self CPU Time 160.06 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 822846.04 μs
Device Time 22934.23 μs
Self CPU Time 822846.04 μs
Self Device Time 22934.23 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
stride_loop_vectorized_diag_matmul_kernel(float const*, float const*, float*, long, long, long)
CPU Time 0.00 μs
Device Time 384043.89 μs
Self CPU Time 0.00 μs
Self Device Time 384043.89 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 22224.20 μs
Device Time 42327.89 μs
Self CPU Time 22224.20 μs
Self Device Time 42327.89 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 336098.95 μs
Device Time 633135.57 μs
Self CPU Time 14348.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 321756.18 μs
Device Time 633135.57 μs
Self CPU Time 18967.26 μs
Self Device Time 633135.57 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 633135.57 μs
Self CPU Time 0.00 μs
Self Device Time 633135.57 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45286 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
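/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:7:5: warning: 2 adjacent parameters of 'stride_loop_vectorized_diag_matmul_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]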
7 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:7:31: note: the first parameter in the range is 'A'
7 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:8:31: note: the last parameter in the range is 'B'
8 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:11:5: warning: 2 adjacent parameters of 'stride_loop_vectorized_diag_matmul_kernel' of similar type ('const int64_t') are easily swapped by mistake [bugprone-easily-swappable-parameters]
11 | const int64_t M,
| ^~~~~~~~~~~~~~~~
12 | const int64_t vec_total) {
| ~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:11:19: note: the first parameter in the range is 'M'
11 | const int64_t M,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:12:19: note: the last parameter in the range is 'vec_total'
12 | const int64_t vec_total) {
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:13:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
13 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:14:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
14 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:27:17: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
27 | int row = base_idx / M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:48:5: warning: 2 adjacent parameters of 'stride_loop_scalar_diag_matmul_kernel' of similar type ('const int64_t') are easily swapped by mistake [bugprone-easily-swappable-parameters]
48 | const int64_t total,
| ^~~~~~~~~~~~~~~~~~~~
49 | const int64_t M) {
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:48:19: note: the first parameter in the range is 'total'
48 | const int64_t total,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:49:19: note: the last parameter in the range is 'M'
49 | const int64_t M) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:50:13: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
50 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:51:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
51 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:55:15: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
55 | int row = i / M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:79:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
79 | int blocks = (vec_total + threads - 1) / threads;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b10_s3_stride_loop_diag_matmul/base/base.cu:83:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
83 | int blocks = (total + threads - 1) / threads;
| ^