← Back to Leaderboard

The AI CUDA Engineer 👷

7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAddcoalesced_memory_activation_kernel_base_base

Level 2 • Task 7
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    Run a 3D convolution, then ReLU -> LeakyReLU -> GELU -> Sigmoid, then add ``bias``.

    Args:
        x (torch.Tensor): Input volume of shape (batch_size, in_channels, depth, height, width).
        conv_weight (torch.Tensor): Convolution kernel of shape
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size).
        conv_bias (torch.Tensor): Convolution bias of shape (out_channels).
        bias (torch.Tensor): Tensor added after the sigmoid, of shape (out_channels, 1, 1, 1)
            so that it broadcasts once per output channel.

    Returns:
        torch.Tensor: The activated convolution output with ``bias`` added.
    """
    # Activations are applied strictly in this order.
    pipeline = (
        F.relu,
        lambda t: F.leaky_relu(t, negative_slope=0.01),
        F.gelu,
        torch.sigmoid,
    )
    features = F.conv3d(x, conv_weight, bias=conv_bias)
    for activation in pipeline:
        features = activation(features)
    return features + bias


class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies ReLU, LeakyReLU, GELU, Sigmoid activations, and bias in sequence.

    The convolution's weight and bias are held as standalone parameters
    (``conv_weight``, ``conv_bias``) so that ``forward`` can pass them to a
    functional implementation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias_shape):
        super().__init__()
        # The conv layer is built only to borrow its default-initialized weight and
        # bias. Creating it before drawing the additive bias keeps the RNG
        # consumption order: conv init first, then torch.randn.
        template = nn.Conv3d(in_channels, out_channels, kernel_size)
        # Registration order of the parameters: bias, conv_weight, conv_bias.
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02)
        self.conv_weight = nn.Parameter(template.weight)
        self.conv_bias = nn.Parameter(template.bias)

    def forward(self, x, fn=module_fn):
        """Apply ``fn`` (default: ``module_fn``) to ``x`` with this model's parameters."""
        return fn(x, self.conv_weight, self.conv_bias, self.bias)


# Problem size used by get_inputs / get_init_inputs.
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
# One additive bias per output channel, broadcast over the three spatial dims.
bias_shape = (out_channels,) + (1,) * 3


def get_inputs():
    """Return the forward() inputs: a single random input volume."""
    input_shape = (batch_size, in_channels, depth, height, width)
    return [torch.randn(input_shape)]


def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    init_args = [in_channels, out_channels, kernel_size, bias_shape]
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies ReLU, LeakyReLU, GELU, Sigmoid activations, and bias in sequence.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias_shape):
        super(Model, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.bias = nn.Parameter(torch.randn(bias_shape) * 0.02) 

    def forward(self, x):
        x = self.conv(x)
        x = torch.relu(x)
        x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
        x = torch.nn.functional.gelu(x)
        x = torch.sigmoid(x)
        x = x + self.bias
        return x

# Benchmark problem dimensions.
batch_size = 128
in_channels = 3
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = 3
# Per-channel bias, broadcastable against (batch, channels, depth, height, width).
bias_shape = (out_channels, 1, 1, 1)

def get_inputs():
    """Build the positional inputs for Model.forward: one random volume."""
    volume = torch.randn(batch_size, in_channels, depth, height, width)
    return [volume]

def get_init_inputs():
    """Return Model's constructor arguments in positional order."""
    return list((in_channels, out_channels, kernel_size, bias_shape))

Kernel Information

Related Kernels (Level 2, Task 7 • 7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Input validation helpers.
// Macro arguments are parenthesized so that an expression argument binds as a unit.
// CHECK_INPUT is wrapped in do { } while (0) so that it expands to a single
// statement and stays correct inside an unbraced if/else.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

#define WARP_SIZE 32
// Threads per block used when launching apply_activations_and_bias_kernel.
#define BLOCK_SIZE 256

// Read four consecutive floats starting at addr with one 128-bit vector load.
// addr must satisfy float4's 16-byte alignment requirement.
__device__ __forceinline__ float4 load_float4(const float* addr) {
    return *reinterpret_cast<const float4*>(addr);
}

// Write val to the four consecutive floats starting at addr with one 128-bit store.
// addr must satisfy float4's 16-byte alignment requirement.
__device__ __forceinline__ void store_float4(float* addr, float4 val) {
    float4* const dst = reinterpret_cast<float4*>(addr);
    *dst = val;
}

// Per-element epilogue matching module_fn:
//   ReLU -> LeakyReLU(negative_slope=0.01) -> GELU -> Sigmoid -> + bias[bias_idx].
// GELU uses the exact erf form, 0.5 * x * (1 + erf(x / sqrt(2))). That is what
// F.gelu computes with its default approximate='none'; the tanh approximation
// does not match that reference.
__device__ __forceinline__ float process_value(float val, const float* bias, int bias_idx) {
    // ReLU. Written as a comparison rather than fmaxf so that a NaN input
    // propagates like torch.relu does; fmaxf(0.0f, NaN) would return 0.
    val = (val < 0.0f) ? 0.0f : val;

    // LeakyReLU(0.01). After the ReLU above this is the identity (val is >= 0 or
    // NaN), but it is kept so the pipeline mirrors the reference step for step.
    val = fmaxf(0.01f * val, val);

    // GELU, exact erf form.
    const float inv_sqrt2 = 0.70710678118654752440f;
    val = 0.5f * val * (1.0f + erff(val * inv_sqrt2));

    // Sigmoid.
    val = 1.0f / (1.0f + expf(-val));

    // Per-channel bias, loaded through the read-only data cache.
    val += __ldg(&bias[bias_idx]);

    return val;
}

// Fused epilogue kernel: applies process_value (ReLU -> LeakyReLU -> GELU ->
// Sigmoid -> + bias) in place to every element of `output`.
//
// `output` is indexed as a flat, contiguous (batch, out_channels, depth, height,
// width) array. Element i belongs to channel (i / (depth*height*width)) %
// out_channels and uses bias[that channel], so `bias` must hold at least
// out_channels floats.
//
// Each thread owns one group of 4 consecutive elements:
//   - a full group uses a single float4 load/store;
//   - the final partial group (element count not a multiple of 4) uses scalar accesses.
// Launch with at least ceil(numel / 4) threads in total.
__global__ void apply_activations_and_bias_kernel(
    float* __restrict__ output, const float* __restrict__ bias,
    int batch_size, int out_channels, int depth, int height, int width
) {
    // 64-bit arithmetic: the flat element index (and group * 4) can exceed
    // INT_MAX for large activations.
    const int64_t spatial_size = static_cast<int64_t>(depth) * height * width;
    const int64_t total_elements =
        static_cast<int64_t>(batch_size) * out_channels * spatial_size;

    const int64_t group = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t first = group * 4;
    if (first >= total_elements) {
        return;
    }

    if (first + 3 < total_elements) {
        // first is a multiple of 4, so &output[first] is 16-byte aligned as long
        // as the base pointer is. This is assumed for the fresh conv3d output.
        // NOTE(review): confirm this holds if the kernel is ever fed a sliced view.
        float4 data = load_float4(&output[first]);

        const int64_t first_plane = first / spatial_size;
        const int64_t last_plane = (first + 3) / spatial_size;
        if (first_plane == last_plane) {
            // Common case: all four elements lie in the same (batch, channel) plane.
            const int c = static_cast<int>(first_plane % out_channels);
            data.x = process_value(data.x, bias, c);
            data.y = process_value(data.y, bias, c);
            data.z = process_value(data.z, bias, c);
            data.w = process_value(data.w, bias, c);
        } else {
            // The group straddles a plane boundary. This is possible whenever
            // depth*height*width is not a multiple of 4, so each element needs
            // its own channel's bias.
            data.x = process_value(data.x, bias, static_cast<int>(((first + 0) / spatial_size) % out_channels));
            data.y = process_value(data.y, bias, static_cast<int>(((first + 1) / spatial_size) % out_channels));
            data.z = process_value(data.z, bias, static_cast<int>(((first + 2) / spatial_size) % out_channels));
            data.w = process_value(data.w, bias, static_cast<int>(((first + 3) / spatial_size) % out_channels));
        }

        store_float4(&output[first], data);
    } else {
        // Tail: fewer than 4 elements remain.
        for (int64_t i = first; i < total_elements; ++i) {
            const int c = static_cast<int>((i / spatial_size) % out_channels);
            output[i] = process_value(output[i], bias, c);
        }
    }
}

// Host entry point.
// Runs conv3d through ATen, then applies the fused activation + bias epilogue in
// place with apply_activations_and_bias_kernel.
//
// Requirements:
//   - all inputs are contiguous CUDA tensors;
//   - bias has exactly out_channels elements (e.g. shape (out_channels, 1, 1, 1)),
//     one per output channel.
// Float dtype is enforced by data_ptr<float>() below.
torch::Tensor module_fn_cuda(
    torch::Tensor x,
    torch::Tensor conv_weight,
    torch::Tensor conv_bias,
    torch::Tensor bias
) {
    CHECK_INPUT(x);
    CHECK_INPUT(conv_weight);
    CHECK_INPUT(conv_bias);
    CHECK_INPUT(bias);

    // Make x's device current so the kernel launch and stream lookup target it,
    // even when it is not the caller's current device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    auto output = torch::conv3d(x, conv_weight, conv_bias);

    const int64_t out_channels = output.size(1);
    // The kernel reads bias[channel] for every channel; a smaller bias would be
    // read out of bounds.
    TORCH_CHECK(bias.numel() == out_channels,
                "bias must have out_channels (", out_channels, ") elements, got ",
                bias.numel());

    const int64_t total_elements = output.numel();
    if (total_elements == 0) {
        // Nothing to do; a launch with zero blocks is an invalid configuration.
        return output;
    }

    // One thread per group of 4 elements.
    const int64_t total_vectors = (total_elements + 3) / 4;
    const int64_t blocks = (total_vectors + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // Launch on PyTorch's current stream so the kernel is ordered after conv3d,
    // which was enqueued on that same stream.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    apply_activations_and_bias_kernel<<<static_cast<unsigned int>(blocks), BLOCK_SIZE, 0, stream>>>(
        output.data_ptr<float>(), bias.data_ptr<float>(),
        static_cast<int>(output.size(0)), static_cast<int>(out_channels),
        static_cast<int>(output.size(2)), static_cast<int>(output.size(3)),
        static_cast<int>(output.size(4))
    );

    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "apply_activations_and_bias_kernel launch failed: ",
                cudaGetErrorString(launch_err));

    return output;
}

// Python binding: exposes module_fn_cuda as `forward` on the compiled extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def("forward", &module_fn_cuda,
            "CUDA implementation of module_fn");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 2.960 inst/cycle 0.000 5
Executed Ipc Elapsed 2.864 inst/cycle 0.000 5
Issue Slots Busy 73.994 % 0.012 5
Issued Ipc Active 2.962 inst/cycle 0.000 5
SM Busy 80.872 % 0.014 5
Memory Throughput 2519989549747.464 byte/second 173959671717838028800.000 5
Mem Busy 42.522 % 0.050 5
Max Bandwidth 75.214 % 0.165 5
L1/TEX Hit Rate 44.386 % 0.003 5
L2 Hit Rate 50.424 % 0.027 5
Mem Pipes Busy 15.880 % 0.008 5
Warp Cycles Per Issued Instruction 18.560 cycle 0.007 5
Warp Cycles Per Executed Instruction 18.566 cycle 0.007 5
Avg. Active Threads Per Warp 27.500 0.000 5
Avg. Not Predicated Off Threads Per Warp 24.850 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 10.000 block 0.000 5
Block Limit Shared Mem 32.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 86.110 % 0.004 5
Achieved Active Warps Per SM 55.110 warp 0.002 5
Analysis Rules
Rule Description
INF HighPipeUtilization ALU is the highest-utilized pipeline (45.2%) based on active cycles, taking into account the rates of its different instructions. It executes integer and logic operations. It is well-utilized, but should not be a bottleneck.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (86.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::conv3d
CPU Time 621582.82 μs
Device Time 4423577.13 μs
Self CPU Time 10646.47 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::convolution
CPU Time 610936.35 μs
Device Time 4423577.13 μs
Self CPU Time 14450.03 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::_convolution
CPU Time 596486.33 μs
Device Time 4423577.13 μs
Self CPU Time 30070.68 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::cudnn_convolution
CPU Time 498369.94 μs
Device Time 3839600.97 μs
Self CPU Time 160296.59 μs
Self Device Time 3839600.97 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
sm80_xmma_fprop_implicit_gemm_indexed_f32f32_f32f32_f32_nchwkcrs_nchw_tilesize32x32x8_stage3_warpsize1x2x1_g1_ffma_aligna4_alignc4_execute_kernel__5x_cudnn
CPU Time 0.00 μs
Device Time 3839599.50 μs
Self CPU Time 0.00 μs
Self Device Time 3839599.50 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 4117846.10 μs
Device Time 83337.44 μs
Self CPU Time 4117846.10 μs
Self Device Time 83337.44 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45293 warnings generated when compiling for host.
Suppressed 45324 warnings (45277 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:6:35 bugprone-macro-parentheses
6 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:7:41: warning: macro argument should be enclosed in parentheses [bugprone-macro-parentheses]
7 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
| ^
| ()
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:45:21: warning: 2 adjacent parameters of 'apply_activations_and_bias_kernel' of similar type ('int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
45 | int batch_size, int out_channels, int depth, int height, int width
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:45:25: note: the first parameter in the range is 'out_channels'
45 | int batch_size, int out_channels, int depth, int height, int width
| ^~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:45:43: note: the last parameter in the range is 'depth'
45 | int batch_size, int out_channels, int depth, int height, int width
| ^~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:47:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
47 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:48:15: warning: Value stored to 'lane_id' during its initialization is never read [clang-analyzer-deadcode.DeadStores]
48 | const int lane_id = tid % WARP_SIZE;
| ^~~~~~~ ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:48:15: note: Value stored to 'lane_id' during its initialization is never read
48 | const int lane_id = tid % WARP_SIZE;
| ^~~~~~~ ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:49:15: warning: Value stored to 'warp_id' during its initialization is never read [clang-analyzer-deadcode.DeadStores]
49 | const int warp_id = tid / WARP_SIZE;
| ^~~~~~~ ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:49:15: note: Value stored to 'warp_id' during its initialization is never read
49 | const int warp_id = tid / WARP_SIZE;
| ^~~~~~~ ~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:50:30: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
50 | const int block_offset = blockIdx.x * BLOCK_SIZE;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:54:15: warning: Value stored to 'elements_per_channel' during its initialization is never read [clang-analyzer-deadcode.DeadStores]
54 | const int elements_per_channel = spatial_size;
| ^~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:54:15: note: Value stored to 'elements_per_channel' during its initialization is never read
54 | const int elements_per_channel = spatial_size;
| ^~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:89:19: warning: the parameter 'x' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
89 | torch::Tensor x,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:90:19: warning: the parameter 'conv_weight' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
90 | torch::Tensor conv_weight,
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:92:19: warning: the parameter 'bias' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
92 | torch::Tensor bias
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:101:22: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
101 | int batch_size = output.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:102:24: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
102 | int out_channels = output.size(1);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:103:17: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
103 | int depth = output.size(2);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:104:18: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
104 | int height = output.size(3);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250213_optimize_b10_s4_e0_cross_no/level_2/task_7/b3_s3_coalesced_memory_activation_kernel_base/base/base.cu:105:17: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
105 | int width = output.size(4);
| ^