← Back to Leaderboard

The AI CUDA Engineer 👷

45_Gemm_Sigmoid_Sum_LogSumExp — fused_gemm_sigmoid_logsumexp_edit_1

Level 2 • Task 45
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    linear1_weight: torch.Tensor,
    linear1_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Reduce a batch to one value: logsumexp over rows of the row-wise sum of sigmoid(x @ W^T + b).

    Args:
        x (torch.Tensor): Input of shape (batch_size, input_size).
        linear1_weight (torch.Tensor): Linear weight of shape (hidden_size, input_size).
        linear1_bias (torch.Tensor): Linear bias of shape (hidden_size,).

    Returns:
        torch.Tensor: 0-dim tensor equal to logsumexp_i( sum_j sigmoid(x W^T + b)[i, j] ).
    """
    activations = torch.sigmoid(F.linear(x, linear1_weight, linear1_bias))
    per_row_totals = activations.sum(dim=1)  # one scalar per batch row
    return torch.logsumexp(per_row_totals, dim=0)


class Model(nn.Module):
    """
    Gemm -> Sigmoid -> row Sum -> LogSumExp model with standalone parameters.

    The linear weight and bias are stored directly as ``nn.Parameter``s (not via an
    ``nn.Linear`` submodule), so ``forward`` can pass them to an alternative ``fn``
    that shares ``module_fn``'s signature.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # output_size is accepted for interface compatibility; it is not used here.
        super().__init__()
        template = nn.Linear(input_size, hidden_size)
        base_bias = template.bias
        # Small random perturbation of the default bias initialisation.
        jitter = torch.randn(
            base_bias.shape, device=base_bias.device, dtype=base_bias.dtype
        )
        self.linear1_weight = nn.Parameter(template.weight)
        self.linear1_bias = nn.Parameter(base_bias + jitter * 0.02)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` (default ``module_fn``) to ``x`` and this model's weight and bias."""
        return fn(x, self.linear1_weight, self.linear1_bias)


# Problem dimensions used by get_inputs / get_init_inputs.
batch_size, input_size, hidden_size, output_size = 128, 10, 20, 5


def get_inputs():
    """Return the forward-pass inputs: one random tensor of shape (batch_size, input_size)."""
    sample = torch.randn(batch_size, input_size)
    return [sample]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    return [input_size, hidden_size, output_size]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Gemm -> Sigmoid -> row Sum -> LogSumExp model built on an ``nn.Linear`` layer.

    The layer's default bias is perturbed once at construction by Gaussian noise
    scaled by 0.02.
    """
    def __init__(self, input_size, hidden_size, output_size):
        # output_size is accepted for interface compatibility; it is not used here.
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        initial_bias = self.linear1.bias
        noise = torch.randn(initial_bias.shape, device=initial_bias.device, dtype=initial_bias.dtype)
        self.linear1.bias = nn.Parameter(initial_bias + noise * 0.02)

    def forward(self, x):
        """Return logsumexp over the batch of the per-row sum of sigmoid(linear1(x))."""
        row_totals = torch.sigmoid(self.linear1(x)).sum(dim=1)
        return torch.logsumexp(row_totals, dim=0)

# Problem dimensions used by get_inputs / get_init_inputs.
batch_size, input_size, hidden_size, output_size = 128, 10, 20, 5

def get_inputs():
    """Return the forward-pass inputs: one random tensor of shape (batch_size, input_size)."""
    sample = torch.randn(batch_size, input_size)
    return [sample]

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    return [input_size, hidden_size, output_size]

Kernel Information

Related Kernels (Level 2, Task 45 • 45_Gemm_Sigmoid_Sum_LogSumExp)

Rank Kernel Name Runtime (ms) Speedup Native Speedup Compile
🥇 fused_gemm_sigmoid_logsumexp_base 0.01 5.98 2.74
🥇 fused_gemm_sigmoid_logsumexp_edit_1 0.01 5.98 2.74
🥉 minimal_sync_shared_memory_base_base 0.01 5.31 2.44
🥉 variable_block_size_tuning_base_base 0.01 5.31 2.44
🥉 aligned_memory_access_base_base 0.01 5.31 2.44
🥉 optimized_thread_block_mapping_base_base 0.01 5.31 2.44
🥉 strided_loop_optimized_base_base 0.01 5.31 2.44
🥉 warp_fused_base 0.01 5.31 2.44
🥉 minimal_sync_fused_base_base 0.01 5.31 2.44
🥉 shared_memory_optimized_base_base 0.01 5.31 2.44
11 fused_linear_reduction_base 0.01 4.78 2.19
11 variable_block_size_base_base 0.01 4.78 2.19
11 warp_divergence_minimization_base 0.01 4.78 2.19
11 fused_gemm_reduction_atomic_opt_edit_1 0.01 4.78 2.19
11 optimized_sync_fusion_kernel_base 0.01 4.78 2.19
11 modular_fused_kernel_opt_base 0.01 4.78 2.19
11 block_size_experimentation_edit_1 0.01 4.78 2.19
11 block_size_experimentation_base 0.01 4.78 2.19
19 fused_gemm_reduction_atomic_opt_base 0.01 4.35 2.00
19 logsumexp_warp_reduce_base 0.01 4.35 2.00
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Device function for computing dot product
// Inner product of two float vectors of length `size`, accumulated in FP32
// in index order 0, 1, ..., size-1. A non-positive `size` yields 0.
__device__ __forceinline__ float dot_product(
    const float* __restrict__ vec1,
    const float* __restrict__ vec2,
    const int size
) {
    float acc = 0.0f;
    const float* const stop = vec1 + size;
    #pragma unroll 4
    for (const float *lhs = vec1, *rhs = vec2; lhs < stop; ++lhs, ++rhs) {
        acc += (*lhs) * (*rhs);
    }
    return acc;
}

// Device function for sigmoid activation
// Logistic function: sigma(x) = 1 / (1 + e^(-x)), in single precision.
__device__ __forceinline__ float sigmoid(float x) {
    const float denominator = 1.0f + expf(-x);
    return 1.0f / denominator;
}

// Device function for parallel reduction in shared memory
// Block-wide sum reduction through shared memory.
//
// Every thread of the block must call this, because it contains __syncthreads().
// Each thread passes its own `val` and tid == threadIdx.x, and `shared` must hold
// at least blockDim.x floats.
//
// Works for any blockDim.x, not only powers of two: each round folds the upper
// (active - half) slots onto the lower ones. For power-of-two sizes the pairing
// (and therefore the float summation order) is the classic s = n/2, n/4, ... tree.
// Returns the block total to every thread. It ends with a barrier, so the caller
// may immediately reuse `shared`, including for another reduction.
__device__ __forceinline__ float block_reduce_sum(float val, float* shared, const int tid) {
    shared[tid] = val;
    __syncthreads();

    for (unsigned int active = blockDim.x; active > 1u; ) {
        const unsigned int half = (active + 1u) / 2u;
        // Reads come from [half, active) and writes go to [0, active - half), so
        // they never overlap within a round.
        if (static_cast<unsigned int>(tid) + half < active) {
            shared[tid] += shared[tid + half];
        }
        __syncthreads();
        active = half;
    }

    const float result = shared[0];
    __syncthreads();  // no thread may overwrite shared[0] until every thread has read it
    return result;
}

// Device function for parallel max reduction
// Block-wide max reduction through shared memory.
//
// Every thread of the block must call this, because it contains __syncthreads().
// Each thread passes its own `val` and tid == threadIdx.x, and `shared` must hold
// at least blockDim.x floats.
//
// Works for any blockDim.x, not only powers of two. Uses fmaxf, which ignores a
// NaN operand when the other operand is a number.
// Returns the block maximum to every thread. It ends with a barrier, so the
// caller may immediately reuse `shared`, including for another reduction.
__device__ __forceinline__ float block_reduce_max(float val, float* shared, const int tid) {
    shared[tid] = val;
    __syncthreads();

    for (unsigned int active = blockDim.x; active > 1u; ) {
        const unsigned int half = (active + 1u) / 2u;
        // Reads come from [half, active) and writes go to [0, active - half).
        if (static_cast<unsigned int>(tid) + half < active) {
            shared[tid] = fmaxf(shared[tid], shared[tid + half]);
        }
        __syncthreads();
        active = half;
    }

    const float result = shared[0];
    __syncthreads();  // no thread may overwrite shared[0] until every thread has read it
    return result;
}

// Stage 1 of the forward pass: one block per batch row.
//
// Each thread accumulates sigmoid(dot(input[row, :], weight[col, :]) + bias[col])
// over the hidden columns col = tid, tid + blockDim.x, ...; the block then sums
// those partials and writes the row total to row_sums[row].
//
// forward() launches this with a power-of-two block size and blockDim.x floats
// of dynamic shared memory, as the shared-memory tree reduction requires.
__global__ void fused_gemm_sigmoid_rowsum_kernel(
    const float* __restrict__ input,    // [batch_size, input_size], row-major
    const float* __restrict__ weight,   // [hidden_size, input_size], row-major
    const float* __restrict__ bias,     // [hidden_size]
    float* __restrict__ row_sums,       // [batch_size], written by this kernel
    const int batch_size,
    const int input_size,
    const int hidden_size
) {
    extern __shared__ float shared_mem[];
    const int row = static_cast<int>(blockIdx.x);
    const int tid = static_cast<int>(threadIdx.x);

    // Uniform across the block, so no thread skips the barriers below.
    if (row >= batch_size) return;

    // 64-bit offsets so row * input_size cannot overflow a 32-bit int.
    const float* row_input = input + static_cast<int64_t>(row) * input_size;

    float local_sum = 0.0f;
    for (int col = tid; col < hidden_size; col += static_cast<int>(blockDim.x)) {
        const float* col_weight = weight + static_cast<int64_t>(col) * input_size;
        local_sum += sigmoid(dot_product(row_input, col_weight, input_size) + bias[col]);
    }

    const float row_total = block_reduce_sum(local_sum, shared_mem, tid);
    if (tid == 0) {
        row_sums[row] = row_total;
    }
}

// Stage 2: numerically stable logsumexp over values[0..n), written to output[0].
// Must be launched with exactly one block, a power-of-two block size, and
// blockDim.x floats of dynamic shared memory.
__global__ void logsumexp_kernel(
    const float* __restrict__ values,
    float* __restrict__ output,
    const int n
) {
    extern __shared__ float shared_mem[];
    const int tid = static_cast<int>(threadIdx.x);
    const int stride = static_cast<int>(blockDim.x);

    float local_max = -INFINITY;
    for (int i = tid; i < n; i += stride) {
        local_max = fmaxf(local_max, values[i]);
    }
    const float max_val = block_reduce_max(local_max, shared_mem, tid);
    // Every thread must finish reading shared_mem[0] (inside block_reduce_max)
    // before block_reduce_sum below starts overwriting shared_mem.
    __syncthreads();

    // Shift by the max for stability. When the max is +/-inf, shift by 0 instead,
    // so that an all -inf (or empty) input gives -inf and a +inf input gives +inf
    // rather than the NaN produced by inf - inf.
    const float shift = isinf(max_val) ? 0.0f : max_val;

    float local_exp_sum = 0.0f;
    for (int i = tid; i < n; i += stride) {
        local_exp_sum += expf(values[i] - shift);
    }
    const float sum_exp = block_reduce_sum(local_exp_sum, shared_mem, tid);

    if (tid == 0) {
        output[0] = logf(sum_exp) + shift;
    }
}

// Host entry point: logsumexp over the batch of sum_j sigmoid(input @ weight^T + bias)[:, j].
//
// This runs as two kernel launches:
//   1. Write one sum per row into a temporary [batch_size] buffer.
//   2. Reduce that buffer to a single value with one block.
// Two launches are needed because stage 2 depends on every row sum, and blocks
// within a single launch cannot wait for one another.
//
// Returns a float32 tensor of shape [1] on input's device.
// NOTE(review): the PyTorch reference (torch.logsumexp(..., dim=0)) returns a 0-dim
// tensor; shape [1] is kept for compatibility with existing callers.
// NOTE(review): kernels are launched on the default stream, not PyTorch's current
// stream, and no device guard is set; confirm that is acceptable for callers.
torch::Tensor forward(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "forward: input, weight and bias must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "forward: input, weight and bias must be float32");
    TORCH_CHECK(input.dim() == 2 && weight.dim() == 2 && bias.dim() == 1,
                "forward: expected input [B, K], weight [H, K], bias [H]");
    TORCH_CHECK(weight.size(1) == input.size(1),
                "forward: weight.size(1) must equal input.size(1)");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "forward: bias.size(0) must equal weight.size(0)");

    // The kernels index raw pointers assuming dense row-major storage.
    input = input.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    const int batch_size = static_cast<int>(input.size(0));
    const int input_size = static_cast<int>(input.size(1));
    const int hidden_size = static_cast<int>(weight.size(0));

    auto options = torch::TensorOptions()
        .dtype(input.dtype())
        .device(input.device());

    auto row_sums = torch::empty({batch_size}, options);
    auto final_output = torch::empty({1}, options);

    // Choose a power-of-two block size, starting at one warp and growing to
    // cover `work` items, capped at `max_threads`.
    auto pow2_block = [](int work, int max_threads) {
        int threads = 32;
        while (threads < work && threads < max_threads) {
            threads <<= 1;
        }
        return threads;
    };
    const int rowsum_threads = pow2_block(hidden_size, 256);
    const int lse_threads = pow2_block(batch_size, 1024);

    if (batch_size > 0) {
        fused_gemm_sigmoid_rowsum_kernel<<<batch_size, rowsum_threads, rowsum_threads * sizeof(float)>>>(
            input.data_ptr<float>(),
            weight.data_ptr<float>(),
            bias.data_ptr<float>(),
            row_sums.data_ptr<float>(),
            batch_size,
            input_size,
            hidden_size
        );
        const cudaError_t rowsum_err = cudaGetLastError();
        TORCH_CHECK(rowsum_err == cudaSuccess,
                    "fused_gemm_sigmoid_rowsum_kernel launch failed: ", cudaGetErrorString(rowsum_err));
    }

    // Runs even for an empty batch, so the result is logsumexp([]) = -inf.
    logsumexp_kernel<<<1, lse_threads, lse_threads * sizeof(float)>>>(
        row_sums.data_ptr<float>(),
        final_output.data_ptr<float>(),
        batch_size
    );
    const cudaError_t lse_err = cudaGetLastError();
    TORCH_CHECK(lse_err == cudaSuccess,
                "logsumexp_kernel launch failed: ", cudaGetErrorString(lse_err));

    return final_output;
}

// Python binding: exposes the host-side launcher `forward` as `<extension>.forward(...)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def("forward", &forward, "Forward pass");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.186 inst/cycle 0.000 5
Executed Ipc Elapsed 0.114 inst/cycle 0.000 5
Issue Slots Busy 4.646 % 0.050 5
Issued Ipc Active 0.186 inst/cycle 0.000 5
SM Busy 4.646 % 0.050 5
Memory Throughput 4013993955.456 byte/second 9285686217653540.000 5
Mem Busy 5.012 % 0.022 5
Max Bandwidth 2.918 % 0.005 5
L1/TEX Hit Rate 80.970 % 0.000 5
L2 Hit Rate 108.390 % 0.425 5
Mem Pipes Busy 2.918 % 0.005 5
Warp Cycles Per Issued Instruction 20.692 cycle 0.617 5
Warp Cycles Per Executed Instruction 20.758 cycle 0.624 5
Avg. Active Threads Per Warp 30.680 0.000 5
Avg. Not Predicated Off Threads Per Warp 18.110 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 42.000 block 0.000 5
Block Limit Warps 16.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 6.192 % 0.000 5
Achieved Active Warps Per SM 3.960 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN ThreadDivergence Instructions are executed in warps, which are groups of 32 threads. Optimal instruction throughput is achieved if all 32 threads of a warp execute the same instruction. The chosen launch configuration, early thread completion, and divergent flow control can significantly lower the number of active threads in a warp per cycle. This kernel achieves an average of 30.7 threads being active per cycle. This is further reduced to 18.1 threads per warp due to predication. The compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Try to avoid different execution paths within a warp when possible. In addition, ensure your kernel makes use of Independent Thread Scheduling, which allows a warp to reconverge after a data-dependent conditional block by explicitly calling __syncwarp().
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (6.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 244125.61 μs
Device Time 2.62 μs
Self CPU Time 63.19 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 244062.42 μs
Device Time 2.62 μs
Self CPU Time 119.19 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 243765.90 μs
Device Time 0.00 μs
Self CPU Time 116.65 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 243465.95 μs
Device Time 0.00 μs
Self CPU Time 243465.95 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 445264.98 μs
Device Time 15660.59 μs
Self CPU Time 445264.98 μs
Self Device Time 15660.59 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
fused_gemm_sigmoid_logsumexp_kernel(float const*, float const*, float const*, float*, int, int, int)
CPU Time 0.00 μs
Device Time 34440.55 μs
Self CPU Time 0.00 μs
Self Device Time 34440.55 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 17221.44 μs
Device Time 30040.63 μs
Self CPU Time 17221.44 μs
Self Device Time 30040.63 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 66590.00 μs
Device Time 564128.75 μs
Self CPU Time 12218.46 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 54373.24 μs
Device Time 564128.75 μs
Self CPU Time 16223.04 μs
Self Device Time 564128.75 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 564128.75 μs
Self CPU Time 0.00 μs
Self Device Time 564128.75 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45346 warnings generated when compiling for host.
Suppressed 45378 warnings (45331 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:58:5 bugprone-easily-swappable-parameters
58 | const float* __restrict__ input,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
59 | const float* __restrict__ weight,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
60 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:58:31: note: the first parameter in the range is 'input'
58 | const float* __restrict__ input,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:60:31: note: the last parameter in the range is 'bias'
60 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:62:5: warning: 3 adjacent parameters of 'fused_gemm_sigmoid_logsumexp_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
62 | const int batch_size,
| ^~~~~~~~~~~~~~~~~~~~~
63 | const int input_size,
| ~~~~~~~~~~~~~~~~~~~~~
64 | const int hidden_size
| ~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:62:15: note: the first parameter in the range is 'batch_size'
62 | const int batch_size,
| ^~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:64:15: note: the last parameter in the range is 'hidden_size'
64 | const int hidden_size
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:67:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
67 | const int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:68:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
68 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:73:31: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
73 | const float* row_input = &input[row * input_size];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:73:37: note: make conversion explicit to silence this warning
4 | const float* row_input = &input[row * input_size];
| ^~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:73:37: note: perform multiplication in a wider type
73 | const float* row_input = &input[row * input_size];
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:75:51: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
75 | for (int col = tid; col < hidden_size; col += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:76:36: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
76 | const float* col_weight = &weight[col * input_size];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:76:43: note: make conversion explicit to silence this warning
76 | const float* col_weight = &weight[col * input_size];
| ^~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:76:43: note: perform multiplication in a wider type
76 | const float* col_weight = &weight[col * input_size];
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:90:44: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
90 | for (int i = tid; i < batch_size; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:98:44: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
98 | for (int i = tid; i < batch_size; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:109:19: warning: the parameter 'input' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
109 | torch::Tensor input,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:110:19: warning: the parameter 'weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
110 | torch::Tensor weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:111:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
111 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:113:28: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
113 | const int batch_size = input.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:114:28: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
114 | const int input_size = input.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250207_optimize_b5_s4_e1_sweep/level_2/task_45/b4_s0_fused_gemm_sigmoid_logsumexp/edit_1/edit_1.cu:115:29: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
115 | const int hidden_size = weight.size(0);
| ^