
The AI CUDA Engineer 👷

35_lstm_workload_balanced_base

Level 3 • Task 35

Kernel Information

Related Kernels (Level 3, Task 35 • 35_LTSM)

Rank  Kernel Name  Runtime (ms)  Speedup (Native)  Speedup (Compile)
🥇 35_lstm_grid_stride_base_base 72.97 0.44 0.83
🥈 35_lstm_modular_device_edit_1 75.07 0.43 0.81
🥉 35_lstm_shared_memory_base 86.54 0.37 0.70
4 35_lstm_atomic_reduction_base_base 86.99 0.37 0.69
5 35_lstm_workload_balanced_base 87.77 0.36 0.69
6 35_lstm_aligned_base 88.03 0.36 0.69
7 35_lstm_tiled_unroll_edit_1 88.19 0.36 0.69
8 35_lstm_load_balancing_base 88.28 0.36 0.68
9 fused_tiled_base 88.40 0.36 0.68
10 35_lstm_ldg_aligned_v2_base 88.50 0.36 0.68
11 35_lstm_load_balancing_edit_1 88.68 0.36 0.68
12 35_LTSM 88.90 0.36 0.68
13 35_lstm_memory_coalescing_edit_1 89.05 0.36 0.68
14 modular_35_ltsm_base 89.17 0.36 0.68
15 35_lstm_shared_memory_edit_1 89.34 0.36 0.68
16 fused_tiled_edit_1 89.35 0.36 0.68
17 35_lstm_unrolled_base 89.58 0.36 0.67
18 35_lstm_memory_coalescing_base 89.77 0.36 0.67
19 35_lstm_warp_reduce_base 89.78 0.36 0.67
20 35_lstm_warp_aligned_base 89.81 0.36 0.67
#include <torch/extension.h>
#include <algorithm>
#include <vector>
#include <cmath>

// Fast sigmoid built on the __expf hardware intrinsic (slightly lower precision than expf)
__device__ __forceinline__ float sigmoid_fast(float x) {
    return 1.0f / (1.0f + __expf(-x));
}

// Optimized LSTM kernel with better workload distribution
__global__ void lstm_elementwise_forward(
    const float* __restrict__ gates,
    const float* __restrict__ prev_c,
    float* __restrict__ h,
    float* __restrict__ c,
    const int batch_size,
    const int hidden_size
) {
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int num_threads = blockDim.x * gridDim.x;
    const int total_elements = batch_size * hidden_size;
    
    // Each thread processes multiple elements in a strided fashion
    for (int idx = bid * blockDim.x + tid; idx < total_elements; idx += num_threads) {
        const int b = idx / hidden_size;
        const int n = idx % hidden_size;
        const int gate_offset = b * hidden_size * 4 + n;
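        // Gate layout matches the PyTorch LSTM convention: each row of `gates`
        // holds [i | f | g | o] blocks, each of width hidden_size.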
        
        // Coalesced memory access for gates
        const float i_gate = sigmoid_fast(gates[gate_offset]);
        const float f_gate = sigmoid_fast(gates[gate_offset + hidden_size]);
        const float g_gate = tanhf(gates[gate_offset + 2 * hidden_size]);
        const float o_gate = sigmoid_fast(gates[gate_offset + 3 * hidden_size]);
        
        const float c_prev = prev_c[idx];
        const float c_new = f_gate * c_prev + i_gate * g_gate;
        const float h_new = o_gate * tanhf(c_new);
        
        c[idx] = c_new;
        h[idx] = h_new;
    }
}

// Optimized linear kernel with balanced workload: each block handles one output
// element per grid-stride iteration, reducing the dot product in two stages
// (warp shuffle, then a cross-warp pass through dynamic shared memory).
__global__ void linear_forward_balanced(
    const float* __restrict__ input,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ output,
    const int batch_size,
    const int in_features,
    const int out_features
) {
    extern __shared__ float shmem[];  // one partial sum per warp

    const int tid = threadIdx.x;
    const int wid = tid / 32;   // warp ID within the block
    const int lane = tid % 32;  // lane within the warp
    const int num_warps = blockDim.x / 32;

    for (int out_idx = blockIdx.x; out_idx < batch_size * out_features; out_idx += gridDim.x) {
        const int batch = out_idx / out_features;
        const int feat = out_idx % out_features;

        float sum = 0.0f;
        const float* in_row = input + batch * in_features;
        const float* w_row = weight + feat * in_features;

        // All threads of the block stride over the input features
        for (int k = tid; k < in_features; k += blockDim.x) {
            sum += in_row[k] * w_row[k];
        }

        // Stage 1: reduction within each warp
        #pragma unroll
        for (int offset = 16; offset > 0; offset /= 2) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        }

        // Stage 2: reduce the per-warp partial sums through shared memory
        if (lane == 0) {
            shmem[wid] = sum;
        }
        __syncthreads();

        if (wid == 0) {
            float final_sum = (lane < num_warps) ? shmem[lane] : 0.0f;
            #pragma unroll
            for (int offset = 16; offset > 0; offset /= 2) {
                final_sum += __shfl_down_sync(0xffffffff, final_sum, offset);
            }
            if (lane == 0) {
                if (bias != nullptr) {
                    final_sum += bias[feat];
                }
                output[out_idx] = final_sum;
            }
        }
        __syncthreads();  // shmem is reused on the next grid-stride iteration
    }
}

torch::Tensor lstm_forward_cuda(
    torch::Tensor input,
    torch::Tensor w_ih,
    torch::Tensor w_hh,
    torch::Tensor b_ih,
    torch::Tensor b_hh,
    torch::Tensor h0,
    torch::Tensor c0
) {
    const int batch_size = input.size(0);
    const int seq_len = input.size(1);
    const int hidden_size = h0.size(1);
    
    auto h = h0.clone();
    auto c = c0.clone();
    std::vector<torch::Tensor> outputs;
    
    // Optimize thread block size for H100
    const int threads_per_block = 256;
    const int num_sms = 132;  // H100 has 132 SMs
    const int blocks_per_sm = 16;
    const int total_blocks = num_sms * blocks_per_sm;
    
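    // Per timestep: gates = x_t * W_ih^T + b_ih + h_{t-1} * W_hh^T + b_hh,
    // followed by the elementwise kernel that updates h and c in place.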
    for (int t = 0; t < seq_len; t++) {
        auto xt = input.select(1, t);
        auto gates = torch::addmm(b_ih, xt, w_ih.t());
        gates = torch::addmm(gates, h, w_hh.t());
        gates += b_hh;
        
        lstm_elementwise_forward<<<total_blocks, threads_per_block>>>(
            gates.data_ptr<float>(),
            c.data_ptr<float>(),
            h.data_ptr<float>(),
            c.data_ptr<float>(),
            batch_size,
            hidden_size
        );
        
        // Clone before storing: h is updated in place on later timesteps, and an
        // unsqueezed view would otherwise alias (and end up repeating) the final h.
        outputs.push_back(h.unsqueeze(1).clone());
    }
    
    return torch::cat(outputs, 1);
}

torch::Tensor linear_forward_cuda(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor bias
) {
    const int batch_size = input.size(0);
    const int in_features = input.size(1);
    const int out_features = weight.size(0);
    
    auto output = torch::empty({batch_size, out_features}, input.options());
    
    const int threads_per_block = 128;
    const int num_blocks = std::min(65535, (batch_size * out_features + threads_per_block - 1) / threads_per_block);
    // One float of dynamic shared memory per warp for the cross-warp reduction
    const int shared_mem_bytes = (threads_per_block / 32) * sizeof(float);
    
    linear_forward_balanced<<<num_blocks, threads_per_block, shared_mem_bytes>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.defined() ? bias.data_ptr<float>() : nullptr,
        output.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );
    
    return output;
}

torch::Tensor forward(
    torch::Tensor x,
    std::vector<torch::Tensor> lstm_weights_ih,
    std::vector<torch::Tensor> lstm_weights_hh,
    std::vector<torch::Tensor> lstm_biases_ih,
    std::vector<torch::Tensor> lstm_biases_hh,
    torch::Tensor fc_weight,
    torch::Tensor fc_bias,
    torch::Tensor h0,
    torch::Tensor c0,
    bool is_training
) {
    h0 = h0.to(x.device());
    c0 = c0.to(x.device());
    
    torch::Tensor out = x;
    const int num_layers = lstm_weights_ih.size();
    
    for (int i = 0; i < num_layers; i++) {
        auto w_ih = lstm_weights_ih[i].to(x.device());
        auto w_hh = lstm_weights_hh[i].to(x.device());
        auto b_ih = lstm_biases_ih[i].to(x.device());
        auto b_hh = lstm_biases_hh[i].to(x.device());
        
        auto h_i = h0.narrow(0, i, 1).squeeze(0);
        auto c_i = c0.narrow(0, i, 1).squeeze(0);
        
        out = lstm_forward_cuda(out, w_ih, w_hh, b_ih, b_hh, h_i, c_i);
    }
    
    out = out.select(1, -1);
    out = linear_forward_cuda(out, fc_weight, fc_bias);
    
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "LSTM forward with balanced workload distribution");
}
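
For reference, here is a minimal host-side sketch in Python of how an extension like this could be built and sanity-checked against nn.LSTM and nn.Linear. The file name "lstm_workload_balanced.cu", the module name, and the tensor sizes below are illustrative assumptions, not part of the submission.

# Hedged usage sketch (not part of the submission): build the extension with
# torch.utils.cpp_extension.load and compare against nn.LSTM + nn.Linear.
import torch
from torch.utils.cpp_extension import load

ext = load(name="lstm_workload_balanced",
           sources=["lstm_workload_balanced.cu"],  # assumed file name
           verbose=True)

batch, seq_len, input_size, hidden_size, num_layers, out_features = 4, 32, 16, 32, 2, 8
device = "cuda"

lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True).to(device)
fc = torch.nn.Linear(hidden_size, out_features).to(device)

x  = torch.randn(batch, seq_len, input_size, device=device)
h0 = torch.zeros(num_layers, batch, hidden_size, device=device)
c0 = torch.zeros(num_layers, batch, hidden_size, device=device)

# Per-layer parameter lists in the order the extension expects
w_ih = [getattr(lstm, f"weight_ih_l{i}") for i in range(num_layers)]
w_hh = [getattr(lstm, f"weight_hh_l{i}") for i in range(num_layers)]
b_ih = [getattr(lstm, f"bias_ih_l{i}") for i in range(num_layers)]
b_hh = [getattr(lstm, f"bias_hh_l{i}") for i in range(num_layers)]

with torch.no_grad():
    out_ext = ext.forward(x, w_ih, w_hh, b_ih, b_hh, fc.weight, fc.bias, h0, c0, False)
    seq, _  = lstm(x, (h0, c0))
    out_ref = fc(seq[:, -1, :])

# The two results should agree to within fast-math float32 tolerance.
print((out_ext - out_ref).abs().max().item())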