← Back to Leaderboard

The AI CUDA Engineer 👷

22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish — aligned_ldg_fused_kernel_base

Level 2 • Task 22
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    scale_factor: float,
    clamp_min: float,
    clamp_max: float,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run the reference pipeline: linear -> scale -> self-add -> clamp -> LogSumExp -> x * mish(x).

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, input_size)
        scale_factor (float): Multiplier applied to the linear output
        clamp_min (float): Lower bound used by the clamp
        clamp_max (float): Upper bound used by the clamp
        weight (torch.Tensor): Weight matrix of shape (hidden_size, input_size)
        bias (torch.Tensor): Bias vector of shape (hidden_size)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1). The LogSumExp
        reduces dim 1 with keepdim=True, so the hidden dimension collapses to 1.
    """
    projected = F.linear(x, weight, bias)
    scaled = projected * scale_factor
    # The "residual" step adds the tensor to itself, i.e. it doubles every value.
    doubled = scaled + scaled
    bounded = doubled.clamp(clamp_min, clamp_max)
    lse = bounded.logsumexp(dim=1, keepdim=True)
    # The final step multiplies the LogSumExp value by its own Mish activation.
    return lse * F.mish(lse)


class Model(nn.Module):
    """
    Holds the linear layer's weight and bias and delegates the computation
    (matmul, scale, self-add, clamp, LogSumExp, Mish) to a functional implementation.
    """

    def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
        super().__init__()
        # scale_factor, clamp_min and clamp_max are accepted here but not stored;
        # forward() receives them again on every call.
        linear = nn.Linear(input_size, hidden_size)
        self.weight = linear.weight
        # Offset the freshly initialised bias by +0.02; the intent is to keep it away from zero.
        shifted_bias = linear.bias + torch.ones_like(linear.bias) * 0.02
        self.bias = nn.Parameter(shifted_bias)

    def forward(self, x, scale_factor, clamp_min, clamp_max, fn=module_fn):
        """Call ``fn`` (default: module_fn) with the inputs plus this module's weight and bias."""
        params = (self.weight, self.bias)
        return fn(x, scale_factor, clamp_min, clamp_max, *params)


# Benchmark configuration read by get_inputs() and get_init_inputs().
# batch_size x input_size is the shape of the random input batch;
# input_size and hidden_size are the dimensions handed to the model's nn.Linear.
batch_size = 128
input_size = 512
hidden_size = 1024
# Multiplier and clamp bounds forwarded to the model.
scale_factor = 2.0
clamp_min = -10.0
clamp_max = 10.0


def get_inputs():
    """Return the forward() arguments: a random (batch_size, input_size) batch, then scale and clamp bounds."""
    sample = torch.randn(batch_size, input_size)
    return [sample, scale_factor, clamp_min, clamp_max]


def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    init_args = [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Linear layer followed by scaling, a self-addition ("residual"), clamping,
    a LogSumExp over the feature dimension, and a final multiplication by Mish.
    """
    def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
        super().__init__()
        linear = nn.Linear(input_size, hidden_size)
        # Offset the freshly initialised bias by +0.02 and re-register it as a parameter.
        linear.bias = nn.Parameter(linear.bias + torch.ones_like(linear.bias) * 0.02)
        self.matmul = linear
        self.scale_factor, self.clamp_min, self.clamp_max = scale_factor, clamp_min, clamp_max

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, input_size).

        Returns:
            Output tensor of shape (batch_size, 1); the LogSumExp reduces dim 1
            with keepdim=True.
        """
        scaled = self.matmul(x) * self.scale_factor
        doubled = scaled + scaled  # tensor added to itself
        bounded = torch.clamp(doubled, self.clamp_min, self.clamp_max)
        lse = torch.logsumexp(bounded, dim=1, keepdim=True)
        return lse * torch.nn.functional.mish(lse)  # value times its Mish activation

# Benchmark configuration read by get_inputs() and get_init_inputs().
# batch_size x input_size is the shape of the random input batch;
# input_size and hidden_size are the dimensions handed to the model's nn.Linear.
batch_size = 128
input_size = 512
hidden_size = 1024
# Multiplier and clamp bounds passed to the Model constructor.
scale_factor = 2.0
clamp_min = -10.0
clamp_max = 10.0

def get_inputs():
    """Return the forward() arguments: one random tensor of shape (batch_size, input_size)."""
    sample = torch.randn(batch_size, input_size)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    init_args = [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
    return init_args

Kernel Information

Related Kernels (Level 2, Task 22 • 22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_syncthreads_optimized_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_syncthreads_optimized_base 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_even_workload_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_ldg_aligned_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_ldg_aligned_base 0.03 2.19 1.49
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_warp_optimized_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_warp_optimized_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_warp_optimized_edit_1 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_stride_loop_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_even_workload_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_2dindex_optimized_base 0.03 2.12 1.44
12 coalesced_fused_kernel_base_base 0.04 1.76 1.20
12 block_size_optimization_base 0.04 1.76 1.20
12 atomic_optimized_fused_kernel_base 0.04 1.76 1.20
15 reduced_sync_fused_kernel_base 0.04 1.67 1.14
16 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_optimized_base 0.04 1.63 1.11
16 fused_kernel_base 0.04 1.63 1.11
16 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_optimized_edit_1 0.04 1.63 1.11
19 aligned_ldg_fused_kernel_base 0.04 1.44 0.98
20 unroll_optimization_base_base 0.04 1.41 0.96
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
#include <vector>
#include <math.h>

// Define tile and block sizes
// TILE_DIM:   edge length of the square shared-memory tiles in the GEMM kernel,
//             which is launched with TILE_DIM * TILE_DIM threads per block.
// BLOCK_SIZE: threads per block (and shared-memory slots) in the reduction kernel.
#define TILE_DIM 16
#define BLOCK_SIZE 128

// GEMM kernel using __ldg() for read-only global memory loads and assuming data is 128-bit aligned
// Tiled GEMM with a fused epilogue. For every (row, col) it computes
//     v = dot(input[row, :], weight[col, :]) + bias[col]
// and stores clamp(2 * scale_factor * v, clamp_min, clamp_max) in intermediate[row, col].
//
// Layout:  input        is [batch_size, input_size],  row-major
//          weight       is [hidden_size, input_size], row-major
//          intermediate is [batch_size, hidden_size], row-major
// Launch:  1-D blocks of TILE_DIM * TILE_DIM threads; gridDim.x tiles hidden_size and
//          gridDim.y tiles batch_size (TILE_DIM rows/columns per block).
//
// Global reads use __ldg(). These are scalar float loads, so nothing beyond the
// natural 4-byte alignment of float is required from the pointers.
__global__ void aligned_ldg_gemm_kernel(
    const float* __restrict__ input,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ intermediate,
    int batch_size,
    int input_size,
    int hidden_size,
    float scale_factor,
    float clamp_min,
    float clamp_max
) {
    // One extra column of padding per row to spread accesses across shared-memory banks
    __shared__ float input_shared[TILE_DIM][TILE_DIM + 1];
    __shared__ float weight_shared[TILE_DIM][TILE_DIM + 1];

    // Origin of this block's output tile and this thread's position inside it
    int by = blockIdx.y * TILE_DIM;
    int bx = blockIdx.x * TILE_DIM;
    int tx = threadIdx.x % TILE_DIM;
    int ty = threadIdx.x / TILE_DIM;

    float sum = 0.0f;

    // Walk the reduction (input_size) dimension one tile at a time
    int numTiles = (input_size + TILE_DIM - 1) / TILE_DIM;
    for (int tile = 0; tile < numTiles; ++tile) {
        int kCol = tile * TILE_DIM + tx;  // position along input_size loaded by this thread

        // Input tile: thread (ty, tx) loads input[by + ty][kCol]; out-of-range entries are zero.
        // Flat indices are formed in size_t so row * input_size cannot overflow a 32-bit int.
        int aRow = by + ty;
        if (aRow < batch_size && kCol < input_size) {
            input_shared[ty][tx] = __ldg(&input[static_cast<size_t>(aRow) * input_size + kCol]);
        } else {
            input_shared[ty][tx] = 0.0f;
        }

        // Weight tile: thread (ty, tx) loads weight[bx + ty][kCol], so threads with consecutive
        // tx read consecutive addresses. The value is stored transposed, which leaves
        // weight_shared[k][n] == weight[bx + n][tile * TILE_DIM + k] for the product below.
        int wRow = bx + ty;
        if (wRow < hidden_size && kCol < input_size) {
            weight_shared[tx][ty] = __ldg(&weight[static_cast<size_t>(wRow) * input_size + kCol]);
        } else {
            weight_shared[tx][ty] = 0.0f;
        }

        // Both tiles must be fully written before any thread reads them
        __syncthreads();

        // Compute partial dot product
        #pragma unroll
        for (int k = 0; k < TILE_DIM; ++k) {
            sum += input_shared[ty][k] * weight_shared[k][tx];
        }

        // Do not overwrite the tiles until every thread has finished reading them
        __syncthreads();
    }

    int row = by + ty;
    int col = bx + tx;
    if (row < batch_size && col < hidden_size) {
        sum += __ldg(&bias[col]);
        sum *= scale_factor;
        sum += sum; // residual addition (doubling the value)
        sum = fminf(fmaxf(sum, clamp_min), clamp_max);
        intermediate[static_cast<size_t>(row) * hidden_size + col] = sum;
    }
}

// Reduction kernel using __ldg() for reading intermediate results
// Row-wise reduction: one block per row of `intermediate` ([batch_size, hidden_size], row-major).
// For its row the block computes lse = logsumexp(row) using the max-subtraction form,
// then writes output[row] = lse * mish(lse), where mish(v) = v * tanh(softplus(v)).
//
// Launch: gridDim.x covers the rows; blockDim.x must equal BLOCK_SIZE because both the
// strided loops and the shared-memory tree reduction are written in terms of BLOCK_SIZE.
__global__ void aligned_ldg_reduction_kernel(
    const float* __restrict__ intermediate,
    float* __restrict__ output,
    int hidden_size,
    int batch_size
) {
    __shared__ float sdata[BLOCK_SIZE];
    int tid = threadIdx.x;
    int row = blockIdx.x;

    // The condition depends only on blockIdx, so the whole block exits together and no
    // thread is left waiting at a later __syncthreads().
    if (row >= batch_size) {
        return;
    }

    // Form the row offset in size_t so row * hidden_size cannot overflow a 32-bit int
    const float* row_ptr = intermediate + static_cast<size_t>(row) * hidden_size;

    // Pass 1: each thread scans elements tid, tid + BLOCK_SIZE, ... for its local maximum
    float thread_max = -INFINITY;
    for (int i = tid; i < hidden_size; i += BLOCK_SIZE) {
        thread_max = fmaxf(thread_max, __ldg(&row_ptr[i]));
    }
    sdata[tid] = thread_max;
    __syncthreads();

    // Tree reduction to find the maximum value in the row
    for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    float row_max = sdata[0];
    // Every thread must have read sdata[0] before sdata is reused for the sums below
    __syncthreads();

    // Pass 2: sum of exponentials, shifted by the row maximum to avoid overflow in expf
    float sum_exp = 0.0f;
    for (int i = tid; i < hidden_size; i += BLOCK_SIZE) {
        sum_exp += expf(__ldg(&row_ptr[i]) - row_max);
    }
    sdata[tid] = sum_exp;
    __syncthreads();

    // Tree reduction to add up the per-thread partial sums
    for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        float lse = row_max + logf(sdata[0]);
        // log1pf keeps precision when expf(lse) is tiny, where 1.0f + expf(lse) rounds to 1
        float softplus = log1pf(expf(lse));
        float mish_val = lse * tanhf(softplus);
        output[row] = lse * mish_val;
    }
}

// Host function that launches the fused kernels
// Host entry point: validates the inputs, then launches the GEMM + epilogue kernel
// followed by the row-wise LogSumExp/Mish reduction kernel.
//
// Args:
//   x            float32 CUDA tensor of shape [batch_size, input_size]
//   scale_factor multiplier applied after the bias add
//   clamp_min    lower clamp bound
//   clamp_max    upper clamp bound
//   weight       float32 CUDA tensor of shape [hidden_size, input_size]
//   bias         float32 CUDA tensor of shape [hidden_size]
// Returns:
//   float32 tensor of shape [batch_size, 1] on x's device.
// Raises (via TORCH_CHECK):
//   c10::Error if a tensor is not a float32 CUDA tensor, the tensors are on different
//   devices, the shapes disagree, or the sizes exceed what the kernels can index.
torch::Tensor aligned_ldg_forward(
    torch::Tensor x,
    float scale_factor,
    float clamp_min,
    float clamp_max,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "x, weight and bias must be CUDA tensors");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "x, weight and bias must be float32");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D [batch_size, input_size]");
    TORCH_CHECK(weight.dim() == 2 && weight.size(1) == x.size(1),
                "weight must be 2-D [hidden_size, input_size] with input_size matching x");
    TORCH_CHECK(bias.dim() == 1 && bias.size(0) == weight.size(0),
                "bias must be 1-D [hidden_size] with hidden_size matching weight");

    // The kernels take their sizes as int. Keep every dimension and every element count
    // within int range so the narrowing below is lossless and index arithmetic in the
    // kernels cannot overflow.
    constexpr int64_t kMaxInt = std::numeric_limits<int>::max();
    const int64_t batch64 = x.size(0);
    const int64_t input64 = x.size(1);
    const int64_t hidden64 = weight.size(0);
    TORCH_CHECK(batch64 <= kMaxInt && input64 <= kMaxInt && hidden64 <= kMaxInt &&
                x.numel() <= kMaxInt && weight.numel() <= kMaxInt &&
                (hidden64 == 0 || batch64 <= kMaxInt / hidden64),
                "tensor sizes are too large for 32-bit indexing");

    // The GEMM grid tiles the batch along gridDim.y, which CUDA limits to 65535 blocks.
    constexpr int64_t kMaxGridY = 65535;
    const int64_t row_tiles = (batch64 + TILE_DIM - 1) / TILE_DIM;
    const int64_t col_tiles = (hidden64 + TILE_DIM - 1) / TILE_DIM;
    TORCH_CHECK(row_tiles <= kMaxGridY, "batch_size is too large for the GEMM launch grid");

    const int batch_size = static_cast<int>(batch64);
    const int input_size = static_cast<int>(input64);
    const int hidden_size = static_cast<int>(hidden64);

    // The kernels index raw pointers assuming dense row-major storage.
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    // Launch on the tensors' device and on PyTorch's current stream so the kernels are
    // ordered correctly with the surrounding PyTorch operations.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
    auto output = torch::empty({batch_size, 1}, options);
    if (batch_size == 0) {
        // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
        return output;
    }
    auto intermediate = torch::empty({batch_size, hidden_size}, options);

    // Launch GEMM kernel (skipped when there are no output columns: the grid would be empty)
    if (hidden_size > 0) {
        dim3 gemm_grid(static_cast<unsigned int>(col_tiles),
                       static_cast<unsigned int>(row_tiles));
        dim3 gemm_block(TILE_DIM * TILE_DIM);

        aligned_ldg_gemm_kernel<<<gemm_grid, gemm_block, 0, stream>>>(
            x.data_ptr<float>(),
            weight.data_ptr<float>(),
            bias.data_ptr<float>(),
            intermediate.data_ptr<float>(),
            batch_size,
            input_size,
            hidden_size,
            scale_factor,
            clamp_min,
            clamp_max
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Launch reduction kernel: one block of BLOCK_SIZE threads per row
    aligned_ldg_reduction_kernel<<<batch_size, BLOCK_SIZE, 0, stream>>>(
        intermediate.data_ptr<float>(),
        output.data_ptr<float>(),
        hidden_size,
        batch_size
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python binding: registers aligned_ldg_forward on the extension module under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &aligned_ldg_forward, "Fused forward pass with aligned __ldg() optimizations (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.300 inst/cycle 0.000 3
Executed Ipc Elapsed 0.130 inst/cycle 0.000 3
Issue Slots Busy 7.693 % 0.002 3
Issued Ipc Active 0.310 inst/cycle 0.000 3
SM Busy 7.693 % 0.002 3
Memory Throughput 122660225442.837 byte/second 1603232021498351616.000 3
Mem Busy 8.377 % 0.010 3
Max Bandwidth 4.510 % 0.001 3
L1/TEX Hit Rate 49.810 % 0.000 3
L2 Hit Rate 66.637 % 0.001 3
Mem Pipes Busy 3.537 % 0.001 3
Warp Cycles Per Issued Instruction 12.080 cycle 0.002 3
Warp Cycles Per Executed Instruction 12.283 cycle 0.002 3
Avg. Active Threads Per Warp 31.140 0.000 3
Avg. Not Predicated Off Threads Per Warp 23.760 0.000 3
Max Active Clusters 0.000 cluster 0.000 3
Max Cluster Size 8.000 block 0.000 3
Overall GPU Occupancy 0.000 % 0.000 3
Cluster Occupancy 0.000 % 0.000 3
Block Limit SM 32.000 block 0.000 3
Block Limit Registers 16.000 block 0.000 3
Block Limit Shared Mem 42.000 block 0.000 3
Block Limit Warps 16.000 block 0.000 3
Theoretical Active Warps per SM 64.000 warp 0.000 3
Theoretical Occupancy 100.000 % 0.000 3
Achieved Occupancy 5.890 % 0.000 3
Achieved Active Warps Per SM 3.770 warp 0.000 3
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 31.1 threads being active per cycle. This is further reduced to 23.8 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (5.9%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 487931.45 μs
Device Time 220.61 μs
Self CPU Time 62.51 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 487868.94 μs
Device Time 220.61 μs
Self CPU Time 124.33 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 487210.57 μs
Device Time 0.00 μs
Self CPU Time 135.52 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 478332.88 μs
Device Time 0.00 μs
Self CPU Time 478332.88 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 46891.49 μs
Device Time 231064.84 μs
Self CPU Time 12873.35 μs
Self Device Time 231064.84 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 148323.06 μs
Device Time 12849.41 μs
Self CPU Time 148323.06 μs
Self Device Time 12849.41 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aligned_ldg_gemm_kernel(float const*, float const*, float const*, float*, int, int, int, float, float, float)
CPU Time 0.00 μs
Device Time 108811.26 μs
Self CPU Time 0.00 μs
Self Device Time 108811.26 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 55930.15 μs
Device Time 231064.84 μs
Self CPU Time 9046.96 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 231064.84 μs
Self CPU Time 0.00 μs
Self Device Time 231064.84 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45294 warnings generated when compiling for host.
Suppressed 45326 warnings (45279 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:13:5 bugprone-easily-swappable-parameters
13 | const float* __restrict__ input,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14 | const float* __restrict__ weight,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:13:31: note: the first parameter in the range is 'input'
13 | const float* __restrict__ input,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:15:31: note: the last parameter in the range is 'bias'
15 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:19:5: warning: 3 adjacent parameters of 'aligned_ldg_gemm_kernel' of convertible types are easily swapped by mistake [bugprone-easily-swappable-parameters]
19 | int hidden_size,
| ^~~~~~~~~~~~~~~~
20 | float scale_factor,
| ~~~~~~~~~~~~~~~~~~~
21 | float clamp_min,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:19:9: note: the first parameter in the range is 'hidden_size'
19 | int hidden_size,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:21:11: note: the last parameter in the range is 'clamp_min'
21 | float clamp_min,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:20:5: note: 'int' and 'float' may be implicitly converted
20 | float scale_factor,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:29:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
29 | int by = blockIdx.y * TILE_DIM;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:30:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
30 | int bx = blockIdx.x * TILE_DIM;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:31:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
31 | int tx = threadIdx.x % TILE_DIM;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:32:14: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
32 | int ty = threadIdx.x / TILE_DIM;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:83:5: warning: 2 adjacent parameters of 'aligned_ldg_reduction_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
83 | int hidden_size,
| ^~~~~~~~~~~~~~~~
84 | int batch_size
| ~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:83:9: note: the first parameter in the range is 'hidden_size'
83 | int hidden_size,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:84:9: note: the last parameter in the range is 'batch_size'
84 | int batch_size
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:87:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
87 | int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:88:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
88 | int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:137:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
137 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:141:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
141 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:142:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
142 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:144:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
144 | int batch_size = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:145:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
145 | int input_size = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_22/b10_s3_aligned_ldg_fused_kernel/base/base.cu:146:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
146 | int hidden_size = weight.size(0);
| ^