
The AI CUDA Engineer 👷

98_KLDivLoss • kl_div_balanced_workload_base

Level 1 • Task 98
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Computes the Kullback-Leibler Divergence for comparing two distributions.

    Args:
        predictions (torch.Tensor): Predicted values.
        targets (torch.Tensor): Target values.

    Returns:
        torch.Tensor: Kullback-Leibler Divergence.
    """
    return F.kl_div(torch.log(predictions), targets, reduction="batchmean")


class Model(nn.Module):
    """
    A model that computes Kullback-Leibler Divergence for comparing two distributions.

    Parameters:
        None
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, predictions, targets, fn=module_fn):
        return fn(predictions, targets)


batch_size = 128
input_shape = (4096,)
dim = 1


def get_inputs():
    return [
        torch.randn(batch_size, *input_shape).softmax(dim=-1),
        torch.randn(batch_size, *input_shape).softmax(dim=-1),
    ]


def get_init_inputs():
    return []
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    A model that computes Kullback-Leibler Divergence for comparing two distributions.

    Parameters:
        None
    """
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, predictions, targets):
        return torch.nn.functional.kl_div(torch.log(predictions), targets, reduction='batchmean')

batch_size = 128
input_shape = (4096, )
dim = 1

def get_inputs():
    return [torch.randn(batch_size, *input_shape).softmax(dim=-1), torch.randn(batch_size, *input_shape).softmax(dim=-1)]

def get_init_inputs():
    return []
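
For the inputs above, both predictions and targets are row-wise softmax outputs, so every entry is strictly positive. Under reduction="batchmean", F.kl_div sums target * (log(target) - log(prediction)) over all elements and divides by the batch size. The snippet below is a small self-contained check of that equivalence; the variable names are illustrative only and not part of the original listings.

import torch
import torch.nn.functional as F

batch_size, dim = 128, 4096
predictions = torch.randn(batch_size, dim).softmax(dim=-1)
targets = torch.randn(batch_size, dim).softmax(dim=-1)

# F.kl_div expects log-probabilities as input and probabilities as target.
kl_builtin = F.kl_div(torch.log(predictions), targets, reduction="batchmean")

# Manual computation: sum of target * (log target - log prediction), divided by the batch size.
kl_manual = (targets * (targets.log() - predictions.log())).sum() / batch_size

print(torch.allclose(kl_builtin, kl_manual))  # expected: True, up to floating-point tolerance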

Kernel Information

Related Kernels (Level 1, Task 98 • 98_KLDivLoss)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 optimized_kl_div_cuda_base 0.01 2.83 3.20
🥈 kl_div_sync_optimized_base 0.01 2.59 2.93
🥈 optimized_kl_div_kernel_base 0.01 2.59 2.93
🥈 kl_div_balanced_workload_base 0.01 2.59 2.93
🥈 kl_div_warp_reduce_base_base 0.01 2.59 2.93
🥈 optimized_kl_div_base 0.01 2.59 2.93
🥈 kl_div_modular_reduce_base_base 0.01 2.59 2.93
🥈 kldiv_optimized_stride_base_base_base 0.01 2.59 2.93
🥈 vectorized_aligned_kl_base 0.01 2.59 2.93
🥈 98_KLDivLoss_optimal_reduce_edit_1 0.01 2.59 2.93
🥈 strided_warp_kl_base_base 0.01 2.59 2.93
🥈 fast_strided_kl_base 0.01 2.59 2.93
🥈 coalesced_chunked_kl_base 0.01 2.59 2.93
🥈 kldiv_modular_per_thread_base_base 0.01 2.59 2.93
🥈 kldiv_unrolled_reduction_base_base 0.01 2.59 2.93
🥈 kl_div_unrolled_reduce_base_base 0.01 2.59 2.93
🥈 warp_block_vec4_opt_base 0.01 2.59 2.93
🥈 vectorized_kldiv_base_base 0.01 2.59 2.93
🥈 kl_div_even_workload_distribution_base 0.01 2.59 2.93
🥈 adaptive_kl_div_cuda_base 0.01 2.59 2.93
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define BLOCK_SIZE 256
#define ELEMENTS_PER_THREAD 4

__device__ __forceinline__ float warp_reduce(float val) {
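    // Tree reduction within a warp: each step halves the number of active lanes via register shuffles.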
    #pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

__global__ void kl_div_kernel_balanced(
    const float* __restrict__ log_predictions,
    const float* __restrict__ targets,
    float* __restrict__ output,
    const int n,
    const int elements_per_thread) {
    
    const int tid = threadIdx.x;
    const int wid = tid / WARP_SIZE;
    const int lane = tid % WARP_SIZE;
    const int global_thread_id = blockIdx.x * blockDim.x + tid;
    
    extern __shared__ float warp_results[];
    
    float thread_sum = 0.0f;
    
    // Each thread processes multiple elements
    const int start_idx = global_thread_id * elements_per_thread;
    #pragma unroll
    for (int i = 0; i < elements_per_thread; i++) {
        const int idx = start_idx + i;
        if (idx < n) {
            const float log_pred = log_predictions[idx];
            const float target = targets[idx];
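            // Per-element term accumulated by this kernel: exp(log_pred) - target * log_pred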
            thread_sum += expf(log_pred) - target * log_pred;
        }
    }
    
    // First level reduction within warps
    thread_sum = warp_reduce(thread_sum);
    
    // Store warp results
    if (lane == 0) {
        warp_results[wid] = thread_sum;
    }
    __syncthreads();
    
    // Final reduction across warps
    if (wid == 0) {
        float warp_sum = (lane < (BLOCK_SIZE / WARP_SIZE)) ? warp_results[lane] : 0.0f;
        warp_sum = warp_reduce(warp_sum);
        
        if (lane == 0) {
            atomicAdd(output, warp_sum);
        }
    }
}

torch::Tensor kl_div_cuda_forward(
    torch::Tensor log_predictions,
    torch::Tensor targets) {
    
    const int n = log_predictions.numel();
    auto output = torch::zeros({1}, log_predictions.options());
    
    // Calculate optimal grid dimensions
    const int elements_per_thread = ELEMENTS_PER_THREAD;
    const int total_threads_needed = (n + elements_per_thread - 1) / elements_per_thread;
    const int blocks = (total_threads_needed + BLOCK_SIZE - 1) / BLOCK_SIZE;
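    // For the task inputs above (n = 128 * 4096 = 524288), this yields 131072 threads in 512 blocks of 256.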
    
    // Shared memory for warp results
    const int warps_per_block = BLOCK_SIZE / WARP_SIZE;
    const int shared_mem = warps_per_block * sizeof(float);
    
    kl_div_kernel_balanced<<<blocks, BLOCK_SIZE, shared_mem>>>(
        log_predictions.data_ptr<float>(),
        targets.data_ptr<float>(),
        output.data_ptr<float>(),
        n,
        elements_per_thread
    );
    
    return output / static_cast<float>(n);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &kl_div_cuda_forward, "KL divergence forward (CUDA balanced)");
}
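
To try the kernel end to end, it can be built as a PyTorch extension and run against the reference above. The sketch below is illustrative: it assumes the CUDA listing is saved to a file named kl_div_balanced.cu (the filename and extension name are placeholders) and that a CUDA-capable GPU and the CUDA toolkit are available. Note that the host wrapper divides the accumulated sum by the total element count n, while the PyTorch reference's batchmean divides by the batch size, so the two printed values are not directly comparable.

import torch
from torch.utils.cpp_extension import load

# Build the extension from the CUDA source above (assumed to be saved as kl_div_balanced.cu).
kl_ext = load(name="kl_div_balanced", sources=["kl_div_balanced.cu"], verbose=True)

batch_size, dim = 128, 4096
predictions = torch.randn(batch_size, dim, device="cuda").softmax(dim=-1)
targets = torch.randn(batch_size, dim, device="cuda").softmax(dim=-1)

# Reference value from PyTorch.
ref = torch.nn.functional.kl_div(torch.log(predictions), targets, reduction="batchmean")

# Kernel value: the extension exposes forward(log_predictions, targets).
out = kl_ext.forward(torch.log(predictions), targets)

print("reference:", ref.item())
print("kernel:   ", out.item())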