
The AI CUDA Engineer 👷

33_VanillaRNN • optimized_rnn_reduction_base

Level 3 • Task 33

Kernel Information

Related Kernels (Level 3, Task 33 • 33_VanillaRNN)

Rank  Kernel Name                      Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇    fused_rnn_i2h_warp_base          0.02          1.21              2.67
🥈    warp_optimized_rnn_base          0.02          1.15              2.56
🥈    optimized_rnn_reduction_base     0.02          1.15              2.56
4     atomic_rnn_optimized_edit_1      0.02          1.11              2.45
4     modular_warp_rnn_base            0.02          1.11              2.45
6     balanced_load_rnn_base_base      0.03          0.83              1.84
6     optimized_concat_kernel_base     0.03          0.83              1.84
6     optimized_unroll_concat_base     0.03          0.83              1.84
6     shared_memory_optimized_edit_1   0.03          0.83              1.84
6     stride_loops_rnn_base            0.03          0.83              1.84
6     optimal_blocksize_rnn_edit_1     0.03          0.83              1.84
6     modular_vanillarnn_edit_1        0.03          0.83              1.84
6     unroll_optimized_rnn_base_base   0.03          0.83              1.84
6     optimized_concat_base            0.03          0.83              1.84
15    unrolled_rnn_base_base           0.03          0.80              1.78
15    efficient_concat_base            0.03          0.80              1.78
15    sync_optimized_rnn_base_base     0.03          0.80              1.78
15    atomic_optimized_rnn_base        0.03          0.80              1.78
15    warp_aligned_rnn_base            0.03          0.80              1.78
15    optimized_concat_kernel_edit_1   0.03          0.80              1.78
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Kernel to concatenate x and hidden into a combined tensor
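// Memory layout: each output row is the corresponding row of x followed by the
// corresponding row of hidden, i.e. combined[row] = [ x[row, :] | hidden[row, :] ].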
__global__ void concat_kernel(
    const float* __restrict__ x,
    const float* __restrict__ hidden,
    float* __restrict__ combined,
    int batch_size,
    int x_size,
    int hidden_size,
    int total_elements
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int combined_width = x_size + hidden_size;
    for (; idx < total_elements; idx += blockDim.x * gridDim.x) {
        int row = idx / combined_width;
        int col = idx % combined_width;
        if (col < x_size) {
            combined[idx] = x[row * x_size + col];
        } else {
            combined[idx] = hidden[row * hidden_size + (col - x_size)];
        }
    }
}

// Kernel for computing the linear transformation with tanh activation
// This kernel computes: out[row, col] = tanh( bias[col] + dot( A[row, :], weight[col, :] ) )
// where A is the combined tensor of shape [B, K] and weight is i2h_weight of shape [M, K] (row-major).
// Each warp (32 threads) cooperatively computes one output element using warp-level reduction (__shfl_down_sync).
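// For example (illustrative sizes, not from the benchmark): with B = 16 and M = 256,
// warp 300 computes out[1, 44], since row = 300 / 256 = 1 and col = 300 % 256 = 44.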
__global__ void linear_tanh_kernel(
    const float* __restrict__ A,       // Combined tensor, shape [B, K]
    const float* __restrict__ weight,  // i2h_weight, shape [M, K] (row-major)
    const float* __restrict__ bias,    // i2h_bias, shape [M]
    float* __restrict__ out,           // Output tensor, shape [B, M]
    int B, int K, int M                // Dimensions: batch, input features, output neurons
) {
    int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    int warp_id = global_thread_id / 32; // each warp computes one output element
    int lane_id = global_thread_id % 32;
    int row = warp_id / M;              // batch index
    int col = warp_id % M;              // neuron index

    if (row >= B) return;

    float sum = 0.0f;
    const float* a_row = A + row * K;   // Pointer to the beginning of the row in combined
    const float* w_row = weight + col * K; // weight is stored row-major; row 'col' of weight

    // Each thread in the warp processes a strided portion of the K dimension
    for (int k = lane_id; k < K; k += 32) {
        sum += a_row[k] * w_row[k];
    }

    // Warp-level reduction using shuffle operations
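    // Offsets 16, 8, 4, 2, 1 form a tree reduction: each step folds the partial
    // sums from the upper half of the active lanes into the lower half, so after
    // five steps lane 0 holds the complete dot product for this (row, col) pair.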
    for (int offset = 16; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    // The first lane writes the result after adding bias and applying tanh
    if (lane_id == 0) {
        float val = sum + bias[col];
        out[row * M + col] = tanhf(val);
    }
}

// Main function which launches the kernels
torch::Tensor module_fn_cuda(
    torch::Tensor x,
    torch::Tensor i2h_weight,
    torch::Tensor i2h_bias,
    torch::Tensor h2o_weight,
    torch::Tensor h2o_bias,
    torch::Tensor hidden
) {
    // Ensure all tensors are contiguous and on CUDA; the kernels below assume float32 data
    x = x.contiguous().cuda();
    i2h_weight = i2h_weight.contiguous().cuda();
    i2h_bias = i2h_bias.contiguous().cuda();
    h2o_weight = h2o_weight.contiguous().cuda();
    h2o_bias = h2o_bias.contiguous().cuda();
    hidden = hidden.contiguous().cuda();

    int batch_size = x.size(0);
    int x_size = x.size(1);
    int hidden_input_size = hidden.size(1);
    int combined_width = x_size + hidden_input_size;

    // Allocate the combined tensor
    auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device());
    torch::Tensor combined = torch::empty({batch_size, combined_width}, options);

    int total_elements = batch_size * combined_width;
    int threads = 256;
    int blocks = (total_elements + threads - 1) / threads;
    concat_kernel<<<blocks, threads>>>(
        x.data_ptr<float>(),
        hidden.data_ptr<float>(),
        combined.data_ptr<float>(),
        batch_size,
        x_size,
        hidden_input_size,
        total_elements
    );

    // Compute the linear transformation with tanh activation for the i2h layer
    // i2h_weight has shape [M, combined_width] and i2h_bias has shape [M]
    int M = i2h_weight.size(0); // output neurons
    int K = combined_width;     // input dimensionality for the transformation

    // Allocate the hidden state tensor after transformation
    torch::Tensor hidden_new = torch::empty({batch_size, M}, options);

    // Each warp (32 threads) computes one output element, so total warps = batch_size * M
    int total_warps = batch_size * M;
    int total_threads = total_warps * 32; // 32 threads per warp
    int threads_per_block = 256;
    int grid = (total_threads + threads_per_block - 1) / threads_per_block;
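    // Example (illustrative sizes): batch_size = 16 and M = 256 give 4096 warps,
    // i.e. 131072 threads, which round up to 512 blocks of 256 threads each.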

    linear_tanh_kernel<<<grid, threads_per_block>>>(
        combined.data_ptr<float>(),
        i2h_weight.data_ptr<float>(),
        i2h_bias.data_ptr<float>(),
        hidden_new.data_ptr<float>(),
        batch_size,
        K,
        M
    );

    // Final output: compute the h2o layer: output = h2o_bias + hidden_new * h2o_weight.t()
    torch::Tensor output = torch::addmm(h2o_bias, hidden_new, h2o_weight.t());
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &module_fn_cuda, "Optimized RNN forward with reduction (CUDA)");
}
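
For reference, a minimal usage sketch in Python (the file name, tensor sizes, and tolerance below are illustrative assumptions, not part of the kernel listing). It JIT-compiles the extension with torch.utils.cpp_extension.load, calls the exported forward function, and validates the result against the same math expressed with stock PyTorch ops:

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# JIT-compile the CUDA listing above (assumes it is saved as rnn_reduction.cu)
ext = load(name="optimized_rnn_reduction", sources=["rnn_reduction.cu"], verbose=True)

# Illustrative sizes
batch, input_size, hidden_size, output_size = 8, 1024, 256, 128
x = torch.randn(batch, input_size, device="cuda")
hidden = torch.randn(batch, hidden_size, device="cuda")
i2h_weight = torch.randn(hidden_size, input_size + hidden_size, device="cuda")
i2h_bias = torch.randn(hidden_size, device="cuda")
h2o_weight = torch.randn(output_size, hidden_size, device="cuda")
h2o_bias = torch.randn(output_size, device="cuda")

out = ext.forward(x, i2h_weight, i2h_bias, h2o_weight, h2o_bias, hidden)

# Reference: concat -> tanh(i2h linear) -> h2o linear, using plain PyTorch ops
hidden_ref = torch.tanh(F.linear(torch.cat([x, hidden], dim=1), i2h_weight, i2h_bias))
out_ref = F.linear(hidden_ref, h2o_weight, h2o_bias)
print(torch.allclose(out, out_ref, atol=1e-4))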