Generative AI

Hymba ハイブリッド ヘッド アーキテクチャが小規模言語モデルのパフォーマンスを向上

Reading Time: 4 minutes

Transformer は、その Attention ベースのアーキテクチャによる、強力なパフォーマンス、並列化能力、および KV (Key-Value) キャッシュを通じた長期記憶のおかげで、言語モデル (LM) の主流となっています。しかし、二次計算コストと高いメモリ要求により、効率性に課題が生じています。これに対し、Mamba や Mamba-2 のような状態空間モデル (SSMs) は、複雑さを一定にして効率的なハードウェア最適化を提供しますが、メモリ想起タスクが苦手でそれは一般的なベンチマークでのパフォーマンスに影響を与えています。

NVIDIA の研究者は最近、効率性とパフォーマンスの両方を向上させるために、Transformer の Attention メカニズムを SSM と統合したハイブリッド ヘッド並列アーキテクチャを特徴とする小規模言語モデル (SLM) ファミリである Hymba を提案しました。Hymba では、Attention ヘッドが高解像度の記憶能力を提供し、SSM ヘッドが効率的なコンテキストの要約を可能にします。

Hymba の新たなアーキテクチャは、いくつかの洞察を明らかにしています。

  1. Attention のオーバーヘッド: Attention 計算の 50% 以上を、より安価な SSM 計算に置き換えることができます。
  2. ローカル Attention の優位性: SSM ヘッドにより要約されたグローバル情報のおかげで、一般的なタスクやメモリ想起に集中するタスクのパフォーマンスを犠牲にすることなく、ほとんどのグローバル Attention をローカル Attention に置き換えることができます。
  3. KV キャッシュ冗長性: Key-value キャッシュは、ヘッド間とレイヤー間で高い相関性があるため、ヘッド間 (GQA: Group Query Attention) およびレイヤー間 (Cross-layer KV キャッシュ共有) で共有できます。
  4. Softmax の Attention の制限: Attention メカニズムは、合計が 1 になるように制限されており、疎性と柔軟性に制限があります。NVIDIA は、プロンプトの先頭に学習可能なメタトークンを導入し、重要な情報を格納し、Attention メカニズムに関連する「強制的に Attention を行う」負担を軽減します。

この記事では、Hymba 1.5B が同様の規模である最先端のオープンソース モデル、Llama 3.2 1B、OpenELM 1B、Phi 1.5、SmolLM2 1.7B、Danube2 1.8B、Qwen2.5 1.5B などと比較して、良好なパフォーマンスを発揮することが示されています。同等のサイズの Transformer モデルと比較すると、Hymba はより高いスループットを発揮し、キャッシュを保存するために必要なメモリが 10 分の 1 で済みます。

Hymba 1.5B は Hugging Face コレクションと GitHub で公開されています。

Hymba 1.5B のパフォーマンス

図 1 は、Hymba 1.5B と 2B 未満のモデル (Llama 3.2 1B、OpenELM 1B、Phi 1.5、SmolLM2 1.7B、Danube2 1.8B、Qwen2.5 1.5B) を、平均タスク精度、シーケンス長に対するキャッシュ サイズ (MB)、スループット (tok/sec) で比較したものです。

図 1. Hymba 1.5B Base と 2B 未満のモデルのパフォーマンス比較

この一連の実験には、MMLU、ARC-C、ARC-E、PIQA、Hellaswag、Winogrande、SQuAD-C などのタスクが含まれています。スループットは、シーケンス長 8K、バッチ サイズ 128 で PyTorch を使用して NVIDIA A100 GPU で測定します。スループット測定中にメモリ不足 (OOM: Out of Memory) 問題が発生したモデルでは、OOM が解決されるまでバッチ サイズを半分にして、OOM なしで達成可能な最大スループットを測定しました。

Hymba モデルのデザイン

Mamba のような SSM は、Transformer の二次的な複雑性と推論時の KV キャッシュが大きい問題に対処するために導入されました。しかし、メモリ解像度が低いために、SSM は記憶想起とパフォーマンスの点で苦戦しています。これらの制限を克服するために、表 1 で効率的で高性能な小規模言語モデルを開発するためのロードマップを提案します。

構成常識推論 (%) ↑リコール (%) ↑スループット (token/sec) ↑キャッシュ サイズ (MB) ↓設計理由
300M モデル サイズと 100B トレーニング トークンのアブレーション
Transformer (Llama)44.0839.98721.1414.7非効率的ながら正確な記憶
状態空間モデル (Mamba)42.9819.234720.81.9効率的だが不正確な記憶
A. + Attention ヘッド (連続)44.0745.16776.3156.3記憶能力を強化
B. + 複数ヘッド (並列)45.1949.90876.7148.22 つのモジュールのバランスの改善
C. + ローカル / グローバル Attention44.5648.792399.741.2演算 / キャッシュの効率を向上
D. + KV キャッシュ共有45.1648.042756.539.4キャッシュ効率化
E. + メタトークン45.5951.792695.840.0学習した記憶の初期化
1.5B モデル サイズと 1.5T トレーニング トークンへのスケーリング
F. + サイズ / データ60.5664.15664.178.6タスク パフォーマンスのさらなる向上
G. + コンテキスト長の拡張 (2K→8K)60.6468.79664.178.6マルチショットとリコール タスクの改善
表 1. Hymba モデルのデザイン ロードマップ

融合型ハイブリッド モジュール

アブレーション研究によると、ハイブリッド ヘッド モジュール内で Attention と SSM ヘッドを並列にして融合するほうが、シーケンシャルにスタッキングするより優れていることが分かっています。Hymba は、ハイブリッド ヘッド モジュール内で Attention と SSM ヘッドを並列に融合させ、両ヘッドが同時に同じ情報を処理できるようにします。このアーキテクチャは、推論と記憶の正確さを高めます。

図 2. Hymba のハイブリッド ヘッド モジュール

効率性と KV キャッシュの最適化

Attention ヘッドはタスクのパフォーマンスを向上させますが、KV キャッシュの要求を増大させ、スループットを低下させます。これを緩和するために、Hymba はローカルおよびグローバルの Attention を組み合わせ、 Cross-layer KV キャッシュ共有を採用することで、ハイブリッド ヘッド モジュールを最適化します。これにより、パフォーマンスを犠牲にすることなくスループットが 3 倍向上し、キャッシュがほぼ 4 分の 1 に削減されます。

図 3. Hymba モデルのアーキテクチャ

メタトークン

入力の先頭に置かれる 128 の事前学習済みの埋め込みのセットであり、学習済みキャッシュの初期化として機能し、関連情報への注意を強化します。このようなトークンには 2 つの目的があります。

  • バックストップ トークンとして機能し、Attention を効果的に再分配することで Attention の流出を軽減する
  • 圧縮された世界知識をカプセル化する
図 4. メモリの側面から見た Hymba の解釈

モデル解析

このセクションでは、同一のトレーニング設定における異なるアーキテクチャを比較する方法を紹介します。それから、SSM と Attention の Attention マップを異なる学習済みモデルで可視化し、最後に、剪定 (pruning) を通じて Hymba のヘッド重要度分析を行います。このセクションのすべての分析は、Hymba のデザインにおける選択の仕組みと、それが効果的な理由を説明するのに役立ちます。

同一条件での比較

Hymba、純粋な Mamba2、Mamba2 と FFN、Llama3 スタイル、Samba スタイル (Mamba-FFN-Attn-FFN) のアーキテクチャを同一条件で比較しました。すべてのモデルが 10 億のパラメーターで、まったく同じトレーニング レシピで SmolLM-Corpus から 1,000 億トークンをゼロから学習しています。すべての結果は、Hugging Face モデルでゼロショット設定を使用して lm-evaluation-harness を通じて取得されています。Hymba は、常識推論だけでなく、質問応答タスクや記憶想起タスクでも最高のパフォーマンスを発揮します。

表 2 は、言語モデリングタスクと記憶想起タスクおよび常識推論タスクに関するさまざまなモデル アーキテクチャを比較しており、Hymba はすべての評価基準で卓越したパフォーマンスを達成しています。Hymba は、言語モデリングタスクで最も低い Perplexity を示し (Wiki で 18.62、LMB で 10.38)、特に SWDE (54.29) と SQuAD-C (44.71) の記憶想起タスクにおいて堅実な結果を示し、このカテゴリで最高の平均スコア (49.50) を達成しました。

モデル言語モデリング (PPL) ↓記憶想起型 (%) ↑常識推論 (%) ↑
Mamba215.8843.3452.52
Mamba2 と FFN17.4328.9251.14
Llama316.1947.3352.82
Samba16.2836.1752.83
Hymba14.549.554.57
表 2. 同じ設定で 1,000 億トークンで学習されたアーキテクチャの比較

常識推論と質問応答において、Hymba は平均スコア 54.57 で、 SIQA (31.76) や TruthfulQA (31.64) などのほとんどのタスクで、Llama3 や Mamba2 をやや上回っています。全体的に、Hymba はバランスの取れたモデルとして際立っており、多様なカテゴリで効率性とタスク パフォーマンスの両方で優れています。

Attention マップの可視化

さらに、Attention マップの要素を 4 つのタイプに分類しました。

  1. Meta: すべての実トークンからメタトークンへの Attention スコア。このカテゴリは、モデルがメタトークンに Attention を向ける傾向を反映するものです。Attention マップでは、通常、モデルにメタトークンがある場合、最初の数列 (例えば Hymba の場合は 128) に位置しています。
  2. BOS: すべての実トークンからセンテンスの開始トークンまでの Attention スコア。Attention マップでは、通常、メタトークンの直後の最初の列に位置します。
  3. Self: すべての実トークンからそれ自身への Attention スコア。Attention マップでは、通常、対角線上に位置しています。
  4. Cross: すべての実トークンから他の実トークンへの Attention スコア。Attention マップでは、通常、対角線外の領域に位置しています。

Hymba の Attention パターンは、vanilla (加工されていない) Transformer のそれとは大きく異なります。vanilla Transformer の Attention スコアは BOS に集中しており、Attention Sink の結果と一致しています。さらに、vanilla Transformer は、Self-Attention スコアの比率も高くなっています。Hymba では、メタトークン、Attention ヘッド、SSM ヘッドが互いに補完し合うように機能し、異なるタイプのトークン間で、よりバランスの取れた Attention スコアの分布を実現しています。

具体的には、メタトークンが BOS からの Attention スコアをオフロードすることで、モデルがより実際のトークンに集中できるようになります。SSM ヘッドはグローバルなコンテキストを要約し、現在のトークン (Self-Attention スコア) により重点を置きます。一方、Attention ヘッドは、Self と BOS トークンに対する注意が低く、他のトークン (すなわち、Cross Attention スコア) への注意が高くなります。これは、Hymba のハイブリッド ヘッド デザインが、異なるタイプのトークン間の Attention 分布のバランスを効果的に取ることができ、パフォーマンスの向上につながる可能性があることを示唆しています。

図 5. メタトークン、Sliding Window Attention、Mamba 貢献の組み合わせによる Hymba の Attention マップの概略図
図 6. Llama 3.2 3B と Hymba 1.5B の異なるカテゴリからの Attention スコアの合計

ヘッド重要度分析

各レイヤーのAttention と SSM ヘッドの相対的な重要性を分析するために、それぞれを削除して最終的な精度を記録しました。分析の結果、以下のことが明らかになりました。

  • 同じレイヤーの  Attention / SSM ヘッドの相対的な重要性は入力適応であり、タスクによって異なります。これは、さまざまな入力の処理において、異なる役割を果たす可能性があることを示唆しています。
  • 最初のレイヤーの SSM ヘッドは言語モデリングタスクに不可欠で、これを削除すると、ランダム推測レベルにまで大幅に精度が低下します。
  • 一般的に、Attention / SSM ヘッドを 1 つ削除すると、Hellaswag ではそれぞれ平均 0.24%/1.1% 精度が低下します。
図 7. Hellaswag の 1K サンプルを使用して測定した、各レイヤーの Attention または SSM ヘッドを削除した後の達成精度

モデル アーキテクチャと学習のベスト プラクティス

このセクションでは、Hymba 1.5B Base と Hymba 1.5B Instruct の主要アーキテクチャ上の決定事項と学習方法の概要について説明します。

モデル アーキテクチャ

  • ハイブリッド アーキテクチャ: Mamba は要約に優れ、通常は現在のトークンにより重点を置きます。Attention はより正確でスナップショット メモリとして機能します。標準的なシーケンシャル融合ではなく、並列に組み合わせることで利点を統合することができます。SSM と Attention ヘッド間のパラメーター比は 5:1 を選択しました。
  • Sliding Window Attention: 完全な Attention ヘッドは 3 つのレイヤー (最初、最後、中間) に維持され、残りの 90% のレイヤーで Sliding Window Attention ヘッドが使用されます。
  • Cross-layer KV キャッシュ共有: 連続する 2 つの Attention レイヤー間に実装されます。これは、ヘッド間の GQA KV キャッシュ共有に加えて行われます。
  • メタトークン: これらの 128 トークンは教師なし学習が可能であり、大規模言語モデル (LLM) におけるエントロピー崩壊の問題を回避し、Attention Sink 現象を緩和するのに役立ちます。さらに、モデルはこれらのトークンに一般的な知識を格納します。

学習のベスト プラクティス

  • 事前学習: 2 段階のベースモデル学習を選択しました。ステージ 1 では、一定の高い学習率を維持し、フィルタリングされていない大規模なコーパス データの使用しました。続いて、高品質のデータを用いて 1e-5 まで継続的に学習率を減衰させました。このアプローチにより、ステージ 1 の継続的な学習と再開が可能になります。
  • 指示ファインチューニング: 指示モデルの調整は 3 つの段階で行われます。まず、SFT-1 は、コード、数学、関数呼び出し、ロール プレイ、その他のタスク固有のデータで学習を実施し、強力な推論能力をモデルに付与します。次に、SFT-2 はモデルに人間の指示に従うことを教えます。最後に、DPO を活用して、モデルを人間の好みに合わせ、モデルの安全性を高めます。
図 8. Hymba モデル ファミリに適応した学習パイプライン

パフォーマンスと効率性の評価

1.5T の事前学習トークンだけで、Hymba 1.5B モデルはすべての小規模言語モデルの中で最高の性能を発揮し、Transformer ベースの LM よりも優れたスループットとキャッシュ効率を実現します。

例えば、13 倍以上のトークン数で事前学習された最も強力なベースラインである Qwen2.5 に対してベンチマークした場合、Hymba 1.5B は平均精度が 1.55%、スループットが 1.41 倍、キャッシュ効率が 2.90 倍に向上します。2T 未満のトークンで学習された最も強力な小規模言語モデル、すなわち h2o-danube2 と比較すると、この方法は平均精度が 5.41%、スループットが 2.45 倍、キャッシュ効率が 6.23 倍に向上しています。

モデルパラメーター数学習トークントークン(1 秒あたり)キャッシュ(MB)MMLU 5-shotARC-E 0-shotARC-C 0-shotPIQA 0-shotWino. 0-shotHella. 0-shotSQuAD -C1-shot平均
OpenELM-11.1B1.5T24634627.0662.3719.5474.7661.848.3745.3848.57
Renev0.11.3B1.5T80011332.9467.0531.0676.4962.7551.1648.3652.83
Phi1.51.3B0.15T241157342.5676.1844.7176.5672.854830.0955.85
SmolLM1.7B1T238157327.0676.4743.4375.7960.9349.5845.8154.15
Cosmo1.8B.2T244157326.162.4232.9471.7655.842.938.5147.2
h20dan-ube21.8B2T27149240.0570.6633.1976.0166.9353.749.0355.65
Llama 3.2 1B1.2B9T53526232.1265.5331.3974.4360.6947.7240.1850.29
Qwen2.51.5B18T46922960.9275.5141.2175.7963.3850.249.5359.51
AMDOLMo1.2B1.3T387104926.9365.9131.5774.9261.6447.333.7148.85
SmolLM21.7B11T238157350.2977.7844.7177.0966.3853.5550.560.04
Llama3.2 3B3.0B9T19191856.0374.5442.3276.6669.8555.2943.4659.74
Hymba1.5B1.5T6647951.1976.9445.977.3166.6153.5555.9361.06
表 2. Hymba 1.5B ベース モデルの結果

指示モデル

Hymba 1.5B Instruct モデルは、全タスク平均で最高のパフォーマンスを達成し、直近の最高性能モデルである Qwen 2.5 Instruct を約 2% 上回りました。特に、Hymba 1.5B は GSM8K/GPQA/BFCLv2 で、それぞれ 58.76/31.03/46.40 のスコアで他のすべてのモデルを上回っています。これらの結果は、特に複雑な推論能力を必要とする分野において、Hymba 1.5B の優位性を示しています。

モデルパラメーター数MMLU ↑IFEval ↑GSM8K ↑GPQA ↑BFCLv2 ↑平均↑
SmolLM1.7B27.8025.161.3625.67-*20.00
OpenELM1.1B25.656.2556.0321.62-*27.39
Llama 3.21.2B44.4158.9242.9924.1120.2738.14
Qwen2.51.5B59.7346.7856.0330.1343.8547.30
SmolLM21.7B49.1155.0647.6829.2422.8340.78
Hymba 1.5B1.5B52.7957.1458.7631.0346.4049.22
表 3. Hymba 1.5B Instruct モデルの結果

まとめ

新しい Hymba ファミリの小規模言語モデルは、ハイブリッド ヘッド アーキテクチャを採用し、Attention ヘッドの高解像な記憶能力と SSM ヘッドの効率的なコンテキストの要約を組み合わせています。Hymba のパフォーマンスをさらに最適化するために、学習可能なメタトークンが導入され、Attention ヘッドと SSM ヘッドの両方で学習済みキャッシュとして機能し、顕著な情報に注目するモデルの精度を強化しました。Hymba のロードマップ、包括的な評価、アブレーション研究を通じて、Hymba は幅広いタスクにわたって新たな最先端のパフォーマンスを確立し、正確さと効率性の両面で優れた結果を達成しました。さらに、この研究は、ハイブリッド ヘッド アーキテクチャの利点に関する貴重な洞察をもたらし、効率的な言語モデルの今後の研究に有望な方向性を示しています。

Hybma 1.5B BaseHymba 1.5B Instruct の詳細はこちらをご覧ください。

謝辞

この成果は、Wonmin Byeon、Zijia Chen、Ameya Sunil Mahabaleshwarkar、Shih-Yang Liu、Matthijs Van Keirsbilck、Min-Hung Chen、Yoshi Suhara、Nikolaus Binder、Hanah Zhang、Maksim Khadkevich、Yingyan Celine Lin、Jan Kautz、Pavlo Molchanov、Nathan Horrocks など、NVIDIA の多くのメンバーの貢献なくしては実現しませんでした。

関連情報

Tags