← Back to Leaderboard

The AI CUDA Engineer 👷

67_Conv2d_GELU_GlobalAvgPool • unrolled_fused_conv_gelu_pool_base

Level 2 • Task 67
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 2D convolution, then GELU, then average each channel over its
    spatial extent.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        conv_weight (torch.Tensor): Convolution weight tensor of shape
            (out_channels, in_channels, kernel_size, kernel_size)
        conv_bias (torch.Tensor): Convolution bias tensor of shape (out_channels)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, out_channels)
    """
    activated = F.gelu(F.conv2d(x, conv_weight, bias=conv_bias))
    # Pooling to a 1x1 map leaves two trailing singleton dims; drop both.
    pooled = F.adaptive_avg_pool2d(activated, 1)
    return pooled.squeeze(-1).squeeze(-1)


class Model(nn.Module):
    """
    Convolution + GELU + global average pooling, with the convolution weight
    and bias stored directly as parameters of this module so that they can be
    handed to a functional implementation.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # A temporary nn.Conv2d is built only to obtain its initialised tensors.
        reference_conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.conv_weight = nn.Parameter(reference_conv.weight)
        self.conv_bias = nn.Parameter(reference_conv.bias)

    def forward(self, x, fn=module_fn):
        """Call ``fn`` with ``x`` and this module's convolution weight and bias."""
        weight, bias = self.conv_weight, self.conv_bias
        return fn(x, weight, bias)


# Problem size used by get_inputs() and get_init_inputs().
batch_size = 128
in_channels = 3
out_channels = 16
height = 32
width = 32
kernel_size = 3


def get_inputs():
    """Return the forward-call arguments: a list holding one random input batch."""
    shape = (batch_size, in_channels, height, width)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Return the constructor arguments for Model, in positional order."""
    init_args = [in_channels, out_channels, kernel_size]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Convolution followed by GELU and global average pooling over the spatial
    dimensions.
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, in_channels, height, width)
        Returns:
            Output tensor of shape (batch_size, out_channels)
        """
        activated = nn.functional.gelu(self.conv(x))
        # Pool every channel down to 1x1, then drop the two singleton dims.
        pooled = nn.functional.adaptive_avg_pool2d(activated, 1)
        return pooled.squeeze(-1).squeeze(-1)

# Problem size used by get_inputs() and get_init_inputs().
batch_size = 128
in_channels = 3
out_channels = 16
height = 32
width = 32
kernel_size = 3

def get_inputs():
    """Return the forward-call arguments: a list holding one random input batch."""
    shape = (batch_size, in_channels, height, width)
    return [torch.randn(*shape)]

def get_init_inputs():
    """Return the constructor arguments for Model, in positional order."""
    init_args = [in_channels, out_channels, kernel_size]
    return init_args

Kernel Information

Related Kernels (Level 2, Task 67 • 67_Conv2d_GELU_GlobalAvgPool)

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

// Compile-time configuration for the fused kernel.
// KERNEL_SIZE: side length of the square convolution window; the kernel's
// inner loops and the shared-memory weight layout are sized from it.
#define KERNEL_SIZE 3
// BLOCK_SIZE: threads per block; used both as the launch block dimension and
// as the length of the per-block partial-sum array in shared memory.
#define BLOCK_SIZE 256
// WARP_SIZE / NUM_WARPS: not referenced by the code visible in this file.
#define WARP_SIZE 32
#define NUM_WARPS (BLOCK_SIZE/WARP_SIZE)

// GELU via the error function: 0.5 * x * (1 + erf(x / sqrt(2))).
__device__ __forceinline__ float gelu_activate(float x) {
    const float erf_term = erff(x / 1.41421356f);  // 1.41421356f ~= sqrt(2)
    const float half_x = 0.5f * x;
    return half_x * (1.f + erf_term);
}

// Sums sdata[0..31] into sdata[0] using one warp.
//
// Contract (as used by the kernel in this file):
//   - must be called by all 32 lanes of a single warp, with tid in [0, 32);
//   - sdata must have at least 48 readable elements, because lane tid reads
//     sdata[tid + 16] in the first step; only lane 0's result is meaningful.
//
// Each step needs every lane's read to complete before any lane's write
// (e.g. lane 0 must read sdata[16] before lane 16 updates it). `volatile`
// alone does not guarantee that ordering on GPUs with independent thread
// scheduling, so the read and the write are separated by __syncwarp().
__device__ __forceinline__ void warp_reduce(volatile float* sdata, int tid) {
    // Make writes issued by other lanes before this call visible to this lane.
    __syncwarp();
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float neighbour = sdata[tid + offset];
        __syncwarp();  // all reads of this step are done before any write
        sdata[tid] += neighbour;
        __syncwarp();  // all writes of this step are visible to the next step
    }
}

// Fused Conv2d (KERNEL_SIZE x KERNEL_SIZE window, no padding, unit stride)
// + bias + GELU + global average pool.
//
// Launch layout expected by this kernel:
//   - grid  = (out_channels, N): blockIdx.x selects the output channel,
//             blockIdx.y selects the batch element;
//   - block = BLOCK_SIZE threads (all strides and the reduction assume it);
//   - dynamic shared memory = (in_channels * KERNEL_SIZE^2 + BLOCK_SIZE) floats:
//     first the filter of this output channel, then one partial sum per thread.
//
// Each block writes one value: output[n * out_channels + c_out], the mean of
// GELU(conv + bias) over the out_h * out_w output pixels.
// input is indexed as a dense (N, in_channels, in_h, in_w) array and weight as
// a dense (out_channels, in_channels, KERNEL_SIZE, KERNEL_SIZE) array.
__global__ void unrolled_fused_conv_gelu_pool_kernel(
    const float* __restrict__ input,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ output,
    const int N,
    const int in_channels,
    const int in_h,
    const int in_w,
    const int out_channels,
    const int out_h,
    const int out_w
) {
    extern __shared__ float shared_mem[];

    const int tid = threadIdx.x;
    const int n = blockIdx.y;
    const int c_out = blockIdx.x;

    // Partition shared memory: [filter weights | per-thread partial sums].
    const int weight_size = in_channels * KERNEL_SIZE * KERNEL_SIZE;
    float* conv_weights = shared_mem;
    float* partial_sums = shared_mem + weight_size;

    // Stage this output channel's filter in shared memory.
    // Offsets are computed in size_t so large tensors cannot overflow int.
    const float* filter = weight + static_cast<size_t>(c_out) * weight_size;
    for (int i = tid; i < weight_size; i += BLOCK_SIZE) {
        conv_weights[i] = filter[i];
    }
    __syncthreads();

    // Loop-invariant values hoisted out of the pixel loop.
    const float bias_val = bias[c_out];
    const size_t plane = static_cast<size_t>(in_h) * in_w;
    const float* batch_input = input + static_cast<size_t>(n) * in_channels * plane;
    const int total_pixels = out_h * out_w;

    // Each thread accumulates GELU outputs for pixels tid, tid + BLOCK_SIZE, ...
    float thread_sum = 0.0f;
    for (int pixel_idx = tid; pixel_idx < total_pixels; pixel_idx += BLOCK_SIZE) {
        const int out_row = pixel_idx / out_w;
        const int out_col = pixel_idx % out_w;

        float conv_result = 0.0f;
        for (int ic = 0; ic < in_channels; ic++) {
            // Top-left corner of the input window for this output pixel.
            const float* in_ptr =
                batch_input + ic * plane + static_cast<size_t>(out_row) * in_w + out_col;
            const float* w_ptr = conv_weights + ic * KERNEL_SIZE * KERNEL_SIZE;

            #pragma unroll
            for (int kh = 0; kh < KERNEL_SIZE; kh++) {
                #pragma unroll
                for (int kw = 0; kw < KERNEL_SIZE; kw++) {
                    conv_result += in_ptr[kh * in_w + kw] * w_ptr[kh * KERNEL_SIZE + kw];
                }
            }
        }

        // Add bias and apply GELU
        thread_sum += gelu_activate(conv_result + bias_val);
    }

    // Store partial sum in shared memory
    partial_sums[tid] = thread_sum;
    __syncthreads();

    // Tree reduction across the block down to 64 live partial sums.
    for (int s = BLOCK_SIZE / 2; s > 32; s >>= 1) {
        if (tid < s) {
            partial_sums[tid] += partial_sums[tid + s];
        }
        __syncthreads();
    }

    // Final reduction within the first warp.
    if (tid < 32) {
        volatile float* smem = partial_sums;
        if (BLOCK_SIZE >= 64) smem[tid] += smem[tid + 32];
        // warp_reduce reads values written by other lanes just above; make
        // those writes visible without relying on implicit lockstep execution.
        __syncwarp();
        warp_reduce(smem, tid);

        // Write result: mean over all output pixels.
        if (tid == 0) {
            output[static_cast<size_t>(n) * out_channels + c_out] =
                smem[0] / float(total_pixels);
        }
    }
}

// Host entry point for the fused Conv2d + GELU + global-average-pool kernel.
//
// Args:
//   input:       CUDA float32 tensor, (N, in_channels, in_h, in_w)
//   conv_weight: CUDA float32 tensor, (out_channels, in_channels, KERNEL_SIZE, KERNEL_SIZE)
//   conv_bias:   CUDA float32 tensor, (out_channels)
// Returns:
//   CUDA float32 tensor of shape (N, out_channels).
// Raises (via TORCH_CHECK) when a tensor is not on CUDA, is not float32, has an
// unexpected rank/shape, or when the kernel launch fails.
//
// The kernel only implements a KERNEL_SIZE x KERNEL_SIZE window with unit
// stride and no padding, and indexes its inputs as dense arrays, so those
// preconditions are enforced here instead of producing wrong results or
// out-of-bounds reads.
torch::Tensor forward(
    torch::Tensor input,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias
) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(conv_weight.is_cuda(), "conv_weight must be a CUDA tensor");
    TORCH_CHECK(conv_bias.is_cuda(), "conv_bias must be a CUDA tensor");
    TORCH_CHECK(conv_weight.device() == input.device() && conv_bias.device() == input.device(),
                "input, conv_weight and conv_bias must be on the same device");

    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(conv_weight.scalar_type() == torch::kFloat32, "conv_weight must be float32");
    TORCH_CHECK(conv_bias.scalar_type() == torch::kFloat32, "conv_bias must be float32");

    TORCH_CHECK(input.dim() == 4, "input must be 4-D (N, C, H, W)");
    TORCH_CHECK(conv_weight.dim() == 4, "conv_weight must be 4-D (out_channels, in_channels, kH, kW)");
    TORCH_CHECK(conv_weight.size(2) == KERNEL_SIZE && conv_weight.size(3) == KERNEL_SIZE,
                "only ", KERNEL_SIZE, "x", KERNEL_SIZE, " convolution kernels are supported");
    TORCH_CHECK(conv_weight.size(1) == input.size(1),
                "conv_weight in_channels must match input channels");
    TORCH_CHECK(conv_bias.dim() == 1 && conv_bias.size(0) == conv_weight.size(0),
                "conv_bias must have shape (out_channels)");
    TORCH_CHECK(input.size(2) >= KERNEL_SIZE && input.size(3) >= KERNEL_SIZE,
                "input spatial size must be at least the kernel size");

    // The kernel uses raw dense indexing; materialise contiguous copies if needed.
    input = input.contiguous();
    conv_weight = conv_weight.contiguous();
    conv_bias = conv_bias.contiguous();

    const int N = static_cast<int>(input.size(0));
    const int in_channels = static_cast<int>(input.size(1));
    const int in_h = static_cast<int>(input.size(2));
    const int in_w = static_cast<int>(input.size(3));
    const int out_channels = static_cast<int>(conv_weight.size(0));
    // Valid (no padding), unit-stride convolution output size.
    const int out_h = in_h - KERNEL_SIZE + 1;
    const int out_w = in_w - KERNEL_SIZE + 1;

    auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
    auto output = torch::empty({N, out_channels}, options);

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (output.numel() == 0) {
        return output;
    }

    // One block per (output channel, batch element); grid.y is limited to 65535.
    TORCH_CHECK(N <= 65535, "batch size ", N, " exceeds the CUDA grid.y limit (65535)");
    dim3 grid(out_channels, N);

    // Calculate shared memory size: space for weights + space for partial sums
    const size_t shared_mem_size =
        (static_cast<size_t>(in_channels) * KERNEL_SIZE * KERNEL_SIZE + BLOCK_SIZE) * sizeof(float);

    unrolled_fused_conv_gelu_pool_kernel<<<grid, BLOCK_SIZE, shared_mem_size>>>(
        input.data_ptr<float>(),
        conv_weight.data_ptr<float>(),
        conv_bias.data_ptr<float>(),
        output.data_ptr<float>(),
        N, in_channels, in_h, in_w,
        out_channels, out_h, out_w
    );

    // Surface launch-configuration errors (e.g. too much shared memory) immediately.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "unrolled_fused_conv_gelu_pool_kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return output;
}

// Python binding: registers the host function forward() on the extension
// module under the name "forward", with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Unrolled Fused Conv2d + GELU + GlobalAvgPool");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.296 inst/cycle 0.000 5
Executed Ipc Elapsed 2.050 inst/cycle 0.000 5
Issue Slots Busy 57.554 % 0.033 5
Issued Ipc Active 2.304 inst/cycle 0.000 5
SM Busy 57.554 % 0.033 5
Memory Throughput 56477827154.732 byte/second 129497362984474624.000 5
Mem Busy 73.132 % 0.225 5
Max Bandwidth 48.182 % 0.099 5
L1/TEX Hit Rate 90.154 % 0.000 5
L2 Hit Rate 89.296 % 0.394 5
Mem Pipes Busy 46.036 % 0.089 5
Warp Cycles Per Issued Instruction 23.228 cycle 0.000 5
Warp Cycles Per Executed Instruction 23.264 cycle 0.000 5
Avg. Active Threads Per Warp 31.060 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.890 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 30.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 83.676 % 0.065 5
Achieved Active Warps Per SM 53.554 warp 0.026 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (27.3%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (84.0%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 533516.48 μs
Device Time 82.21 μs
Self CPU Time 57.12 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 533459.36 μs
Device Time 82.21 μs
Self CPU Time 111.84 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 532973.45 μs
Device Time 0.00 μs
Self CPU Time 129.80 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 532393.85 μs
Device Time 0.00 μs
Self CPU Time 532393.85 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 640337.09 μs
Device Time 21318.28 μs
Self CPU Time 640337.09 μs
Self Device Time 21318.28 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
unrolled_fused_conv_gelu_pool_kernel(float const*, float const*, float const*, float*, int, int, int, int, int, int, int)
CPU Time 0.00 μs
Device Time 188615.66 μs
Self CPU Time 0.00 μs
Self Device Time 188615.66 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 20078.79 μs
Device Time 42571.74 μs
Self CPU Time 20078.79 μs
Self Device Time 42571.74 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 154167.83 μs
Device Time 634921.60 μs
Self CPU Time 15801.86 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 138366.97 μs
Device Time 634921.60 μs
Self CPU Time 16396.97 μs
Self Device Time 634921.60 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 634921.60 μs
Self CPU Time 0.00 μs
Self Device Time 634921.60 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45292 warnings generated when compiling for host.
Suppressed 45323 warnings (45276 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
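/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:24:5: warning: 3 adjacent parameters of 'unrolled_fused_conv_gelu_pool_kernel' of similar type ('const float *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]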
24 | const float* __restrict__ input,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
25 | const float* __restrict__ weight,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26 | const float* __restrict__ bias,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:24:31: note: the first parameter in the range is 'input'
24 | const float* __restrict__ input,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:26:31: note: the last parameter in the range is 'bias'
26 | const float* __restrict__ bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:28:5: warning: 2 adjacent parameters of 'unrolled_fused_conv_gelu_pool_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
28 | const int N,
| ^~~~~~~~~~~~
29 | const int in_channels,
| ~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:28:15: note: the first parameter in the range is 'N'
28 | const int N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:29:15: note: the last parameter in the range is 'in_channels'
29 | const int in_channels,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:31:5: warning: 3 adjacent parameters of 'unrolled_fused_conv_gelu_pool_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
31 | const int in_w,
| ^~~~~~~~~~~~~~~
32 | const int out_channels,
| ~~~~~~~~~~~~~~~~~~~~~~~
33 | const int out_h,
| ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:31:15: note: the first parameter in the range is 'in_w'
31 | const int in_w,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:33:15: note: the last parameter in the range is 'out_h'
33 | const int out_h,
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:40:28: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
40 | float* partial_sums = &shared_mem[in_channels * KERNEL_SIZE * KERNEL_SIZE];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:40:39: note: make conversion explicit to silence this warning
40 | float* partial_sums = &shared_mem[in_channels * KERNEL_SIZE * KERNEL_SIZE];
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:40:39: note: perform multiplication in a wider type
40 | float* partial_sums = &shared_mem[in_channels * KERNEL_SIZE * KERNEL_SIZE];
| ^~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:42:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
42 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:43:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
43 | const int n = blockIdx.y;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:44:23: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
44 | const int c_out = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:73:39: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
73 | const float* w_ptr = &conv_weights[ic * KERNEL_SIZE * KERNEL_SIZE];
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:73:52: note: make conversion explicit to silence this warning
73 | const float* w_ptr = &conv_weights[ic * KERNEL_SIZE * KERNEL_SIZE];
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:73:52: note: perform multiplication in a wider type
73 | const float* w_ptr = &conv_weights[ic * KERNEL_SIZE * KERNEL_SIZE];
| ^~~~~~~~~~~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:116:19: warning: the parameter 'input' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
116 | torch::Tensor input,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:117:19: warning: the parameter 'conv_weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
117 | torch::Tensor conv_weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:118:19: warning: the parameter 'conv_bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
118 | torch::Tensor conv_bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:124:19: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
124 | const int N = input.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:125:29: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
125 | const int in_channels = input.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:126:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
126 | const int in_h = input.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:127:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
127 | const int in_w = input.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_67/b10_s1_unrolled_fused_conv_gelu_pool/base/base.cu:128:30: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
128 | const int out_channels = conv_weight.size(0);
| ^