
The AI CUDA Engineer 👷

97_Matmul_BatchNorm_BiasAdd_Divide_Swish • fused_bn_swish_opt_base

Level 2 • Task 97
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    bn_eps: float,
    bn_momentum: float,
    divide_value: float,
    weight: torch.Tensor,
    bias: torch.Tensor,
    bn_weight: torch.Tensor,
    bn_bias: torch.Tensor,
    bn_running_mean: torch.Tensor,
    bn_running_var: torch.Tensor,
    add_bias: torch.Tensor,
) -> torch.Tensor:
    """
    Applies matrix multiplication, batch normalization, bias addition, division and Swish activation.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features)
        bn_eps (float): Small constant for numerical stability in batch norm
        bn_momentum (float): Momentum for batch norm running stats
        divide_value (float): Value to divide by
        weight (torch.Tensor): Weight matrix of shape (out_features, in_features)
        bias (torch.Tensor): Bias vector of shape (out_features)
        bn_weight (torch.Tensor): Batch norm weight of shape (out_features)
        bn_bias (torch.Tensor): Batch norm bias of shape (out_features)
        bn_running_mean (torch.Tensor): Batch norm running mean of shape (out_features)
        bn_running_var (torch.Tensor): Batch norm running variance of shape (out_features)
        add_bias (torch.Tensor): Additional bias term of shape (1,)

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, out_features)
    """
    x = F.linear(x, weight, bias)
    x = F.batch_norm(
        x,
        bn_running_mean,
        bn_running_var,
        bn_weight,
        bn_bias,
        training=True,
        momentum=bn_momentum,
        eps=bn_eps,
    )
    x = x + add_bias
    x = x / divide_value
    x = x * torch.sigmoid(x)
    return x


class Model(nn.Module):
    """
    Model that performs a matrix multiplication, batch normalization, bias addition, division and Swish activation.
    """

    def __init__(
        self, in_features, out_features, bn_eps, bn_momentum, bias_shape, divide_value
    ):
        super(Model, self).__init__()
        gemm = nn.Linear(in_features, out_features)
        bn = nn.BatchNorm1d(out_features, eps=bn_eps, momentum=bn_momentum)
        self.weight = gemm.weight
        self.bias = gemm.bias
        self.bn_weight = bn.weight
        self.bn_bias = bn.bias
        self.bn_running_mean = nn.Parameter(bn.running_mean, requires_grad=False)
        self.bn_running_var = nn.Parameter(bn.running_var, requires_grad=False)
        self.add_bias = nn.Parameter(torch.randn(bias_shape) * 0.02)

    def forward(self, x, bn_eps, bn_momentum, divide_value, fn=module_fn):
        return fn(
            x,
            bn_eps,
            bn_momentum,
            divide_value,
            self.weight,
            self.bias,
            self.bn_weight,
            self.bn_bias,
            self.bn_running_mean,
            self.bn_running_var,
            self.add_bias,
        )


batch_size = 128
in_features = 1024
out_features = 512
bn_eps = 1e-5
bn_momentum = 0.1
bias_shape = (1,)
divide_value = 1.0


def get_inputs():
    return [torch.randn(batch_size, in_features), bn_eps, bn_momentum, divide_value]


def get_init_inputs():
    return [in_features, out_features, bn_eps, bn_momentum, bias_shape, divide_value]
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a matrix multiplication, batch normalization, bias addition, division, and Swish activation.
    """
    def __init__(self, in_features, out_features, bn_eps=1e-5, bn_momentum=0.1, bias_shape=(1,), divide_value=1.0):
        super(Model, self).__init__()
        self.matmul = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features, eps=bn_eps, momentum=bn_momentum)
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)
        self.divide_value = divide_value

    def forward(self, x):
        x = self.matmul(x)
        x = self.bn(x)
        x = x + self.bias
        x = x / self.divide_value
        x = x * torch.sigmoid(x)
        return x

batch_size = 128
in_features = 1024
out_features = 512
bn_eps = 1e-5
bn_momentum = 0.1
bias_shape = (1,)
divide_value = 1.0

def get_inputs():
    return [torch.randn(batch_size, in_features)]

def get_init_inputs():
    return [in_features, out_features, bn_eps, bn_momentum, bias_shape, divide_value]
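
For orientation, here is a minimal usage sketch that exercises the eager-mode Model above and checks it against the functional module_fn from the first listing. It is an illustration only (not part of the benchmark harness) and assumes both listings are executed in the same scope.

import torch

# Hypothetical usage sketch: build the eager-mode Model and compare it with module_fn.
torch.manual_seed(0)
model = Model(*get_init_inputs())        # (in_features, out_features, bn_eps, ...)
x = get_inputs()[0]                      # random (batch_size, in_features) input

model.train()                            # BatchNorm normalizes with batch statistics
y_eager = model(x)

y_fn = module_fn(
    x, bn_eps, bn_momentum, divide_value,
    model.matmul.weight, model.matmul.bias,
    model.bn.weight, model.bn.bias,
    model.bn.running_mean.clone(), model.bn.running_var.clone(),  # clones avoid double-updating buffers
    model.bias,
)
print(y_eager.shape)                     # torch.Size([128, 512])
print(torch.allclose(y_eager, y_fn, atol=1e-6))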

Kernel Information

Related Kernels (Level 2, Task 97 • 97_Matmul_BatchNorm_BiasAdd_Divide_Swish)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 block_tuned_fused_bn_swish_base 0.03 2.28 1.94
🥇 optimized_thread_block_mapping_base 0.03 2.28 1.94
🥇 block_experiment_fused_bn_swish_base 0.03 2.28 1.94
🥇 blocksize_optimized_fused_bn_swish_base 0.03 2.28 1.94
🥇 fused_bn_swish_opt_base 0.03 2.28 1.94
🥇 fused_bn_swish_combined_base 0.03 2.28 1.94
🥇 fused_bn_swish_ldg_base_base 0.03 2.28 1.94
8 fused_bn_swish_atomic_opt_base_base 0.03 2.19 1.87
8 optimized_fused_bn_swish_base 0.03 2.19 1.87
8 sync_optimized_fused_bn_swish_base 0.03 2.19 1.87
8 stride_loop_optimized_fused_bn_swish_base_base 0.03 2.19 1.87
8 fused_bn_swish_warp_base 0.03 2.19 1.87
8 fused_bn_swish_atomic_opt_base 0.03 2.19 1.87
8 warp_divergence_optimized_fused_bn_swish_base 0.03 2.19 1.87
8 fused_bn_swish_base 0.03 2.19 1.87
8 shared_param_fused_bn_swish_base 0.03 2.19 1.87
8 tuned_block_size_bn_swish_base 0.03 2.19 1.87
18 adaptive_block_fused_bn_swish_base_base 0.03 2.11 1.80
19 atomic_optimized_matmul_bn_base 0.04 1.44 1.23
20 stream_optimized_fused_bn_swish_base 0.04 1.35 1.14
/*
  This kernel fuses batch normalization, bias addition, division, and Swish activation.
  Each block processes one feature column. It uses vectorized loads (processing 4 elements per iteration)
  and warp-level reduction via __shfl_down_sync to compute the feature-wise sum and sum of squares efficiently.
  The computed mean and variance are then used to update running statistics and normalize the inputs.
*/
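
// Math reference for the kernel below (per feature column f, reduced over the batch dimension):
//   mean_f = (1/batch_size) * sum_i x_linear[i][f]
//   var_f  = (1/batch_size) * sum_i x_linear[i][f]^2 - mean_f^2        (biased batch variance)
//   running_mean[f] = (1 - momentum) * running_mean[f] + momentum * mean_f
//   running_var[f]  = (1 - momentum) * running_var[f]  + momentum * var_f
//   y   = (x - mean_f) * rsqrt(var_f + eps) * gamma_f + beta_f + add_bias
//   y   = y / divide_value
//   out = y * sigmoid(y)                                               (Swish)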

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <math.h>

// Kernel: fused_bn_swish_opt_kernel
// Each block handles one feature (column) from the x_linear tensor of shape [batch_size, out_features].
// We use vectorized loads and warp-level reduction to compute the feature-wise sum and sumsq.

template <typename scalar_t>
__global__ void fused_bn_swish_opt_kernel(
    const scalar_t* __restrict__ x_linear,       // Input: [batch_size, out_features]
    scalar_t* __restrict__ output,               // Output: same shape as x_linear
    const scalar_t* __restrict__ bn_weight,        // BatchNorm scale, shape: [out_features]
    const scalar_t* __restrict__ bn_bias,          // BatchNorm bias, shape: [out_features]
    scalar_t* __restrict__ bn_running_mean,        // Running mean, shape: [out_features]
    scalar_t* __restrict__ bn_running_var,         // Running variance, shape: [out_features]
    const scalar_t* __restrict__ add_bias,         // Additional bias (1-element tensor)
    const float bn_eps,
    const float bn_momentum,
    const float divide_value,
    const int batch_size,
    const int out_features) {

    // Each block processes one output feature
    const int f = blockIdx.x;
    if (f >= out_features) return;
    const int tid = threadIdx.x;

    // Use vectorized processing with a vector size of 4
    const int vec_size = 4;
    const int vec_limit = (batch_size / vec_size) * vec_size;

    // Each thread accumulates partial sums and sums-of-squares
    float thread_sum = 0.0f;
    float thread_sumsq = 0.0f;

    // Vectorized loop: process multiples of 4
    for (int i = tid * vec_size; i < vec_limit; i += blockDim.x * vec_size) {
        // Manually load 4 elements. Note: Elements are accessed with a stride of out_features.
        float v0 = static_cast<float>(x_linear[(i + 0) * out_features + f]);
        float v1 = static_cast<float>(x_linear[(i + 1) * out_features + f]);
        float v2 = static_cast<float>(x_linear[(i + 2) * out_features + f]);
        float v3 = static_cast<float>(x_linear[(i + 3) * out_features + f]);
        thread_sum   += v0 + v1 + v2 + v3;
        thread_sumsq += v0 * v0 + v1 * v1 + v2 * v2 + v3 * v3;
    }

    // Remainder loop: process remaining elements
    for (int i = vec_limit + tid; i < batch_size; i += blockDim.x) {
        float v = static_cast<float>(x_linear[i * out_features + f]);
        thread_sum   += v;
        thread_sumsq += v * v;
    }

    // Warp-level reduction using shuffle
    unsigned int mask = 0xffffffff;
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        thread_sum   += __shfl_down_sync(mask, thread_sum, offset);
        thread_sumsq += __shfl_down_sync(mask, thread_sumsq, offset);
    }

    // Use shared memory to accumulate results from each warp
    extern __shared__ float shared[]; // shared memory: first part for sums, second for sumsq
    const int warpId = tid >> 5;         // tid / 32
    const int lane   = tid & 31;           // tid % 32
    const int numWarps = blockDim.x / 32;

    if (lane == 0) {
        shared[warpId] = thread_sum;
        shared[warpId + numWarps] = thread_sumsq;
    }
    __syncthreads();

    // Let the first warp reduce the per-warp partials
    float total_sum = 0.0f;
    float total_sumsq = 0.0f;
    if (tid < numWarps) {
        total_sum = shared[tid];
        total_sumsq = shared[tid + numWarps];
    }
    // Only first-warp threads participate: reduce the per-warp partials with warp
    // shuffles so every partial is folded into lane 0 (lanes >= numWarps hold 0).
    if (tid < 32) {
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            total_sum   += __shfl_down_sync(mask, total_sum, offset);
            total_sumsq += __shfl_down_sync(mask, total_sumsq, offset);
        }
    }

    // Store computed mean and variance in shared memory for use in second pass
    if (tid == 0) {
        float mean = total_sum / batch_size;
        float var = total_sumsq / batch_size - mean * mean;
        // Update running statistics (no atomics needed as one block per feature)
        bn_running_mean[f] = bn_running_mean[f] * (1.0f - bn_momentum) + mean * bn_momentum;
        bn_running_var[f]  = bn_running_var[f]  * (1.0f - bn_momentum) + var * bn_momentum;
        shared[0] = mean;
        shared[1] = var;
    }
    __syncthreads();

    // Load the computed mean and variance
    float mean = shared[0];
    float var  = shared[1];
    float inv_std = rsqrtf(var + bn_eps);
    float gamma = static_cast<float>(bn_weight[f]);
    float beta  = static_cast<float>(bn_bias[f]);
    float extra_bias = static_cast<float>(add_bias[0]);

    // Second pass: apply BN, add extra bias, divide and apply Swish activation
    // Process vectorized loop
    for (int i = tid * vec_size; i < vec_limit; i += blockDim.x * vec_size) {
        #pragma unroll
        for (int j = 0; j < vec_size; j++) {
            const int idx = (i + j) * out_features + f;
            float val = static_cast<float>(x_linear[idx]);
            float normalized = (val - mean) * inv_std;
            // Apply batch norm transform with bias and extra bias
            float transformed = fmaf(normalized, gamma, beta) + extra_bias;
            float divided = transformed / divide_value;
            float swish = divided / (1.f + expf(-divided));
            output[idx] = static_cast<scalar_t>(swish);
        }
    }
    
    // Process remaining elements
    for (int i = vec_limit + tid; i < batch_size; i += blockDim.x) {
        const int idx = i * out_features + f;
        float val = static_cast<float>(x_linear[idx]);
        float normalized = (val - mean) * inv_std;
        float transformed = fmaf(normalized, gamma, beta) + extra_bias;
        float divided = transformed / divide_value;
        float swish = divided / (1.f + expf(-divided));
        output[idx] = static_cast<scalar_t>(swish);
    }
}


// Host function to launch the kernel
torch::Tensor module_fn_cuda(
    torch::Tensor x,
    float bn_eps,
    float bn_momentum,
    float divide_value,
    torch::Tensor weight,
    torch::Tensor bias,
    torch::Tensor bn_weight,
    torch::Tensor bn_bias,
    torch::Tensor bn_running_mean,
    torch::Tensor bn_running_var,
    torch::Tensor add_bias) {

    const auto batch_size = x.size(0);
    const auto out_features = weight.size(0);

    // Ensure input tensors are contiguous
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    // Perform linear transformation: x_linear = x @ weight.T + bias
    auto x_linear = torch::addmm(bias, x, weight.t());
    auto output = torch::empty_like(x_linear);

    // Launch configuration: one block per feature column
    const int threads = 128;
    const int blocks = out_features;
    // Shared memory: allocate 2 floats per warp (numWarps = threads/32)
    const size_t shared_mem_size = 2 * (threads / 32) * sizeof(float);

    AT_DISPATCH_FLOATING_TYPES(x_linear.scalar_type(), "fused_bn_swish_opt_kernel", ([&] {
        fused_bn_swish_opt_kernel<scalar_t><<<blocks, threads, shared_mem_size>>>(
            x_linear.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            bn_weight.data_ptr<scalar_t>(),
            bn_bias.data_ptr<scalar_t>(),
            bn_running_mean.data_ptr<scalar_t>(),
            bn_running_var.data_ptr<scalar_t>(),
            add_bias.data_ptr<scalar_t>(),
            bn_eps,
            bn_momentum,
            divide_value,
            batch_size,
            out_features);
    }));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn_cuda, "Fused BN, bias add, division and Swish forward (CUDA optimized)");
}
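
To run the kernel end-to-end, the extension can be built with torch.utils.cpp_extension and compared against the PyTorch reference. The sketch below is illustrative only: the file name "fused_bn_swish_opt.cu", the tolerances, and the harness wiring are assumptions, not taken from this page. It requires a CUDA GPU and assumes the Model and helpers from the listings above are in scope.

import torch
from torch.utils.cpp_extension import load

# Hypothetical build-and-check sketch; "fused_bn_swish_opt.cu" is an assumed file name
# containing the CUDA source above.
ext = load(name="fused_bn_swish_opt", sources=["fused_bn_swish_opt.cu"])

model = Model(*get_init_inputs()).to("cuda")   # eager-mode PyTorch reference
model.train()                                  # kernel normalizes with batch statistics
x = get_inputs()[0].to("cuda")

out_cuda = ext.forward(
    x, bn_eps, bn_momentum, divide_value,
    model.matmul.weight, model.matmul.bias,
    model.bn.weight, model.bn.bias,
    model.bn.running_mean.clone(), model.bn.running_var.clone(),
    model.bias,
)
out_ref = model(x)                             # PyTorch path: Linear -> BatchNorm1d -> bias -> divide -> Swish
print(torch.allclose(out_cuda, out_ref, atol=1e-4, rtol=1e-4))

One caveat worth noting: the kernel updates bn_running_var with the biased batch variance, whereas PyTorch's BatchNorm1d applies Bessel's correction to the running-variance update. The forward outputs compared above do not depend on that buffer, but the running statistics would drift apart over repeated training steps.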
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.420 inst/cycle 0.000 5
Executed Ipc Elapsed 0.252 inst/cycle 0.000 5
Issue Slots Busy 10.944 % 0.058 5
Issued Ipc Active 0.440 inst/cycle 0.000 5
SM Busy 10.944 % 0.058 5
Memory Throughput 44783780022.960 byte/second 387569521833071936.000 5
Mem Busy 24.568 % 0.146 5
Max Bandwidth 22.874 % 0.107 5
L1/TEX Hit Rate 34.940 % 0.000 5
L2 Hit Rate 96.220 % 0.251 5
Mem Pipes Busy 5.998 % 0.007 5
Warp Cycles Per Issued Instruction 27.552 cycle 1.207 5
Warp Cycles Per Executed Instruction 28.588 cycle 1.294 5
Avg. Active Threads Per Warp 30.780 0.000 5
Avg. Not Predicated Off Threads Per Warp 26.450 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 16.000 block 0.000 5
Block Limit Shared Mem 28.000 block 0.000 5
Block Limit Warps 16.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 18.790 % 0.012 5
Achieved Active Warps Per SM 12.024 warp 0.005 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (19.0%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
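For context on the Occupancy rule above, the achieved figure follows directly from the warp counts reported in the metrics table:

\[
\text{achieved occupancy} \approx \frac{\text{achieved active warps per SM}}{\text{theoretical active warps per SM}} = \frac{12.024}{64} \approx 18.8\%
\]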
Operation / Metric Value Unit
aten::to
CPU Time 390324.30 μs
Device Time 236.86 μs
Self CPU Time 58.18 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_to_copy
CPU Time 390266.12 μs
Device Time 236.86 μs
Self CPU Time 109.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::empty_strided
CPU Time 407611.20 μs
Device Time 0.00 μs
Self CPU Time 18222.14 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaDeviceGetStreamPriorityRange
CPU Time 388808.66 μs
Device Time 0.00 μs
Self CPU Time 388808.66 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 62358.67 μs
Device Time 511470.32 μs
Self CPU Time 11083.29 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 51289.54 μs
Device Time 511470.32 μs
Self CPU Time 14948.24 μs
Self Device Time 511470.32 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::addmm
CPU Time 447259.25 μs
Device Time 109739.85 μs
Self CPU Time 157017.71 μs
Self Device Time 109739.85 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize32x32x8_stage3_warpsize1x2x1_ffma_aligna4_alignc4_execute_kernel__51_cublas
CPU Time 0.00 μs
Device Time 98823.74 μs
Self CPU Time 0.00 μs
Self Device Time 98823.74 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 511470.32 μs
Self CPU Time 0.00 μs
Self Device Time 511470.32 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45302 warnings generated when compiling for host.
Suppressed 45331 warnings (45284 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:22:5: warning: 2 adjacent parameters of 'fused_bn_swish_opt_kernel' of similar type ('const scalar_t *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
22 | const scalar_t* __restrict__ bn_weight, // BatchNorm scale, shape: [out_features]
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 | const scalar_t* __restrict__ bn_bias, // BatchNorm bias, shape: [out_features]
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:22:34: note: the first parameter in the range is 'bn_weight'
22 | const scalar_t* __restrict__ bn_weight, // BatchNorm scale, shape: [out_features]
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:23:34: note: the last parameter in the range is 'bn_bias'
23 | const scalar_t* __restrict__ bn_bias, // BatchNorm bias, shape: [out_features]
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:24:5: warning: 2 adjacent parameters of 'fused_bn_swish_opt_kernel' of similar type ('scalar_t *__restrict') are easily swapped by mistake [bugprone-easily-swappable-parameters]
24 | scalar_t* __restrict__ bn_running_mean, // Running mean, shape: [out_features]
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
25 | scalar_t* __restrict__ bn_running_var, // Running variance, shape: [out_features]
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:24:28: note: the first parameter in the range is 'bn_running_mean'
24 | scalar_t* __restrict__ bn_running_mean, // Running mean, shape: [out_features]
| ^~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:25:28: note: the last parameter in the range is 'bn_running_var'
25 | scalar_t* __restrict__ bn_running_var, // Running variance, shape: [out_features]
| ^~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:27:5: warning: 5 adjacent parameters of 'fused_bn_swish_opt_kernel' of convertible types are easily swapped by mistake [bugprone-easily-swappable-parameters]
27 | const float bn_eps,
| ^~~~~~~~~~~~~~~~~~~
28 | const float bn_momentum,
| ~~~~~~~~~~~~~~~~~~~~~~~~
29 | const float divide_value,
| ~~~~~~~~~~~~~~~~~~~~~~~~~
30 | const int batch_size,
| ~~~~~~~~~~~~~~~~~~~~~
31 | const int out_features) {
| ~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:27:17: note: the first parameter in the range is 'bn_eps'
27 | const float bn_eps,
| ^~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:31:15: note: the last parameter in the range is 'out_features'
31 | const int out_features) {
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:30:5: note: 'const float' and 'const int' may be implicitly converted: 'const float' (as 'float') -> 'const int' (as 'int'), 'const int' (as 'int') -> 'const float' (as 'float')
30 | const int batch_size,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:34:19: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
34 | const int f = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:36:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
36 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:47:54: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
47 | for (int i = tid * vec_size; i < vec_limit; i += blockDim.x * vec_size) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:58:56: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
58 | for (int i = vec_limit + tid; i < batch_size; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:75:26: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
75 | const int numWarps = blockDim.x / 32;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:102:34: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
102 | float mean = total_sum / batch_size;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:103:35: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
103 | float var = total_sumsq / batch_size - mean * mean;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:122:54: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
122 | for (int i = tid * vec_size; i < vec_limit; i += blockDim.x * vec_size) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:137:56: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
137 | for (int i = vec_limit + tid; i < batch_size; i += blockDim.x) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:156:5: warning: 2 adjacent parameters of 'module_fn_cuda' of similar type ('torch::Tensor') are easily swapped by mistake [bugprone-easily-swappable-parameters]
156 | torch::Tensor bias,
| ^~~~~~~~~~~~~~~~~~~
157 | torch::Tensor bn_weight,
| ~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:156:19: note: the first parameter in the range is 'bias'
156 | torch::Tensor bias,
| ^~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:157:19: note: the last parameter in the range is 'bn_weight'
157 | torch::Tensor bn_weight,
| ^~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:177:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
177 | const int blocks = out_features;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:179:36: warning: performing an implicit widening conversion to type 'unsigned long' of a multiplication performed in type 'int' [bugprone-implicit-widening-of-multiplication-result]
179 | const size_t shared_mem_size = 2 * (threads / 32) * sizeof(float);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:179:36: note: make conversion explicit to silence this warning
179 | const size_t shared_mem_size = 2 * (threads / 32) * sizeof(float);
| ^~~~~~~~~~~~~~~~~~
| static_cast<unsigned long>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:179:36: note: perform multiplication in a wider type
179 | const size_t shared_mem_size = 2 * (threads / 32) * sizeof(float);
| ^
| static_cast<long>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250203_optimize_b10_s4_e0_sweep/level_2/task_97/b4_s1_fused_bn_swish_opt/base/base.cu:181:5: warning: inside a lambda, '__func__' expands to the name of the function call operator; consider capturing the name of the enclosing function explicitly [bugprone-lambda-function-name]
181 | AT_DISPATCH_FLOATING_TYPES(x_linear.scalar_type(), "fused_bn_swish_opt_kernel", ([&] {
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:237:34: note: expanded from macro 'AT_DISPATCH_FLOATING_TYPES'
237 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:233:3: note: expanded from macro 'AT_DISPATCH_CASE_FLOATING_TYPES'
233 | AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:74:3: note: expanded from macro 'AT_DISPATCH_CASE'
74 | AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
| ^
note: (skipping 1 expansions in backtrace; use -fmacro-backtrace-limit=0 to see all)
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h:58:7: note: expanded from macro 'AT_PRIVATE_CHECK_SELECTIVE_BUILD'
58 | AT_ERROR( \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:711:32: note: expanded from macro 'AT_ERROR'
711 | C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
| ^
/home/robert_sakana_ai/miniconda3/envs/llm2cuda/lib/python3.11/site-packages/torch/include/c10/util/Exception.h:536:9: note: expanded from macro 'TORCH_CHECK'
536 | __func__, \
| ^