本記事では、Nemotron-Personas-Japan を Seed とした高品質な合成データを用い、 NeMo RL による Supervised Fine-Tuning (SFT) を行うことで、 小規模モデルの日本語常識推論性能を改善する方法を解説します。本記事では Supervised Fine-Tuning を便宜上 SFT の略称で記載します。
合成データ生成方法に関しては「Nemotron-Personas-Japan を用いた NVIDIA NeMo Data Designer による合成データ生成」を参照してください
大規模言語モデル (LLM) は高性能で汎用性が高い一方、その運用には膨大な計算コストやインフラが必要で、現実的な本番運用には課題があります。一方で、Small Language Models (SLM) は軽量で高速に動作する利点があるものの、複雑な推論や多様な文脈理解を要するビジネス ユース ケースでは性能が十分でない場合もあります。
このように、「性能」と「現実的な運用性」のトレードオフが存在する中で、多くの企業は次のような課題に直面しています。
- 事前学習済みモデルは存在するが、特定タスクへの適応 (ドメイン適合) が不十分
- 合成データや独自評価指標を活用して継続的にモデルを改善したいが、開発基盤が分断されがち
- SFT から DPO、RL に至るまで、開発パイプラインが統合されておらず、再現性や効率に課題がある
NeMo RL とは何か
NeMo RL は、NVIDIA が提供する NVIDIA NeMo Framework におけるポストトレーニング (post-training) ライブラリです。
NeMo RL は、以下を前提とした設計になっています。
- モデル規模: 数 B パラメーターから 100B 超まで
- 実行環境: 単一 GPU、マルチ GPU、クラスター環境
- ローカル環境での小規模検証
- 本番環境での大規模学習
- 様々なモデル サイズとハードウェア構成に対応するためにトレーニング バックエンドには DTensor、Megatron、生成バックエンドには vLLM、Megatron を採用
- SFT → DPO → RL への拡張が容易
- 大規模モデルや分散学習まで見据えた設計
- NVIDIA NeMo エコシステムとの親和性が高い
- YAML ベースの設定ファイルで学習条件を管理
- JSON / JSONL 形式のデータを直接利用可能
NeMo RL は、以下のようなポストトレーニングをサポートします。
- Supervised Fine-Tuning (SFT): 教師付きデータ (人手 or 合成データ) によるファインチューニング
- Direct Preference Optimization (DPO): 人間の好みや比較データを用いた最適化
- Reinforcement learning (RL): GRPO, GSPO, DAPO などのポリシー最適化手法
- Reward Modeling: 報酬モデルの学習
本記事では、NeMo RL が提供する機能のうち、SFT に焦点を当て、以下の役割で使用します。
- NeMo Data Designer で生成した高品質な合成データを投入
- nvidia/nvidia-nemotron-nano-9b-v2 を SFT で改善
検証環境
本記事で検証に使用した環境は下記です。
ハードウェア動作環境
- DGX-A100
ソフトウェア環境
- 2026.01.06 時点の main ブランチのコードを使用
NVIDIA-NeMo/RL: Scalable toolkit for efficient model reinforcement
下記の方法で必要なコードを取得します。
git clone https://github.com/NVIDIA-NeMo/RL.git nemo-rl --recursive
cd nemo-rl
git submodule update --init --recursive
nvidia/nvidia-nemotron-nano-9b-v2 を SFT するため、下記で必要なライブラリをインストールします。
uv sync --extra automodel
合成データを用いた NeMo RL による SFT
合成データの前処理
SFT 実行時に LLM がインデックス分布の偏りを利用して予測するケースがあるため、答えのインデックスを下記コードでリバランスして、学習データと検証データに分離します。
nvidia/nvidia-nemotron-nano-9b-v2 は思考トークン (Chain-of-Thought) 対応していますが、本記事の SFT では 推論時に思考トークン (Chain-of-Thought) を使用しない想定でデータを準備します。
import json
import random
import collections
import codecs
from pathlib import Path
def load_jsonl(file_path):
data = []
error_count = 0
decoder = json.JSONDecoder()
text = ""
idx = 0
line_num = 1
# 'replace' でデコードエラーを潰し、末尾の不完全なUTF-8でも落とさない
inc = codecs.getincrementaldecoder("utf-8")(errors="replace")
def _parse_available_text(is_final: bool = False) -> None:
nonlocal text, idx, line_num, error_count
while idx < len(text):
# 空白をスキップ
while idx < len(text) and text[idx].isspace():
if text[idx] == "\n":
line_num += 1
idx += 1
if idx >= len(text):
break
try:
obj, end_idx = decoder.raw_decode(text, idx)
data.append(obj)
idx = end_idx
continue
except json.JSONDecodeError as e:
# チャンク境界でJSONが未完成の可能性があるので、改行が無ければ次チャンクを待つ
next_newline = text.find("\n", idx)
if next_newline == -1:
if is_final:
error_count += 1
print(
f"Warning: Trailing/incomplete JSON around line {line_num}, "
f"position {idx}: {e}"
)
break
# 壊れた/不完全な行をスキップして続行
error_count += 1
print(
f"Warning: Error parsing around line {line_num}, "
f"position {idx}: {e}"
)
idx = next_newline + 1
line_num += 1
if error_count > 100:
print(" Too many errors, stopping...")
idx = len(text)
break
# バッファが肥大化しないように定期的に詰める
if idx > 1_000_000:
text = text[idx:]
idx = 0
chunk_size = 8 * 1024 * 1024
with open(file_path, "rb") as f:
while True:
b = f.read(chunk_size)
if not b:
break
text += inc.decode(b)
_parse_available_text(is_final=False)
# finalize decoder + parse remainder
text += inc.decode(b"", final=True)
_parse_available_text(is_final=True)
if error_count > 0:
print(f" Total errors: {error_count} objects/lines skipped")
return data
def save_jsonl(data, file_path):
"""データをJSONLファイルに保存する"""
with open(file_path, 'w', encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
def split_train_valid(data, train_ratio=0.9, seed=42):
"""
データをtrainとvalidに分割する
"""
random.seed(seed)
shuffled_data = data.copy()
random.shuffle(shuffled_data)
split_idx = int(len(shuffled_data) * train_ratio)
train_data = shuffled_data[:split_idx]
valid_data = shuffled_data[split_idx:]
return train_data, valid_data
def _count_jcqa_labels_raw(data):
"""元の data(load_jsonl の直後)から jcqa の answer_index 分布を数える"""
counts = collections.Counter()
total = 0
for item in data:
jcqa = item.get("jcqa_data")
if jcqa is None:
continue
idx = jcqa.get("answer_index")
if idx is None:
continue
counts[str(idx)] += 1
total += 1
return counts, total
def _rebalance_one_jcqa(jcqa, target_label):
choices = []
for i in range(5):
choices.append(jcqa.get(f"choice{i}", ""))
orig_idx = jcqa.get("answer_index")
if orig_idx is None:
return
correct_text = choices[orig_idx]
other_texts = [c for i, c in enumerate(choices) if i != orig_idx]
random.shuffle(other_texts)
new_choices = [None] * 5
new_choices[target_label] = correct_text
it = iter(other_texts)
for i in range(5):
if new_choices[i] is None:
new_choices[i] = next(it)
# jcqa_data を上書き
for i in range(5):
jcqa[f"choice{i}"] = new_choices[i]
jcqa["answer_index"] = target_label
def rebalance_jcqa_labels_in_place(data, seed=42):
random.seed(seed)
# before: 元の分布を出力
before_counts, n = _count_jcqa_labels_raw(data)
print(f"[jcqa] total jcqa_data examples: {n}")
print(f"[jcqa] label counts BEFORE rebalance: {dict(before_counts)}")
jcqa_indices = [i for i, item in enumerate(data) if "jcqa_data" in item]
n_jcqa = len(jcqa_indices)
if n_jcqa == 0:
print("[jcqa] No jcqa_data found, skip rebalance.")
return
# 目標件数:ほぼ N/5 ずつ
base = n_jcqa // 5
rem = n_jcqa % 5 # 余りは先頭から +1 して配る
target_counts = {str(l): base for l in range(5)}
for l in range(rem):
target_counts[str(l)] += 1
print(f"[jcqa] target label counts: {target_counts}")
# シャッフルした順番で、各サンプルにターゲットラベルを割り当てる
random.shuffle(jcqa_indices)
current_counts = {str(l): 0 for l in range(5)}
for idx in jcqa_indices:
item = data[idx]
jcqa = item["jcqa_data"]
# まだ目標に達していないラベルの中から選ぶ
available_labels = [
l for l in range(5)
if current_counts[str(l)] < target_counts[str(l)]
]
if not available_labels:
# すべて埋まっていたら、何か適当に
available_labels = list(range(5))
target_label = random.choice(available_labels)
current_counts[str(target_label)] += 1
# 1サンプルの choice 並びと answer_index を target_label に揃える
_rebalance_one_jcqa(jcqa, target_label)
print(f"[jcqa] label counts AFTER rebalance: {current_counts}")
def convert_jcommonsenseqa(item):
"""jcommonsenseqaデータをmessagesフォーマットに変換"""
jcqa = item.get('jcqa_data', {})
question = jcqa.get('question', '')
choices_lines = []
for i in range(5):
choice_key = f'choice{i}'
if choice_key in jcqa:
choices_lines.append(f" {i}. {jcqa[choice_key]}")
choices_text = '\n'.join(choices_lines)
user_content = (
"以下の質問に答えてください。\n\n"
f"質問:\n{question}\n\n"
"選択肢:\n"
f"{choices_text}\n\n"
"最も適切な選択肢の番号だけを、半角数字で1つ出力してください。"
)
ground_truth = jcqa.get('answer_index')
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
assistant_content = f"<think></think>\n\n{ground_truth_str}"
return {
"messages": [
{"role": "user", "content": user_content},
{"role": "assistant", "content": assistant_content}
],
"extra_env_info": {
"dataset_type": "jcommonsenseqa",
"ground_truth": ground_truth_str
}
}
def count_labels_in_converted(data, name):
counts = collections.Counter()
total = 0
for ex in data:
gt = ex.get("extra_env_info", {}).get("ground_truth")
if gt is None:
continue
counts[gt] += 1
total += 1
print(f"[{name}] total examples: {total}")
print(f"[{name}] label counts: {dict(counts)}")
from collections import defaultdict
def convert_data(input_file_path, output_dir):
print(f"Loading {input_file_path}...")
data = load_jsonl(input_file_path)
print(f" Loaded {len(data)} examples")
# ★ jcommonsenseqa のラベル分布をリバランス(choice 並び+answer_index を書き換える)
print("\nRebalancing jcommonsenseqa label indices (0-4)...")
rebalance_jcqa_labels_in_place(data, seed=42)
# 変換
jcommonsenseqa_data = []
for idx, item in enumerate(data):
if 'jcqa_data' in item:
converted = convert_jcommonsenseqa(item)
jcommonsenseqa_data.append(converted)
print(f"\nJCommonsenseQA (converted): {len(jcommonsenseqa_data)} examples")
# 変換後のラベル分布も表示
count_labels_in_converted(jcommonsenseqa_data, "jcommonsenseqa_converted")
# 出力ディレクトリを作成
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# train/valid に分割
print("\nSplitting data into train/valid...")
jcommonsenseqa_train, jcommonsenseqa_valid = split_train_valid(jcommonsenseqa_data)
print(f" JCommonsenseQA: train={len(jcommonsenseqa_train)}, valid={len(jcommonsenseqa_valid)}")
# train/valid ごとのラベル分布も確認
count_labels_in_converted(jcommonsenseqa_train, "jcommonsenseqa_train")
count_labels_in_converted(jcommonsenseqa_valid, "jcommonsenseqa_valid")
# ファイルを保存
# JCommonsenseQA
jcommonsenseqa_train_file = output_path / "jcommonsenseqa_train.jsonl"
jcommonsenseqa_valid_file = output_path / "jcommonsenseqa_valid.jsonl"
print(f"\nSaving JCommonsenseQA train data to {jcommonsenseqa_train_file}...")
save_jsonl(jcommonsenseqa_train, jcommonsenseqa_train_file)
print(f"Saving JCommonsenseQA valid data to {jcommonsenseqa_valid_file}...")
save_jsonl(jcommonsenseqa_valid, jcommonsenseqa_valid_file)
print("\n✓ Done!")
return str(jcommonsenseqa_train_file)
if __name__ == "__main__":
base_dir = Path(__file__).parent
# 元のファイル名は環境に合わせて変更してください
input_file = base_dir / "with_seed_data.jsonl"
output_dir = base_dir / "simple_format_arrange"
print("=" * 60)
print("Converting to simple format with balanced jcommonsenseqa labels...")
print("=" * 60)
jcommonsenseqa_file = convert_data(
str(input_file),
str(output_dir)
)
print(f"\nOutput files:")
print(f" JCommonsenseQA: {jcommonsenseqa_file}")
SFT 用の SFT yaml 設定
標準で提供されている sft.yaml を変更して使用します。
SFT 関連のパラメーターとデータ パス、データ フォーマットを変更しました。データ フォーマットは OpenAI Format 形式で変換しているので、これを設定してます。 OpenAI Format とは、messages 配列で user / assistant の対話形式を表現するデータフォーマットです。主な変更部分をピックアップします。以下の checkpoint_dir、train_data_path、val_data_path は適宜変更してください。
# SFT Algorithm Configuration
sft:
## total number of steps to train will equal
## min((max_num_epochs * len(train_dataloader)), max_num_steps)
max_num_epochs: 2
max_num_steps: 55
val_period: 10
val_batches: 4
val_global_batch_size: 32
val_micro_batch_size: 1
val_at_start: true
seed: 42
checkpointing:
enabled: true
checkpoint_dir: {save model path}
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
keep_top_k: 3
save_period: 55
checkpoint_must_save_by: null
policy:
model_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
# chat_template can be a Jinja template string or path to a .jinja file
chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true
train_global_batch_size: 256
train_micro_batch_size: 1
max_total_sequence_length: 2048
precision: "bfloat16"
offload_optimizer_for_logprob: false
dtensor_cfg:
enabled: true
env_vars: {}
cpu_offload: False
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: false
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
sequence_length_round: 64
sequence_packing:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64
# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
max_grad_norm: 1.0
optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 3.0e-6
weight_decay: 0.05
betas: [0.9, 0.98]
eps: 1e-5
# when using Dtensor, we need to set foreach
# and fused to False
foreach: False
fused: False
## ignored since enabled=false, but needed for testing purposes
megatron_cfg:
enabled: false
env_vars: {}
empty_unused_memory_level: 1
activation_checkpointing: false
tensor_model_parallel_size: 1
expert_tensor_parallel_size: 1
expert_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
pipeline_dtype: ${policy.precision}
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
sequence_parallel: false
freeze_moe_router: false
moe_router_dtype: null
moe_router_load_balancing_type: "aux_loss"
moe_router_bias_update_rate: 1e-3
moe_permute_fusion: false
#gives ~20% training perf speedup with sequence packing
apply_rope_fusion: True
# gives ~25% training perf speedup with sequence packing and apply_rope_fusion
bias_activation_fusion: True
defer_fp32_logits: False
optimizer:
optimizer: "adam"
lr: 1.0e-6
min_lr: 4.9999e-6
weight_decay: 0.1
bf16: false
fp16: false
params_dtype: "float32"
#adam
adam_beta1: 0.9
adam_beta2: 0.98
adam_eps: 1e-5
#sgd
sgd_momentum: 0.9
#distributed optimizer
use_distributed_optimizer: true
use_precision_aware_optimizer: true
clip_grad: ${policy.max_grad_norm}
# optimizer cpu offload
optimizer_cpu_offload: false
optimizer_offload_fraction: 0.0
scheduler:
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: 1000
lr_warmup_iters: 50
lr_warmup_init: 4.9999e-6
distributed_data_parallel_config:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
data_parallel_sharding_strategy: "optim_grads_params"
use_custom_fsdp: false
data:
max_input_seq_length: ${policy.max_total_sequence_length}
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
num_workers: 1
dataset_name: "openai_format"
train_data_path: {train synthesis data path}
val_data_path: {valid synthesis data path}
chat_key: "messages" # Key for messages in the data
system_key: null # Key for system message (optional)
system_prompt: null # Default system prompt (optional)
tool_key: "tools" # Key for tools in the data
use_preserving_dataset: false # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures)
# You can use custom response datasets for training and validation. For example:
# data:
# dataset_name: ResponseDataset
# train_data_path: <PathToTrainingDataset> # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
# val_data_path: <PathToValidationDataset>
# input_key: <QuestionKey>, default is "input"
# output_key: <AnswerKey>, default is "output"
# train_split: <TrainSplit>, default is None # used for HuggingFace datasets
# val_split: <ValSplit>, default is None # used for HuggingFace datasets
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.
## unused with openai_format dataset
prompt_file: null
split: null
output_key: null
seed: null
logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
tensorboard_enabled: true
mlflow_enabled: false
swanlab_enabled: false # Disable SwanLab logging
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "sft-nemotron-nano-v2-9b"
name: "sft-nemotron-nano-v2-9b-${data.dataset_name}_with_seed_adjust_filter_2000_8k"
tensorboard:
log_dir: "tb_logs-sft-dev-${data.dataset_name}"
mlflow:
experiment_name: "sft-dev"
run_name: "sft-dev-${data.dataset_name}"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
cluster:
gpus_per_node: 8
num_nodes: 1
下記コマンドで実行します。本検証では、約 8,000 件の合成データを用い、2 エポックで SFT を行いました。
uv run python examples/run_sft.py --config {config yaml}
NVIDIA NeMo RL を用いた合成データの SFT 結果


学習データと検証データでも loss が下がっており、 SFT が上手く動作していることが確認できます。学習データと検証データの loss には著しい差がないことが見られます。
まとめ
本ブログでは、Nemotron-Personas-Japan を Seed とした高品質な合成データから、 NeMo RL による SFT を行いました。
次のステップとしては JCommonsenseQA の精度が合成データを使った SFT によって改善するか「マルチ LLM 対応の NVIDIA NIM による合成データ SFT (Seed あり / なし) の効果分析」で検証します。