生成式人工智能/大语言模型

探索在大模型训练中使用 Megatron-Core 训练框架提高显存使用效率

前言

在大模型训练中,显存(GPU Memory)始终是最稀缺的资源之一。随着模型规模迈入百亿、千亿甚至万亿参数级别,如何在有限显存中“塞下”训练任务,对研发和运维团队都是巨大挑战。在强化学习中训练部分更是会面临超长序列、可变长度序列、训练和 Rollout 显存占用相互干扰等更严峻的挑战。NVIDIA Megatron-Core 作为流行的大模型训练框架,提供了灵活高效的并行化策略;理解这些策略对显存的影响,才能更好地规划训练超参数,在不 OOM(out of memory)的情况下尽可能提升硬件使用效率。

本文将带您了解:

1. 系统梳理影响显存的关键超参数及其作用机理;

2. 拆解静态显存与动态显存的组成,并给出量化公式;

3. 介绍一款与 Megatron-Core 原生流程深度耦合的显存估计器 Megatron Memory Estimator,通过 Qwen3-235B、DeepSeek V3-671B 等真实配置的对比校准,验证估计精度;

5. 通过实际案例展示通过显存估计器优化训练超参数设置和优化训练框架的过程。

图1: 显存估计器用户界面截图 ,以 Qwen3-30B-A3B 为例

显存的组成与衡量方法

通过torch的显存可视化工具捕捉一个典型的模型训练中的显存占用如图2所示。图2跟踪了两个micro batch step的forward和backward。可以简单将显存分为静态显存和动态显存。其中静态部分是forward和backward过程中不变的部分,对应图2中32.0G以下的部分,动态显存对应之上的部分,在forward过程中逐渐增加,并在backward过程中逐渐降低。

图2: 模型在训练过程中的显存,通过torch的显存可视化工具捕捉

静态显存

静态显存主要组成部分包括模型参数、梯度和优化器的所占用的空间,及一些其他的系统开销。其他系统开销包括CUDA context、NCCL 通信缓存、CUDA graph等开销,较难建模,本文将忽略。本文只考虑bf16训练的情况,暂不考虑fp8训练。

对应每张卡上的参数,将包含bf16类型的参数、fp32类型的梯度、Adam优化器(fp32的主参数、fp32一阶动量和fp32二阶动量),共18字节,其中优化器的12字节通常会通过开启分布式优化器选项来降低,全局中重复的参数只会有一份优化器参数。因此每个参数占用的空间是6+12/R,其中R为参数重复的次数。对于MoE模型来说,由于Megatron支持parallel folding,模型的模型会分为稠密部分和MoE部分,其中稠密部分的R为DP*CP,MoE部分的R为EDP=n_GPU/PP/EP/ETP。(下面将介绍各种P及其含义)

动态显存

模型前向传播过程中暂存的中间结果,用于反向传播时计算梯度,通常被称为激活(Activation),绝大部分为bf16数据类型。其数量正比于batch_size*s*d 。(批大小*序列长度*隐藏层维度)

对显存影响的关键超参数

Megatron-Core 支持以下并行、重算维度,组合后可覆盖当下主流大模型训练需求。

并行设置

名称缩写作用典型取值
张量并行TP切分Hidden Dimension `d/TP`1,2,4,8
流水线并行PP切分Layers1,2,4,8,16
虚拟流水线并行VPP细分PP内的Layer Chunks,减少训练中的空泡(bubble)1,2,3,4
上下文并行CP切分Sequence Dimension `s/CP`1,2,4,8,16
专家并行EP切分 experts 1,2,4,8,16
专家张量并行ETPMoE 层的张量并行1
数据并行DP切分输入数据 batch框架自动推导

 ⚠️ 约束关系:`n_GPU / PP = TP×CP×DP = EP×ETP×EDP`,其中 `EDP` 为专家数据并行度。

重算策略

除了完全不重算的情况之外,为了降低动态显存,Megatron-Core 0.14 提供两档重算:

1. 完全重算(full):仅保留层间输入,丢弃层内全部激活;降低显存效果最好,代价是增加约30%的计算量

2. 细粒度重算(selective):支持 `core_attn`, `moe_act`, `layernorm`, `mla_up_proj`, `mlp`, `moe` 六个子模块按需组合重算。

显存估计器的设计

设计原则

  • 基于 Megatron-Core 模块化设计
  • 复用模型构建的代码和模型forward代码
  • 细粒度展示各层级显存占用

设计思路

当前Megatron基于torch实现,所有模块均派生自torch.nn.Module, 构成训练GPT类模型的模块如图3 所示

图3 GPT模型需要的模块结构,图片由DeepWiki生成

我们通过实现一个基类MemEstimator并基于此基类派生出所有需要的模块类,根据每个模块的显存占用特点分别计算其中的参数量和激活量。然后复用Megatron中本身构建模型的代码,实现一个Megatron模拟器,并可以展示出个层次的模块数据量。

这里展示一个例子,以 Qwen3-30B-A3B 在32卡上开启tp=4 ep=32 etp=1 序列长度4k为例,其中n_params和n_act均为个数,例如整体参数量1228.48M(百万),激活量3389.12M(百万)

GPTModel /* n_params=1228.48M n_act=3389.12M */ (

  (embedding): LanguageModelEmbedding /* n_params=74.19M n_act=8.00M */ (

    (word_embeddings): VocabParallelEmbedding /* n_params=74.19M n_act=8.00M */ ()

    (embedding_dropout): Dropout /* n_params=0.00M n_act=0.00M */ ()

  )

  (decoder): TransformerBlock /* n_params=1080.11M n_act=2936.00M */ (

    (layers): ModuleList /* n_params=1080.11M n_act=2928.00M */ (

      (0-47): 48 x TransformerLayer /* n_params=22.50M n_act=61.00M */ (

        (input_layernorm): IdentityOp /* n_params=0.00M n_act=0.00M */ ()

        (self_attention): SelfAttention /* n_params=4.50M n_act=17.00M */ (

          (core_attention): TEDotProductAttention /* n_params=0.00M n_act=1.00M */ ()

          (linear_qkv): ColumnParallelLinear /* n_params=2.50M n_act=0.00M */ ()

          (q_layernorm): RMSNorm /* n_params=0.00M n_act=0.00M */ ()

          (k_layernorm): RMSNorm /* n_params=0.00M n_act=0.00M */ ()

          (linear_proj): RowParallelLinear /* n_params=2.00M n_act=4.00M */ ()

        )

        (self_attn_bda): GetBiasDropoutAdd /* n_params=0.00M n_act=2.00M */ ()

        (pre_cross_attn_layernorm): IdentityOp /* n_params=0.00M n_act=0.00M */ ()

        (cross_attention): IdentityOp /* n_params=0.00M n_act=0.00M */ ()

        (cross_attn_bda): IdentityOp /* n_params=0.00M n_act=0.00M */ ()

        (pre_mlp_layernorm): RMSNorm /* n_params=0.00M n_act=2.00M */ ()

        (mlp): MoELayer /* n_params=18.00M n_act=38.00M */ (

          (router): TopKRouter /* n_params=0.00M n_act=4.00M */ ()

          (experts): TEGroupedMLP /* n_params=18.00M n_act=18.00M */ (

            (linear_fc1): TEColumnParallelGroupedLinear /* n_params=12.00M n_act=12.00M */ ()

            (linear_fc2): TERowParallelGroupedLinear /* n_params=6.00M n_act=16.00M */ ()

          )

        )

        (mlp_bda): GetBiasDropoutAdd /* n_params=0.00M n_act=2.00M */ ()

      )

    )

    (final_layernorm): RMSNorm /* n_params=0.00M n_act=8.00M */ ()

  )

  (output_layer): ColumnParallelLinear /* n_params=74.19M n_act=148.38M */ ()

)

根据每张卡上每个模型的参数量和激活量,分别乘上每个参数和每个激活对应的字节数(6+12/R和2)即可获得每张卡上的显存占用。

结果分析

与流行模型配置的对比和校准

选取两个时下流行的大模型,使用流行的配置开启训练,并对比显存估计的结果与真实的显存占用。对QWen3 235B和DeepSeek v3模型均使用256张显卡,开启完全重计算,流水线并行PP=8,不同之处在于专家并行EP分别为8和32,虚拟流水线并行VPP分别为2和4. 注意测试时会强制保证专家路由均衡,以排除专家负载不均衡的影响,实际训练中激活显存峰值将会随不均衡程度浮动。两个模型的实际峰值与估计峰值相差均小于2GB。表格中实际静态显存和实际峰值显存均通过torch的显存分析工具采集。

Qwen3-235B-A22B pp8vpp2ep8 256GPU full_recompute
PP_rank估计静态估计峰值实际静态实际峰值
036.24136.742.6
135.540.33642.1
235.5403641.7
335.539.63641.4
435.539.435.841.4
535.539.235.841
635.538.835.840.6
736.242.636.642.7
max42.642.7
DeepSeek V3 671B pp8vpp4ep32 256GPU full_recompute
PP_rank估计静态估计峰值实际静态实际峰值
052.659.25461.8
150.657.75260.3
258.265.758.667.1
358.265.558.966.9
458.265.458.966.8
558.265.358.966.7
658.26558.966.6
749.25849.960.1
max65.767.1

动态显存分析

  • TP: 对于流程中所有的embedding b*s*d, 其中的d会被切分成d/tp_size份,所以每张卡的embedding的显存占用是b*s*d/tp_size。
  • CP: 对于流程中所有的embedding b*s*d, 其中的s会被切分成s/cp_size份,所以每张卡的activation的显存占用是b*s*d/cp_size。
  • EP: EP 无法降低动态显存,但由于不同expert的负载不均衡,越大的EP越有可能将token路由到某些热点expert,导致这些expert的显存峰值放大。
  • ETP:ETP无法降低动态显存,在专家数量多的时候通常设置为1,专家数量少的时候也可以进一步设置以进一步降低静态显存。
  • PP: PP的情况较为复杂,笼统得说PP无法降低动态显存。因为当开启PP时,会通过1f1b的调度策略有一个预热的过程,对于n个pprank,虽然每个卡上的模型只有1/n份,对应一个batch也只有1/n的激活量,但是峰值显存上不同的 PP_rank 分别会有n, n-1, … , 1个batch的激活量,对于PPrnk0来说,其峰值显存是 b*s*d*(1/n)*n =b*s*d 未发生变化
图4: 流水线并行1F1B调度示意图
  • VPP情况相比PP更复杂一些,对于pp_size=n,vpp_size=m的情况,模型被分为了n*m个chunk,每个pprank上会有m个chunk,显存峰值出现在PP_rank0,其上有m个chunk,对于vpp_chunk0,有2*n-1个microbatch,对于其他vpp_chunk,有n个micro_batch,假设各chunk的的参数量均匀则一共有 2*n-1 + n*(m-1) = m*n+n-1个microbatch,对于pprank0来说,其峰值显存是 b*s*d*(1/m/n)*(m*n+n-1),会略大于单纯开pp的情况
图5: 虚拟流水线并行 Interleave 1F1B调度示意图

总结:只有TP和CP能降低激活量, EP和ETP只会改变集群内激活值的分布,无法降低激活量,PP和VPP由于1f1b的流水线预热机制,无法有效降低峰值激活量。

注:为什么EP和ETP无法降低动态显存?ETP和EP只能改变activation在整个训练集群中的分布,但是整个集群中的activation总量是不变的。考虑专家路由均匀时,EP让每张卡上的专家数变为1/EP倍,但是每个专家路由到的token数量变为EP倍;ETP让每张卡上的hidden dim变为1/ETP倍,但每张卡上路由到的token数量也会变为ETP倍。

激活重算/激活卸载

对每一部分激活量,可以通过卸载到cpu或者重算的方式来降低显存。Megatron-Core0.13当前对卸载的支持还在开发中,但重算已经支持。包括完全重算和细粒度重算两种,完全重算是将每层transformer的activation丢弃掉仅保留层与层之间的activation,细粒度重算将特定的activation进行重算,图示参考图6。

图6 细粒度重算模块,以DeepSeek V3为例

这里基于Qwen3-30B-A3B EP8 32卡和Moonligt-16B-A3B EP8 8卡的配置,序列长度4k,micro_batch_size=1,来分析不同重算下的激活显存占用。

Qwen3 30B-A3B EP8 32GPUmoonlight 8GPU EP8
no_recompute26.319
full4.24.2
core_attn2617.7
moe_act19.212.7
layernorm25.216.9
moe11.77.73
mlp/17.4
mla_up_proj/17.7

通过分析不同重算方式,完全重算对激活量的释放最彻底,在超大模型上经常使用。

  • core_attn由于当前实现中已经使用了融合算子,实际可降低的空间很有限,实践中通常不会启用。
  • moe_act会降低可观的显存,且不需要在重算过程中重复MoE的all2all通信,实践中推荐使用。
  • layernorm虽然降低的显存不多,但由于计算量很小,实践中推荐使用
  • moe模块将整个MoE进行重算,和moe_act重算互斥,降低的显存也较为可观,但重算过程中将重新执行MoE的all2all通信,当序列较长,attention计算占比较高的时候推荐打开。
  • mlp属于模型的稠密部分,通常不推荐开启重算
  • mla_up_proj 在MLA中显存降低能力近似于core_attn,但重算代价小很多,推荐打开。

静态显存分析

静态显存分为2个部分,稠密部分(dense)和MoE部分,

n_param_dense * (6+12/R_dense) + n_param_moe * (6+12/R_moe)

其中R_dense= DP*CP =n_gpu/PP/TP, R_moe=EDP = n_gpu/PP/EP/ETP。

  • 其中TP/ETP 对模型的影响是均匀的,TP=n时,n_param_dense=n_param_dense/n
  • PP切分由于首尾会有embedding层,所以不会完全均匀,但大体上是均匀的,这里假设均匀,TP、EP、ETP的切分都是均匀的

所以实际显存是

param_dense/TP/PP * (6+12/(n_gpu/PP/TP)) + param_moe/PP/EP/ETP * (6+12/(n_gpu/PP/EP/ETP))

MoE部分往往占据显存的大头,因此从显存角度分析推荐EP设置大一些。

优化器卸载

Megatron-Core0.13 现已支持通过cpu分担optimizer的显存占用,并可以通过超参数设置卸载到cpu的比例,每个参数的6字节(bf16参数,fp32梯度)无法卸载,其余可以卸载。

实践中veRL强化学习框架通过CPU optimizer可以实现32卡训练qwen3 235B或96卡训练deepseek 671B。参见https://verl.readthedocs.io/en/latest/perf/dpsk.html

用例分析

可以通过https://huggingface.co/spaces/ISEEKYAN/megatron_memory_estimator访问显存估计器的web界面,或者基于源码通过命令行工具使用https://github.com/ISEEKYAN/mbridge/tree/main/memory_estimator

Megatron用户:逐步压缩显存的实践

用户目标在32张80GB显存的GPU上实现Qwen3-30B-A3B的强化学习训练,序列长度是10k,用户使用显存估计器做对并行配置进行摸底

编号配置静态显存/GB动态显存/GB总峰值/GB
1无并行181.265.96257.16
2TP453.2930.6193.91
3PP454.3857.34121.72
4EP839.4565.96115.41
5EP3224.2665.96100.22
6TP2 EP835.1833.0278.2
7CP2 EP839.4533.4582.9
8EP8 full_recompute39.4510.6160.06
9TP2 EP8 selective_recompute(moe_act+layernorm)35.1823.1868.35

注:1. 开启EP时ETP均设置为1

2. 显存估计假设专家路由均匀,实际专家不均匀会导致动态显存有一定浮动

讨论:

  1. 编号1是完全不开启并行和重算时的配置,静态显存高达181GB,在此基础上2、3、4、5分别开启了TP、PP、机内EP、机间EP,观察静态显存发现均有可观降幅,但只有2号因为开启TP4带来了动态显存的有效降低。
  2. 从直觉上分析,如果显存允许的话,并行程度越低,通信开销越小,模型训练效率越高,这个直觉大部分情况下与实际情况相符,此处不展开讨论。但以此为前提,虽然仍然可以继续进行更高并行度的切分,但更高并行度预期将降低训练效率。
  3. EP8与EP32的权衡,已知当前主流的硬件为单节点8块GPU,节点内GPU通过nvlink互联,节点间通过IB/以太网互联。EP的实现原理是通过all2all对token进行专家分发,计算后再通过all2all分发回去,跨机EP速率将大幅机内EP的。在未启用all2all通信计算掩盖技术(此处不展开讨论)的情况下应该优先使用EP8。
  4. 6、7号设置通过TP/CP/EP的组合实现了接近可运行的配置,但实际上由于专家不均衡、其他开销等原因在实际训练过程中将会遭遇CUDA out of memory。
  5. 8号设置实验通过完全重计算实现动态显存的降低,9号设置则是通过细粒度重算在6号设置的基础上进一步降低动态显存。

通过Megatron预训练实验模拟强化学习的训练性能,在32*Hopper架构 80GB GPU 的实验中,8号实验的TFLOPS为160,9号实验的TFLOPS为167。在笔者的大量实验中,此模型此硬件条件情况下8号与9号取得的性能为最佳。

Megatron框架开发者:定位显存热点

Megatron开发者可以通过显存分析工具的breakdown视角,详细察看每个模块的激活量,通过权衡激活量和计算量寻找性价比高(激活量/计算量)的模块的激活为其开发进行重算或卸载功能。

总结

本文分析了基于 Megatron Core 0.13 的大模型训练过程中显存情况,展开介绍了影响显存占用的各种并行与重算对应的作用原理,并展示了一个为 Megatron Core 定制的显存估计器,方便用户模拟在各种超参数组合情况下的显存占用情况。希望本文可以帮助您在大模型训练过程中找到性能与显存的最佳平衡点。

标签