大语言模型 (LLM) 的演变标志着其处理和生成文本的能力有了显著提升。在这些发展中,上下文长度的概念 (模型可以处理的单个输入样本中的 token 数量) 已成为定义这些模型在不同应用中可以实现的目标的关键因素。
例如,这些模型允许处理视频输入、总结冗长的文档、在多轮对话中保持一致性、通过思维链进行推理,以及使用大量示例执行详细的上下文学习。在视频生成和理解、法律文档分析、低资源语言翻译等必须保留和利用全面上下文的场景中,以及在使用 AI 助手时,这种扩展功能尤为重要。
在本文中,我们将探讨长上下文 LLM 的技术基础,以及如何有效训练它们的技巧。我们规划了需求和挑战,以及如何使用 NVIDIA NeMo 框架通过各种优化技术来解决这些问题,从而提供高吞吐量的高效训练。
对扩展上下文长度的需求和相关挑战
随着越来越多的多模态用例涌现,处理长视频内容需要模型同时处理数千帧,同时保持时间一致性。具有扩展上下文长度的模型 (例如支持多达 100 万个 token 的模型) 可以在大量视觉输入帧中保留详细的时间信息。
针对复杂推理优化的模型 (例如 DeepSeek-R1 和 Llama Nemotron) 依靠扩展上下文来通过链式推理来解决多步问题。如果没有足够的上下文窗口大小,这些模型将截断关键逻辑路径,从而导致错误。DeepSeek-R1 的上下文长度超过 128K,而 Llama 4 则将上下文长度的界限扩大到超过 1000 万个 token。
使用扩展上下文长度训练 LLM 会带来重大的技术障碍,尤其是在内存管理方面。基于 Transformer 的 LLM 会随着序列长度的增加 (如果使用闪光注意力,则时间复杂度为 O (n^ 2)) 进行计算扩展。这使得使用超长上下文进行训练的成本非常高昂。
借助 NVIDIA NeMo 实现长上下文训练
作为开发者,您可以在长上下文训练期间通过以下方式改进内存管理:
- 激活重新计算
- 上下文并行
- 激活卸载
NeMo 框架通过先进的实施实现了这些功能,并为热门社区模型提供长上下文方法。
激活重新计算
在训练期间存储中间激活函数所需的显存会随着序列长度和模型深度的增加而增加,甚至会迅速超过最大 GPU 的容量。
NeMo 框架支持激活重新计算,这是一种可解决此瓶颈的内存节省技术。训练过程没有存储反向传播所需的所有中间激活函数,而是选择性地仅检查子集 (例如每个 Transformer 层的输入) 。在向后传递期间计算梯度时,系统会通过重新执行部分向前传递来即时重新计算所需的激活值。
通过仅存储一小部分激活函数并重新计算其余部分,激活函数重新计算可显著减少内存占用。这对于在有限的 GPU 显存中拟合超长序列和大批量大小至关重要。随着上下文长度的增加,激活内存甚至可能超过模型权重和优化器状态所需的内存。重新计算允许将训练扩展到更长的环境,同时保持成本效益。

上下文并行
虽然激活重新计算通过在向后传递期间丢弃和重新计算激活函数来有效减少内存使用量,但这种方法会引入大量的重新计算用度 (通常每个训练步骤高达 30%) ,从而减缓训练过程。
上下文并行 (CP) 提供了一种更高效的替代方案。CP 在 NeMo 框架中实现,并在针对近无限上下文的块状 Transformer 的 Ring Attention 中引入,它将序列维度分割到多个 GPU。每个 GPU 仅处理和存储序列的一个块,从而能够在不超过显存限制的情况下训练具有更长输入序列的模型。
CP 与序列并行 (SP) 的不同之处在于,SP 仅分割几个选定层 (例如 LayerNorm 和 Dropout) 的序列,而 CP 分割所有层的序列,通信成本通常与计算重叠。这使 CP 能够克服单 GPU 显存容量的限制,同时避免重新计算用度。这种方法为在长序列上训练大型模型提供了可扩展且高计算效率的解决方案,使其成为大规模深度学习时代的强大工具。
上下文并行的工作原理
总的来说,CP 允许标准模组 (如 Linear、LayerNorm 和其他逐点运算) 在不进行修改的情况下运行。这些层不需要令牌间通信,因此自然支持分割序列布局。对于注意力机制,每个令牌的查询 (Q) 必须关注同一序列中所有令牌的键 (K) 和值 (V) 。
在前向传递期间,CP 在每个 GPU 上为其本地序列块存储 KV,在后向传递期间根据需要再次收集 KV 张量,从而更高效地利用内存。所涉及的通信集合 ( all-gather 和 reduce-scatter) 作为环形拓扑内的优化点对点通信来实现。交换 KV 还可以利用 MQA/ GQA 来减少通信量,因为它们只有一个或几个 KV 注意力头。
例如,在图 2 中,GPU0 和 GPU1 形成一个张量并行组,GPU0 和 GPU2 形成一个上下文并行组,相互交换 KV 对。GPU1 和 GPU3 之间也会执行相同的操作。CP 通过以下方式进一步提高性能:
- 利用最新的开源软件 (OSS) 和 NVIDIA cuDNN 闪存注意力内核,实现更快、更节省内存的注意力计算。
- 消除由低三角形因果遮罩引起的不必要计算,并在 GPU 之间实现最佳负载平衡。

AG/ RS:全聚在正向,并减少向后散射。RS/ AG:归约散射在前向,全聚在后向,/ AG:无操作在前向,全聚在后向。
CP 基准测试
图 3 显示了在 Llama 3 8B 上,上下文并行的序列长度在 16K 到 100 万个序列之间的效率。从 32K 及以上的序列长度开始,可以看到使用 CP 会产生更高的 teraFLOPS。序列长度为 100 万时,必须使用 CP 才能运行模型。请注意,尽管序列长度增加,teraFLOPS 仍开始趋于平稳,这表明 CP 实现是以最小的开销高效完成的。

激活卸载
除 CPU 卸载外,高效管理 GPU 显存的另一项技术是 CPU 卸载。CPU 卸载的工作原理是将中间激活函数和非活动权重卸载到 CPU 内存,从而减少 GPU 显存占用峰值。NeMo 框架支持 Transformer 层卸载,允许用户配置使用此策略的层数。在前向传递期间,NeMo Framework 会在最佳时间卸载激活函数,而在后向传递期间,它会根据需要重新加载这些函数。
这种动态卸载机制有助于进一步扩展每个 GPU 的内存容量,尤其是在训练非常深度的模型时,使其成为大型模型训练中上下文并行的重要补充。
总结
虽然您可以实施各种技术来改进模型长上下文长度,但在进行优化时,最好考虑模型架构和硬件选择。
NVIDIA NeMo 框架是适用于 LLM、语音模型和多模态模型的 GPU 加速训练框架,可提供经过测试的长上下文模型训练方法。这些 recipe 可在 NeMo Framework LLM recipe 目录中找到。现有方案包括 Llama 3 8B 和 70B、Mixtral 8x7B 以及 Nemotron 4 15B 和 22B,序列长度分别为 16K、64K 和 128K。
您还可以通过预训练检查点扩展上下文窗口。有关更多信息,请参阅长上下文 recipe 文档。