对话式人工智能

使用 NVIDIA NeMo 训练本地化多语种 LLM,第 2 部分

第一部分 中,我们讨论了如何训练单语分词器,并将其与预训练 LLM 的分词器合并,以形成多语言分词器。在本文中,我们将向您展示如何将自定义分词器集成到预训练 LLM,以及如何在 NVIDIA NeMo 中实现这一目标。

Diagram shows the process starting with the TH Wikipedia dataset and a GPT2 BPE pretrained tokenizer, merging with Megatron GPT-1.3B.nemo.
图 1.训练本地化多语种 LLM 的工作流程

准备工作

开始之前,请先导入以下库:

import torch 
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel 
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder 
from omegaconf import OmegaConf

模型修改

合并后,组合分词器的词汇量大于 GPT-megatron-1.3 B 模型预训练分词器的词汇量。这意味着您必须扩展 GPT – megatron – 1.3 B 模型的嵌入层,以适应组合分词器 (图 2)。

Diagram shows the process to modify the model’s original embedding layer to accommodate a tokenizer with a larger vocabulary size.
图 2.扩展新分词器的模型嵌入层

关键步骤包括以下内容:

  • 使用所需增加的词汇量创建新的嵌入层。
  • 通过从原始嵌入层复制现有权重来初始化它。
  • 将新词表条目设置为零权重。

然后,此扩展嵌入层会替换预训练模型中的原始层,使其能够以新语言处理其他标记,同时保留在初始预训练过程中学习的知识。

加载并提取嵌入层

运行以下代码以加载 GPT-megatron-1.3 B.nemo 模型:

#Initialization
trainer_config = OmegaConf.load('/opt/NeMo/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml')
trainer_config.trainer.accelerator='gpu' if torch.cuda.is_available() else 'cpu'
trainer = MegatronTrainerBuilder(trainer_config).create_trainer()
#load gpt-megatron-1.3b.nemo and its config
nemo_model = MegatronGPTModel.restore_from('./path_to_1.3B_nemo_model',trainer=trainer)
nemo_config = OmegaConf.load('./path_to_1.3B_nemo_model_config.yaml')

加载模型后,您可以从模型中提取嵌入层的权重,即 state_dict 参数。该嵌入层用于生成新的嵌入层。

#Extract original embedding layer
embed_weight = nemo_model.state_dict()[f'model.language_model.embedding.word_embeddings.weight']
print(f"Shape of original embedding layer: {embed_weight.shape}")

生成新的嵌入层

现在,您必须计算新嵌入层和原始嵌入层之间的维度差异。根据此差异创建张量,并将其连接到原始嵌入层,以形成新的嵌入层。

差值基于以下内容计算得出:

  • 组合分词器词汇量
  • 原始嵌入层长度
  • 一个参数:model_config.yamlmodel.make_vocab_size_divisible_by

以下是差值的方程:

Diff = [\frac{N}{X}] \times X-O

  • 组合分词器词汇量 = N
  • 原始 NeMo 嵌入层长度 = O
  • model.make_vocab_size_divisible_by = X

机制是,为了更大限度地提高计算效率,您需要将嵌入层填充到可被 8 的倍数整除的数字上。给定 tokenizer 词汇量,模型预计会有一个带填充的嵌入层。

如果您从头开始训练,此过程应该是自动的,但在这种情况下,您必须手动填充新的嵌入层。

tokenizer = AutoTokenizer.from_pretrained('./path_to_new_merged_tokenizer')
if len(tokenizer)% nemo_config.make_vocab_size_divisible_by != 0:
  tokenizer_diff = (int(len(tokenizer)/nemo_config.make_vocab_size_divisible_by)+1) * nemo_config.make_vocab_size_divisible_by - embed_weight.shape[0]
else:
  tokenizer_diff = tokenizer.vocab_size - embed_weight.shape[0]

现在,您可以生成额外的张量作为新标记的初始权重。然后,此张量连接到之前提取的原始嵌入层,以形成新的嵌入层。

hidden_size = embed_weight.shape[1]
random_embed = torch.zeros((tokenizer_diff, hidden_size)).to('cuda')
new_embed_weight = torch.cat((embed_weight, random_embed), dim=0)

修改并输出新模型

在此步骤中,您将修改模型配置中与 tokenizer 相关的设置,以与新词表保持一致。回想一下,新嵌入的形状不同于原始嵌入层。如果直接替换原始模型中的嵌入层,您将遇到层大小不匹配错误。

加载包含更新的 tokenizer 配置的空模型实例,并将预训练模型的值分配给其 state_dict,同时添加新的嵌入层。

最后,以 .nemo 格式保存此修改后的模型,以便对扩展的词汇表进行持续预训练。

state_dict = nemo_model.state_dict()
state_dict[f'model.language_model.embedding.word_embeddings.weight'] = new_embed_weight

NEW_TOKENIZER_PATH = './path_to_new_merged_tokenizer'
nemo_config['tokenizer']['vocab_file'] = f"{NEW_TOKENIZER_PATH}/vocab.json"
nemo_config['tokenizer']['merge_file'] = f"{NEW_TOKENIZER_PATH}/merges.txt"
nemo_config['vocab_file'] = f"{NEW_TOKENIZER_PATH}/vocab.json"
nemo_config['merges_file'] = f"{NEW_TOKENIZER_PATH}/merges.txt"

new_nemo_model = MegatronGPTModel(nemo_config,trainer)
new_nemo_model.load_state_dict(state_dict)
new_nemo_model.save_to('./path_to_modified_nemo_model')

运行以下代码,检查新模型在英语提示下是否表现良好:

python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \ gpt_model_file='./path_to_modified_nemo_model' \
prompts='ENTER YOUR PROMPT' \
inference.greedy=True \
inference.add_BOS=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1

数据预处理

为了确保数据的一致性,请重复运行数据预处理脚本以处理训练、验证和测试数据集。有关更多信息,请参阅 第 3 步:将数据拆分为训练、验证和测试

替换 --json_key 参数,用于指定数据集中文档文本的键值:

python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ --input='./path_to_train/val/test_dataset' \
--json-keys=text \
--tokenizer-library=megatron \
--vocab './path_to_merged_tokenizer_vocab_file'\
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file './path_to_merged_tokenizer_merge_file' \
--append-eod \
--output-prefix='./path_to_output_preprocessed_dataset'

持续预训练

与您的模型相比,用于持续预训练的默认配置文件可能具有不同的模型配置。运行以下代码以覆盖这些配置。并相应地更新 tokenizer 和 data prefix 参数。

ori_conf = OmegaConf.load('./path_to_original_GPT-1.3B_model/model_config.yaml')
conf = OmegaConf.load('/opt/NeMo/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml')
for key in ori_conf.keys():
  conf['model'][key] = ori_conf[key]
# Set global_batch_size based on micro_batch_size
conf['model']["global_batch_size"] = conf['model']["micro_batch_size"] * conf.get('data_model_parallel_size',1) * conf.get('gradient_accumulation_steps',1)
# Reset data_prefix (dataset path)
conf['model']['data']['data_prefix'] = '???'
# Reset tokenizer config 

NEW_TOKENIZER_PATH = "./path_to_new_merged_tokenizer"
conf['model']['tokenizer']['vocab_file'] = f"{NEW_TOKENIZER_PATH}/vocab.json"
conf['model']['tokenizer']['merge_file'] = f"{NEW_TOKENIZER_PATH}/merges.txt"
conf['model']['vocab_file'] = f"{NEW_TOKENIZER_PATH}/vocab.json"
conf['model']['merges_file'] = f"{NEW_TOKENIZER_PATH}/merges.txt"
OmegaConf.save(config=conf,f='/opt/NeMo/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml')

运行以下代码以开始持续预训练。应根据您的特定硬件和设置修改以下参数:

  • nproc_per_node:每个节点上的 GPU 数量。
  • model.data.data_prefix:训练、验证和测试数据集的路径,请参阅代码示例以了解相关格式。
  • exp_manager.name:输出文件夹名称,该名称将用于保存中间检查点,在./nemo_experiments/文件夹中。
  • trainer.devices:每个节点的 GPU 设备数量。
  • trainer.num_nodes:表示模型训练的节点数。
  • trainer.val_check_interval:在训练过程中执行验证检查的频率(以步数为单位)。
  • trainer.max_steps:指定训练过程的最大步长。
  • model.tensor_model_parallel_size:对于 13B 模型,请继续使用1。对于更大的模型,请使用更大的尺寸。
  • model.pipeline_model_parallel_size:对于 13B 模型,建议保持为 1。对于更大的模型,建议使用更大的尺寸。
  • model.micro_batch_size:根据 GPU 的视频随机存取存储器(vRAM)大小进行调整。
  • model.global_batch_size:其值取决于micro_batch_size。欲了解更多信息,请参阅 批处理(Batching)
DATA = '{train:[1.0,training_data_indexed/train_text_document], validation:[training_data_indexed/val_text_document], test:[training_data_indexed/test_text_document]}'

!torchrun --nproc_per_node=1 \ /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_continue_training.py \ 
"model.data.data_prefix={DATA}"\ 
name=megatron_gpt_ \ 
exp_manager.name=megatron_gpt_1 \ 
restore_from_path='./path_to_modified_nemo_model' \ 
trainer.devices=1 \ 
trainer.num_nodes=1 \ 
trainer.precision=16 \ 
trainer.val_check_interval=300 \ 
trainer.max_steps=1200 \ 
model.megatron_amp_O2=False \ 
model.tensor_model_parallel_size=1 \ 
model.pipeline_model_parallel_size=1 \ 
model.micro_batch_size=1 \ 
model.global_batch_size=1 \ 
++model.use_flash_attention=False \
++model.seq_len_interpolation_factor=null

模型推理

在训练期间,生成的中间文件将存储在./nemo_experiments 文件夹中。在这里,您应该可以找到所需的模型 Checkpoint 文件和 hparams.yaml 文件。

使用以下代码使用 Checkpoint 文件进行推理:

python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \
'checkpoint_dir="./path_to_checkpoint_folder"' \
'checkpoint_name="name of checkpoint file in .ckpt format"' \
'hparams_file="./path_to_hparams_file"' \
prompts='ENTER YOUR PROMPT' \
inference.greedy=True \
inference.add_BOS=True \
trainer.devices=1 \
trainer.num_nodes=1 \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1

表 1 比较了原始 GPT-megatron-1.3 B 模型和使用泰文维基百科数据训练的 GPT-megatron-1.3 B 模型生成的句子输出。在本文中,我们截断了一些重复输出标记。

提示 GPT-megatron-1.3 B.nemo 的输出 输出经过训练的 GPT-megatron-1.3 B-TH.nemo
泰国的省会城市是 泰国的省会城市为曼谷。泰国的省会城市为曼谷。n nHistory n n 泰国的省会城市原名 Chiang Mai,意思是“太阳之城”。泰国的省会城市原名 Chiang Mai,意思是“太阳之城”。  
     
表 1.句子输出对比

训练后,模型提高了对泰语的理解,尽管其英语性能降低。这是由于使用单语数据集持续预训练导致模型遗忘。为避免这种情况,我们建议使用包含英语和目标语言的语料库进行训练。

结束语

通过遵循此工作流程,您可以有效地扩展基础 LLM 的语言支持,使其能够理解并生成多种语言的内容。此方法使用在初始预训练期间学习的现有知识和表征,同时使模型能够通过持续学习适应和获得新的语言技能。

此过程的成功确实取决于用于分词器训练和持续预训练的目标语言数据的质量和数量。为了确保最佳性能,尤其是减轻灾难性的遗忘,仔细的训练课程和训练策略也是必要的。

要开始使用,请下载 NeMo 框架容器 或下载并设置 NVIDIA/NeMo 开源库 在 GitHub 上。您可以在低资源语言上使用自己的精选数据集,并按照本文中的步骤在基础语言模型(LLM)上添加所需的新语言支持。

作为 NeMo 微服务抢先体验,您还可以请求访问 NVIDIA NeMo 策展人NVIDIA NeMo 定制器 微服务。这些微服务共同简化了大语言模型(LLM)的数据管护和自定义,使您能够更快地将解决方案推向市场。

 

Tags