
The AI CUDA Engineer 👷

12_Matmul_with_diagonal_matrices_unified_diag_matmul_base

Level 1 • Task 12
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(A, B):
    """
    Performs a matrix multiplication of a diagonal matrix with another matrix.

    Args:
        A (torch.Tensor): A 1D tensor representing the diagonal of the diagonal matrix. Shape: (N,).
        B (torch.Tensor): A 2D tensor representing the second matrix. Shape: (N, M).

    Returns:
        torch.Tensor: The result of the matrix multiplication. Shape: (N, M).
    """
    return torch.diag(A) @ B


class Model(nn.Module):
    """
    Simple model that performs a matrix multiplication of a diagonal matrix with another matrix.
    C = diag(A) * B
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A, B, fn=module_fn):
        return fn(A, B)


M = 4096
N = 4096


def get_inputs():
    A = torch.randn(N)
    B = torch.randn(N, M)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Simple model that performs a matrix multiplication of a diagonal matrix with another matrix.
    C = diag(A) * B
    """
    def __init__(self):
        super(Model, self).__init__()
    
    def forward(self, A, B):
        """
        Performs the matrix multiplication.

        Args:
            A (torch.Tensor): A 1D tensor representing the diagonal of the diagonal matrix. Shape: (N,).
            B (torch.Tensor): A 2D tensor representing the second matrix. Shape: (N, M).

        Returns:
            torch.Tensor: The result of the matrix multiplication. Shape: (N, M).
        """
        return torch.diag(A) @ B

M = 4096
N = 4096

def get_inputs():
    A = torch.randn(N)
    B = torch.randn(N, M)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed
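
Because torch.diag(A) materializes a dense N x N matrix before the matmul, the eager reference does far more work than necessary: multiplying by diag(A) simply scales row i of B by A[i]. The custom kernel below exploits exactly this row-scaling view. A minimal PyTorch sketch of the equivalence, using the shapes from the task definition above:

import torch

N, M = 4096, 4096
A = torch.randn(N)
B = torch.randn(N, M)

# Reference: build the dense diagonal matrix and multiply.
reference = torch.diag(A) @ B

# Equivalent row-wise scaling: broadcast A across the columns of B.
row_scaled = A.unsqueeze(1) * B

torch.testing.assert_close(reference, row_scaled)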

Kernel Information

Related Kernels (Level 1, Task 12 • 12_Matmul_with_diagonal_matrices_)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 flat_no_atomic_diag_matmul_base 0.05 54.40 55.46
🥇 shared_mem_diag_matmul_base_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 optimized_block_size_diag_matmul_base 0.05 54.40 55.46
🥇 hybrid_diag_matmul_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_edit_1 0.05 54.40 55.46
🥇 diag_matmul_warp_sync_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_edit_1 0.05 54.40 55.46
🥇 diag_matmul_readonly_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 diag_matmul_shared_min_sync_base 0.05 54.40 55.46
🥇 stride_loop_diag_matmul_base 0.05 54.40 55.46
🥇 shared_memory_optimized_diag_matmul_base 0.05 54.40 55.46
🥇 unified_diag_matmul_base 0.05 54.40 55.46
🥇 diag_matmul_modular_base 0.05 54.40 55.46
🥇 adaptive_diag_matmul_edit_1 0.05 54.40 55.46
19 coalesced_diag_matmul_base 0.05 53.35 54.39
19 diag_matmul_readonly_edit_1 0.05 53.35 54.39
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Unified kernel that selects between vectorized and row-based scalar approaches
__global__ void unified_diag_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int64_t N,
    const int64_t M,
    const bool use_vectorized
) {
    if (use_vectorized) {
        // Vectorized branch: works when each row's length M is divisible by 4
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;
        // Total number of elements in C
        int64_t total = N * M;
        // Each float4 covers 4 consecutive floats
        int64_t vec_total = total / 4;

        // Cast B and C pointers to float4
        const float4* B_vec = reinterpret_cast<const float4*>(B);
        float4* C_vec = reinterpret_cast<float4*>(C);

        for (; idx < vec_total; idx += stride) {
            int base_idx = idx * 4;  // Corresponding starting index in the original array
            int row = base_idx / M;  // Determine the row based on the flat index
            float a_val = A[row];
            
            float4 b_val = B_vec[idx];
            float4 c_val;
            c_val.x = a_val * b_val.x;
            c_val.y = a_val * b_val.y;
            c_val.z = a_val * b_val.z;
            c_val.w = a_val * b_val.w;
            
            C_vec[idx] = c_val;
        }
    } else {
        // Scalar row-based branch using grid-stride loop over rows.
        // Each block will iterate over rows, and threads in the block will collaborate on processing
        // columns within a row for improved memory coalescing.
        for (int row = blockIdx.x; row < N; row += gridDim.x) {
            float a_val = A[row];
            int row_offset = row * M;
            for (int col = threadIdx.x; col < M; col += blockDim.x) {
                int idx = row_offset + col;
                C[idx] = a_val * B[idx];
            }
        }
    }
}

at::Tensor forward(at::Tensor A, at::Tensor B) {
    TORCH_CHECK(A.dim() == 1, "A must be a 1D tensor");
    TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
    TORCH_CHECK(A.size(0) == B.size(0), "Dimension mismatch: A.size(0) must match B.size(0)");

    A = A.contiguous();
    B = B.contiguous();

    const int64_t N = A.size(0);
    const int64_t M = B.size(1);
    auto C = torch::empty({N, M}, B.options());

    // Decide which approach to use:
    // Use the vectorized method if M is divisible by 4 and sufficiently large (e.g., M >= 512) 
    // to better harness memory throughput.
    bool use_vectorized = (M % 4 == 0) && (M >= 512);

    if (use_vectorized) {
        const int threads = 256;
        int64_t total = N * M;
        int64_t vec_total = total / 4;
        int blocks = (vec_total + threads - 1) / threads;
        // Cap the grid size at 65535 blocks; the grid-stride loop covers any remaining elements
        blocks = (blocks > 65535) ? 65535 : blocks;
        unified_diag_matmul_kernel<<<blocks, threads>>>(
            A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(),
            N, M, true);
    } else {
        // For the scalar branch, use a grid-stride loop over rows for improved coalescing
        int threads = (M < 256) ? (((M + 31) / 32) * 32) : 256;
        int blocks = (N < 256) ? N : 256;
        unified_diag_matmul_kernel<<<blocks, threads>>>(
            A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(),
            N, M, false);
    }

    return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Unified diagonal matrix multiplication using vectorized and row-based kernels");
}
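
A minimal sketch of how this extension could be built and sanity-checked against the PyTorch reference, assuming the source above is saved as unified_diag_matmul.cu (the file name and module name are illustrative, not part of the archived kernel):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source; the PYBIND11_MODULE above exposes forward().
ext = load(name="unified_diag_matmul", sources=["unified_diag_matmul.cu"])

A = torch.randn(4096, device="cuda", dtype=torch.float32)
B = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)

C = ext.forward(A, B)

# Compare against the eager reference used by the task definition.
torch.testing.assert_close(C, torch.diag(A) @ B)

With M = 4096, the dispatch condition (M % 4 == 0) && (M >= 512) holds, so this benchmark exercises the vectorized float4 path rather than the scalar row-based branch.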
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.354 inst/cycle 0.000 5
Executed Ipc Elapsed 1.246 inst/cycle 0.000 5
Issue Slots Busy 33.932 % 0.094 5
Issued Ipc Active 1.354 inst/cycle 0.000 5
SM Busy 33.932 % 0.094 5
Memory Throughput 2674143214459.126 byte/second 194593542367397838848.000 5
Mem Busy 46.738 % 0.057 5
Max Bandwidth 79.844 % 0.173 5
L1/TEX Hit Rate 2.700 % 0.000 5
L2 Hit Rate 49.920 % 0.017 5
Mem Pipes Busy 8.882 % 0.002 5
Warp Cycles Per Issued Instruction 37.760 cycle 0.237 5
Warp Cycles Per Executed Instruction 37.862 cycle 0.237 5
Avg. Active Threads Per Warp 32.000 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.370 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 10.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 80.638 % 0.029 5
Achieved Active Warps Per SM 51.608 warp 0.012 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (20.8%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (80.3%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 473288.72 μs
Device Time 7231.42 μs
Self CPU Time 47.16 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 473241.56 μs
Device Time 7231.42 μs
Self CPU Time 114.22 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 465631.61 μs
Device Time 0.00 μs
Self CPU Time 97.57 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 465167.36 μs
Device Time 0.00 μs
Self CPU Time 465167.36 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 690512.52 μs
Device Time 17762.41 μs
Self CPU Time 690512.52 μs
Self Device Time 17762.41 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
unified_diag_matmul_kernel(float const*, float const*, float*, long, long, bool)
CPU Time 0.00 μs
Device Time 321922.47 μs
Self CPU Time 0.00 μs
Self Device Time 321922.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 18557.60 μs
Device Time 35286.99 μs
Self CPU Time 18557.60 μs
Self Device Time 35286.99 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 280597.46 μs
Device Time 530774.87 μs
Self CPU Time 11239.04 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 269360.23 μs
Device Time 530774.87 μs
Self CPU Time 13561.14 μs
Self Device Time 530774.87 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 530774.87 μs
Self CPU Time 0.00 μs
Self Device Time 530774.87 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45287 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:7:5 bugprone-easily-swappable-parameters
7 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 | const float* __restrict__ B,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:7:31: note: the first parameter in the range is 'A'
7 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:8:31: note: the last parameter in the range is 'B'
8 | const float* __restrict__ B,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:16:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
16 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:17:22: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
17 | int stride = blockDim.x * gridDim.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:29:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
29 | int row = base_idx / M; // Determine the row based on the flat index
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:45:24: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
45 | for (int row = blockIdx.x; row < N; row += gridDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:45:52: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
45 | for (int row = blockIdx.x; row < N; row += gridDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:47:30: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
47 | int row_offset = row * M;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:48:28: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
48 | for (int col = threadIdx.x; col < M; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:48:57: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
48 | for (int col = threadIdx.x; col < M; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:77:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
77 | int blocks = (vec_total + threads - 1) / threads;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:85:35: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
85 | int threads = (M < 256) ? (((M + 31) / 32) * 32) : 256;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_12/b8_s3_unified_diag_matmul/base/base.cu:86:34: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
86 | int blocks = (N < 256) ? N : 256;
| ^