数据中心/云端

使用 JAX 和 XLA 优化推理工作负载的低延迟通信

在生产环境中使用大语言模型 (LLM) 进行推理需要满足严格的延迟限制。此过程的关键阶段是 LLM 解码,下一步 token 的时间成为优化的关键指标。为了最大限度地减少运行时延迟,常见的做法是在多个 GPU 之间划分推理,通常是通过对 Transformer 块中的多层感知器 (MLP) 和投影 GEMM 层应用张量并行

在解码阶段的每个步骤中,消息大小和计算需求相对较小。由于核函数调用或通信设置,任何静态用度往往会主导通信和计算时间。在这篇博客文章中,我们将分享我们开发的一些技术,以更大限度地减少导致整体解码延迟的此类用度。

我们在单节点上使用 8 路张量并行 Gemma2 LLM 模型运行推理解码阶段,该节点上有 8 个 NVIDIA H100 Tensor Core GPU 通过 NVIDIA NVLink 连接。我们观察到,张量并行层中的全归约集合成为端到端解码延迟的重大瓶颈。以大约 30 KB/s 的消息大小运行的 all-reduce 集合约占端到端解码延迟的 23%。由于计算和通信内核之间的数据依赖性,这些集合无法与之前或之后的计算重叠。

传统的 all-reduce 方法使用环形算法,该算法以环形方式围绕 N 个 GPU 进行多达 2N-2 阶段的数据交换。虽然环形算法在中大型消息大小 (大于 10 mb) 下的带宽最佳,但在通信小型消息时,环形交换中涉及的数据交换和重复的 GPU 间同步障碍会导致显著的延迟 (高达 2 倍) 。

与环形算法不同,我们实施了自定义的单步全局归约算法,其中每个秩 (设备) 聚合来自对等方的数据,并在单个阶段中执行归约。这相当于先进行全局收集,然后对全局收集的缓冲区进行局部归约。

虽然这增加了数据交换的数量,从而增加了总带宽,但由于双向 NVLink 通信,交换同时发生,整体通信延迟降低。

我们还通过在自定义的 all-reduce 内核中使用 cudaDeviceEnablePeerAccess 来直接访问在对等 GPU 上注册的缓冲区,从而避免任何额外的内存复制用度。这种实现对于在单个进程上运行的单节点多 GPU 设置特别有用,其中共享的 CUDA 上下文使跨对等 GPU 访问设备内存指针变得更加容易。

// Fused One shot All Reduce + Root Mean Square Normalization kernel 

// peer_comm_buffer: 
//        -> thread_offsets
// 	|  [00 01 02 03,...]
// 	v  [10 11 12 13,....]
//   ranks
//
// Outputs : [00+10.., 01+11.., 02+12..., 03+13..]
__global__ void OneShotARNormKernel(std::vector<T*> peer_comm_buffer_ptrs, T* sum_vec, T* weight_buffer, float eps, int hidden_size)
{

for (int ii = 0; ii < NUM_RANKS; ++ii)
{
    // One-shot All Reduce sum
    sum_vec = add(sum_vec, peer_comm_buffer_ptrs[ii][thread_offset]);
}

......

// All Reduce and Norm fusion
// Compute x^2
squares = compute_square(sum_vec);
// Compute sum(x^2) across grid
summed_squares = block_reduce_sum(squares);
// Compute RMS denominator
float denom = __fsqrt_rn(__fdividef(summed_squares, hidden_size) + eps);
// Load per-element affine param if necessary
if learnable_affine
{
    weight_vec = weight_buffer + thread_offset;
}
// rms_norm = (sum_vec / (denom)) * weight_vec
sum_vec = rms_norm(denom, sum_vec, weight_vec);
}



// Define the C++ custom call to be invoked from JAX FFI

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

// Global ptr vector shared amongst all ranks or threads
std::vector<void*> peer_mem_ptr(MAX_RANKS);

void AllReduceDispatcher(...) {
 // Each thread/rank in the node populates its input buffer to global ptr map
 peer_mem_ptr[rank_id] = input;
 // Synchronizing all devices so the pointer vector is fully populated
 Barrier();
 
 // Launch kernel with peer_mem_ptr
 OneShotARNormKernel<<<...>>>(peer_mem_ptr, ...)
}

// Define fusion kernel launcher func
ffi::Error customAllReduce(cudaStream_t stream,
                       ffi::AnyBuffer input,
                       ffi::AnyBuffer weight_buffer,
                       ffi::Result<ffi::AnyBuffer> sum_vec,
                       int hidden_size,
                       float eps,
                       int rank_id)
{
    // Launch dispatcher...
    AllReduceDispatcher(...)
}


// Create symbol ArNorm with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    ArNorm, customAllReduce,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
        .Arg<ffi::AnyBuffer>()    // input
        .Arg<ffi::AnyBuffer>()    // weight_buffer
        .Ret<ffi::AnyBuffer>()    // sum_vec
        .Attr<int>("hidden_size")
        .Attr<float>("eps")
        .Attr<int>("rank_id")); 
# Invoke custom call in JAX application using JAX FFI

from jax.lib import xla_client
import ctypes
# Load the library built from the c++ funtion
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libcustom_ar_kernel.so")

library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
# Register the FFI function in jax
XLA_CUSTOM_CALL_TARGET_AR_NORM = "ar-norm"
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_AR_NORM,
                                       fn=ffi.pycapsule(library.ArNorm),
                                       platform=XLA_PLATFORM,
                                       api_version=XLA_CUSTOM_CALL_API_VERSION)

# Invoke the custom call
output = ffi.ffi_call(
    XLA_CUSTOM_CALL_TARGET_AR_NORM,
    jax.ShapeDtypeStruct(input.shape, input.dtype),  # output type
    input,                                           # input buffer
    weight_buffer,                                   # weight buffer
    hidden_size=hidden_size,
    eps=eps,
    rank_id=rank_id)

为使自定义调用与 XLA 的原生 CUDA Graph 兼容,我们在注册自定义调用处理程序时指定了 xla::ffi::Traits::kCmdBufferCompatible 特质。

// Creates symbol ArNorm with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    ArNorm, customAllReduce,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
        .Arg<ffi::AnyBuffer>()   			    // input
        .Arg<ffi::AnyBuffer>()                     // weight_buffer
        .Ret<ffi::AnyBuffer>()                     // sum_vec
        .Attr<int>("hidden_size")
        .Attr<float>("eps")
        .Attr<int>("rank_id"),
        {xla::ffi::Traits::kCmdBufferCompatible});

此单步全归约内核与相邻层归一化和逐点加法运算进一步融合到 CUDA C++ 中实现的单融合设备内核中。通过将这些计算运算与 one-shot all-reduce 融合在一起,我们可以最大限度地减少内核启动用度以及进出设备 HBM 显存的数据移动。

使用 JAX 外部函数接口,将内核作为自定义调用集成到模型实现中。与独立的 all-reduce 内核相比,融合的自定义 all – reduce 内核使我们的内核时间加快了约 3 倍,解码阶段的端到端延迟降低了约 27%。自定义内核与模型中的其他计算内核一起分组并作为单个 CUDA Graph 启动,从而最大限度地减少内核启动用度,并将解码延迟额外降低 5%。

Figure 1 shows the list of compute and All-reduce communication layers in Gemma2 decode model that get fused into a custom kernel. The model has two such sets of fused layers.
图 1。融合到自定义内核中的 Gemma2 解码模型中的计算和通信层

针对低延迟推理的其他优化

在推理的解码阶段,减少小消息大小的通信用度非常重要,特别是在数据依赖项阻止任何计算 – 通信重叠的情况下。在推理解码工作负载中,针对更高的吞吐量和更大的消息大小而优化的集合算法无法很好地扩展到更小的通信负载。

这些算法可以通过自定义实现内核进行调整,从而实现通信块与计算的融合或交错。JAX 外部函数接口允许编写此类自定义内核并插入高级模型,同时仍使用 CUDA Graphs 等 XLA/ GPU 优化。

即将推出多项功能,用于解决多 GPU 集群中运行的推理工作负载中的通信延迟问题。在 NCCL 2.27 和后续版本中迁移到对称内存模型将改善通信用度,从而为较小负载提供高达 4 倍的通信内核速度。此外,还可以使用 NVIDIA OpenSHMEM 库中提供的 GPU 发起的设备端通信 API 交错计算通信代码块,以隐藏通信延迟。

最近,Mosaic-GPU DSL 引入了表达这种交错计算通信融合模式的能力,这些模式使用 NVSHMEM 进行 GPU 发起的通信。交错计算通信块使我们能够表达高效的分布式融合核函数,例如用于张量并行 GEMM 或混合专家范式中使用的专家并行分组 GEMM 核函数。

 

标签