The AI CUDA Engineer 👷


Level 1 • Task 39
import torch
import torch.nn as nn
import torch.nn.functional as F

def module_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Applies L2 normalization to the input tensor.

    Each vector taken along dimension 1 is divided by its Euclidean norm
    via ``torch.nn.functional.normalize`` with ``p=2``.

    Args:
        x (torch.Tensor): Input tensor of shape (*, dim, *).

    Returns:
        torch.Tensor: Output tensor with L2 normalization applied, same shape as input.
    """
    return F.normalize(x, p=2, dim=1)

class Model(nn.Module):
    """Simple model that performs L2 normalization."""

    def __init__(self):
        """Initializes the L2Norm layer; no parameters or buffers are created."""
        super().__init__()

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Applies L2 normalization to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (*, dim, *).
            fn: Callable applied to ``x`` to produce the result. Defaults to
                ``module_fn``; any callable taking the tensor can be
                substituted.

        Returns:
            torch.Tensor: Output tensor with L2 normalization applied, same shape as input.
        """
        return fn(x)

# Problem size used when building sample inputs: ``batch_size`` rows of
# ``dim`` values each.
batch_size = 16
dim = 16384


def get_inputs():
    """Build the sample input list.

    Returns a one-element list holding a ``(batch_size, dim)`` tensor drawn
    from ``torch.randn``.
    """
    return [torch.randn(batch_size, dim)]

def get_init_inputs():
    """Return a fresh, empty list (no initialization inputs are provided)."""
    init_args = []
    return init_args
import torch
import torch.nn as nn

class Model(nn.Module):
    """Simple model that performs L2 normalization."""

    def __init__(self):
        """Initializes the L2Norm layer; it takes no arguments and holds no state."""
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies L2 normalization to the input tensor.

        Every vector along dimension 1 is divided by its own L2 norm. No
        epsilon is added to the denominator, so an all-zero vector yields
        NaN (0 / 0).

        Args:
            x (torch.Tensor): Input tensor of shape (*, dim, *).

        Returns:
            torch.Tensor: Output tensor with L2 normalization applied, same shape as input.
        """
        # keepdim=True keeps the reduced axis so the division broadcasts over it.
        return x / torch.norm(x, p=2, dim=1, keepdim=True)

# Sample-input dimensions: ``batch_size`` rows, ``dim`` columns.
batch_size = 16
dim = 16384


def get_inputs():
    """Create the sample inputs: a list with one ``torch.randn`` tensor of
    shape ``(batch_size, dim)``."""
    sample = torch.randn(batch_size, dim)
    inputs = [sample]
    return inputs

def get_init_inputs():
    """Return the initialization inputs, which is always a new empty list."""
    return list()

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>

// Scales every vector of `input` to unit L2 length and stores it in `output`.
//
// Work distribution: one thread block per vector (blockIdx.x selects the
// vector); the threads of the block stride over that vector's C elements.
//
// Addressing: element i of vector v is read from
//     input[v * outer_stride + i * stride_C]
// and written to the SAME offset in `output`, so both buffers must share one
// layout.
//
// Launch requirements: blockDim.x must be a power of two and at most 256
// (the shared reduction buffer has 256 slots and the tree reduction halves
// the active range each step).
//
// The scale factor is rsqrt(sum_of_squares + 1e-12).
template <typename scalar_t>
__global__ void l2norm_strided_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    const int C,
    const int total_vectors,
    const int stride_C,
    const int outer_stride) {

    const int vector_idx = blockIdx.x;
    // The whole block takes this exit together, so the barriers below are
    // never reached by only part of a block.
    if (vector_idx >= total_vectors) return;

    // 64-bit offset: vector_idx * outer_stride can exceed the int range.
    const long long base = static_cast<long long>(vector_idx) * outer_stride;
    const scalar_t* in_row = input + base;
    scalar_t* out_row = output + base;

    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // One partial sum per thread.
    __shared__ scalar_t shared_mem[256];
    scalar_t thread_sum = 0;

    // float -> float4, otherwise (double) -> double2; both are 16 bytes wide.
    constexpr int vec_size = sizeof(scalar_t) == 4 ? 4 : 2;
    constexpr size_t vec_bytes = 16;

    // Wide loads/stores are only legal on contiguous rows whose start address
    // is 16-byte aligned in BOTH buffers. Rows of a contiguous tensor whose
    // length is not a multiple of vec_size fail this test and take the
    // scalar loop instead.
    const bool use_vectors =
        stride_C == 1 &&
        reinterpret_cast<size_t>(in_row) % vec_bytes == 0 &&
        reinterpret_cast<size_t>(out_row) % vec_bytes == 0;

    // Number of leading elements covered by whole vectors (0 when the wide
    // path is disabled, which routes everything through the scalar loops).
    const int aligned_C = use_vectors ? (C / vec_size) * vec_size : 0;

    // ---- Pass 1: per-thread sum of squares --------------------------------
    if (use_vectors) {
        if constexpr (sizeof(scalar_t) == 4) {
            const float4* in_vec = reinterpret_cast<const float4*>(in_row);
            const int num_vectors = aligned_C / 4;
            for (int i = tid; i < num_vectors; i += stride) {
                const float4 v = in_vec[i];
                thread_sum += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
            }
        } else {
            const double2* in_vec = reinterpret_cast<const double2*>(in_row);
            const int num_vectors = aligned_C / 2;
            for (int i = tid; i < num_vectors; i += stride) {
                const double2 v = in_vec[i];
                thread_sum += v.x * v.x + v.y * v.y;
            }
        }
    }

    // Scalar loop: the tail of a vectorized row, or the entire row when the
    // data is strided or misaligned.
    for (int i = aligned_C + tid; i < C; i += stride) {
        const scalar_t val = in_row[static_cast<long long>(i) * stride_C];
        thread_sum += val * val;
    }

    shared_mem[tid] = thread_sum;
    __syncthreads();  // every partial sum must be visible before reducing

    // ---- Block-wide tree reduction ----------------------------------------
    // A barrier follows every step, all the way down to one element, so no
    // implicit warp-lockstep execution is assumed.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_mem[tid] += shared_mem[tid + s];
        }
        __syncthreads();
    }

    // Thread 0 turns the sum of squares into the reciprocal norm.
    if (tid == 0) {
        shared_mem[0] = rsqrt(shared_mem[0] + 1e-12);
    }
    __syncthreads();  // publish the factor to the whole block

    const scalar_t inv_norm = shared_mem[0];

    // ---- Pass 2: scale and store ------------------------------------------
    if (use_vectors) {
        if constexpr (sizeof(scalar_t) == 4) {
            float4* out_vec = reinterpret_cast<float4*>(out_row);
            const float4* in_vec = reinterpret_cast<const float4*>(in_row);
            const int num_vectors = aligned_C / 4;
            for (int i = tid; i < num_vectors; i += stride) {
                float4 v = in_vec[i];
                v.x *= inv_norm;
                v.y *= inv_norm;
                v.z *= inv_norm;
                v.w *= inv_norm;
                out_vec[i] = v;
            }
        } else {
            double2* out_vec = reinterpret_cast<double2*>(out_row);
            const double2* in_vec = reinterpret_cast<const double2*>(in_row);
            const int num_vectors = aligned_C / 2;
            for (int i = tid; i < num_vectors; i += stride) {
                double2 v = in_vec[i];
                v.x *= inv_norm;
                v.y *= inv_norm;
                out_vec[i] = v;
            }
        }
    }

    for (int i = aligned_C + tid; i < C; i += stride) {
        const long long offset = static_cast<long long>(i) * stride_C;
        out_row[offset] = in_row[offset] * inv_norm;
    }
}

// Host entry point: L2-normalizes `input` along dimension 1 on the GPU.
//
// Layout handling:
//   * 2-D inputs whose strides can be reproduced by the output are processed
//     in place of their own layout (the kernel reads and writes at the same
//     offsets), so no copy is made.
//   * Every other input (more than two dimensions, or a 2-D view whose
//     strides `empty_like` does not preserve) is rearranged so that
//     dimension 1 is last and contiguous, processed as a flat list of rows,
//     and the result is returned with the dimensions swapped back. In that
//     case the returned tensor may be a non-contiguous view.
//
// Raises (via TORCH_CHECK) if the input is not a CUDA tensor, has fewer than
// two dimensions, is too large for the kernel's 32-bit arguments, or if the
// kernel launch fails.
torch::Tensor forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() >= 2, "Input must be at least 2D");

    // Nothing to normalize. Returning early also avoids dividing by C == 0
    // and launching a zero-sized grid.
    if (input.numel() == 0) {
        return torch::empty_like(input);
    }

    const int64_t C = input.size(1);

    torch::Tensor src = input;
    torch::Tensor output;

    // The kernel addresses `output` with the input's strides, which is only
    // valid when both tensors really share those strides.
    bool shares_layout = false;
    if (input.dim() == 2) {
        output = torch::empty_like(input);
        shares_layout = output.strides().equals(input.strides());
    }

    int64_t total_vectors = 0;
    int64_t stride_C = 0;
    int64_t outer_stride = 0;
    if (shares_layout) {
        total_vectors = input.size(0);
        stride_C = input.stride(1);
        outer_stride = input.stride(0);
    } else {
        // Swap dim 1 to the end and densify: memory then holds numel / C
        // back-to-back rows of length C.
        src = input.transpose(1, -1).contiguous();
        output = torch::empty(src.sizes(), src.options());
        total_vectors = src.numel() / C;
        stride_C = 1;
        outer_stride = C;
    }

    // The kernel takes plain `int` arguments.
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(C <= int_max && total_vectors <= int_max &&
                    stride_C <= int_max && outer_stride <= int_max,
                "Input is too large for l2norm_strided_kernel's 32-bit indexing");

    // Must remain a power of two <= 256 to match the kernel's shared buffer.
    // (The original author notes 256 as tuned for H100.)
    const int threads = 256;
    const dim3 blocks(static_cast<unsigned int>(total_vectors));

    AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "l2norm_strided", ([&] {
        l2norm_strided_kernel<scalar_t><<<blocks, threads>>>(
            src.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            static_cast<int>(C),
            static_cast<int>(total_vectors),
            static_cast<int>(stride_C),
            static_cast<int>(outer_stride));
    }));

    // Surface bad launch configurations immediately instead of at the next
    // unrelated CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "l2norm_strided_kernel launch failed: ",
                cudaGetErrorString(launch_status));

    // Undo the dimension swap used by the repacked path (a no-op for 2-D).
    return shares_layout ? output : output.transpose(1, -1);
}

// Exposes `forward` to Python as the extension module's `forward` function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "L2 normalization with stride optimization");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.444 inst/cycle 0.000 5
Executed Ipc Elapsed 0.040 inst/cycle 0.000 5
Issue Slots Busy 11.160 % 0.110 5
Issued Ipc Active 0.446 inst/cycle 0.000 5
SM Busy 11.160 % 0.110 5
Memory Throughput 148176237549.346 byte/second 5409637807862392832.000 5
Mem Busy 6.958 % 0.011 5
Max Bandwidth 6.542 % 0.010 5
L1/TEX Hit Rate 33.330 % 0.000 5
L2 Hit Rate 68.344 % 0.064 5
Mem Pipes Busy 0.572 % 0.000 5
Warp Cycles Per Issued Instruction 17.396 cycle 0.016 5
Warp Cycles Per Executed Instruction 17.478 cycle 0.016 5
Avg. Active Threads Per Warp 31.720 0.000 5
Avg. Not Predicated Off Threads Per Warp 29.970 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 8.000 block 0.000 5
Block Limit Shared Mem 16.000 block 0.000 5
Block Limit Warps 8.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 12.310 % 0.000 5
Achieved Active Warps Per SM 7.878 warp 0.000 5
Operation / Metric Value Unit
CPU Time 273402.65 μs
Device Time 40.13 μs
Self CPU Time 30.49 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 273372.16 μs
Device Time 40.13 μs
Self CPU Time 76.21 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 291889.22 μs
Device Time 0.00 μs
Self CPU Time 18947.41 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 272745.89 μs
Device Time 0.00 μs
Self CPU Time 272745.89 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 487054.36 μs
Device Time 21075.07 μs
Self CPU Time 487054.36 μs
Self Device Time 21075.07 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void l2norm_strided_kernel<float>(float const*, float*, int, int, int, int)
CPU Time 0.00 μs
Device Time 52134.50 μs
Self CPU Time 0.00 μs
Self Device Time 52134.50 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 21929.42 μs
Device Time 40645.89 μs
Self CPU Time 21929.42 μs
Self Device Time 40645.89 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 63181.44 μs
Device Time 601920.42 μs
Self CPU Time 11683.94 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
CPU Time 51498.81 μs
Device Time 601920.42 μs
Self CPU Time 16433.94 μs
Self Device Time 601920.42 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 601920.42 μs
Self CPU Time 0.00 μs
Self Device Time 601920.42 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
