开发与优化

CUTLASS 3.x:用于 GEMM 内核设计的正交、可重用和组合抽象

GPU 上的 GEMM 优化是一个模块化问题。高性能实现需要指定超参数,例如图块形状、数学和复制指令以及线程束专用方案。这些超参数在很大程度上彼此独立;此外,最佳选择可能会因硬件、问题形状或其他用户需求而有显著差异。

通过重新设计 3.x,CUTLASS 旨在通过可组合、正交构建块的分层系统最大限度地覆盖 GEMM 实现空间,同时提高代码可读性,并将支持扩展到后续的 NVIDIA 架构 (如 Hopper 和 Blackwell) 。由于这种设计理念与 GPU 的分层硬件设计相关联,因此对于其他 GPU 应用程序也是一个不错的选择,例如,FlashAttention-3 在其设计中使用熟悉的 CUTLASS 抽象概念。

在 CUTLASS 博客系列的第二篇博文中,我们将探讨 CUTLASS 3.x 中 GEMM 分层系统背后的设计原则,并解压 CUTLASS 如何从第 1 部分中介绍的低级 CuTe 抽象中构建 GEMM 内核。

CUTLASS 3.x 中的新概念 GEMM 层次结构

CUTLASS 3.x 开发了一个独立于特定硬件功能的概念 GEMM 层次结构。它分为五个层:

A diagram of green shaded semi-circles that are nested within each other to depict the GEMM hierarchy concept; from Atom to Device
图 1。独立于硬件的 CUTLASS GEMM 层次结构概念图
  • 原子层:特定于架构的指令和相关的元信息 cute::Mma_Atom<>cute::Copy_Atom<>
  • Tiled MMA/ Copy:空间微核,支持架构特定原子的任意交错和平铺 cute::TiledMma<>cute::TiledCopy<>
  • 集合层:时间微核函数,使用架构特定的同步来编排执行一个或多个空间微核函数,以计算单个输出图块 cutlass::gemm::collective::CollectiveMma<>cutlass::epilogue::collective::CollectiveEpilogue<>
  • 内核层:用于在线程块/ 集群网格上执行内核的设备代码 cutlass::gemm::kernel::GemmUniversal<>
  • 设备层:主机侧设置和接口 cutlass::gemm::device::GemmUniversalAdapter<>

每个层都用作前一层抽象的合成点,可以使用模板参数进行高度定制。用户可以坚持使用最高层,信任 CUTLASS 的编译时逻辑来提供高性能 GEMM 实现,也可以选择使用较低级别的层次结构所带来的高级修改。Atom 和 Tiled MMA/ Copy 层提供的空间微核是 CuTe 的领域,我们将在第 1 部分讨论这些微核。本文的其余部分将介绍高层中提供的 GEMM 的时间级和内核级组织。

以下是如何在 CUTLASS 3.x 中定义 GEMM 内核的基本示例:

// Step 1: Generate the required collective layer mainloop specialization
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TilesShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

// Step 2: Specify the collective layer epilogue type
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    cutlass::gemm::TagToStrideC_t<LayoutC>,
    cutlass::gemm::TagToStrideC_t<LayoutC>,
    cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;

// Step 3: Compose the mainloop and epilogue together at the kernel layer
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>, // ProblemShape [M,N,K,L]
    CollectiveMainloop,
    CollectiveEpilogue
>;

// Step 4: Wrap up the kernel::GemmUniversal kernel class
// with the device adapter to obtain a host-side handle to the kernel
using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

集合层:Mainloop

集合是一组线程,它们相互协作以执行工作,并且可以并行重复以形成整个内核。通常,这是线程块或集群。TiledMMA 和 TiledCopy 对象用于描述并行工作进程对计算和复制工作的空间分配 (例如,线程束、线程组,甚至 Blackwell MMA 的线程块) ,而集合层则负责以时间方式组织此工作,方法是设置工作流和线程束专用方案,以及使用硬件加速同步基元来管理工作流和集合主回路的定义如下:

using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
  DispatchPolicy,
  TileShape,
  ElementA, // dtype, e.g. float
  StrideA,  // e.g. Stride<_1, int> for M-major
  ElementB, StrideB,
  TiledMma,
  GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, TransformA,
  GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, TransformB
>;

集合主回路是低层抽象的合成点:TiledMma、每个操作数的 GMEM 到 SMEM 加载的 TiledCopy,以及用于 SMEM 到 RMEM 加载的可选复制原子,用于寄存器来源的 MMA。这些抽象在很大程度上是正交的,允许将不同的 MMA 操作与不同的复制操作结合起来,同时更大限度地重复使用代码。

可以说,最重要的部分是调度策略,该策略定义了特定算法或 GPU 架构的主循环专用化。例如,调度策略 MainloopSm90TmaGmmaWarpSpecialized 将 CollectiveMma 专门用于 Hopper TMA 线程束专用实现。它本身就是一个模板,可以针对工作流阶段、集群形状和内核调度选择 (例如针对 Hopper GEMM 内核的 pingpong 或协同调度) 进行参数化。

您可以在 GEMM 集合文件夹中找到专门的集合主循环实现示例。

集合构建器

CollectiveMma 具有各种调优旋钮,允许用户根据 TiledCopy 和 TiledMma 对象精确指定 GEMM 主回路,但伴随这种灵活性,复杂性也随之增加。通常,用户希望从有关流水线、硬件功能和资源可用性的高阶考虑因素中推断出这些对象和相关的 SMEM 布局。CUTLASS 还可以使用 CollectiveBuilder 接口执行此推理。使用 CollectiveBuilder 的主循环声明如下所示:

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
  ArchTag,       // e.g. cute::arch::Sm90 for Hopper
  OpClass,       // e.g. cute::arch::OpClassTensorOp for Tensor Cores
  ElementA, LayoutA, AlignmentA,
  ElementB, LayoutB, AlignmentB,
  ElementAccumulator,
  TileShape, ClusterShape,
  StageCount,    // e.g. cutlass::gemm::collective::StageCountAuto
  KernelSchedule // e.g. cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

模板参数从用户友好型标准中选择,并使用它们将较低级别的参数推导至 CollectiveMma 模板:

  • 架构专业领域:GPU 架构和 MMA 运算符类型 (例如 SIMT 或 Tensor Core) 。
  • 操作数和累加器信息:操作数和累加器的数据类型,以及全局内存中操作数 (例如,行或列主) 的对齐和编译时布局信息。
  • 图块形状:用于推理 TiledMma 和 TiledCopy 对象以及 SMEM 布局。
  • 调度信息:集群形状、工作流阶段计数和内核调度均由调度算法使用。对于阶段计数和内核调度参数,有默认的“Auto” (自动) 选项,这些选项指示 CUTLASS 尝试为给定的架构和参数自动选择最佳选项。

集合层:Epilogue

集合 epilogue 是集合 API 的另一端。它负责在每次主循环迭代后对工作图的后处理和输出存储进行时间编排。与主循环一样,这意味着集合结语是复制运算 (输出存储) 和一些数学运算 (通常是元素级运算,但可能也包括归约) 的合成点。与主循环不同,这些数学运算本身通过结语访客树 (EVT) 形式高度可组合。这对于 AI 工作负载尤其有用,因为这些工作负载通常需要在 GEMM 之后立即计算激活函数。CUTLASS 的集合结语负责将此激活函数融合到内核中,从而消除不必要的数据移动。

CUTLASS GitHub 上定义了几个结语。模板参数在不同实现之间存在显著差异,但通常包括以下信息:

  • 矩阵 C 和 D 的数据类型和编译时布局信息。
  • 指定任何其他后处理的融合运算。
  • GMEM 商店和任何 SMEM 暂存的平铺复制操作。
  • 与集合主循环一样,调度策略包含有关集群大小、TMA 使用、线程束专门化等的信息。

适用于结语的 CollectiveBuilder 提供了一个更统一、更高级的界面:

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  ArchTag,
  OpClass,
  TileShape,
  ClusterShape,
  EpilogueTileType,
  ElementAccumulator,
  ElementCompute,
  ElementC, GmemLayoutTagC, AlignmentC,
  ElementD, GmemLayoutTagD, AlignmentD,
  EpilogueScheduleType,
  FusionOpOrCallbacks
>::CollectiveOp;

其中许多参数在主循环构建器中很常见,但也有一些是新参数:

  • 结语可以将 CTA 图块划分为更小的图块,以实现更好的数学拷贝重叠。
  • 累加器 (主回路的输出) 现在是结语的输入。结语计算可在不同的中间数据类型 (由 ElementCompute 给出) 中进行。
  • CUTLASS 提供多种常见的融合运算,例如 D = activation(alpha * AB + beta * C)。用户还可以使用 Epilogue Visitor Trees 构建定制的融合操作。有关结语访客树的更多信息,请参阅此 Colfax 教程
  • 结语调度类型定义了 TMA 和线程束专用化的用法。默认的 EpilogueScheduleAuto 指示 CUTLASS 尝试推断出最佳选项。

要了解这两个集合构建器的实际应用,我们参考了用于 Hopper 的 CUTLASS 示例 49 和用于 Blackwell 的示例 71

内核层

集合层完全定义了核函数执行期间集合所完成的计算。内核层的作用是将集合扩展到涵盖整个动态大小问题空间的线程块或集群网格上。内核层通过将加载、存储、MMA 等的基本程序拼接在一起,将集合主回路和集合结语组合到设备内核中。

内核层的入口点 API 是 cutlass::gemm::kernel::GemmUniversal 类,这是一种无状态通用设备内核,可将 GEMM 实现为集合主回路和集合结语的合成。无状态意味着调用者通过向内核传入参数来管理内核的状态。通用意味着 GemmUniversal 是 2.x 和 3.x GEMM 内核的入口点。对于 3.x API,GemmUniversal 的基本用法如下所示:

using GemmKernel = cutlass::gemm::kernel::GemmUniversal&lt;
    ProblemShape, // e.g. Shape&lt;int, int, int> for a fully generic GEMM
    CollectiveMainloop,
    CollectiveEpilogue
>;

TiledMma 和 tg_ 21 一样,tg_ 22 和 tg_ 23 是通过 tg_ 24 合成的正交抽象。第一个模板参数,即问题形状,主要用于在普通 GEMM (具有 rank-3 问题形状) 和批量 GEMM (具有 rank-4 问题形状) 之间进行选择,但如果需要,也可以静态地限制某些问题维度。

GemmUniversal 的实例化可以在 tg_ 26 形式的文件中找到,其中 tg_ 27 主要基于集合主循环的 tg_ 28 参数进行调度。所有实例化均提供一致的接口:

  • 用于向内核传递参数的接口,包括问题形状、硬件信息、张量的指针和布局,以及结语参数。
  • 静态初始化功能用于获取网格和块维度,检查内核是否可在硬件上实现,并为结语或图块调度程序所需的任何归约操作或全局屏障设置全局内存工作空间。
  • 最重要的是,它们将核函数逻辑实现为 operator()。这是一个设备函数,虽然内核层包含内核执行的所有逻辑,但尚未显示从主机启动的方法。

例如,此处定义了 Blackwell 的 TMA 线程束专用内核。

图块调度

内核层也是用于指定图块调度程序的合成点。正如内核调度程序定义集合内工作的时间安排一样,图块调度程序定义集合内工作的顺序和分布。对于最基本的图块调度程序,每个输出图块分配一个 CTA。CUTLASS 3.x 为 Hopper 实现了两个额外的图块调度程序:一个是持久性调度程序,可为每个 SM 启动一个 CTA,并让每个 CTA (可能) 在其生命周期内计算多个输出图块;另一个是 Stream-K 调度程序,它也是持久性的,但会沿 K 模式额外划分一些输出图块工作,以实现更好的负载平衡。在 Blackwell 架构中,则使用具有集群启动控制的调度程序。有关图块调度的更多深入信息,请参阅此 Colfax 教程

我们可以使用以下命令扩展上述核函数以使用 Stream-K 图块调度程序:

using GemmKernel = cutlass::gemm::kernel::GemmUniversal&lt;
    cute::Shape&lt;int,int,int,int>,
    CollectiveMainloop,
    CollectiveEpilogue,
    cutlass::gemm::StreamKScheduler
>;

CUTLASS 示例 74 是使用 Stream-K 调度程序的更详细示例。

设备层

用于核函数启动 (包括使用集群支持或在不同设备或 CUDA 流上启动) 的主机端逻辑在设备层中实施。设备层的主要入口点是 cutlass::gemm::device::GemmUniversalAdapter,它将 tg_ 32 核函数封装在一个有状态、可重复使用的句柄中。有状态意味着句柄实例包含核函数需要运行的状态 (即,它管理核函数参数本身) 。可重用意味着同一句柄实例可用于多次使用不同参数调用核函数。

GemmUniversalAdapter 是在 GitHub 上实现的。此示例展示了如何使用 GemmUniversalAdapter 启动核函数:

using GemmHandle = cutlass::gemm::kernel::GemmUniversalAdapter<GemmKernel>;
using Arguments = typename GemmHandle::Arguments;    // surfaced from GemmKernel
Arguments args {
    cutlass::Gemm::kBatched,                   // mode (here batched GEMM)
    cute::make_shape(M, N, K, L),              // problem shape
    {A, stride_A, B, stride_B},                // mainloop args
    {{alpha, beta}, C, stride_C, D, stride_D}, // epilogue args
    make_kernel_hardware_info(device_id),      // hardware info
    {}                                         // scheduler args (here default)
};
GemmHandle gemm;

// Check that problem can run with given shape and hardware
cutlass::Status status;
status = GemmHandle::can_implement(args);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Problem not supported\n";
  exit(EXIT_FAILURE);
}

// Set up global memory workspace
size_t workspace_size = GemmHandle::get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Initialize GEMM handle state from arguments
status = gemm.initialize(args, workspace.get());
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Failed to initialize GEMM kernel\n";
  exit(EXIT_FAILURE);
}

// Launch kernel
status = gemm.run();  // can supply CUDA stream and CUDA host adaptor here
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Failed to launch GEMM kernel\n";
  exit(EXIT_FAILURE);
}

总结

在本文中,我们讨论了如何将 CUTLASS 库从概念上组织为层次结构,其中每层的对象由来自下层的正交对象组成。这种设计可实现高度可定制的 GEMM 实现,并实现高级别的代码重用。在该系列的下一篇也是最后一篇文章中,我们将介绍 CUTLASS 4.0 中引入的更改,尤其是 CuTe Python DSL。

有关更多信息,您可以在 GitHub 上下载软件,阅读我们的文档,或加入我们的开发者论坛进行更深入的讨论。

致谢

感谢 Cris Cecka、Jack Kosaian、Mark Hoemmen、Haicheng Wu 和 Matt Nicely 为本文做出的贡献。

 

标签