
The AI CUDA Engineer 👷

29_Matmul_Mish_Mish • optimized_tiled_matmul_mish_base

Level 2 • Task 29

Kernel Information

Related Kernels (Level 2, Task 29 • 29_Matmul_Mish_Mish)

Rank | Kernel Name | Runtime (ms) | Speedup (Native) | Speedup (Compile)
🥇 | 29_Matmul_Mish_Mish | 0.01 | 3.65 | 9.33
🥇 | aligned_ldg_29_matmul_mish_mish_base | 0.01 | 3.65 | 9.33
🥇 | optimized_ldg_matmul_mish_base | 0.01 | 3.65 | 9.33
🥇 | stride_loop_optimized_matmul_mish_base | 0.01 | 3.65 | 9.33
🥇 | optimized_tiled_kernel_base | 0.01 | 3.65 | 9.33
🥇 | uniform_control_flow_optimized_matmul_mish_base | 0.01 | 3.65 | 9.33
🥇 | matmul_mish_coalesced_base | 0.01 | 3.65 | 9.33
🥇 | fast_mish_tiled_base | 0.01 | 3.65 | 9.33
🥇 | unrolled_tiled_matmul_mish_base | 0.01 | 3.65 | 9.33
🥇 | matmul_mish_unroll_edit_1 | 0.01 | 3.65 | 9.33
🥇 | matmul_mish_aligned_ldg_base | 0.01 | 3.65 | 9.33
🥇 | matmul_mish_aligned_ldg_edit_1 | 0.01 | 3.65 | 9.33
🥇 | matmul_mish_coalesced_edit_1 | 0.01 | 3.65 | 9.33
🥇 | modular_matmul_mish_base | 0.01 | 3.65 | 9.33
🥇 | strided_thread_parallel_base | 0.01 | 3.65 | 9.33
🥇 | strided_thread_parallel_edit_1 | 0.01 | 3.65 | 9.33
🥇 | modular_strided_thread_parallel_base | 0.01 | 3.65 | 9.33
🥇 | warp_reduce_dot_product_base_base | 0.01 | 3.65 | 9.33
🥇 | warp_reduction_dot_base | 0.01 | 3.65 | 9.33
🥇 | tuned_block_size_128_base | 0.01 | 3.65 | 9.33
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

#define TILE_DIM 16

// Numerically stable softplus: log(1 + e^x) = max(x, 0) + log1p(e^{-|x|}),
// which avoids overflow of expf for large positive x.
__device__ float softplus_combined(float x) {
    float abs_x = fabsf(x);
    float z = expf(-abs_x);
    return fmaxf(x, 0.0f) + log1pf(z);
}

// Mish activation: mish(x) = x * tanh(softplus(x)).
__device__ float mish_combined(float x) {
    float sp = softplus_combined(x);
    return x * tanhf(sp);
}

// Computes output = mish(mish(x @ weight^T + bias)) using a shared-memory
// tiled matrix multiply with TILE_DIM x TILE_DIM tiles.
__global__ void optimized_forward_kernel(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    const float* __restrict__ bias,
    float* __restrict__ output,
    int M, // batch size
    int K, // in_features
    int N  // out_features
) {
    // Row and column of the output element computed by this thread.
    int row = blockIdx.y * TILE_DIM + threadIdx.y;
    int col = blockIdx.x * TILE_DIM + threadIdx.x;

    float sum = 0.0f;

    __shared__ float As[TILE_DIM][TILE_DIM];
    __shared__ float Bs[TILE_DIM][TILE_DIM];

    // Iterate over tiles of the K dimension; each iteration stages one
    // TILE_DIM-wide slice of x and weight in shared memory.
    int numTiles = (K + TILE_DIM - 1) / TILE_DIM;
    for (int t = 0; t < numTiles; t++) {
        // Load a tile of x (M x K), zero-padding out-of-range elements.
        int tiledCol = t * TILE_DIM + threadIdx.x;
        As[threadIdx.y][threadIdx.x] = (row < M && tiledCol < K) ? x[row * K + tiledCol] : 0.0f;

        // Load a tile of weight (N x K, row-major); the product below
        // therefore computes x @ weight^T, matching nn.Linear.
        int tiledRow = t * TILE_DIM + threadIdx.y;
        Bs[threadIdx.y][threadIdx.x] = (col < N && tiledRow < K) ? weight[col * K + tiledRow] : 0.0f;

        __syncthreads();

        // Accumulate the partial dot product contributed by this K-slice.
        for (int i = 0; i < TILE_DIM; i++) {
            sum += As[threadIdx.y][i] * Bs[i][threadIdx.x];
        }
        __syncthreads();
    }

    if (row < M && col < N) {
        // Add the bias, then apply Mish twice.
        float val = sum + bias[col];
        float mish1 = mish_combined(val);
        output[row * N + col] = mish_combined(mish1);
    }
}

torch::Tensor optimized_forward(
    torch::Tensor x,
    torch::Tensor weight,
    torch::Tensor bias
) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "all inputs must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32, "all inputs must be float32");
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
                "all inputs must be contiguous");
    TORCH_CHECK(x.dim() == 2, "x must be 2D");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1D");

    int M = x.size(0);
    int K = x.size(1);
    int N = weight.size(0);

    TORCH_CHECK(weight.size(1) == K, "weight shape mismatch");
    TORCH_CHECK(bias.size(0) == N, "bias shape mismatch");

    auto output = torch::empty({M, N}, x.options());

    dim3 blockDim(TILE_DIM, TILE_DIM);
    dim3 gridDim((N + TILE_DIM - 1) / TILE_DIM, (M + TILE_DIM - 1) / TILE_DIM);

    optimized_forward_kernel<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        M, K, N
    );
    // Surface any launch or configuration errors immediately.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &optimized_forward, "Optimized Tiled Matmul Mish Mish forward (CUDA)");
}
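
A minimal usage sketch, assuming the source above is saved as matmul_mish.cu (the file name, module name, and tensor shapes below are illustrative, not part of the kernel): the extension can be JIT-compiled from Python and sanity-checked against the equivalent eager-mode computation, F.linear followed by two F.mish calls.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Hypothetical file/module names; adjust to wherever the CUDA source is saved.
ext = load(name="matmul_mish_ext", sources=["matmul_mish.cu"])

# Illustrative shapes: batch 128, in_features 64, out_features 32.
x = torch.randn(128, 64, device="cuda", dtype=torch.float32)
weight = torch.randn(32, 64, device="cuda", dtype=torch.float32)
bias = torch.randn(32, device="cuda", dtype=torch.float32)

out = ext.forward(x, weight, bias)
ref = F.mish(F.mish(F.linear(x, weight, bias)))
# Loose tolerance to allow for fp32 accumulation-order differences.
print(torch.allclose(out, ref, atol=1e-4))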