← Back to Leaderboard

The AI CUDA Engineer 👷

51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd • fused_forward_coalesced_base

Level 2 • Task 51
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    subtract: torch.Tensor,
) -> torch.Tensor:
    """
    Reference pipeline: Gemm -> Subtract -> GlobalAvgPool -> LogSumExp -> GELU -> ResidualAdd.

    The GEMM output is averaged over the feature dimension (``keepdim=True``),
    so the following LogSumExp runs over a dimension of size 1 and leaves the
    value unchanged. The GELU'd per-row scalar is then broadcast-added back
    onto the original input.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        weight (torch.Tensor): Linear weight of shape (out_features, in_features)
        bias (torch.Tensor): Linear bias of shape (out_features)
        subtract (torch.Tensor): Vector subtracted after the GEMM, shape (out_features)

    Returns:
        torch.Tensor: Tensor with the same shape as ``x``.
    """
    # Snapshot of the input for the residual branch. detach() means no
    # gradient flows back into ``x`` through this path.
    residual = x.clone().detach()

    # Gemm followed by the per-feature subtraction.
    projected = F.linear(x, weight, bias) - subtract

    # Average over features -> shape (batch_size, 1).
    pooled = projected.mean(dim=1, keepdim=True)

    # logsumexp over a size-1 dimension; kept for fidelity with the spec.
    reduced = pooled.logsumexp(dim=1, keepdim=True)

    # Exact (erf-based) GELU, then broadcast the per-row scalar over the residual.
    return F.gelu(reduced) + residual


class Model(nn.Module):
    """
    Parameter holder for the Gemm/Subtract/AvgPool/LogSumExp/GELU/ResidualAdd pipeline.

    The actual computation is delegated to a functional implementation
    (``module_fn`` by default) so that alternative kernels can be swapped in
    through ``forward``'s ``fn`` argument.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Reuse nn.Linear's default initialisation, then register its tensors
        # directly on this module as ``weight`` and ``bias``.
        template_linear = nn.Linear(in_features, out_features)
        self.weight = nn.Parameter(template_linear.weight)
        self.bias = nn.Parameter(template_linear.bias)
        # Small random vector subtracted from the GEMM output.
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x, fn=module_fn):
        # ``fn`` must accept (x, weight, bias, subtract) in this order.
        weight, bias, subtract = self.weight, self.bias, self.subtract
        return fn(x, weight, bias, subtract)


# Problem sizes used by the benchmark harness.
batch_size = 128
in_features = 1024
out_features = 512


def get_inputs():
    """Return the forward() arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    constructor_args = [in_features, out_features]
    return constructor_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a series of operations: Gemm, Subtract, GlobalAvgPool, LogSumExp, GELU, and ResidualAdd.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(Model, self).__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=bias)
        self.subtract = nn.Parameter(torch.randn(out_features) * 0.02)

    def forward(self, x):
        original_x = x.clone().detach()
        # Gemm
        x = self.gemm(x)

        # Subtract
        x = x - self.subtract

        # GlobalAvgPool
        x = torch.mean(x, dim=1, keepdim=True)

        # LogSumExp
        x = torch.logsumexp(x, dim=1, keepdim=True)

        # GELU
        x = torch.nn.functional.gelu(x)

        # ResidualAdd
        x = x + original_x

        return x

# Problem sizes used by the benchmark harness.
batch_size = 128
in_features = 1024
out_features = 512

def get_inputs():
    """Return the forward() arguments: one random (batch_size, in_features) tensor."""
    sample = torch.randn(batch_size, in_features)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    constructor_args = [in_features, out_features]
    return constructor_args

Kernel Information

Related Kernels (Level 2, Task 51 • 51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 fused_forward_base 0.05 1.62 0.92
🥇 fused_forward_edit_1 0.05 1.62 0.92
🥉 fused_forward_coalesced_base 0.05 1.58 0.90
4 fused_forward_coalesced_edit_1 0.05 1.55 0.89
5 optimized_fused_kernel_base 0.06 1.32 0.76
6 fused_pipeline_base 0.06 1.28 0.73
6 threadblock_mapping_opt_base 0.06 1.28 0.73
8 atomic_optimized_pipeline_base 0.06 1.26 0.72
8 efficient_thread_block_mapping_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
8 fused_pool_gelu_atomic_minimal_base 0.06 1.26 0.72
8 fused_pool_gelu_warp_edit_base 0.06 1.26 0.72
8 aligned_memory_access_base_base 0.06 1.26 0.72
14 constant_memory_optimization_base 0.07 1.24 0.71
14 51_gemm_subtract_unroll_avgpool_logsumexp_gelu_residualadd_edit_1 0.07 1.24 0.71
14 uniform_control_flow_base_base_base 0.07 1.24 0.71
17 modular_device_functions_optimized_base 0.07 1.22 0.70
17 modular_device_functions_base_base 0.07 1.22 0.70
19 experiment_block_sizes_base 0.07 1.19 0.68
19 tiled_gemm_shared_edit_2_base 0.07 1.19 0.68
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cmath>
#include <cstdint>
#include <limits>

//------------------------------------------------------------------------------
// GELU, tanh approximation:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// NOTE(review): the Python reference (module_fn / Model.forward) calls
// F.gelu(x) without `approximate=`, i.e. the exact erf-based GELU, so this
// function only approximates the reference result.
__device__ float gelu_approx(float x) {
    const float kAlpha = 0.044715f;
    const float kBeta  = 0.7978845608f; // sqrt(2/M_PI)
    float inner = kBeta * (x + kAlpha * x * x * x);
    return x * 0.5f * (1.0f + tanhf(inner));
}

//------------------------------------------------------------------------------
// Exact (erf-based) GELU, matching torch.nn.functional.gelu's default
// (approximate='none') as used by the PyTorch reference implementation.
__device__ __forceinline__ float gelu_exact(float v) {
    const float kInvSqrt2 = 0.70710678118654752440f;  // 1 / sqrt(2)
    return 0.5f * v * (1.0f + erff(v * kInvSqrt2));
}

//------------------------------------------------------------------------------
// Fused kernel: one block per row of 'x'.
//
// The average over out_features of (x_row @ W^T + bias - subtract) equals
//   (dot(x_row, weight_sum) + sum(bias - subtract)) / out_features,
// where weight_sum is W summed over its output dimension. So the GEMM collapses
// to one dot product per row. LogSumExp over the resulting size-1 dimension
// is the identity and is omitted. The GELU'd scalar is broadcast-added to the
// original row (residual).
//
// Launch requirements (enforced by forward_cuda_coalesced):
//   * blockDim.x is a power of two (the tree reduction relies on it);
//   * dynamic shared memory holds blockDim.x floats.
__global__ void fused_forward_coalesced_kernel(
    const float* __restrict__ x,           // [batch_size, in_features], contiguous
    const float* __restrict__ weight_sum,  // [in_features]: weight summed over dim 0
    const float* __restrict__ constant,    // device scalar: sum(bias - subtract)
    float* __restrict__ out,               // [batch_size, in_features]
    int batch_size,
    int in_features,
    int out_features                       // divisor of the average pool
) {
    const int row = static_cast<int>(blockIdx.x);
    if (row >= batch_size) return;

    const int tid = static_cast<int>(threadIdx.x);
    const int num_threads = static_cast<int>(blockDim.x);

    // 64-bit offset: row * in_features can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(row) * in_features;
    const float* row_x = x + row_offset;
    float* row_out = out + row_offset;

    extern __shared__ float sdata[];  // blockDim.x floats

    // Strided partial dot product; consecutive threads read consecutive
    // addresses of row_x and weight_sum.
    float partial = 0.0f;
    for (int j = tid; j < in_features; j += num_threads) {
        partial += row_x[j] * weight_sum[j];
    }
    sdata[tid] = partial;
    __syncthreads();

    // Shared-memory tree reduction (requires power-of-two blockDim.x).
    for (int s = num_threads / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Only thread 0 reads the reduced sum and overwrites sdata[0], so no other
    // thread touches sdata[0] until after the barrier below.
    if (tid == 0) {
        const float pooled = (sdata[0] + *constant) / static_cast<float>(out_features);
        sdata[0] = gelu_exact(pooled);
    }
    __syncthreads();
    const float pool_val = sdata[0];

    // Residual add: broadcast the per-row scalar across the row.
    for (int j = tid; j < in_features; j += num_threads) {
        row_out[j] = row_x[j] + pool_val;
    }
}

//------------------------------------------------------------------------------
// Host entry point. It validates the inputs and precomputes
// weight_sum = weight.sum(0) and constant = sum(bias - subtract) on the device.
// It then launches one block per row of x on PyTorch's current CUDA stream.
//
// Notes:
//   * the constant stays on the device (no .item()), avoiding a host sync;
//   * a CUDAGuard makes the launch target x's device;
//   * empty inputs return early (a zero-block grid is an invalid launch).
torch::Tensor forward_cuda_coalesced(
    const torch::Tensor& x,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    const torch::Tensor& subtract
) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(subtract.is_cuda(), "subtract must be a CUDA tensor");
    TORCH_CHECK(weight.device() == x.device(), "weight must be on the same device as x");
    TORCH_CHECK(bias.device() == x.device(), "bias must be on the same device as x");
    TORCH_CHECK(subtract.device() == x.device(), "subtract must be on the same device as x");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be float32");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat, "weight must be float32");
    TORCH_CHECK(x.dim() == 2, "x must be 2D (batch_size x in_features)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2D (out_features x in_features)");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1D (out_features)");
    TORCH_CHECK(subtract.dim() == 1, "subtract must be 1D (out_features)");

    const int64_t batch_size   = x.size(0);
    const int64_t in_features  = x.size(1);
    const int64_t out_features = weight.size(0);

    TORCH_CHECK(weight.size(1) == in_features, "weight.shape[1] must match x.shape[1]");
    TORCH_CHECK(bias.size(0) == out_features, "bias.shape[0] must match weight.shape[0]");
    TORCH_CHECK(subtract.size(0) == out_features, "subtract.shape[0] must match weight.shape[0]");

    // The kernel uses 32-bit sizes and one block per row (gridDim.x <= INT_MAX).
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(batch_size <= kIntMax, "batch_size exceeds INT_MAX");
    TORCH_CHECK(in_features <= kIntMax, "in_features exceeds INT_MAX");
    TORCH_CHECK(out_features <= kIntMax, "out_features exceeds INT_MAX");

    // Launch on x's device, regardless of the caller's current device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    auto x_contig = x.contiguous();
    auto weight_contig = weight.contiguous();

    // weight_sum[j] = sum_o weight[o, j]  -> shape [in_features]
    auto weight_sum = torch::sum(weight_contig, 0);

    // constant = sum(bias - subtract), kept on the device as a float32 scalar
    // so no device->host synchronization is needed.
    auto constant_tensor =
        torch::sum(bias.contiguous() - subtract.contiguous()).to(torch::kFloat);

    auto out = torch::empty({batch_size, in_features}, x.options());
    if (batch_size == 0 || in_features == 0) {
        return out;  // nothing to write; a 0-block launch would be invalid
    }

    constexpr int kThreads = 256;
    static_assert((kThreads & (kThreads - 1)) == 0,
                  "kThreads must be a power of two for the tree reduction");
    const size_t shared_mem_bytes = kThreads * sizeof(float);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    fused_forward_coalesced_kernel<<<static_cast<unsigned int>(batch_size),
                                     kThreads, shared_mem_bytes, stream>>>(
        x_contig.data_ptr<float>(),
        weight_sum.data_ptr<float>(),
        constant_tensor.data_ptr<float>(),
        out.data_ptr<float>(),
        static_cast<int>(batch_size),
        static_cast<int>(in_features),
        static_cast<int>(out_features)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return out;
}

//------------------------------------------------------------------------------
// Python bindings: expose forward_cuda_coalesced as `<extension>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "forward",
        forward_cuda_coalesced,
        "Fused Forward Coalesced CUDA Kernel");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.362 inst/cycle 0.000 5
Executed Ipc Elapsed 0.200 inst/cycle 0.000 5
Issue Slots Busy 9.216 % 0.009 5
Issued Ipc Active 0.368 inst/cycle 0.000 5
SM Busy 9.216 % 0.009 5
Memory Throughput 94771818915.544 byte/second 171703386115525280.000 5
Mem Busy 7.502 % 0.003 5
Max Bandwidth 5.730 % 0.001 5
L1/TEX Hit Rate 25.000 % 0.000 5
L2 Hit Rate 76.074 % 0.064 5
Mem Pipes Busy 3.774 % 0.000 5
Warp Cycles Per Issued Instruction 21.626 cycle 0.047 5
Warp Cycles Per Executed Instruction 22.152 cycle 0.049 5
Avg. Active Threads Per Warp 31.530 0.000 5
Avg. Not Predicated Off Threads Per Warp 25.530 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 12.440 % 0.000 5
Achieved Active Warps Per SM 7.960 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (12.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
cudaStreamSynchronize
CPU Time 589720.90 μs
Device Time 0.00 μs
Self CPU Time 589720.90 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::sum
CPU Time 279423.22 μs
Device Time 213578.48 μs
Self CPU Time 172789.03 μs
Self Device Time 213578.48 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 226065.05 μs
Device Time 12433.37 μs
Self CPU Time 226065.05 μs
Self Device Time 12433.37 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::reduce_kernel<128, 4, at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4> >(at::native::ReduceOp<float, at::native::func_wrapper_t<float, at::native::sum_functor<float, float, float>::operator()(at::TensorIterator&)::{lambda(float, float)#1}>, unsigned int, float, 4>)
CPU Time 0.00 μs
Device Time 185970.49 μs
Self CPU Time 0.00 μs
Self Device Time 185970.49 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::item
CPU Time 652174.81 μs
Device Time 18837.46 μs
Self CPU Time 8753.84 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_local_scalar_dense
CPU Time 643420.97 μs
Device Time 18837.46 μs
Self CPU Time 26933.70 μs
Self Device Time 18837.46 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 92429.09 μs
Device Time 742895.10 μs
Self CPU Time 21188.75 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 71241.80 μs
Device Time 742895.10 μs
Self CPU Time 28121.37 μs
Self Device Time 742895.10 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 742973.50 μs
Self CPU Time 0.00 μs
Self Device Time 742973.50 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45291 warnings generated when compiling for host.
Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:22:5 bugprone-easily-swappable-parameters
22 | const float* __restrict__ x, // Input: shape [batch_size, in_features]
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 | const float* __restrict__ weight_sum, // Precomputed: shape [in_features]
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:22:31: note: the first parameter in the range is 'x'
22 | const float* __restrict__ x, // Input: shape [batch_size, in_features]
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:23:31: note: the last parameter in the range is 'weight_sum'
23 | const float* __restrict__ weight_sum, // Precomputed: shape [in_features]
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:26:5: warning: 3 adjacent parameters of 'fused_forward_coalesced_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
26 | int batch_size,
| ^~~~~~~~~~~~~~~
27 | int in_features,
| ~~~~~~~~~~~~~~~~
28 | int out_features // Used for normalization
| ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:26:9: note: the first parameter in the range is 'batch_size'
26 | int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:28:9: note: the last parameter in the range is 'out_features'
28 | int out_features // Used for normalization
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:30:15: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
30 | int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:34:26: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
34 | const float* row_x = x + row * in_features;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:34:30: note: make conversion explicit to silence this warning
34 | const float* row_x = x + row * in_features;
| ^~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:34:30: note: perform multiplication in a wider type
34 | const float* row_x = x + row * in_features;
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:35:22: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
35 | float* row_out = out + row * in_features;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:35:28: note: make conversion explicit to silence this warning
35 | float* row_out = out + row * in_features;
| ^~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:35:28: note: perform multiplication in a wider type
35 | float* row_out = out + row * in_features;
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:42:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
42 | for (int j = threadIdx.x; j < in_features; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:42:53: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
42 | for (int j = threadIdx.x; j < in_features; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:51:23: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
51 | for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:70:18: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
70 | for (int j = threadIdx.x; j < in_features; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:70:53: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
70 | for (int j = threadIdx.x; j < in_features; j += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:118:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
118 | int blocks = batch_size; // One block per row
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:126:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
126 | batch_size,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:127:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
127 | in_features,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250208_optimize_b5_s4_e1_sweep/level_2/task_51/b5_s1_fused_forward_coalesced/base/base.cu:128:9: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
128 | out_features
| ^