← Back to Leaderboard

The AI CUDA Engineer 👷

76_Gemm_Add_ReLU • warp_tile_ldg_opt_base

Level 2 • Task 76
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a bias-free linear projection, add the bias, then clamp negatives to zero.

    Args:
        x (torch.Tensor): Input tensor with shape (batch_size, in_features)
        weight (torch.Tensor): Weight matrix with shape (out_features, in_features)
        bias (torch.Tensor): Bias tensor with shape (out_features,)

    Returns:
        torch.Tensor: Output tensor with shape (batch_size, out_features)
    """
    # The bias is added as a separate step rather than being handed to F.linear.
    projected = F.linear(x, weight)
    shifted = projected + bias
    return F.relu(shifted)


class Model(nn.Module):
    """
    Linear projection + bias + ReLU whose forward pass is delegated to a
    functional implementation (``module_fn`` unless another callable is given).
    """

    def __init__(self, in_features, out_features, bias_shape):
        super().__init__()
        # A bias-free nn.Linear is created only to obtain its initialised weight;
        # the layer itself is not kept as a submodule.
        linear = nn.Linear(in_features, out_features, bias=False)
        self.weight = nn.Parameter(linear.weight)
        # Small random bias: standard-normal samples scaled by 0.02.
        noise = torch.randn(bias_shape)
        self.bias = nn.Parameter(noise * 0.02)

    def forward(self, x, fn=module_fn):
        """Call ``fn(x, weight, bias)`` with this module's parameters and return its result."""
        output = fn(x, self.weight, self.bias)
        return output


# Problem size consumed by get_inputs() and get_init_inputs() below.
batch_size = 128  # first dimension of the sample input
in_features = 1024  # second dimension of the sample input
out_features = 512
bias_shape = (out_features,)  # one bias entry per output feature


def get_inputs():
    """Return a one-element list holding a standard-normal tensor of shape (batch_size, in_features)."""
    sample = torch.randn(batch_size, in_features)
    return [sample]


def get_init_inputs():
    """Return ``[in_features, out_features, bias_shape]``, in the order ``Model.__init__`` takes them."""
    init_args = [in_features, out_features, bias_shape]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Reference model: a bias-free linear layer, followed by a learned bias add and a ReLU.
    """
    def __init__(self, in_features, out_features, bias_shape):
        super().__init__()
        self.gemm = nn.Linear(in_features, out_features, bias=False)
        # Small random bias: standard-normal samples scaled by 0.02.
        scaled_noise = torch.randn(bias_shape) * 0.02
        self.bias = nn.Parameter(scaled_noise)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, in_features).
        Returns:
            torch.Tensor: Output tensor with shape (batch_size, out_features).
        """
        projected = self.gemm(x)
        return torch.relu(projected + self.bias)

# Problem size consumed by get_inputs() and get_init_inputs() below.
batch_size = 128  # first dimension of the sample input
in_features = 1024  # second dimension of the sample input
out_features = 512
bias_shape = (out_features,)  # one bias entry per output feature

def get_inputs():
    """Build the sample input: a single random-normal (batch_size, in_features) tensor wrapped in a list."""
    random_batch = torch.randn(batch_size, in_features)
    return [random_batch]

def get_init_inputs():
    """List the constructor arguments for ``Model`` in positional order."""
    ctor_args = [in_features, out_features, bias_shape]
    return ctor_args

Kernel Information

Related Kernels (Level 2, Task 76 • 76_Gemm_Add_ReLU)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 shared_warp_tile_kernel_base 0.03 0.93 1.54
🥇 combined_warp_tile_base 0.03 0.93 1.54
🥉 optimized_block_size_kernel_base 0.03 0.89 1.49
4 warp_tile_ldg_base 0.03 0.87 1.44
4 even_workload_dist_base_base 0.03 0.87 1.44
4 hybrid_warp_tile_kernel_base 0.03 0.87 1.44
4 warp_tile_hybrid_base 0.03 0.87 1.44
8 warp_tile_ldg_opt_base 0.03 0.81 1.36
8 warp_reduction_optimized_base_base 0.03 0.81 1.36
10 optimized_shared_memory_base_base 0.03 0.79 1.32
10 warp_tile_base_base 0.03 0.79 1.32
12 hybrid_optimized_kernel_base 0.04 0.77 1.28
13 warp_reduction_gemm_base 0.04 0.71 1.18
13 warp_tile_aligned_base_base 0.04 0.71 1.18
15 vectorized_warp_unroll_base_base 0.04 0.69 1.15
15 vectorized_warp_unroll_base_edit_1 0.04 0.69 1.15
15 warp_reduction_unrolled_gemm_edit_1 0.04 0.69 1.15
18 unrolled_warp_gemm_edit_1 0.04 0.67 1.12
18 unrolled_warp_gemm_base 0.04 0.67 1.12
18 vectorized_warp_reduction_base 0.04 0.67 1.12
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>

// Define warp size and tile size: each warp will process TILE_SIZE output features simultaneously.
// Both constants are read by the kernel (index math, reduction) and by the host-side
// grid computation, so the two stay in agreement as long as only these defines change.
#define WARP_SIZE 32
#define TILE_SIZE 4

// Fused kernel: out[b, o] = max(0, dot(x[b, :], weight[o, :]) + bias[o]).
//
// Work distribution:
//   * blockIdx.x selects the batch row b.
//   * Every warp of a block owns TILE_SIZE consecutive output features; blockIdx.y selects
//     which group of (warps_per_block * TILE_SIZE) features the block covers.
//   * The 32 lanes of a warp split the dot product over in_features and combine their
//     partial sums with warp shuffles; lane 0 adds the bias, applies ReLU and stores.
//
// Requirements on the launch: blockDim.x must be a multiple of WARP_SIZE, because the
// reduction uses a full-warp participation mask.
//
// Memory access: all global reads go through __ldg(). Rows of x and weight are read as
// float4 only when that is safe, i.e. in_features is a multiple of 4 and both base pointers
// are 16-byte aligned. Otherwise every row start cannot be guaranteed to be 16-byte aligned,
// so the kernel falls back to scalar loads for the whole row instead of faulting on a
// misaligned vector load.
//
// x, weight and out are indexed as dense row-major buffers with row lengths in_features,
// in_features and out_features respectively.
__global__ void warp_tile_ldg_opt_kernel(const float* __restrict__ x,
                                          const float* __restrict__ weight,
                                          const float* __restrict__ bias,
                                          float* __restrict__ out,
                                          int in_features,
                                          int out_features) {
    // Each block corresponds to one batch row
    const int batch_idx = static_cast<int>(blockIdx.x);

    // Determine warp-related indices
    const int warps_per_block = static_cast<int>(blockDim.x) / WARP_SIZE;
    const int warp_id_in_block = static_cast<int>(threadIdx.x) / WARP_SIZE;
    const int lane = static_cast<int>(threadIdx.x) % WARP_SIZE;

    // First output feature handled by this warp.
    const int base_out =
        (static_cast<int>(blockIdx.y) * warps_per_block + warp_id_in_block) * TILE_SIZE;

    // base_out depends only on the warp index, so this exit is taken by whole warps
    // and never splits the lanes that participate in the shuffles below.
    if (base_out >= out_features) return;

    // Per-lane partial sums, one per output feature of the tile. Initialised in a loop
    // so the code stays valid for any TILE_SIZE.
    float sums[TILE_SIZE];
    #pragma unroll
    for (int t = 0; t < TILE_SIZE; t++) {
        sums[t] = 0.0f;
    }

    // Row offsets are computed in size_t: batch_idx * in_features can exceed INT_MAX.
    const float* x_row = x + static_cast<size_t>(batch_idx) * in_features;

    // float4 loads need 16-byte aligned addresses. Row starts are aligned for every row
    // only if the base pointers are aligned and the row length is a multiple of 4 floats.
    const bool can_vectorize =
        (in_features % 4 == 0) &&
        (((reinterpret_cast<size_t>(x) | reinterpret_cast<size_t>(weight)) & 15) == 0);
    const int vec_count = can_vectorize ? in_features / 4 : 0;  // float4 chunks per row
    const int scalar_start = vec_count * 4;                     // first element left to the scalar loop
    const float4* x_vec = reinterpret_cast<const float4*>(x_row);  // dereferenced only if vec_count > 0

    // Process each output feature in the current tile
    #pragma unroll
    for (int t = 0; t < TILE_SIZE; t++) {
        const int out_idx = base_out + t;
        if (out_idx >= out_features) break;  // last tile may be partial

        const float* w_row = weight + static_cast<size_t>(out_idx) * in_features;
        const float4* w_vec = reinterpret_cast<const float4*>(w_row);

        float local_sum = 0.0f;

        // Vectorized part: lanes stride over the float4 chunks of the row.
        for (int i = lane; i < vec_count; i += WARP_SIZE) {
            const float4 x_val = __ldg(x_vec + i);
            const float4 w_val = __ldg(w_vec + i);
            local_sum += x_val.x * w_val.x + x_val.y * w_val.y + x_val.z * w_val.z + x_val.w * w_val.w;
        }

        // Scalar part: empty on the vectorized path, the whole row on the fallback path.
        for (int i = scalar_start + lane; i < in_features; i += WARP_SIZE) {
            local_sum += __ldg(x_row + i) * __ldg(w_row + i);
        }

        sums[t] = local_sum;
    }

    // Warp-level reduction using shuffle operations for each tile output.
    // After the loop, lane 0 holds the sum over all lanes.
    #pragma unroll
    for (int t = 0; t < TILE_SIZE; t++) {
        for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
            sums[t] += __shfl_down_sync(0xffffffff, sums[t], offset);
        }
    }

    // The first thread in the warp writes the final results
    if (lane == 0) {
        float* out_row = out + static_cast<size_t>(batch_idx) * out_features;
        #pragma unroll
        for (int t = 0; t < TILE_SIZE; t++) {
            const int out_idx = base_out + t;
            if (out_idx < out_features) {
                const float result = sums[t] + __ldg(bias + out_idx);
                // Apply ReLU activation
                out_row[out_idx] = (result > 0.0f) ? result : 0.0f;
            }
        }
    }
}


// Host function to launch the kernel.
//
// Computes relu(x @ weight^T + bias) for float32 CUDA tensors.
//   x:      (batch_size, in_features)
//   weight: (out_features, in_features)
//   bias:   out_features elements
// Returns a new (batch_size, out_features) tensor with x's options.
//
// Inputs are validated and made contiguous before launch because the kernel indexes raw
// pointers as dense row-major buffers. Raises (via TORCH_CHECK) on non-CUDA tensors,
// device/dtype/shape mismatches, sizes that exceed the launch limits, or a failed launch.
torch::Tensor warp_tile_ldg_opt_forward(torch::Tensor x,
                                          torch::Tensor weight,
                                          torch::Tensor bias) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor");
    TORCH_CHECK(x.device() == weight.device() && x.device() == bias.device(),
                "x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    weight.scalar_type() == torch::kFloat32 &&
                    bias.scalar_type() == torch::kFloat32,
                "x, weight and bias must be float32 tensors");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D (batch_size, in_features)");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D (out_features, in_features)");
    TORCH_CHECK(weight.size(1) == x.size(1),
                "weight.size(1) must match x.size(1) (in_features)");
    TORCH_CHECK(bias.numel() == weight.size(0),
                "bias must have out_features elements");

    // Each block handles one batch row. The grid's y-dimension covers all output features
    // in chunks of (warps_per_block * TILE_SIZE).
    const int warps_per_block = 8;  // Tunable parameter
    const int threads_per_block = warps_per_block * WARP_SIZE;
    const int64_t features_per_block = static_cast<int64_t>(warps_per_block) * TILE_SIZE;

    // The kernel takes 32-bit sizes, gridDim.x is limited to 2^31 - 1 and gridDim.y to 65535.
    constexpr int64_t kInt32Max = 2147483647;
    constexpr int64_t kMaxGridY = 65535;
    const int64_t grid_y = (weight.size(0) + features_per_block - 1) / features_per_block;
    TORCH_CHECK(x.size(0) <= kInt32Max, "batch_size is too large for this kernel");
    TORCH_CHECK(x.size(1) <= kInt32Max, "in_features is too large for this kernel");
    TORCH_CHECK(grid_y <= kMaxGridY, "out_features is too large for this kernel");

    const int batch_size = static_cast<int>(x.size(0));
    const int in_features = static_cast<int>(x.size(1));
    const int out_features = static_cast<int>(weight.size(0));

    // Run allocation and launch on x's device, even if it is not the current one.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel assumes dense row-major storage; these are no-ops for contiguous inputs.
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    auto out = torch::empty({batch_size, out_features}, x.options());

    // A zero-sized grid is an invalid launch configuration; there is nothing to compute.
    if (out.numel() == 0) {
        return out;
    }

    dim3 grid(static_cast<unsigned int>(batch_size), static_cast<unsigned int>(grid_y));
    dim3 block(threads_per_block);

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    warp_tile_ldg_opt_kernel<<<grid, block, 0, stream>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        out.data_ptr<float>(),
        in_features,
        out_features
    );

    // Surface launch-configuration errors here instead of at some later, unrelated CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "warp_tile_ldg_opt_kernel launch failed: ", cudaGetErrorString(launch_status));

    return out;
}

// Python binding: registers warp_tile_ldg_opt_forward under the name "forward"
// on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &warp_tile_ldg_opt_forward, "GEMM with bias and ReLU using __ldg and 128-bit aligned loads (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 1.486 inst/cycle 0.000 5
Executed Ipc Elapsed 1.336 inst/cycle 0.000 5
Issue Slots Busy 37.308 % 0.018 5
Issued Ipc Active 1.492 inst/cycle 0.000 5
SM Busy 37.308 % 0.018 5
Memory Throughput 79715062133.274 byte/second 296012034291023488.000 5
Mem Busy 67.772 % 0.218 5
Max Bandwidth 66.662 % 0.206 5
L1/TEX Hit Rate 59.126 % 0.497 5
L2 Hit Rate 95.614 % 46.744 5
Mem Pipes Busy 22.062 % 0.024 5
Warp Cycles Per Issued Instruction 25.236 cycle 0.007 5
Warp Cycles Per Executed Instruction 25.334 cycle 0.008 5
Avg. Active Threads Per Warp 30.240 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.330 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 5.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 40.000 warp 0.000 5
Theoretical Occupancy 62.500 % 0.000 5
Achieved Occupancy 59.024 % 0.021 5
Achieved Active Warps Per SM 37.776 warp 0.009 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (21.5%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy (62.5%) is limited by the number of required registers. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 538029.63 μs
Device Time 189.82 μs
Self CPU Time 71.58 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 537958.05 μs
Device Time 189.82 μs
Self CPU Time 135.37 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 537261.85 μs
Device Time 0.00 μs
Self CPU Time 138.72 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 534040.54 μs
Device Time 0.00 μs
Self CPU Time 534040.54 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 433015.31 μs
Device Time 546.81 μs
Self CPU Time 433015.31 μs
Self Device Time 546.81 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
warp_tile_ldg_opt_kernel(float const*, float const*, float const*, float*, int, int)
CPU Time 0.00 μs
Device Time 153830.00 μs
Self CPU Time 0.00 μs
Self Device Time 153830.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 129968.87 μs
Device Time 407784.39 μs
Self CPU Time 8927.71 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 121043.14 μs
Device Time 407784.39 μs
Self CPU Time 10421.21 μs
Self Device Time 407784.39 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 407784.39 μs
Self CPU Time 0.00 μs
Self Device Time 407784.39 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45316 warnings generated when compiling for host.
Suppressed 45347 warnings (45300 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:13:42 bugprone-easily-swappable-parameters
13 | __global__ void warp_tile_ldg_opt_kernel(const float* __restrict__ x,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
14 | const float* __restrict__ weight,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:13:68: note: the first parameter in the range is 'x'
13 | __global__ void warp_tile_ldg_opt_kernel(const float* __restrict__ x,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:15:69: note: the last parameter in the range is 'bias'
15 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:17:43: warning: 2 adjacent parameters of 'warp_tile_ldg_opt_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
17 | int in_features,
| ^~~~~~~~~~~~~~~~
18 | int out_features) {
| ~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:17:47: note: the first parameter in the range is 'in_features'
17 | int in_features,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:18:47: note: the last parameter in the range is 'out_features'
18 | int out_features) {
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:20:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
20 | int batch_idx = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:23:27: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
23 | int warps_per_block = blockDim.x / WARP_SIZE; // Number of warps per block
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:24:28: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
24 | int warp_id_in_block = threadIdx.x / WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:25:16: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
25 | int lane = threadIdx.x % WARP_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:28:20: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
28 | int base_out = (blockIdx.y * warps_per_block + warp_id_in_block) * TILE_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:37:26: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
37 | const float* x_row = x + batch_idx * in_features;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:37:30: note: make conversion explicit to silence this warning
5 | const float* x_row = x + batch_idx * in_features;
| ^~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:37:30: note: perform multiplication in a wider type
37 | const float* x_row = x + batch_idx * in_features;
| ^~~~~~~~~
| static_cast<ptrdiff_t>()
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:41:9: warning: Value stored to 'remainder' during its initialization is never read [clang-analyzer-deadcode.DeadStores]
41 | int remainder = in_features - vec_count * 4; // Number of remaining floats
| ^~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:41:9: note: Value stored to 'remainder' during its initialization is never read
41 | int remainder = in_features - vec_count * 4; // Number of remaining floats
| ^~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:50:30: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
50 | const float* w_row = weight + out_idx * in_features;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:50:39: note: make conversion explicit to silence this warning
50 | const float* w_row = weight + out_idx * in_features;
| ^~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:50:39: note: perform multiplication in a wider type
50 | const float* w_row = weight + out_idx * in_features;
| ^~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:95:55: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
95 | torch::Tensor warp_tile_ldg_opt_forward(torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:96:57: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
96 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:97:57: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
97 | torch::Tensor bias) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:102:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | int batch_size = x.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:103:23: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
103 | int in_features = x.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_76/b7_s1_warp_tile_ldg_opt/base/base.cu:104:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
104 | int out_features = weight.size(0);
| ^