import torch
import torch.nn as nn
import torch.nn.functional as F
def module_fn(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Return the log-softmax of ``x`` taken along axis ``dim``.

    Args:
        x (torch.Tensor): Input tensor, e.g. of shape (batch_size, dim).
        dim (int): Axis to normalize over.

    Returns:
        torch.Tensor: Tensor with the same shape as ``x`` holding
        ``x - logsumexp(x, dim, keepdim=True)``.
    """
    result = F.log_softmax(x, dim)
    return result
class Model(nn.Module):
    """
    Thin ``nn.Module`` wrapper that applies a log-softmax function along an
    axis fixed at construction time.
    """

    def __init__(self, dim):
        super(Model, self).__init__()
        # Axis handed to ``fn`` on every forward call.
        self.dim = dim

    def forward(self, x: torch.Tensor, fn=module_fn) -> torch.Tensor:
        """
        Return ``fn(x, self.dim)``. ``fn`` defaults to ``module_fn`` and may be
        replaced by any callable with the same ``(tensor, dim)`` signature.
        """
        axis = self.dim
        return fn(x, axis)
batch_size = 16
dim = 16384
sm_dim = 1


def get_inputs():
    """Build the positional arguments for ``Model.forward``: a single
    standard-normal tensor of shape (batch_size, dim)."""
    shape = (batch_size, dim)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Build the positional arguments for ``Model.__init__``: the softmax axis."""
    return [sm_dim]
import torch
import torch.nn as nn
class Model(nn.Module):
    """
    ``nn.Module`` returning the log-softmax of its input along a configurable
    axis (``dim``, default 1).
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        # Axis normalized by forward().
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-softmax of ``x`` along ``self.dim``.

        Args:
            x (torch.Tensor): Input tensor, e.g. of shape (batch_size, dim).

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        return x.log_softmax(dim=self.dim)
batch_size = 16
dim = 16384
sm_dim = 1


def get_inputs():
    """Build the positional arguments for ``Model.forward``: a single
    standard-normal tensor of shape (batch_size, dim)."""
    shape = (batch_size, dim)
    return [torch.randn(*shape)]


def get_init_inputs():
    """Build the positional arguments for ``Model.__init__``: the softmax axis."""
    return [sm_dim]
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
// Row-wise LogSoftmax kernel with a compile-time block size and warp-shuffle
// reductions.
//
// Launch contract:
//   * gridDim.x  == number of rows; block b handles row b.
//   * blockDim.x == BLOCK_SIZE (a multiple of 32, at most 1024).
//   * input/output point to contiguous [rows, dim_size] buffers whose reduced
//     axis is the innermost one.
//   * Static shared memory only: two 32-entry scalar_t arrays; no dynamic
//     shared memory is needed.
//
// For every element x of a row the kernel writes
//   (x - row_max) - log(sum_k exp(x_k - row_max)),
// subtracting the row maximum first so that exp() cannot overflow.
template <typename scalar_t, int BLOCK_SIZE>
__global__ void combined_logsoftmax_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int dim_size) {
  static_assert(BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024 && BLOCK_SIZE % 32 == 0,
                "BLOCK_SIZE must be a multiple of 32 in [32, 1024]");
  constexpr int kWarpSize = 32;                    // warp width on NVIDIA GPUs
  constexpr int kNumWarps = BLOCK_SIZE / kWarpSize;
  constexpr unsigned int kFullMask = 0xffffffffu;  // every lane participates

  // 64-bit offset: blockIdx.x * dim_size can exceed INT_MAX for large inputs.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * dim_size;
  const scalar_t* input_row = input + row_offset;
  scalar_t* output_row = output + row_offset;

  const int warp_id = threadIdx.x / kWarpSize;
  const int lane = threadIdx.x % kWarpSize;

  // One slot per warp for each block-level reduction.
  __shared__ scalar_t warp_max[32];
  __shared__ scalar_t warp_sum[32];

  // ---- Phase 1: row maximum ------------------------------------------------
  // Threads stride through the row by BLOCK_SIZE; threads with no elements
  // keep -inf, which is neutral for max.
  scalar_t thread_max = -std::numeric_limits<scalar_t>::infinity();
  for (int idx = threadIdx.x; idx < dim_size; idx += BLOCK_SIZE) {
    const scalar_t val = input_row[idx];
    thread_max = (val > thread_max) ? val : thread_max;
  }
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const scalar_t other = __shfl_down_sync(kFullMask, thread_max, offset);
    thread_max = (other > thread_max) ? other : thread_max;
  }
  if (lane == 0) {
    warp_max[warp_id] = thread_max;
  }
  __syncthreads();
  // Only thread 0 reads the per-warp partials before overwriting slot 0, so no
  // thread reads warp_max[0] while it is being written.
  if (threadIdx.x == 0) {
    scalar_t block_max = warp_max[0];
    for (int i = 1; i < kNumWarps; ++i) {
      block_max = (warp_max[i] > block_max) ? warp_max[i] : block_max;
    }
    warp_max[0] = block_max;
  }
  __syncthreads();
  const scalar_t max_val = warp_max[0];

  // ---- Phase 2: sum of exp(x - max) ----------------------------------------
  scalar_t thread_sum = 0;
  for (int idx = threadIdx.x; idx < dim_size; idx += BLOCK_SIZE) {
    thread_sum += exp(input_row[idx] - max_val);
  }
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    thread_sum += __shfl_down_sync(kFullMask, thread_sum, offset);
  }
  if (lane == 0) {
    warp_sum[warp_id] = thread_sum;
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    scalar_t block_sum = 0;
    for (int i = 0; i < kNumWarps; ++i) {
      block_sum += warp_sum[i];
    }
    warp_sum[0] = block_sum;  // broadcast slot
  }
  __syncthreads();
  const scalar_t log_sum = log(warp_sum[0]);

  // ---- Phase 3: write (x - max) - log(sum) ---------------------------------
  for (int idx = threadIdx.x; idx < dim_size; idx += BLOCK_SIZE) {
    output_row[idx] = (input_row[idx] - max_val) - log_sum;
  }
}
// Launches one (dtype, block size) instantiation of combined_logsoftmax_kernel
// with one block per row on `stream`, then surfaces launch errors immediately.
template <typename scalar_t, int BLOCK_SIZE>
static void launch_combined_logsoftmax(
    const torch::Tensor& input,
    torch::Tensor& output,
    int64_t rows,
    int dim_size,
    cudaStream_t stream) {
  combined_logsoftmax_kernel<scalar_t, BLOCK_SIZE>
      <<<static_cast<unsigned int>(rows), BLOCK_SIZE, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          output.data_ptr<scalar_t>(),
          dim_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Picks the smallest power-of-two block size in [32, 512] that covers
// dim_size, capping at 512 threads for longer rows.
static int select_logsoftmax_block_size(int64_t dim_size) {
  int block = 32;
  while (block < 512 && block < dim_size) {
    block *= 2;
  }
  return block;
}

// LogSoftmax of `input` along `dim`, computed by combined_logsoftmax_kernel.
//
// The target dimension is permuted to be innermost and made contiguous, so
// each kernel block reduces one contiguous row. The result is permuted back,
// so the returned tensor has the input's shape. It is a permuted view of a
// contiguous buffer and is not necessarily contiguous itself.
//
// Args:
//   input: CUDA tensor of dtype float32 or float64 with at least one dimension.
//   dim:   dimension to normalize over; negative values count from the end.
// Throws (via TORCH_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK) on non-CUDA input,
// unsupported dtype, out-of-range dim, sizes that exceed the kernel's 32-bit
// row length / grid size, or a failed kernel launch.
torch::Tensor combined_logsoftmax_cuda_forward(torch::Tensor input, int64_t dim) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.scalar_type() == torch::kFloat32 || input.scalar_type() == torch::kFloat64,
              "input must be float32 or float64");
  const int64_t ndim = input.dim();
  TORCH_CHECK(dim >= -ndim && dim < ndim, "dim out of range");
  dim = (dim >= 0) ? dim : dim + ndim;

  // Nothing to compute for an empty tensor. Returning early also avoids
  // dividing numel() by a zero-length reduced dimension and launching a
  // zero-block grid below.
  if (input.numel() == 0) {
    return torch::empty_like(input);
  }

  // Launch on the input's device (which may differ from the current device)
  // and on PyTorch's current stream, so the kernel is ordered with other ops.
  const c10::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Move the target dimension last so every row is contiguous in memory.
  std::vector<int64_t> permute_dims;
  permute_dims.reserve(ndim);
  for (int64_t i = 0; i < ndim; ++i) {
    if (i != dim) {
      permute_dims.push_back(i);
    }
  }
  permute_dims.push_back(dim);
  input = input.permute(permute_dims).contiguous();

  const int64_t dim_size = input.size(-1);
  const int64_t batch_size = input.numel() / dim_size;  // dim_size > 0 here
  TORCH_CHECK(dim_size <= std::numeric_limits<int>::max(),
              "combined_logsoftmax: reduced dimension too large for 32-bit indexing");
  TORCH_CHECK(batch_size <= std::numeric_limits<int>::max(),
              "combined_logsoftmax: too many rows for a 1-D grid");
  auto output = torch::empty_like(input);

  const int block_size = select_logsoftmax_block_size(dim_size);
  const int row_len = static_cast<int>(dim_size);

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "combined_logsoftmax_cuda_forward", ([&] {
    switch (block_size) {
      case 32:
        launch_combined_logsoftmax<scalar_t, 32>(input, output, batch_size, row_len, stream);
        break;
      case 64:
        launch_combined_logsoftmax<scalar_t, 64>(input, output, batch_size, row_len, stream);
        break;
      case 128:
        launch_combined_logsoftmax<scalar_t, 128>(input, output, batch_size, row_len, stream);
        break;
      case 256:
        launch_combined_logsoftmax<scalar_t, 256>(input, output, batch_size, row_len, stream);
        break;
      default:
        launch_combined_logsoftmax<scalar_t, 512>(input, output, batch_size, row_len, stream);
        break;
    }
  }));

  // Inverse permutation restores the caller's dimension order.
  std::vector<int64_t> inverse_permute_dims(ndim);
  for (int64_t i = 0; i < ndim; ++i) {
    inverse_permute_dims[permute_dims[i]] = i;
  }
  return output.permute(inverse_permute_dims);
}
// Python binding: exposes combined_logsoftmax_cuda_forward as `forward`,
// taking the input tensor and the dimension positionally.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
  ext_module.def(
      "forward",
      &combined_logsoftmax_cuda_forward,
      "Combined LogSoftmax forward (CUDA)");
}
Metric | Value | Unit | Variance | Samples |
---|---|---|---|---|
Executed Ipc Active | 1.112 | inst/cycle | 0.000 | 5 |
Executed Ipc Elapsed | 0.100 | inst/cycle | 0.000 | 5 |
Issue Slots Busy | 27.850 | % | 0.035 | 5 |
Issued Ipc Active | 1.114 | inst/cycle | 0.000 | 5 |
SM Busy | 27.850 | % | 0.035 | 5 |
Memory Throughput | 133058544578.312 | byte/second | 337237587015650368.000 | 5 |
Mem Busy | 6.216 | % | 0.002 | 5 |
Max Bandwidth | 5.872 | % | 0.002 | 5 |
L1/TEX Hit Rate | 50.000 | % | 0.000 | 5 |
L2 Hit Rate | 68.484 | % | 0.038 | 5 |
Mem Pipes Busy | 2.212 | % | 0.000 | 5 |
Warp Cycles Per Issued Instruction | 13.968 | cycle | 0.060 | 5 |
Warp Cycles Per Executed Instruction | 14.012 | cycle | 0.062 | 5 |
Avg. Active Threads Per Warp | 31.520 |  | 0.000 | 5 |
Avg. Not Predicated Off Threads Per Warp | 30.300 |  | 0.000 | 5 |
Max Active Clusters | 0.000 | cluster | 0.000 | 5 |
Max Cluster Size | 8.000 | block | 0.000 | 5 |
Overall GPU Occupancy | 0.000 | % | 0.000 | 5 |
Cluster Occupancy | 0.000 | % | 0.000 | 5 |
Block Limit SM | 32.000 | block | 0.000 | 5 |
Block Limit Registers | 4.000 | block | 0.000 | 5 |
Block Limit Shared Mem | 12.000 | block | 0.000 | 5 |
Block Limit Warps | 4.000 | block | 0.000 | 5 |
Theoretical Active Warps per SM | 64.000 | warp | 0.000 | 5 |
Theoretical Occupancy | 100.000 | % | 0.000 | 5 |
Achieved Occupancy | 24.240 | % | 0.000 | 5 |
Achieved Active Warps Per SM | 15.516 | warp | 0.000 | 5 |
Rule | Description |
---|---|
WRN HighPipeUtilization | All compute pipelines are under-utilized. Either this kernel is very small or it doesn't issue enough warps per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. |
WRN Occupancy | This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated theoretical (100.0%) and measured achieved occupancy (24.2%) can be the result of warp scheduling overheads or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block as well as across blocks of the same kernel. See the CUDA Best Practices Guide (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on optimizing occupancy. |
Operation / Metric | Value | Unit |
---|---|---|
aten::to | ||
CPU Time | 440618.22 | μs |
Device Time | 39.91 | μs |
Self CPU Time | 40.31 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::_to_copy | ||
CPU Time | 440577.91 | μs |
Device Time | 39.91 | μs |
Self CPU Time | 97.39 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::empty_strided | ||
CPU Time | 459307.26 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 19188.00 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaDeviceGetStreamPriorityRange | ||
CPU Time | 439919.93 | μs |
Device Time | 0.00 | μs |
Self CPU Time | 439919.93 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaLaunchKernel | ||
CPU Time | 435390.76 | μs |
Device Time | 20415.57 | μs |
Self CPU Time | 435390.76 | μs |
Self Device Time | 20415.57 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void combined_logsoftmax_kernel<float, 512>(float const*, float*, int) | ||
CPU Time | 0.00 | μs |
Device Time | 48564.16 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 48564.16 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
cudaEventRecord | ||
CPU Time | 21137.32 | μs |
Device Time | 39924.59 | μs |
Self CPU Time | 21137.32 | μs |
Self Device Time | 39924.59 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::zero_ | ||
CPU Time | 71799.98 | μs |
Device Time | 585338.78 | μs |
Self CPU Time | 11049.32 | μs |
Self Device Time | 0.00 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
aten::fill_ | ||
CPU Time | 60755.15 | μs |
Device Time | 585338.78 | μs |
Self CPU Time | 15161.61 | μs |
Self Device Time | 585338.78 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, at::detail::Array<char*, 1> >(int, at::native::FillFunctor<int>, at::detail::Array<char*, 1>) | ||
CPU Time | 0.00 | μs |
Device Time | 585417.31 | μs |
Self CPU Time | 0.00 | μs |
Self Device Time | 585417.31 | μs |
CPU Memory Usage | 0 | B |
Device Memory Usage | 0 | B |
Self CPU Memory Usage | 0 | B |
Self Device Memory Usage | 0 | B |
45287 warnings generated when compiling for host. Suppressed 45322 warnings (45275 in non-user code, 47 NOLINT). Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.