对话式人工智能

Hymba 混合头架构提高小型语言模型性能

Transformer 及其基于注意力的架构,凭借强大的性能、并行化功能以及通过键值 (KV) 缓存进行的长期召回,已成为语言模型 (LM) 的主要选择。然而,其二次计算成本和高内存需求带来了效率挑战。相比之下,Mamba 和 Mamba-2 等状态空间模型 (SSM) 可提供恒定的复杂性和高效的硬件优化,但难以处理记忆回收任务,从而影响其在常规基准测试中的性能。

NVIDIA 研究人员最近提出了 Hymba ,这是一系列小语言模型 (SLMs),采用混合 head 并行架构,将 Transformer Attention 机制与 SSMs 集成,以提高效率和性能。在 Hymba 中,attention heads 可实现高分辨率召回,而 SSM heads 可实现高效的上下文摘要。

Hymba 的新型架构揭示了以下几点见解:

  1. 注意力开销: 超过 50% 的注意力计算可以被更便宜的 SSM 计算所取代。
  2. 本地注意力优势: 大多数全球注意力可以被本地注意力取代,而不会影响一般任务和召回密集型任务的性能,得益于 SSM heads 汇总的全局信息。
  3. KV 缓存冗余: 键值缓存在 heads 和层之间高度相关,因此可以在 heads (组查询注意力)和层(跨层 KV 缓存共享)之间共享。
  4. Softmax 注意力限制: 注意力机制的总和被限制为 1,从而限制了稀疏性和灵活性。我们引入了可学习的元令牌,这些元令牌在提示之前提供,用于存储关键信息,并减轻与注意力机制相关的“强制关注”的负担。

本文展示了 Hymba 1.5B 与类似大小的先进开源模型 (包括 Llama 3.2 1B、OpenELM 1B、Phi 1.5、SmolLM2 1.7B、Danube2 1.8B 和 Qwen2.5 1.5B) 相比,表现良好。与类似大小的 Transformer 模型相比,Hymba 还实现了更高的吞吐量,存储缓存所需的内存减少了 10 倍。

Hymba 1.5B 已发布至 Hugging Face 集合和 GitHub

Hymba 15 亿性能 

图 1 比较了 Hymba 1.5B 与次 2B 模型(Llama 3.2 1B、OpenELM 1B、Phi 1.5、SmolLM2 1.7B、Danube2 1.8B、Qwen2.5 1.5B)在平均任务准确性、相对于序列长度的缓存大小(MB)和吞吐量(tok/秒)方面的表现。

A figure showing three performance metrics comparing seven different AI language models in terms of average accuracy, cache size (MB) relative to sequence length, and throughput (tok/sec).
图 1、Hymba 1.5B 基准与低于 2B 模型的性能比较

在这组实验中,任务包括 MMLU、ARC-C、ARC-E、PIQA、Hellaswag、Winogrande 和 SQuAD-C。使用 PyTorch 在序列长度为 8K、批量大小为 128 的 NVIDIA A100 GPU 上测量吞吐量。对于在吞吐量测量期间遇到内存不足(OOM)问题的模型,批量大小减半,直到 OOM 得到解决,以测量在不使用 OOM 时可实现的最大吞吐量。

Hymba 模型设计 

引入 Mamba 等 SSM 是为了解决 Transformer 的二次复杂性和推理时间较大的 KV 缓存问题。然而,由于其低分辨率内存,SSM 在内存召回和性能方面遇到困难。为了克服这些限制,我们在表 1 中提出了开发高效、高性能小型语言模型的路线图。

配置 常识推理 (*) 召回 (%)* 吞吐量 (令牌/秒)* 缓存大小 (MB) 设计理由
Ablations 在 300M 模型大小和 100B 训练令牌上的消融
Transformer (Llama) 44.08 39.98 721.1 414.7 准确召回,低效
状态空间模型 (Mamba) 42.98 19.23 4720.8 1.9 高效且召回不准确
A. + Attention heads (顺序) 44.07 45.16 776.3 156.3 增强召回功能
B. + Multi-heads (并行) 45.19 49.90 876.7 148.2 更好地平衡两个模块
C. 本地/全球 attention 44.56% 48.79 2399.7 41.2 提高计算/缓存效率
D. + KV 缓存共享 45.16 48.04 2756.5 39.4 缓存效率
E. + Meta-tokens  45.59% 51.79 2695.8 40.0 学习内存初始化
扩展至 1.5 亿模型大小和 1.5 万亿训练令牌
F. 规模/数据 60.56% 64.15 664.1 78.6% 进一步提高任务性能
G. 扩展上下文长度 (2K→8K) 60.64% 68.79 664.1 78.6% 改进 multishot 和召回任务
表 1、Hymba 模型的设计路线图

融合混合模组 

根据消融研究,在混合 head 模块中并行融合 attention heads 和 SSM heads 的表现优于顺序堆叠。Hymba 在混合 head 模块中并行融合 attention heads 和 SSM heads,使两个 heads 能够同时处理相同的信息。此架构可提高推理和召回准确性。

A diagram showing the architecture of a dual-path attention mechanism. The flow starts with an Input Projection, leading to Latent Feature extraction which splits into two parallel paths. The upper path (in blue) contains SSM Feature processing through SSM Heads and Gate Normalization. The lower path (in red) processes Attention Features through Attention Heads and Gate Normalization. Both paths converge at a Mean operation before final Output Projection. Arrows indicate the flow of data through the system.
图 2、Hymba 中的混合 head模块

效率和 KV 缓存优化 

Attention heads 可提高任务性能,但会增加 KV 缓存需求并降低吞吐量。为缓解此问题,Hymba 通过结合本地和全局 attention 并采用跨层 KV 缓存共享来优化混合 head 模块,从而将吞吐量提高了 3 倍,并在不牺牲性能的情况下将缓存减少了近 4 倍。

A diagram showing the architecture of a neural network model with Hymba Blocks. The model flows from left to right, starting with an Embedding layer, followed by alternating Hymba Blocks with Full Attention (in red) and SWA (in blue). The blocks are connected with KV sharing every 2 layers, shown in dotted green boxes labeled 'Repeat (N-3)/2'. Below the main flow, there's a detailed view of a module containing Layer norm, Hybrid-head module, another Layer norm, and FFN components. The diagram ends with an LM Head layer on the right.
图 3、Hymba 模型架构

Meta-tokens 

一组包含 128 个预训练嵌入的输入,可用作学习缓存初始化,以增强对相关信息的关注。这些 token 具有双重用途:

  • 充当后盾令牌,有效地重新分配 attention,从而减轻 attention 流失。
  • 封装压缩世界知识
A diagram illustrating the Fading Memory architecture from SSM (State Space Model). The image shows three layers: At the top is a blue rectangular box labeled 'Fading Memory (From SSM)'. Below it are seven gray input tokens arranged horizontally. At the bottom are two sets of memory blocks: on the left are two green blocks labeled 'Meta Memory (Meta Tokens)', and on the right are three red blocks labeled 'Snapshot Memory (From Attn)'. Green arrows connect the Meta Memory to the input tokens, while red arrows connect the Snapshot Memory to the rightmost input tokens. A blue arrow loops back from the Fading Memory box to itself.
图 4、从内存方面解读 Hymba

模型分析 

本节介绍了在相同训练设置下跨不同架构的苹果对比。然后,我们在不同的预训练模型中可视化 SSM 和 Attention 的 attention 图。最后,我们通过剪枝对 Hymba 执行头部重要性分析。本节中的所有分析有助于说明 Hymba 的设计选择如何有效以及为何有效。

苹果与苹果对比 

我们对 Hymba、纯 Mamba2、Mamba2 with FFN、Llama3 风格和 Samba 风格(Mamba-FFN-Attn-FFN)架构进行了苹果到苹果的比较。所有模型都有 1 亿个参数,并使用完全相同的训练方法从 SmolLM-Corpus 中针对 100 亿个令牌从头开始进行训练。所有结果均通过使用 Hugging Face 模型上的零样本设置的 lm-evaluation-harness 获得。Hymba 在常识推理以及问答和召回密集型任务方面表现出色。

表 2 比较了用于语言建模以及召回密集型和常识推理任务的各种模型架构,其中 Hymba 实现了跨指标的强大性能。Hymba 在语言任务中的困惑度最低(Wiki 为 18.62,LMB 为 10.38),并且在召回密集型任务中表现出色,尤其是在 SWDE(54.29)和 SQuAD-C(44.71)中,从而在此类别中获得最高平均分(49.50)。

模型 语言 (PPL) 召回密集型 (%) 常识推理 (*)
Mamba2 15.88 43.34 52.52%
Mamba2 w/ FFN 17.43 28.92 51.14
Llama3 16.19 47.33 52.82%
Samba 16.28 3617 52.83
Hymba 14.5 49.5% 54.57
表 2、在相同设置下使用 100 亿个令牌进行训练的架构对比

在常识推理和问答方面,Hymba 在大多数任务中的表现优于其他模型,例如 SIQA (31.76) 和 TruthfulQA (31.64),平均分为 54.57,略高于 Llama3 和 Mamba2。总的来说,Hymba 是一款出色的平衡模型,在效率和任务性能方面表现出色,适用于各种类别。

Attention 贴图可视化 

我们将 attention 贴图中的元素进一步分为四种类型:

  1. 元: 从所有真实令牌到元令牌的 attention 分数。此类别反映了模型对元令牌的偏好。在注意力图中,如果模型具有元令牌,它们通常位于前几列(例如,Hymba 的 128 列)。
  2. BOS: 从所有真实令牌到序列开始令牌的 attention 分数。在 attention 图中,它们通常位于元令牌之后的第一列中。
  3. Self: 从所有真实令牌到自身的 attention 数。在 attention 映射中,它们通常位于对角线上。
  4. 交叉: 从所有真实令牌到其他真实令牌的 attention 数。在 attention 地图中,它们通常位于对角线外区域。

Hymba 的 attention 模式与 Vanilla Transformer 明显不同。在 Vanilla Transformer 中,attention 得分更集中在 BOS 上,这与 Attention Sink 中的结果一致。此外,Vanilla Transformer 的 Self Attention 得分比例也较高。在 Hymba 中,meta-tokens、attention heads 和 SSM heads 相辅相成,从而在不同类型的 tokens 之间更平衡地分配注意力得分。

具体来说,meta-tokens 可分流 BOS 的 attention 数,使模型能够更专注于真实标记。SSM heads 对全局上下文进行总结,更侧重于当前令牌(Self attention scores)。另一方面,attention heads 对 Self 和 BOS 令牌的关注度较低,而对其他令牌(即 Cross attention scores)的关注度较高。这表明 Hymba 的混合 head 设计可以有效平衡不同类型令牌的 attention 分布,从而有可能带来更好的性能。

A diagram showing the composition of the Hymba attention mechanism. It consists of three components that are added together: Meta Tokens (shown as a vertical green stripe on the left), Sliding Window Attention (displayed as a diagonal green band), and SSM (Mamba) (represented as a triangular green gradient). These three patterns combine to form the final Hymba pattern on the right, which shows a triangular area filled with green squares of varying intensity. Each component is displayed in a square grid format, and the combination is shown using plus signs between the components and an equals sign before the final result.
图 5、Hymba 的 attention图示意图 (元令牌、滑动窗口attention和 Mamba 贡献的组合)
A comparative visualization showing attention patterns across different language models. The image consists of three main parts: 1) Three attention heatmaps for Llama 3.2 3B and Hymba 1.5B models, showing diagonal patterns in purple, yellow, and blue colors. 2) A grid diagram showing BOS (Beginning of Sequence) token connections with Meta and Cross sections marked. 3) Three horizontal stacked bar charts comparing percentage distributions of Meta, BOS, Cross, and Self attention patterns across Llama 3.2 3B and two variants of Hymba models, with percentages clearly labeled in different colors.
图 6、Llama 3.2 3B 和 Hymba 1.5B 中不同类别的 attention 得分总和。

主管重要性分析 

我们通过移除 attention 和 SSM heads 并记录最终精度来分析每层中的相对重要性。我们的分析揭示了以下内容:

  • 同一层中的 attention/SSM heads的相对重要性会根据输入进行自适应,并且会因任务而异,这表明它们在处理各种输入时可以发挥不同的作用。
  • 第一层中的 SSM heads 对于语言建模至关重要,移除它会导致准确度大幅下降到随机猜测的水平。
  • 通常,移除一个attention/SSM heads 会导致 Hellaswag 的平均准确率分别下降 0.24%/1.1%。
A line graph comparing the Hellswag Accuracy (y-axis ranging from 0.45 to 0.50) across 32 different layers (x-axis). The graph shows three elements: a horizontal dashed line labeled Orig Model at approximately 0.493, and two sets of bars in blue and orange representing Remove Attn and Remove SSM, respectively. The bars fluctuate slightly above and below the original model line, with most values falling between 0.47 and 0.495. The graph compares the impact of removing attention mechanisms versus SSM components at different layers of the model.
图 7、移除每层的 Attention 或 SSM 头后,使用 Hellaswag 的 1K 个样本测量得出的准确率。

模型架构和训练最佳实践 

本节概述 Hymba 1.5B Base 和 Hymba 1.5B Instruct 的关键架构决策和训练方法。

模型架构 

  • 混合架构: Mamba 擅长总结,通常更专注于当前 token,而 attention 更精确,可用作快照内存。并行组合可以合并这些优势,但标准顺序融合则不然。我们在 SSM 和 attention heads 之间选择了 5:1 的参数比。
  • 滑窗法 attention heads:全 attention heads 被保留在三个层级(第一层、最后一层和中间层),其余 90%的层级使用滑窗法 attention heads。
  • 跨层 KV 缓存共享 :在每两个连续的 attention 层之间实现。除了在 heads 之间共享 GQA KV 缓存之外,还完成了这一过程。
  • 元令牌: 这些 128 个令牌无需监督即可学习,有助于避免大语言模型 (LLMs) 中的熵崩溃问题,并缓解 attention 汇集现象。此外,模型会将一般知识存储在这些令牌中。

训练最佳实践 

  • 预训练: 我们选择了两个阶段的基础模型训练。第 1 阶段保持恒定的高学习率,并使用较少的过滤大型语料库数据。然后,使用高质量数据将连续学习率衰减至 1e-5。这种方法支持持续训练和恢复第 1 阶段。
  • 指令微调: 指令模型调优分三个阶段执行。首先,SFT-1 通过对代码、数学、函数调用、角色扮演和其他特定任务数据进行训练,为模型提供强大的推理能力。其次,SFT-2 教会模型遵循人类指令。最后,利用 DPO 使模型与人类偏好保持一致,并提高模型的安全性。
Training pipeline for the Hymba model family divided into five sections that read (left to right) General pretraining, LR annealing, SFT-1, SFT-2, and DPO.
图 8、适用于 Hymba 模型系列的训练管线。

性能和效率评估 

Hymba 1.5B 模型仅使用 1.5T 预训练令牌,在所有小型语言模型中表现最佳,并实现比所有基于 Transformer 的语言模型更高的吞吐量和缓存效率。

例如,在与最强基准 Qwen2.5(使用 13 倍以上的 tokens 进行预训练)进行基准测试时,Hymba 1.5B 实现了 1.55%的平均准确性提升、1.41 倍的吞吐量和 2.90 倍的缓存效率。与使用少于 2T 的 tokens 训练的最强小型 LM(即 h2o-danube2)相比,我们的方法实现了 5.41%的平均准确性提升、2.45 倍的吞吐量和 6.23 倍的缓存效率。

模型 #参数 训练令牌 令牌/秒 缓存 (MB) MMLU 5-shot ARC-E 0-shot ARC-C 0-shot PIQA 0-shot Wino0-shot Hella0-shot SQuAD-C 1-shot 平均
开放
ELM-1
11 亿 1.5 T 249 346 27.06 62.37 19.54 74.76 61.8 48.37 45.38 48.57
Rene
v0.1
13 亿 1.5 T 800 113 32.94 67.05 31.06 76.49 62.75 51.16 48.36 52.83
Phi
1.5
13 亿 0.15  241 1573 42.56 76.18 44.71 76.56 72.85 48 30.09 55.85
Smol
LM
17 亿 1T 238 1573 27.06 76.47 43.43 75.79 60.93 49.58 45.81 54.15
Cosmo 18 亿 .2T 244 1573 26.1 62.42 32.94 71.76 55.8 42.9 38.51 47.2
h20
danube2
18 亿 2T 271 492 40.05 70.66 33.19 76.01 66.93 53.7 49.03 55.65
Llama 3.2 1B 12 亿 9T 535 262 32.12 65.53 31.39 74.43 60.69 47.72 40.18 50.29
Qwen
2.5
15 亿 18T 469 229 60.92 75.51 41.21 75.79 63.38 50.2 49.53 59.51
AMD
OLMo
12 亿 1.3 T 387 1049 26.93 65.91 31.57 74.92 61.64 47.3 33.71 48.85
Smol
LM2
17 亿 11T 238 1573 50.29 77.78 44.71 77.09 66.38 53.55 50.5 60.04
Llama
32 3B
30 亿 9T 191 918 56.03 74.54 42.32 76.66 69.85 55.29 43.46 59.74
                         
Hymba 15 亿 1.5 T 664 79 51.19 76.94 45.9 77.31 66.61 53.55 55.93 61.06
表 2、Hymba 1.5 B 基础模型结果

指令模型 

在所有任务中,Hymba 1.5B Instruct 模型的平均性能最高,比之前的先进模型 Qwen 2.5 Instruct 约高出 2%。具体来说,Hymba 1.5B 模型在 GSM8K、GPQA 和 BFCLv2 中的得分分别为 58.76、31.03 和 46.40,优于所有其他模型。这些结果表明 Hymba 1.5B 模型在复杂推理能力方面具有优势,特别是在需要复杂推理能力的领域。

模型 #参数 MMLU ↑ IFEval ↑ GSM8K ↑ GPQA ↑ BFCLv2 ↑ 平均 ↑
SmolLM 17 亿 27.80 25.16 1.36 25.67 20.00
OpenELM 11 亿 25.65 6.25 56.03 21.62 27.39
Llama 3.2 12 亿 44.41 58.92 42.99 24.11 20.27 38.14
Qwen2.5 15 亿 59.73 46.78 56.03 30.13 43.85 47.30
SmolLM2 17 亿 49.11 55.06 47.68 29.24 22.83 40.78
Hymba 15 亿 15 亿 52.79 57.14 58.76 31.03 46.40 49.22
表 3、Hymba 1.5 B 指令模型结果

结束语 

新的 Hymba 系列小型 LM 采用混合 head 架构,将 attention heads 的高分辨率召回功能与 SSM heads 的高效上下文摘要相结合。为进一步优化 Hymba 的性能,我们引入了可学习的元令牌,用作 attention 和 SSM heads 的学习缓存,从而增强模型对显著信息的关注。通过 Hymba 的路线图、全面评估和消融研究,Hymba 在各种任务中设定了新的 state-of-the-art 性能,在准确性和效率方面实现了出色的结果。此外,这项工作还对混合 head 架构的优势提供了宝贵见解,为高效 LM 的未来研究提供了前景光明的方向。

详细了解 Hybma 1.5B Base Hymba 1.5B Instruct

致谢 

这项工作如果没有 NVIDIA 许多人的贡献是不可能完成的 ,包括 Wonmin Byeon、Zijia Chen、Ameya Sunil Mahabaleshwarkar、Shih-Yang Liu、Matthijs Van Keirsbilck、Min-Hung Chen、Yoshi Suhara、Nikolaus Binder、Hanah Zhang、Maksim Khadkevich、Yingyan Celine Lin、Jan Kautz、Pavlo Molchanov 和 Nathan Horrocks。

标签