Generative AI

NVIDIA NeMo RL を用いた合成データによる Supervised Fine-Tuning (SFT)

Reading Time: 7 minutes

本記事では、Nemotron-Personas-Japan を Seed とした高品質な合成データを用い、  NeMo RL による Supervised Fine-Tuning (SFT) を行うことで、  小規模モデルの日本語常識推論性能を改善する方法を解説します。本記事では Supervised Fine-Tuning を便宜上 SFT の略称で記載します。

合成データ生成方法に関しては「Nemotron-Personas-Japan を用いた NVIDIA NeMo Data Designer による合成データ生成」を参照してください

大規模言語モデル (LLM) は高性能で汎用性が高い一方、その運用には膨大な計算コストやインフラが必要で、現実的な本番運用には課題があります。一方で、Small Language Models (SLM) は軽量で高速に動作する利点があるものの、複雑な推論や多様な文脈理解を要するビジネス ユース ケースでは性能が十分でない場合もあります。

このように、「性能」と「現実的な運用性」のトレードオフが存在する中で、多くの企業は次のような課題に直面しています。

  • 事前学習済みモデルは存在するが、特定タスクへの適応 (ドメイン適合) が不十分
  • 合成データや独自評価指標を活用して継続的にモデルを改善したいが、開発基盤が分断されがち
  • SFT から DPO、RL に至るまで、開発パイプラインが統合されておらず、再現性や効率に課題がある

NeMo RL とは何か

NeMo RL は、NVIDIA が提供する NVIDIA NeMo Framework におけるポストトレーニング (post-training) ライブラリです。  

NeMo RL は、以下を前提とした設計になっています。

  • モデル規模: 数 B パラメーターから 100B 超まで
  • 実行環境: 単一 GPU、マルチ GPU、クラスター環境
    • ローカル環境での小規模検証
    • 本番環境での大規模学習
  • 様々なモデル サイズとハードウェア構成に対応するためにトレーニング バックエンドには DTensor、Megatron、生成バックエンドには vLLM、Megatron を採用
  • SFT → DPO → RL への拡張が容易
  • 大規模モデルや分散学習まで見据えた設計
  • NVIDIA NeMo エコシステムとの親和性が高い
  • YAML ベースの設定ファイルで学習条件を管理
  • JSON / JSONL 形式のデータを直接利用可能

NeMo RL は、以下のようなポストトレーニングをサポートします。

  • Supervised Fine-Tuning (SFT): 教師付きデータ (人手 or 合成データ) によるファインチューニング
  • Direct Preference Optimization (DPO): 人間の好みや比較データを用いた最適化
  • Reinforcement learning (RL): GRPO, GSPO, DAPO などのポリシー最適化手法
  • Reward Modeling: 報酬モデルの学習

本記事では、NeMo RL が提供する機能のうち、SFT に焦点を当て、以下の役割で使用します。

検証環境

本記事で検証に使用した環境は下記です。

ハードウェア動作環境

  • DGX-A100

ソフトウェア環境

下記の方法で必要なコードを取得します。

git clone https://github.com/NVIDIA-NeMo/RL.git nemo-rl --recursive
cd nemo-rl

git submodule update --init --recursive

nvidia/nvidia-nemotron-nano-9b-v2 を SFT するため、下記で必要なライブラリをインストールします。

uv sync --extra automodel

合成データを用いた NeMo RL による SFT

合成データの前処理

SFT 実行時に LLM がインデックス分布の偏りを利用して予測するケースがあるため、答えのインデックスを下記コードでリバランスして、学習データと検証データに分離します。

nvidia/nvidia-nemotron-nano-9b-v2 は思考トークン (Chain-of-Thought) 対応していますが、本記事の SFT では 推論時に思考トークン (Chain-of-Thought) を使用しない想定でデータを準備します。

import json
import random
import collections
import codecs
from pathlib import Path


def load_jsonl(file_path):
    data = []
    error_count = 0

    decoder = json.JSONDecoder()
    text = ""
    idx = 0
    line_num = 1

    # 'replace' でデコードエラーを潰し、末尾の不完全なUTF-8でも落とさない
    inc = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def _parse_available_text(is_final: bool = False) -> None:
        nonlocal text, idx, line_num, error_count

        while idx < len(text):
            # 空白をスキップ
            while idx < len(text) and text[idx].isspace():
                if text[idx] == "\n":
                    line_num += 1
                idx += 1
            if idx >= len(text):
                break

            try:
                obj, end_idx = decoder.raw_decode(text, idx)
                data.append(obj)
                idx = end_idx
                continue
            except json.JSONDecodeError as e:
                # チャンク境界でJSONが未完成の可能性があるので、改行が無ければ次チャンクを待つ
                next_newline = text.find("\n", idx)
                if next_newline == -1:
                    if is_final:
                        error_count += 1
                        print(
                            f"Warning: Trailing/incomplete JSON around line {line_num}, "
                            f"position {idx}: {e}"
                        )
                    break

                # 壊れた/不完全な行をスキップして続行
                error_count += 1
                print(
                    f"Warning: Error parsing around line {line_num}, "
                    f"position {idx}: {e}"
                )
                idx = next_newline + 1
                line_num += 1
                if error_count > 100:
                    print("  Too many errors, stopping...")
                    idx = len(text)
                    break

        # バッファが肥大化しないように定期的に詰める
        if idx > 1_000_000:
            text = text[idx:]
            idx = 0

    chunk_size = 8 * 1024 * 1024
    with open(file_path, "rb") as f:
        while True:
            b = f.read(chunk_size)
            if not b:
                break
            text += inc.decode(b)
            _parse_available_text(is_final=False)

    # finalize decoder + parse remainder
    text += inc.decode(b"", final=True)
    _parse_available_text(is_final=True)

    if error_count > 0:
        print(f"  Total errors: {error_count} objects/lines skipped")
    return data


def save_jsonl(data, file_path):
    """データをJSONLファイルに保存する"""
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def split_train_valid(data, train_ratio=0.9, seed=42):
    """
    データをtrainとvalidに分割する
    """
    random.seed(seed)
    shuffled_data = data.copy()
    random.shuffle(shuffled_data)
    split_idx = int(len(shuffled_data) * train_ratio)
    train_data = shuffled_data[:split_idx]
    valid_data = shuffled_data[split_idx:]
    return train_data, valid_data

def _count_jcqa_labels_raw(data):
    """元の data(load_jsonl の直後)から jcqa の answer_index 分布を数える"""
    counts = collections.Counter()
    total = 0
    for item in data:
        jcqa = item.get("jcqa_data")
        if jcqa is None:
            continue
        idx = jcqa.get("answer_index")
        if idx is None:
            continue
        counts[str(idx)] += 1
        total += 1
    return counts, total


def _rebalance_one_jcqa(jcqa, target_label):
    choices = []
    for i in range(5):
        choices.append(jcqa.get(f"choice{i}", ""))

    orig_idx = jcqa.get("answer_index")
    if orig_idx is None:
        return

    correct_text = choices[orig_idx]
    other_texts = [c for i, c in enumerate(choices) if i != orig_idx]

    random.shuffle(other_texts)

    new_choices = [None] * 5
    new_choices[target_label] = correct_text

    it = iter(other_texts)
    for i in range(5):
        if new_choices[i] is None:
            new_choices[i] = next(it)

    # jcqa_data を上書き
    for i in range(5):
        jcqa[f"choice{i}"] = new_choices[i]
    jcqa["answer_index"] = target_label


def rebalance_jcqa_labels_in_place(data, seed=42):
    random.seed(seed)

    # before: 元の分布を出力
    before_counts, n = _count_jcqa_labels_raw(data)
    print(f"[jcqa] total jcqa_data examples: {n}")
    print(f"[jcqa] label counts BEFORE rebalance: {dict(before_counts)}")

    jcqa_indices = [i for i, item in enumerate(data) if "jcqa_data" in item]
    n_jcqa = len(jcqa_indices)
    if n_jcqa == 0:
        print("[jcqa] No jcqa_data found, skip rebalance.")
        return

    # 目標件数:ほぼ N/5 ずつ
    base = n_jcqa // 5
    rem = n_jcqa % 5  # 余りは先頭から +1 して配る
    target_counts = {str(l): base for l in range(5)}
    for l in range(rem):
        target_counts[str(l)] += 1

    print(f"[jcqa] target label counts: {target_counts}")

    # シャッフルした順番で、各サンプルにターゲットラベルを割り当てる
    random.shuffle(jcqa_indices)
    current_counts = {str(l): 0 for l in range(5)}

    for idx in jcqa_indices:
        item = data[idx]
        jcqa = item["jcqa_data"]

        # まだ目標に達していないラベルの中から選ぶ
        available_labels = [
            l for l in range(5)
            if current_counts[str(l)] < target_counts[str(l)]
        ]
        if not available_labels:
            # すべて埋まっていたら、何か適当に
            available_labels = list(range(5))

        target_label = random.choice(available_labels)
        current_counts[str(target_label)] += 1

        # 1サンプルの choice 並びと answer_index を target_label に揃える
        _rebalance_one_jcqa(jcqa, target_label)

    print(f"[jcqa] label counts AFTER  rebalance: {current_counts}")

def convert_jcommonsenseqa(item):
    """jcommonsenseqaデータをmessagesフォーマットに変換"""
    jcqa = item.get('jcqa_data', {})

    question = jcqa.get('question', '')
    choices_lines = []
    for i in range(5):
        choice_key = f'choice{i}'
        if choice_key in jcqa:
            choices_lines.append(f"    {i}. {jcqa[choice_key]}")
    choices_text = '\n'.join(choices_lines)

    user_content = (
        "以下の質問に答えてください。\n\n"
        f"質問:\n{question}\n\n"
        "選択肢:\n"
        f"{choices_text}\n\n"
        "最も適切な選択肢の番号だけを、半角数字で1つ出力してください。"
    )

    ground_truth = jcqa.get('answer_index')
    ground_truth_str = str(ground_truth) if ground_truth is not None else ""
    assistant_content = f"<think></think>\n\n{ground_truth_str}"

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content}
        ],
        "extra_env_info": {
            "dataset_type": "jcommonsenseqa",
            "ground_truth": ground_truth_str
        }
    }

def count_labels_in_converted(data, name):
    counts = collections.Counter()
    total = 0
    for ex in data:
        gt = ex.get("extra_env_info", {}).get("ground_truth")
        if gt is None:
            continue
        counts[gt] += 1
        total += 1
    print(f"[{name}] total examples: {total}")
    print(f"[{name}] label counts: {dict(counts)}")

from collections import defaultdict

def convert_data(input_file_path, output_dir):
    print(f"Loading {input_file_path}...")
    data = load_jsonl(input_file_path)
    print(f"  Loaded {len(data)} examples")

    # ★ jcommonsenseqa のラベル分布をリバランス(choice 並び+answer_index を書き換える)
    print("\nRebalancing jcommonsenseqa label indices (0-4)...")
    rebalance_jcqa_labels_in_place(data, seed=42)

    # 変換
    jcommonsenseqa_data = []

    for idx, item in enumerate(data):
        if 'jcqa_data' in item:
            converted = convert_jcommonsenseqa(item)
            jcommonsenseqa_data.append(converted)

    print(f"\nJCommonsenseQA (converted): {len(jcommonsenseqa_data)} examples")

    # 変換後のラベル分布も表示
    count_labels_in_converted(jcommonsenseqa_data, "jcommonsenseqa_converted")

    # 出力ディレクトリを作成
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # train/valid に分割
    print("\nSplitting data into train/valid...")
    jcommonsenseqa_train, jcommonsenseqa_valid = split_train_valid(jcommonsenseqa_data)

    print(f"  JCommonsenseQA:     train={len(jcommonsenseqa_train)},     valid={len(jcommonsenseqa_valid)}")

    # train/valid ごとのラベル分布も確認
    count_labels_in_converted(jcommonsenseqa_train, "jcommonsenseqa_train")
    count_labels_in_converted(jcommonsenseqa_valid, "jcommonsenseqa_valid")

    # ファイルを保存
    # JCommonsenseQA
    jcommonsenseqa_train_file = output_path / "jcommonsenseqa_train.jsonl"
    jcommonsenseqa_valid_file = output_path / "jcommonsenseqa_valid.jsonl"
    print(f"\nSaving JCommonsenseQA train data to {jcommonsenseqa_train_file}...")
    save_jsonl(jcommonsenseqa_train, jcommonsenseqa_train_file)
    print(f"Saving JCommonsenseQA valid data to {jcommonsenseqa_valid_file}...")
    save_jsonl(jcommonsenseqa_valid, jcommonsenseqa_valid_file)

    print("\n✓ Done!")
    return str(jcommonsenseqa_train_file)


if __name__ == "__main__":
    base_dir = Path(__file__).parent

    # 元のファイル名は環境に合わせて変更してください
    input_file = base_dir / "with_seed_data.jsonl"
    output_dir = base_dir / "simple_format_arrange"

    print("=" * 60)
    print("Converting to simple format with balanced jcommonsenseqa labels...")
    print("=" * 60)

    jcommonsenseqa_file = convert_data(
        str(input_file),
        str(output_dir)
    )

    print(f"\nOutput files:")
    print(f"  JCommonsenseQA:     {jcommonsenseqa_file}")

SFT 用の SFT yaml 設定

標準で提供されている sft.yaml を変更して使用します。

SFT 関連のパラメーターとデータ パス、データ フォーマットを変更しました。データ フォーマットは OpenAI Format 形式で変換しているので、これを設定してます。 OpenAI Format とは、messages 配列で user / assistant の対話形式を表現するデータフォーマットです。主な変更部分をピックアップします。以下の checkpoint_dirtrain_data_pathval_data_path は適宜変更してください。

# SFT Algorithm Configuration
sft:
  ## total number of steps to train will equal
  ## min((max_num_epochs * len(train_dataloader)), max_num_steps)
  max_num_epochs: 2
  max_num_steps: 55 

  val_period: 10
  val_batches: 4 
  val_global_batch_size: 32
  val_micro_batch_size: 1
  val_at_start: true
  seed: 42

checkpointing:
  enabled: true
  checkpoint_dir: {save model path}
  metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
  higher_is_better: false
  keep_top_k: 3
  save_period: 55 
  checkpoint_must_save_by: null

policy:
  model_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
    # chat_template can be a Jinja template string or path to a .jinja file 
    chat_template: "{% for message in messages %}{%- if message['role'] == 'system'  %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user'  %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant'  %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
    chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true
  train_global_batch_size: 256 
  train_micro_batch_size: 1
  max_total_sequence_length: 2048 
  precision: "bfloat16"

  offload_optimizer_for_logprob: false

  dtensor_cfg:
    enabled: true 
    env_vars: {}
    cpu_offload: False
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 3.0e-6
      weight_decay: 0.05
      betas: [0.9, 0.98]
      eps: 1e-5
      # when using Dtensor, we need to set foreach
      # and fused to False
      foreach: False
      fused: False
    
  ## ignored since enabled=false, but needed for testing purposes
  megatron_cfg:
    enabled: false
    env_vars: {}
    empty_unused_memory_level: 1
    activation_checkpointing: false
    tensor_model_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    pipeline_model_parallel_size: 1
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    sequence_parallel: false
    freeze_moe_router: false
    moe_router_dtype: null
    moe_router_load_balancing_type: "aux_loss"
    moe_router_bias_update_rate: 1e-3
    moe_permute_fusion: false
    #gives ~20% training perf speedup with sequence packing 
    apply_rope_fusion: True
    # gives ~25% training perf speedup with sequence packing and apply_rope_fusion
    bias_activation_fusion: True
    defer_fp32_logits: False

    optimizer:
      optimizer: "adam"
      lr: 1.0e-6
      min_lr: 4.9999e-6
      weight_decay: 0.1
      bf16: false
      fp16: false
      params_dtype: "float32"

      #adam
      adam_beta1: 0.9
      adam_beta2: 0.98
      adam_eps: 1e-5

      #sgd
      sgd_momentum: 0.9

      #distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

      # optimizer cpu offload
      optimizer_cpu_offload: false
      optimizer_offload_fraction: 0.0

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: 1000
      lr_warmup_iters: 50
      lr_warmup_init: 4.9999e-6

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: true
      overlap_param_gather: true
      data_parallel_sharding_strategy: "optim_grads_params"
      use_custom_fsdp: false

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  add_bos: true
  add_eos: true
  add_generation_prompt: false
  shuffle: true
  num_workers: 1

  dataset_name: "openai_format"
  train_data_path: {train synthesis data path}
  val_data_path: {valid synthesis data path}
  chat_key: "messages"                     # Key for messages in the data
  system_key: null                         # Key for system message (optional)
  system_prompt: null                      # Default system prompt (optional)
  tool_key: "tools"                        # Key for tools in the data
  use_preserving_dataset: false            # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures)
  
  # You can use custom response datasets for training and validation. For example:
  #   data:
  #     dataset_name: ResponseDataset
  #     train_data_path: <PathToTrainingDataset>  # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
  #     val_data_path: <PathToValidationDataset>
  #     input_key: <QuestionKey>, default is "input"
  #     output_key: <AnswerKey>, default is "output"
  #     train_split: <TrainSplit>, default is None  # used for HuggingFace datasets
  #     val_split: <ValSplit>, default is None  # used for HuggingFace datasets
  # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.

  ## unused with openai_format dataset
  prompt_file: null
  split: null
  output_key: null
  seed: null

logger:
  log_dir: "logs"  # Base directory for all logs
  wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
  tensorboard_enabled: true
  mlflow_enabled: false
  swanlab_enabled: false # Disable SwanLab logging
  monitor_gpus: true  # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "sft-nemotron-nano-v2-9b"
    name: "sft-nemotron-nano-v2-9b-${data.dataset_name}_with_seed_adjust_filter_2000_8k"
  tensorboard:
    log_dir: "tb_logs-sft-dev-${data.dataset_name}"
  mlflow:
    experiment_name: "sft-dev"
    run_name: "sft-dev-${data.dataset_name}"
  gpu_monitoring:
    collection_interval: 10  # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10  # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8 
  num_nodes: 1

下記コマンドで実行します。本検証では、約 8,000 件の合成データを用い、2 エポックで SFT を行いました。

uv run python examples/run_sft.py --config {config yaml}

NVIDIA NeMo RL を用いた合成データの SFT 結果

図 1. 学習データの loss
図 2. 検証データの loss

学習データと検証データでも loss が下がっており、 SFT が上手く動作していることが確認できます。学習データと検証データの loss には著しい差がないことが見られます。

まとめ

本ブログでは、Nemotron-Personas-Japan を Seed とした高品質な合成データから、  NeMo RL による SFT を行いました。

次のステップとしては JCommonsenseQA の精度が合成データを使った SFT によって改善するか「マルチ LLM 対応の NVIDIA NIM による合成データ SFT (Seed あり / なし) の効果分析」で検証します。

関連情報

Tags