← Back to Leaderboard

The AI CUDA Engineer 👷

97_CosineSimilarityLoss — block_tuned_cosine_loss_base_base

Level 1 • Task 97
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_fn(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Return the mean cosine-distance loss between two batches of vectors.

    Cosine similarity is measured along dimension 1, and the loss is the
    average of ``1 - similarity`` over every remaining element.

    Args:
        predictions (torch.Tensor): Predicted values.
        targets (torch.Tensor): Target values.

    Returns:
        torch.Tensor: Scalar (0-dim) Cosine Similarity Loss.
    """
    similarity = F.cosine_similarity(predictions, targets, dim=1)
    distance = 1 - similarity
    return distance.mean()


class Model(nn.Module):
    """
    Module wrapper around a Cosine Similarity Loss function.

    The loss itself is delegated to ``fn`` (``module_fn`` by default), so an
    alternative implementation with the same ``(predictions, targets)``
    signature can be swapped in at call time.

    Parameters:
        None
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets, fn=module_fn):
        loss = fn(predictions, targets)
        return loss


batch_size = 128
input_shape = (4096,)
dim = 1


def get_inputs():
    """Return a fresh ``[predictions, targets]`` pair of random tensors.

    Both tensors have shape ``(batch_size, *input_shape)``; predictions are
    drawn first, then targets.
    """
    shape = (batch_size, *input_shape)
    return [torch.randn(shape) for _ in range(2)]


def get_init_inputs():
    """Return the (empty) list of constructor arguments for the model."""
    return list()
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    A model that computes Cosine Similarity Loss for comparing vectors.

    The loss is ``mean(1 - cos_sim)``, where ``cos_sim`` is the cosine
    similarity between ``predictions`` and ``targets`` along dimension 1.

    Parameters:
        None
    """
    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        cos = torch.nn.functional.cosine_similarity(predictions, targets, dim=1)
        return (1 - cos).mean()

batch_size = 128
input_shape = (4096,)
dim = 1

def get_inputs():
    """Return ``[predictions, targets]``: two random tensors of shape ``(batch_size, *input_shape)``."""
    shape = (batch_size, *input_shape)
    return [torch.randn(shape) for _ in range(2)]

def get_init_inputs():
    """Return the (empty) list of constructor arguments for the model."""
    return list()

Kernel Information

Related Kernels (Level 1, Task 97 • 97_CosineSimilarityLoss)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// One block per row: computes the cosine similarity of predictions[row] and targets[row]
// (both rows of length D) and atomically adds (1 - cos_sim) / N into *output.
//
// Launch with gridDim.x == N and blockDim.x == BLOCK_SIZE; *output must be zeroed beforehand.
template<int BLOCK_SIZE>
__global__ void block_tuned_cosine_similarity_loss_kernel(const float* __restrict__ predictions,
                                                          const float* __restrict__ targets,
                                                          float* output,
                                                          const int N,
                                                          const int D) {
    // The block is reduced warp by warp, so it must consist of whole warps, and all
    // per-warp partials must fit into the single warp that does the final reduction.
    static_assert(BLOCK_SIZE % 32 == 0, "BLOCK_SIZE must be a multiple of the warp size (32)");
    static_assert(BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024, "BLOCK_SIZE must be in [32, 1024]");
    constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / 32;
    constexpr unsigned FULL_MASK = 0xffffffffu;

    const int row = static_cast<int>(blockIdx.x);
    const int tid = static_cast<int>(threadIdx.x);
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;

    // 64-bit offset: row * D can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(row) * D;
    const float* pred_row = predictions + row_offset;
    const float* target_row = targets + row_offset;

    float sum_dot = 0.0f;
    float sum_pred_sq = 0.0f;
    float sum_target_sq = 0.0f;

    // Each thread visits elements tid, tid + BLOCK_SIZE, tid + 2*BLOCK_SIZE, ... of the row.
    #pragma unroll 4
    for (int idx = tid; idx < D; idx += BLOCK_SIZE) {
        const float pred = pred_row[idx];
        const float target = target_row[idx];
        sum_dot += pred * target;
        sum_pred_sq += pred * pred;
        sum_target_sq += target * target;
    }

    // Intra-warp reduction. Every warp is full (BLOCK_SIZE % 32 == 0) and every thread
    // reaches this point, so FULL_MASK is valid. Afterwards lane 0 holds the warp's sums.
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        sum_dot += __shfl_down_sync(FULL_MASK, sum_dot, offset);
        sum_pred_sq += __shfl_down_sync(FULL_MASK, sum_pred_sq, offset);
        sum_target_sq += __shfl_down_sync(FULL_MASK, sum_target_sq, offset);
    }

    __shared__ float s_dot[WARPS_PER_BLOCK];
    __shared__ float s_pred_sq[WARPS_PER_BLOCK];
    __shared__ float s_target_sq[WARPS_PER_BLOCK];

    // Store per-warp results to shared memory
    if (lane_id == 0) {
        s_dot[warp_id] = sum_dot;
        s_pred_sq[warp_id] = sum_pred_sq;
        s_target_sq[warp_id] = sum_target_sq;
    }
    __syncthreads();

    // Final reduction by the whole first warp. All 32 lanes must participate because the
    // shuffles use FULL_MASK; lanes with no warp partial contribute zeros. Starting at
    // offset 16 (rather than WARPS_PER_BLOCK / 2) keeps the tree correct when
    // WARPS_PER_BLOCK is not a power of two (e.g. BLOCK_SIZE = 384 -> 12 warps).
    if (warp_id == 0) {
        sum_dot = (lane_id < WARPS_PER_BLOCK) ? s_dot[lane_id] : 0.0f;
        sum_pred_sq = (lane_id < WARPS_PER_BLOCK) ? s_pred_sq[lane_id] : 0.0f;
        sum_target_sq = (lane_id < WARPS_PER_BLOCK) ? s_target_sq[lane_id] : 0.0f;

        #pragma unroll
        for (int offset = 16; offset > 0; offset /= 2) {
            sum_dot += __shfl_down_sync(FULL_MASK, sum_dot, offset);
            sum_pred_sq += __shfl_down_sync(FULL_MASK, sum_pred_sq, offset);
            sum_target_sq += __shfl_down_sync(FULL_MASK, sum_target_sq, offset);
        }

        if (lane_id == 0) {
            // Clamp the norm product away from zero so all-zero rows yield cos_sim = 0.
            const float eps = 1e-8f;
            const float norm_pred = sqrtf(sum_pred_sq);
            const float norm_target = sqrtf(sum_target_sq);
            const float denominator = fmaxf(norm_pred * norm_target, eps);
            const float cos_sim = sum_dot / denominator;
            // Each row contributes its share of the mean over N rows.
            atomicAdd(output, (1.0f - cos_sim) / static_cast<float>(N));
        }
    }
}

// Launches one BLOCK_SIZE-thread block per row of the (N, D) inputs on `stream`
// and surfaces any launch error as a C++ exception.
template <int BLOCK_SIZE>
static void launch_block_tuned_cosine_similarity_loss(const float* predictions,
                                                      const float* targets,
                                                      float* output,
                                                      int N,
                                                      int D,
                                                      cudaStream_t stream) {
    block_tuned_cosine_similarity_loss_kernel<BLOCK_SIZE><<<N, BLOCK_SIZE, 0, stream>>>(
        predictions, targets, output, N, D);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Computes mean(1 - cosine_similarity(predictions, targets, dim=1)) for two float32 CUDA
// tensors of identical shape (N, D). Returns a 1-element float32 tensor on the input device.
//
// Throws (via TORCH_CHECK) if the inputs are not 2D, not float32, not CUDA tensors on the
// same device, differ in shape, or have a dimension that does not fit in a 32-bit int.
torch::Tensor block_tuned_cosine_similarity_loss_forward(torch::Tensor predictions, torch::Tensor targets) {
    TORCH_CHECK(predictions.dim() == 2, "predictions must be 2D");
    TORCH_CHECK(targets.dim() == 2, "targets must be 2D");
    TORCH_CHECK(predictions.sizes() == targets.sizes(), "Input tensors must have the same shape");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32, "predictions must be float32");
    TORCH_CHECK(targets.scalar_type() == torch::kFloat32, "targets must be float32");
    TORCH_CHECK(predictions.is_cuda(), "predictions must be a CUDA tensor");
    TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor");
    TORCH_CHECK(predictions.device() == targets.device(),
                "predictions and targets must be on the same device");
    TORCH_CHECK(predictions.size(0) <= std::numeric_limits<int>::max() &&
                predictions.size(1) <= std::numeric_limits<int>::max(),
                "Input dimensions must fit in a 32-bit int");

    // The kernel addresses row r at offset r * D, i.e. it assumes dense row-major storage.
    // contiguous() returns the tensor itself when it already has that layout.
    predictions = predictions.contiguous();
    targets = targets.contiguous();

    const int N = static_cast<int>(predictions.size(0));
    const int D = static_cast<int>(predictions.size(1));

    if (N == 0) {
        // A zero-block launch is an invalid configuration. The mean over zero rows is NaN,
        // matching torch.mean on an empty tensor.
        return torch::full({1}, std::numeric_limits<float>::quiet_NaN(), predictions.options());
    }

    // Run on the inputs' device and PyTorch's current stream so the kernel is ordered
    // correctly with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(predictions.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Every block atomically adds into this value, so it must start at zero.
    auto output = torch::zeros({1}, predictions.options());

    const float* pred_ptr = predictions.data_ptr<float>();
    const float* target_ptr = targets.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // Choose block size based on the row length D
    if (D <= 256) {
        launch_block_tuned_cosine_similarity_loss<128>(pred_ptr, target_ptr, out_ptr, N, D, stream);
    } else if (D <= 512) {
        launch_block_tuned_cosine_similarity_loss<256>(pred_ptr, target_ptr, out_ptr, N, D, stream);
    } else if (D <= 1024) {
        launch_block_tuned_cosine_similarity_loss<384>(pred_ptr, target_ptr, out_ptr, N, D, stream);
    } else {
        launch_block_tuned_cosine_similarity_loss<512>(pred_ptr, target_ptr, out_ptr, N, D, stream);
    }

    return output;
}

// Python binding: exposes the host entry point as `forward(predictions, targets)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            &block_tuned_cosine_similarity_loss_forward,
            "Block Tuned Cosine Similarity Loss Forward (CUDA)");
}
Performance Metrics
Metric Value Unit Variance Samples
Executed Ipc Active 0.628 inst/cycle 0.000 5
Executed Ipc Elapsed 0.346 inst/cycle 0.000 5
Issue Slots Busy 16.022 % 0.005 5
Issued Ipc Active 0.640 inst/cycle 0.000 5
SM Busy 16.022 % 0.005 5
Memory Throughput 802288065874.968 byte/second 15284305688874254336.000 5
Mem Busy 13.906 % 0.007 5
Max Bandwidth 24.090 % 0.021 5
L1/TEX Hit Rate 0.000 % 0.000 5
L2 Hit Rate 18.466 % 0.000 5
Mem Pipes Busy 6.766 % 0.001 5
Warp Cycles Per Issued Instruction 23.818 cycle 1.095 5
Warp Cycles Per Executed Instruction 24.338 cycle 1.148 5
Avg. Active Threads Per Warp 29.270 0.000 5
Avg. Not Predicated Off Threads Per Warp 28.050 0.000 5
Max Active Clusters 0.000 cluster 0.000 5
Max Cluster Size 8.000 block 0.000 5
Overall GPU Occupancy 0.000 % 0.000 5
Cluster Occupancy 0.000 % 0.000 5
Block Limit SM 32.000 block 0.000 5
Block Limit Registers 4.000 block 0.000 5
Block Limit Shared Mem 12.000 block 0.000 5
Block Limit Warps 4.000 block 0.000 5
Theoretical Active Warps per SM 64.000 warp 0.000 5
Theoretical Occupancy 100.000 % 0.000 5
Achieved Occupancy 23.428 % 0.000 5
Achieved Active Warps Per SM 14.994 warp 0.000 5
Analysis Rules
Rule Description
WRN HighPipeUtilization All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details.
INF CPIStall Check the Warp Stall Sampling (All Cycles) table for the top stall locations in your source based on sampling data. The Kernel Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details on each stall reason.
WRN Occupancy This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (23.4%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy.
Operation / Metric Value Unit
aten::to
CPU Time 509941.75 μs
Device Time 293.85 μs
Self CPU Time 45.99 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zeros
CPU Time 5891938.15 μs
Device Time 224245.15 μs
Self CPU Time 149734.82 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::zero_
CPU Time 6213956.97 μs
Device Time 7563679.29 μs
Self CPU Time 315783.32 μs
Self Device Time 0.00 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
aten::fill_
CPU Time 5898176.19 μs
Device Time 7563679.29 μs
Self CPU Time 398647.99 μs
Self Device Time 7563679.29 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaLaunchKernel
CPU Time 5869405.95 μs
Device Time 2921.88 μs
Self CPU Time 5869405.95 μs
Self Device Time 2921.88 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void block_tuned_cosine_similarity_loss_kernel<512>(float const*, float const*, float*, int, int)
CPU Time 0.00 μs
Device Time 461056.96 μs
Self CPU Time 0.00 μs
Self Device Time 461056.96 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
cudaEventRecord
CPU Time 252923.70 μs
Device Time 1217113.43 μs
Self CPU Time 252923.70 μs
Self Device Time 1217113.43 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>)
CPU Time 0.00 μs
Device Time 7339434.14 μs
Self CPU Time 0.00 μs
Self Device Time 7339434.14 μs
CPU Memory Usage 0 B
Device Memory Usage 0 B
Self CPU Memory Usage 0 B
Self Device Memory Usage 0 B
Status: Completed
45286 warnings generated when compiling for host.
Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:6:59 bugprone-easily-swappable-parameters
6 | __global__ void block_tuned_cosine_similarity_loss_kernel(const float* __restrict__ predictions,
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7 | const float* __restrict__ targets,
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:6:85: note: the first parameter in the range is 'predictions'
6 | __global__ void block_tuned_cosine_similarity_loss_kernel(const float* __restrict__ predictions,
| ^~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:7:85: note: the last parameter in the range is 'targets'
7 | const float* __restrict__ targets,
| ^~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:9:59: warning: 2 adjacent parameters of 'block_tuned_cosine_similarity_loss_kernel' of similar type ('const int') are easily swapped by mistake [bugprone-easily-swappable-parameters]
9 | const int N,
| ^~~~~~~~~~~~
10 | const int D) {
| ~~~~~~~~~~~
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:9:69: note: the first parameter in the range is 'N'
9 | const int N,
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:10:69: note: the last parameter in the range is 'D'
10 | const int D) {
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:12:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
12 | const int row = blockIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:13:21: warning: narrowing conversion from 'unsigned int' to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
13 | const int tid = threadIdx.x;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:17:29: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
17 | const float* pred_row = predictions + row * D;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:17:43: note: make conversion explicit to silence this warning
4 |
5 | template<int BLOCK_SIZE>
6 | __global__ void block_tuned_cosine_similarity_loss_kernel(const float* __restrict__ predictions,
7 | const float* __restrict__ targets,
8 | float* output,
9 | const int N,
10 | const int D) {
11 | constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / 32;
12 | const int row = blockIdx.x;
13 | const int tid = threadIdx.x;
14 | const int warp_id = tid / warpSize;
15 | const int lane_id = tid % warpSize;
16 |
17 | const float* pred_row = predictions + row * D;
| ^~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:17:43: note: perform multiplication in a wider type
17 | const float* pred_row = predictions + row * D;
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:18:31: warning: result of multiplication in type 'int' is used as a pointer offset after an implicit widening conversion to type 'ptrdiff_t' [bugprone-implicit-widening-of-multiplication-result]
18 | const float* target_row = targets + row * D;
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:18:41: note: make conversion explicit to silence this warning
18 | const float* target_row = targets + row * D;
| ^~~~~~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:18:41: note: perform multiplication in a wider type
18 | const float* target_row = targets + row * D;
| ^~~
| static_cast<ptrdiff_t>( )
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:81:50: warning: narrowing conversion from 'int' to 'float' [bugprone-narrowing-conversions]
81 | atomicAdd(output, (1.0f - cos_sim) / N);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:86:72: warning: the parameter 'predictions' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
86 | torch::Tensor block_tuned_cosine_similarity_loss_forward(torch::Tensor predictions, torch::Tensor targets) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:86:99: warning: the parameter 'targets' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param]
86 | torch::Tensor block_tuned_cosine_similarity_loss_forward(torch::Tensor predictions, torch::Tensor targets) {
| ^
| const &
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:93:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
93 | int N = predictions.size(0);
| ^
/home/robert_sakana_ai/llm_cuda/experiments/20250202_optimize_b10_s4_e0_sweep/level_1/task_97/b6_s3_block_tuned_cosine_loss_base/base/base.cu:94:13: warning: narrowing conversion from 'int64_t' (aka 'long') to signed type 'int' is implementation-defined [bugprone-narrowing-conversions]
94 | int D = predictions.size(1);
| ^