← Back to Leaderboard

The AI CUDA Engineer 👷

22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish — unroll_optimization_base_base

Level 2 • Task 22
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    scale_factor: float,
    clamp_min: float,
    clamp_max: float,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a linear layer, scaling, a self-residual add, clamping, a row-wise
    LogSumExp, and finally multiply the result by its own Mish activation.

    Mathematically, for each row ``i``::

        h   = clamp(2 * scale_factor * (x @ weight.T + bias), clamp_min, clamp_max)
        lse = logsumexp(h[i, :])
        out = lse * mish(lse)

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, input_size)
        scale_factor (float): Factor to scale the linear output by
        clamp_min (float): Minimum value for clamping
        clamp_max (float): Maximum value for clamping
        weight (torch.Tensor): Weight matrix of shape (hidden_size, input_size)
        bias (torch.Tensor): Bias vector of shape (hidden_size)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, 1). LogSumExp reduces
        dim 1 with ``keepdim=True``, so the hidden dimension collapses to size 1.
    """
    x = F.linear(x, weight, bias)
    x = x * scale_factor
    # The "residual connection" adds the scaled activations to themselves,
    # i.e. it doubles them.
    x = x + x
    x = torch.clamp(x, clamp_min, clamp_max)
    x = torch.logsumexp(x, dim=1, keepdim=True)
    # Multiplies by Mish rather than replacing x with mish(x).
    x = x * F.mish(x)
    return x


class Model(nn.Module):
    """
    Linear layer followed by scaling, a self-residual add (``x + x``), clamping,
    a row-wise LogSumExp, and a final multiplication by Mish (``x * mish(x)``).

    This module only owns the parameters; the computation is done by ``fn``
    (``module_fn`` by default). ``scale_factor``, ``clamp_min`` and ``clamp_max``
    are accepted by ``__init__`` but not stored; they are passed to ``forward``
    on every call instead.
    """

    def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
        super(Model, self).__init__()
        # Borrow nn.Linear's default initialisation for weight and bias.
        matmul = nn.Linear(input_size, hidden_size)
        self.weight = matmul.weight
        self.bias = nn.Parameter(
            matmul.bias + torch.ones_like(matmul.bias) * 0.02
        )  # shift by +0.02 to make sure it's nonzero

    def forward(self, x, scale_factor, clamp_min, clamp_max, fn=module_fn):
        """
        Run ``fn`` on ``x`` with this module's weight and bias.

        Returns:
            Whatever ``fn`` returns; for the default ``module_fn`` with a
            2-D input that is a tensor of shape (batch_size, 1).
        """
        return fn(x, scale_factor, clamp_min, clamp_max, self.weight, self.bias)


# Benchmark problem size.
batch_size, input_size, hidden_size = 128, 512, 1024
# Scaling and clamping hyper-parameters.
scale_factor = 2.0
clamp_min, clamp_max = -10.0, 10.0


def get_inputs():
    """Return forward() arguments: a random input batch, then the scale/clamp hyper-parameters."""
    batch = torch.randn(batch_size, input_size)
    return [batch, scale_factor, clamp_min, clamp_max]


def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    ctor_args = [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
    return ctor_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a matrix multiplication, scales the result, adds it to
    itself (the "residual connection" is ``x + x``), clamps the output, reduces
    each row with LogSumExp, and finally multiplies that value by its own Mish
    activation.
    """
    def __init__(self, input_size, hidden_size, scale_factor, clamp_min, clamp_max):
        super(Model, self).__init__()
        self.matmul = nn.Linear(input_size, hidden_size)
        # Shift the default-initialised bias by +0.02.
        self.matmul.bias = nn.Parameter(self.matmul.bias + torch.ones_like(self.matmul.bias) * 0.02)
        self.scale_factor = scale_factor
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, input_size).

        Returns:
            Output tensor of shape (batch_size, 1): LogSumExp reduces dim 1 with
            ``keepdim=True``, leaving one value per row.
        """
        x = self.matmul(x)
        x = x * self.scale_factor
        x = x + x  # residual add of the tensor with itself, i.e. doubling
        x = torch.clamp(x, self.clamp_min, self.clamp_max)
        x = torch.logsumexp(x, dim=1, keepdim=True)
        x = x * torch.nn.functional.mish(x)  # x * mish(x), not mish(x) alone
        return x

# Benchmark problem size.
batch_size, input_size, hidden_size = 128, 512, 1024
# Scaling and clamping hyper-parameters stored by Model.__init__.
scale_factor = 2.0
clamp_min, clamp_max = -10.0, 10.0

def get_inputs():
    """Return forward() arguments: a single random input batch."""
    sample = torch.randn(batch_size, input_size)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for Model, in signature order."""
    ctor_args = [input_size, hidden_size, scale_factor, clamp_min, clamp_max]
    return ctor_args

Kernel Information

Related Kernels (Level 2, Task 22 • 22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_syncthreads_optimized_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_syncthreads_optimized_base 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_even_workload_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_ldg_aligned_edit_1 0.03 2.19 1.49
🥇 22_matmul_scale_residualadd_clamp_logsumexp_mish_ldg_aligned_base 0.03 2.19 1.49
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_warp_optimized_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_warp_optimized_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_warp_optimized_edit_1 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_stride_loop_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_even_workload_base 0.03 2.12 1.44
6 22_matmul_scale_residualadd_clamp_logsumexp_mish_2dindex_optimized_base 0.03 2.12 1.44
12 coalesced_fused_kernel_base_base 0.04 1.76 1.20
12 block_size_optimization_base 0.04 1.76 1.20
12 atomic_optimized_fused_kernel_base 0.04 1.76 1.20
15 reduced_sync_fused_kernel_base 0.04 1.67 1.14
16 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_optimized_base 0.04 1.63 1.11
16 fused_kernel_base 0.04 1.63 1.11
16 22_matmul_scale_residualadd_clamp_logsumexp_mish_shared_memory_optimized_edit_1 0.04 1.63 1.11
19 aligned_ldg_fused_kernel_base 0.04 1.44 0.98
20 unroll_optimization_base_base 0.04 1.41 0.96
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>

// Edge length of the square shared-memory tile used by the matmul kernel;
// the host launches that kernel with TILE_DIM x TILE_DIM threads per block.
constexpr int TILE_DIM = 16;

// Kernel for matrix multiplication, scaling, residual addition, and clamping.
// Computes: out = clamp(2 * scale_factor * (A * W^T + bias), clamp_min, clamp_max)
// A:            [batch_size, input_size], row-major
// W:            [hidden_size, input_size], row-major, used as transposed
// bias:         [hidden_size]
// intermediate: [batch_size, hidden_size], row-major output
// Must be launched with blockDim = (TILE_DIM, TILE_DIM): each thread loads
// exactly one element of each shared tile per iteration.
__global__ void unroll_matmul_scale_resAdd_clamp_kernel(
    const float* __restrict__ A,
    const float* __restrict__ W,
    const float* __restrict__ bias,
    float* __restrict__ intermediate,
    int batch_size,
    int input_size,
    int hidden_size,
    float scale_factor,
    float clamp_min,
    float clamp_max
) {
    // Shared tiles with one padding column to reduce bank conflicts.
    __shared__ float sA[TILE_DIM][TILE_DIM + 1];
    __shared__ float sB[TILE_DIM][TILE_DIM + 1];

    const int tx = static_cast<int>(threadIdx.x);
    const int ty = static_cast<int>(threadIdx.y);
    const int row_base = static_cast<int>(blockIdx.y) * TILE_DIM;  // first batch row of this tile
    const int col_base = static_cast<int>(blockIdx.x) * TILE_DIM;  // first hidden unit of this tile

    const int row = row_base + ty;  // output row computed by this thread
    const int col = col_base + tx;  // output column computed by this thread
    // Row of W that this thread *loads* into sB. It intentionally differs
    // from `col` so that threads with consecutive tx read consecutive
    // addresses of W (a coalesced global load).
    const int w_row = col_base + ty;

    float acc = 0.0f;

    const int numTiles = (input_size + TILE_DIM - 1) / TILE_DIM;
    for (int m = 0; m < numTiles; ++m) {
        const int k_idx = m * TILE_DIM + tx;  // input-feature index loaded by this thread

        // sA[ty][tx] = A[row][k_idx], zero-padded outside the matrix.
        sA[ty][tx] = (row < batch_size && k_idx < input_size)
                         ? A[row * input_size + k_idx]
                         : 0.0f;

        // Store transposed so that sB[k][j] = W[col_base + j][m * TILE_DIM + k],
        // i.e. the W^T tile the inner product below expects, while the global
        // read walks along a row of W.
        sB[tx][ty] = (w_row < hidden_size && k_idx < input_size)
                         ? W[w_row * input_size + k_idx]
                         : 0.0f;

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < TILE_DIM; ++k) {
            acc += sA[ty][k] * sB[k][tx];
        }
        // Prevent the next iteration's loads from overwriting tiles still in use.
        __syncthreads();
    }

    if (row < batch_size && col < hidden_size) {
        acc += __ldg(&bias[col]);
        acc *= scale_factor;
        acc += acc;  // residual addition of the value with itself (doubling)
        acc = fminf(fmaxf(acc, clamp_min), clamp_max);
        intermediate[row * hidden_size + col] = acc;
    }
}

// Kernel for a row-wise reduction that computes logsumexp and then applies Mish.
// One block handles one row of `intermediate` ([batch_size, hidden_size],
// row-major) and writes a single value to output[row]:
//   lse    = row_max + log(sum(exp(val - row_max)))
//   mish   = lse * tanh(softplus(lse)),   softplus(v) = log1p(exp(v))
//   output = lse * mish
// Needs blockDim.x floats of dynamic shared memory; any blockDim.x >= 1 works.
__global__ void unroll_row_reduce_logsumexp_mish_kernel(
    const float* __restrict__ intermediate,
    float* __restrict__ output,
    int batch_size,
    int hidden_size
) {
    const int row = static_cast<int>(blockIdx.x);
    // `row` is uniform across the block, so this early return cannot leave
    // some threads waiting at a later __syncthreads().
    if (row >= batch_size) return;

    extern __shared__ float sdata[];

    const int tid = static_cast<int>(threadIdx.x);
    const int nthreads = static_cast<int>(blockDim.x);
    const float* row_vals = intermediate + static_cast<size_t>(row) * hidden_size;

    // ---- Pass 1: row maximum (subtracted before exp for numerical stability).
    float thread_max = -FLT_MAX;
    for (int col = tid; col < hidden_size; col += nthreads) {
        thread_max = fmaxf(thread_max, row_vals[col]);
    }
    sdata[tid] = thread_max;
    __syncthreads();

    // Tree reduction that also handles a non-power-of-two thread count: each
    // step folds the upper part of the active range onto the lower part.
    for (int active = nthreads; active > 1;) {
        const int half = (active + 1) / 2;
        if (tid < active - half) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + half]);
        }
        __syncthreads();
        active = half;
    }
    const float row_max = sdata[0];
    // Barrier before sdata is reused. Without it, a fast warp could overwrite
    // sdata[0] with its partial sum while a slower warp is still reading row_max.
    __syncthreads();

    // ---- Pass 2: sum of exp(val - row_max).
    float thread_sum = 0.0f;
    for (int col = tid; col < hidden_size; col += nthreads) {
        thread_sum += expf(row_vals[col] - row_max);
    }
    sdata[tid] = thread_sum;
    __syncthreads();

    for (int active = nthreads; active > 1;) {
        const int half = (active + 1) / 2;
        if (tid < active - half) {
            sdata[tid] += sdata[tid + half];
        }
        __syncthreads();
        active = half;
    }

    // Only one value is produced per row, so a single thread finishes up.
    if (tid == 0) {
        const float logsumexp = row_max + logf(sdata[0]);
        // log1pf keeps precision when expf(logsumexp) is small relative to 1.
        const float softplus = log1pf(expf(logsumexp));
        const float mish = logsumexp * tanhf(softplus);
        output[row] = logsumexp * mish;
    }
}

// Host wrapper. It validates the inputs, then launches two kernels on
// PyTorch's current CUDA stream: the fused matmul/scale/clamp kernel, then
// the row-wise logsumexp + Mish kernel.
//
// x:      [batch_size, input_size]  float32, CUDA
// weight: [hidden_size, input_size] float32, CUDA, same device as x
// bias:   [hidden_size]             float32, CUDA, same device as x
// Returns a [batch_size, 1] float32 tensor on x's device.
torch::Tensor module_fn_forward(
    const torch::Tensor& x,
    float scale_factor,
    float clamp_min,
    float clamp_max,
    const torch::Tensor& weight,
    const torch::Tensor& bias
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    weight.scalar_type() == torch::kFloat32 &&
                    bias.scalar_type() == torch::kFloat32,
                "x, weight and bias must be float32 tensors");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D (batch_size, input_size)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D (hidden_size, input_size)");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1-D (hidden_size)");
    TORCH_CHECK(weight.size(1) == x.size(1),
                "weight.size(1) must equal x.size(1) (input_size)");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "bias.size(0) must equal weight.size(0) (hidden_size)");

    // The kernels use plain row-major index arithmetic, so the buffers must be
    // densely packed. contiguous() is a no-op when they already are.
    const torch::Tensor x_c = x.contiguous();
    const torch::Tensor w_c = weight.contiguous();
    const torch::Tensor b_c = bias.contiguous();

    const int batch_size = static_cast<int>(x_c.size(0));
    const int input_size = static_cast<int>(x_c.size(1));
    const int hidden_size = static_cast<int>(w_c.size(0));

    auto options = torch::TensorOptions().dtype(x_c.dtype()).device(x_c.device());
    auto output = torch::empty({batch_size, 1}, options);

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (batch_size == 0) {
        return output;
    }
    TORCH_CHECK(hidden_size > 0, "hidden_size (weight.size(0)) must be positive");

    auto intermediate = torch::empty({batch_size, hidden_size}, options);

    // Launch on PyTorch's current stream so these kernels are ordered with the
    // surrounding PyTorch work instead of the legacy default stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 mm_block(TILE_DIM, TILE_DIM);
    dim3 mm_grid((hidden_size + TILE_DIM - 1) / TILE_DIM,
                 (batch_size + TILE_DIM - 1) / TILE_DIM);
    unroll_matmul_scale_resAdd_clamp_kernel<<<mm_grid, mm_block, 0, stream>>>(
        x_c.data_ptr<float>(),
        w_c.data_ptr<float>(),
        b_c.data_ptr<float>(),
        intermediate.data_ptr<float>(),
        batch_size,
        input_size,
        hidden_size,
        scale_factor,
        clamp_min,
        clamp_max
    );
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "unroll_matmul_scale_resAdd_clamp_kernel launch failed: ",
                cudaGetErrorString(err));

    // One block per row, with one float of dynamic shared memory per thread.
    const int threads = 256;
    dim3 reduce_grid(batch_size);
    const size_t sharedMemSize = threads * sizeof(float);
    unroll_row_reduce_logsumexp_mish_kernel<<<reduce_grid, threads, sharedMemSize, stream>>>(
        intermediate.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        hidden_size
    );
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "unroll_row_reduce_logsumexp_mish_kernel launch failed: ",
                cudaGetErrorString(err));

    return output;
}

// Python binding: exposes module_fn_forward to Python as `forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def(
        "forward",
        &module_fn_forward,
        "Forward pass for module_fn with unrolled loops (CUDA)"
    );
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.404 inst/cycle 0.000 5
Executed Ipc Elapsed 0.228 inst/cycle 0.000 5
Issue Slots Busy 10.174 % 0.001 5
Issued Ipc Active 0.410 inst/cycle 0.000 5
SM Busy 10.174 % 0.001 5
Memory Throughput 86468764764.216 byte/second 678630938957026432.000 5
Mem Busy 5.906 % 0.007 5
Max Bandwidth 4.826 % 0.002 5
L1/TEX Hit Rate 49.810 % 0.000 5
L2 Hit Rate 68.124 % 0.051 5
Mem Pipes Busy 4.826 % 0.002 5
Warp Cycles Per Issued Instruction 19.442 cycle 0.003 5
Warp Cycles Per Executed Instruction 19.512 cycle 0.003 5
Avg. Active Threads Per Warp 31.920 0.000 5
Avg. Not Predicated Off Threads Per Warp 21.680 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 10.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 12.370 % 0.000 5
Achieved Active Warps Per SM 7.920 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 31.9 threads being active per cycle. This is further reduced to 21.7 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 358842.78 μs
Device Time 153.57 μs
Self CPU Time 61.27 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 358781.51 μs
Device Time 153.57 μs
Self CPU Time 112.68 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 358242.51 μs
Device Time 0.00 μs
Self CPU Time 117.74 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 357208.73 μs
Device Time 0.00 μs
Self CPU Time 357208.73 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 62631.56 μs
Device Time 581370.23 μs
Self CPU Time 17679.31 μs
Self Device Time 581370.23 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 518331.37 μs
Device Time 27026.31 μs
Self CPU Time 518331.37 μs
Self Device Time 27026.31 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
unroll_matmul_scale_resAdd_clamp_kernel(float const*, float const*, float const*, float*, int, int, int, float, float, float)
CPU Time 0.00 μs
Device Time 273698.26 μs
Self CPU Time 0.00 μs
Self Device Time 273698.26 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 75697.24 μs
Device Time 581370.23 μs
Self CPU Time 13078.33 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 581448.47 μs
Self CPU Time 0.00 μs
Self Device Time 581448.47 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45313 warnings generated when compiling for host.
Suppressed 45341 warnings (45294 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:17:5 bugprone-easily-swappable-parameters
17 | const float* __restrict__ A,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 | const float* __restrict__ W,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
19 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:17:31: note: the first parameter in the range is 'A'
17 | const float* __restrict__ A,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:19:31: note: the last parameter in the range is 'bias'
19 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:23:5: warning: 3 adjacent parameters of 'unroll_matmul_scale_resAdd_clamp_kernel' of convertible types are easily swapped by mistake [bugprone-easily-swappable-parameters]
23 | int hidden_size,
| ^~~~~~~~~~~~~~~~
24 | float scale_factor,
| ~~~~~~~~~~~~~~~~~~~
25 | float clamp_min,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:23:9: note: the first parameter in the range is 'hidden_size'
23 | int hidden_size,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:25:11: note: the last parameter in the range is 'clamp_min'
25 | float clamp_min,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:24:5: note: 'int' and 'float' may be implicitly converted
24 | float scale_factor,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:32:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
32 | int row = blockIdx.y * TILE_DIM + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:33:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
33 | int col = blockIdx.x * TILE_DIM + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:38:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
38 | int a_col = m * TILE_DIM + threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:45:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
45 | int b_row = m * TILE_DIM + threadIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:79:5: warning: 2 adjacent parameters of 'unroll_row_reduce_logsumexp_mish_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
79 | int batch_size,
| ^~~~~~~~~~~~~~~
80 | int hidden_size
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:79:9: note: the first parameter in the range is 'batch_size'
79 | int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:80:9: note: the last parameter in the range is 'hidden_size'
80 | int hidden_size
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:82:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
82 | int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:88:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
88 | for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:88:59: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
88 | for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:106:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
106 | for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:106:59: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
106 | for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:136:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
136 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:140:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
140 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:141:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
141 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:147:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
147 | int batch_size = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:148:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
148 | int input_size = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:149:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
149 | int hidden_size = weight.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_22/b5_s3_unroll_optimization_base/base/base.cu:173:25: warning: narrowing conversion from 'unsigned long' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
173 | int sharedMemSize = threads * sizeof(float);
| ^