Customizing Neural Machine Translation Models with NVIDIA NeMo, Part 2

In the first post, we walked through the prerequisites for a neural machine translation example from English to Chinese, running the pretrained model with NeMo, and evaluating its performance. In this post, we walk you through curating a custom dataset and fine-tuning the model on that dataset. 

Custom data collection

Custom data collection is crucial in model fine-tuning because it enables a model to adapt to the specific requirements of a particular task or domain.

For example, if the task is to translate computer science-related technical articles or blog posts from English to Chinese, it is vital to collect blog posts that were previously translated or reviewed by humans, paired with their translations, as fine-tuning data. Such articles contain concepts and terminology that are common in computer science but rare in the pretraining dataset.

We recommend collecting at least a few thousand high-quality samples. After fine-tuning on this tailored data, the model can perform much better on technical blog translation tasks.

Data preprocessing pipeline for fine-tuning

Data preprocessing is needed to filter out invalid and redundant data. The NVIDIA NeMo framework includes NVIDIA NeMo Curator for processing the corpora used in LLM pretraining. However, NMT parallel datasets differ from such corpora because they consist of paired source and target texts, which require special filtering methods. Fortunately, NeMo provides out-of-the-box functions and scripts for most of these methods.

You can introduce a simple data preprocessing pipeline to clean English-Chinese parallel translations:

  • Language filtering
  • Length filtering
  • Deduplication
  • Tokenization and normalization (NeMo model only)
  • Converting to JSONL format (ALMA model only)
  • Splitting the datasets

You can also use additional preprocessing approaches to filter out invalid data, for example, using an existing translator to remove potentially incorrect translations. For more information about other data preprocessing methods, see Machine Translation Models.

Original data format

We collected English-Chinese translation pairs in two text files: 

  • en_zh.en stores the English sentences separated by line.
  • en_zh.zh stores the corresponding Chinese translation in each line.
Line   | en_zh.en | en_zh.zh
Line 1 | Accelerate Data Preparation | 加快完成数据准备
Line 2 | An end to end, cloud native, suite of AI and data analytics software, optimized, certified and supported by NVIDIA. | 一款经 NVIDIA 优化、认证和支持的端到端云原生 AI 和数据分析软件套件。
Table 1. Example correspondence between files

In this section, the data files are retained in the same format after each processing step. 
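Because every step pairs lines purely by position, and Python's zip silently stops at the shorter file, a single missing or extra line would misalign every pair that follows it. A quick check like the following sketch can be run on the input files and after each step. The helper name check_aligned is hypothetical and not part of NeMo:

```python
def _read_lines(path):
    # Iterate over the file so that only newlines split lines, as the zip-based scripts do
    with open(path, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def check_aligned(source_path, target_path):
    """Verify that two parallel text files have the same number of lines.

    Returns (num_pairs, num_pairs_with_an_empty_side); raises ValueError on a mismatch.
    """
    src_lines = _read_lines(source_path)
    tgt_lines = _read_lines(target_path)

    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"{source_path} has {len(src_lines)} lines, "
            f"but {target_path} has {len(tgt_lines)}"
        )

    # Pairs where either side is blank usually need to be dropped or fixed
    empty = sum(1 for s, t in zip(src_lines, tgt_lines) if not s.strip() or not t.strip())
    return len(src_lines), empty
```

For example, check_aligned("en_zh.en", "en_zh.zh") returns the number of pairs and how many of them have an empty side.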

Language filtering

NeMo provides language ID filtering, which uses a pretrained language ID classifier from fastText to remove training pairs that aren't in the expected source or target language.

Download the lid.176.bin language classifier model from the fastText website:

wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -O lid.176.bin

Use the following code for language ID filtering:

python /opt/NeMo/scripts/neural_machine_translation/filter_langs_nmt.py \
    --input-src en_zh.en \
    --input-tgt en_zh.zh \
    --output-src en_zh_preprocessed1.en \
    --output-tgt en_zh_preprocessed1.zh \
    --removed-src en_zh_garbage1.en \
    --removed-tgt en_zh_garbage1.zh \
    --source-lang en \
    --target-lang zh \
    --fasttext-model lid.176.bin

en_zh_preprocessed1.en and en_zh_preprocessed1.zh are the valid data pairs retained for the next step. Pairs in en_zh_garbage1.en and en_zh_garbage1.zh are discarded. 

Here are examples of lines in en_zh_garbage1.zh that were filtered out because they are not in Chinese:

NVIDIA OVX
Tensor Core
PROP_2.USD
ConnectX-6 Dx、ConnectX-6
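filter_langs_nmt.py relies on the fastText classifier to make this decision. As a rough, dependency-free illustration of the same idea, the following sketch flags Chinese-side lines that contain few Chinese characters. The rule (share of characters in the CJK Unified Ideographs block, U+4E00 to U+9FFF), the 0.3 threshold, and the function names are assumptions for illustration, not what the NeMo script does:

```python
def cjk_ratio(text):
    """Fraction of non-whitespace characters in the CJK Unified Ideographs block."""
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return 0.0
    cjk = sum(1 for c in chars if "\u4e00" <= c <= "\u9fff")
    return cjk / len(chars)


def looks_chinese(text, min_ratio=0.3):
    """Heuristic stand-in for a language ID check on the Chinese side."""
    return cjk_ratio(text) >= min_ratio


for line in ["NVIDIA OVX", "ConnectX-6 Dx、ConnectX-6", "加快完成数据准备"]:
    print(line, looks_chinese(line))  # False, False, True
```

Unlike a trained classifier, this rule cannot tell Chinese from Japanese text that uses kanji, and it counts CJK punctuation such as 、 as non-Chinese, which is why ConnectX-6 Dx、ConnectX-6 is flagged.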

Length filtering

Length filtering removes sentence pairs from a parallel corpus when a sentence is shorter than a minimum length or longer than a maximum length, because very short or very long translations are often noisy. It also filters pairs based on the length ratio between the source and target sentences.

Before running length filtering, you can use the following script to compute the source-to-target length ratio at a given percentile, which gives you insight into which ratio threshold to use:

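def compute_length_ratio(source_txt, target_txt, percentile):
    """Return the source/target character-length ratio at the given percentile (0 <= percentile < 1)."""
    len_ratios = list()

    with open(source_txt, "r", encoding="utf-8") as src, \
        open(target_txt, "r", encoding="utf-8") as tgt:
        for src_line, tgt_line in zip(src, tgt):
            src_len, tgt_len = len(src_line.strip()), len(tgt_line.strip())
            # Skip pairs with an empty target to avoid division by zero
            if tgt_len == 0:
                continue
            len_ratios.append(src_len / tgt_len)

    len_ratios.sort()

    # Pick the value at the requested percentile, clamped to the last index
    index = min(int(len(len_ratios) * percentile), len(len_ratios) - 1)
    ratio = len_ratios[index]
    print(f"Length ratio @{percentile} percentile is {ratio}")
    return ratio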

compute_length_ratio("en_zh_preprocessed1.en", "en_zh_preprocessed1.zh", 0.95)

The length ratio varies across datasets and language pairs.

Run the following script to perform length filtering. In this example, --ratio 4.6 sets the maximum allowed length ratio.

python /opt/NeMo/scripts/neural_machine_translation/length_ratio_filter.py \
    --input-src en_zh_preprocessed1.en \
    --input-tgt en_zh_preprocessed1.zh \
    --output-src en_zh_preprocessed2.en \
    --output-tgt en_zh_preprocessed2.zh \
    --removed-src en_zh_garbage2.en \
    --removed-tgt en_zh_garbage2.zh \
    --min-length 10 \
    --max-length 512 \
    --ratio 4.6

Similarly, en_zh_preprocessed2.en and en_zh_preprocessed2.zh are the pairs sent to the next step.
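To make the filtering criteria concrete, here is a simplified sketch of the three checks described at the start of this section. It measures length in characters to match compute_length_ratio and applies the ratio limit in both directions; both are assumptions, and this is not the NeMo script's exact implementation. keep_pair is a hypothetical helper name:

```python
def keep_pair(src, tgt, min_length=10, max_length=512, ratio=4.6):
    """Return True if a sentence pair passes length and length-ratio checks.

    Lengths are character counts of the stripped lines (an assumption; the NeMo
    script may measure length differently).
    """
    src_len, tgt_len = len(src.strip()), len(tgt.strip())
    # Drop pairs with an empty side outright (also avoids division by zero below)
    if src_len == 0 or tgt_len == 0:
        return False
    # Both sides must fall inside [min_length, max_length]
    if not (min_length <= src_len <= max_length and min_length <= tgt_len <= max_length):
        return False
    # Neither side may be more than `ratio` times longer than the other
    return max(src_len / tgt_len, tgt_len / src_len) <= ratio
```

With character counts, a 10-character minimum is strict for Chinese: 加快完成数据准备 from Table 1 has only 8 characters, so this sketch would drop that pair. Inspecting en_zh_garbage2.en and en_zh_garbage2.zh after filtering helps you catch thresholds like this that remove valid data.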

Deduplication

In this step, you remove any duplicated translation pairs by using the xxhash library.

pip install xxhash

Use the following Python script for deduplication:

import xxhash

def dedup_file(input_file_lang_1, input_file_lang_2, output_file_lang_1, output_file_lang_2):
    """Write only the first occurrence of each (source, target) pair."""
    hashes = set()
    with open(input_file_lang_1, 'r', encoding='utf-8') as f_lang1, \
        open(input_file_lang_2, 'r', encoding='utf-8') as f_lang2, \
        open(output_file_lang_1, 'w', encoding='utf-8') as f_out_lang1, \
        open(output_file_lang_2, 'w', encoding='utf-8') as f_out_lang2:
        for line_1, line_2 in zip(f_lang1, f_lang2):
            line_1, line_2 = line_1.strip(), line_2.strip()
            # Hash source and target together so that only identical pairs are dropped
            parallel_hash = xxhash.xxh64((line_1 + '\t' + line_2).encode('utf-8')).hexdigest()
            if parallel_hash not in hashes:
                hashes.add(parallel_hash)
                f_out_lang1.write(line_1 + '\n')
                f_out_lang2.write(line_2 + '\n')

dedup_file(
    'en_zh_preprocessed2.en',
    'en_zh_preprocessed2.zh',
    'en_zh_preprocessed3.en',
    'en_zh_preprocessed3.zh'
)

Again, en_zh_preprocessed3.en and en_zh_preprocessed3.zh are the output files of this step.

Tokenization and normalization for the NeMo model

For fine-tuning the NeMo model, additional tokenization and normalization are required:

  • Normalization: Standardizes the punctuation in a sentence, such as quotes written in different ways.
  • Tokenization: Separates punctuation from neighboring words by inserting a space, so that punctuation is not attached to a word. This is a recommended step for NeMo training.

python /opt/NeMo/scripts/neural_machine_translation/preprocess_tokenization_normalization.py \
    --input-src en_zh_preprocessed3.en \
    --input-tgt en_zh_preprocessed3.zh \
    --output-src en_zh_final_nemo.en \
    --output-tgt en_zh_final_nemo.zh \
    --source-lang en \
    --target-lang zh

The script uses different libraries to process different languages. For example, it uses sacremoses for English and Jieba and OpenCC for simplified Chinese.

en_zh_final_nemo.en and en_zh_final_nemo.zh are the dataset files for NeMo fine-tuning.
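To illustrate what these two steps do on the English side, here is a toy, standard-library-only sketch. It handles only a few quote variants and ASCII punctuation and is not a substitute for sacremoses, Jieba, or OpenCC; the quote mapping and function names are hypothetical:

```python
import re

# Curly quote variants mapped to straight ASCII quotes (illustrative subset only)
QUOTE_MAP = {
    "\u201c": '"', "\u201d": '"',  # “ ”
    "\u2018": "'", "\u2019": "'",  # ‘ ’
}


def toy_normalize(text):
    """Replace curly quotes with straight quotes and collapse runs of whitespace."""
    for variant, replacement in QUOTE_MAP.items():
        text = text.replace(variant, replacement)
    return " ".join(text.split())


def toy_tokenize(text):
    """Insert spaces around ASCII punctuation so it is not attached to words."""
    text = re.sub(r'([.,!?;:"()])', r" \1 ", text)
    return " ".join(text.split())


print(toy_tokenize(toy_normalize("\u201cAccelerate\u201d data preparation, now!")))
```

A regex this naive also splits decimals such as 3.5 and abbreviations such as e.g., which the real tokenizers handle; that is why the NeMo script delegates to dedicated libraries.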

Converting to JSONL format for the ALMA model

ALMA training requires JSON Lines (JSONL) as raw input data, where each line in the file is a single JSON structure:

{"translation": {"en": "Accelerate Data Preparation", "zh": "加快完成数据准备"}}

You must convert the en_zh_preprocessed3.en and en_zh_preprocessed3.zh parallel translation text files to the JSONL format:

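import json

def txt2jsonl(source_txt, target_txt, source, target, output_jsonl):
    """Convert two line-aligned text files into ALMA-style JSONL."""
    # Iterate over the files so that only newlines split lines (str.splitlines()
    # would also split on characters such as U+2028 and misalign the pairs)
    with open(source_txt, "r", encoding="utf-8") as f:
        source_lines = [line.rstrip("\n") for line in f]

    with open(target_txt, "r", encoding="utf-8") as f:
        target_lines = [line.rstrip("\n") for line in f]

    # One JSON object per line: {"translation": {"en": ..., "zh": ...}}
    with open(output_jsonl, "w", encoding="utf-8") as f:
        for source_text, target_text in zip(source_lines, target_lines):
            json_item = {"translation": {source: source_text, target: target_text}}
            # ensure_ascii=False keeps Chinese characters readable instead of \u escapes
            f.write(json.dumps(json_item, ensure_ascii=False) + "\n")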

source_txt = "en_zh_preprocessed3.en"
target_txt = "en_zh_preprocessed3.zh"
source = "en"
target = "zh"
output_jsonl = "en_zh_final_alma.jsonl"
txt2jsonl(source_txt, target_txt, source, target, output_jsonl)

en_zh_final_alma.jsonl is the dataset file for ALMA training. 
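Before training, a quick sanity check confirms that every line parses as JSON and has the structure shown above. validate_jsonl is a hypothetical helper, not part of the ALMA repo:

```python
import json


def validate_jsonl(path, source="en", target="zh"):
    """Count valid records in an ALMA-style JSONL file; raise on the first bad line."""
    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)  # raises json.JSONDecodeError on malformed JSON
            pair = record.get("translation") if isinstance(record, dict) else None
            if not isinstance(pair, dict) or source not in pair or target not in pair:
                raise ValueError(
                    f"line {line_no}: expected a 'translation' object with '{source}' and '{target}'"
                )
            count += 1
    return count
```

Calling validate_jsonl("en_zh_final_alma.jsonl") should return the same count as the number of lines in en_zh_preprocessed3.en.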

Splitting datasets

The final step is to split the dataset into training, validation, and test sets. You could use the train_test_split function from the sklearn.model_selection module to do this.

After splitting, you have the following files for NeMo fine-tuning, from the original en_zh_final_nemo.en and en_zh_final_nemo.zh:

  • en_zh_final_nemo_train.en
  • en_zh_final_nemo_train.zh
  • en_zh_final_nemo_val.en
  • en_zh_final_nemo_val.zh
  • en_zh_final_nemo_test.en
  • en_zh_final_nemo_test.zh

You also have the following files for ALMA fine-tuning from the original en_zh_final_alma.jsonl:

  • train.zh-en.json
  • valid.zh-en.json
  • test.zh-en.json

The output files are renamed to follow the ALMA fine-tuning convention.
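If you prefer not to depend on sklearn, the following sketch performs the same three-way split with Python's standard random module. It shuffles one list of line indices and applies it to both NeMo files and the ALMA JSONL file, so the NeMo and ALMA test sets contain the same sentences. It assumes all three inputs have the same number of lines in the same order (both derive from en_zh_preprocessed3.*); the 80/10/10 split, the seed, and the function names are arbitrary, hypothetical choices:

```python
import random


def read_lines(path):
    """Read a text file into a list of lines, splitting on newlines only."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def write_lines(path, lines):
    """Write one item per line."""
    with open(path, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(line + "\n")


def split_indices(n, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle range(n) reproducibly and cut it into train/val/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    return {
        "test": indices[:n_test],
        "val": indices[n_test:n_test + n_val],
        "train": indices[n_test + n_val:],
    }


def split_nemo_and_alma(nemo_prefix="en_zh_final_nemo", alma_jsonl="en_zh_final_alma.jsonl", seed=42):
    """Split the NeMo text files and the ALMA JSONL file with the same shuffled indices."""
    src = read_lines(f"{nemo_prefix}.en")
    tgt = read_lines(f"{nemo_prefix}.zh")
    alma = read_lines(alma_jsonl)
    if not (len(src) == len(tgt) == len(alma)):
        raise ValueError("input files do not have the same number of lines")

    # File names follow the ALMA convention listed above
    alma_names = {"train": "train.zh-en.json", "val": "valid.zh-en.json", "test": "test.zh-en.json"}
    splits = split_indices(len(src), seed=seed)
    for name, idx in splits.items():
        write_lines(f"{nemo_prefix}_{name}.en", [src[i] for i in idx])
        write_lines(f"{nemo_prefix}_{name}.zh", [tgt[i] for i in idx])
        write_lines(alma_names[name], [alma[i] for i in idx])
    return {name: len(idx) for name, idx in splits.items()}
```

Calling split_nemo_and_alma() in the directory that holds the input files writes the nine output files listed above and returns the size of each split.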

Model fine-tuning

In this section, we demonstrate how to fine-tune NeMo and ALMA models separately.

Fine-tuning the NeMo NMT model

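Before fine-tuning, make sure that you followed the instructions in the NMT Evaluation section of the first post and downloaded the NeMo pretrained model to model/pretrained_ckpt/en_zh_24x6.nemo, the path used in the following command (adjust model_path if you stored it elsewhere). You can then fine-tune the NeMo EN-ZH model on your custom dataset; the command expects the split dataset files in a data/ directory. Set the batch size (tokens_in_batch) according to the available GPU memory.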

python /opt/NeMo/examples/nlp/machine_translation/enc_dec_nmt_finetune.py \
  model_path=model/pretrained_ckpt/en_zh_24x6.nemo \
  trainer.devices=1 \
  trainer.max_epochs=10 \
  +trainer.val_check_interval=600 \
  model.train_ds.tgt_file_name=data/en_zh_final_nemo_train.zh \
  model.train_ds.src_file_name=data/en_zh_final_nemo_train.en \
  model.train_ds.tokens_in_batch=3000 \
  model.validation_ds.tgt_file_name=data/en_zh_final_nemo_val.zh \
  model.validation_ds.src_file_name=data/en_zh_final_nemo_val.en \
  model.validation_ds.tokens_in_batch=1000 \
  model.test_ds.tgt_file_name=data/en_zh_final_nemo_test.zh \
  model.test_ds.src_file_name=data/en_zh_final_nemo_test.en \
  +exp_manager.exp_dir=output/ \
  +exp_manager.create_checkpoint_callback=True \
  +exp_manager.checkpoint_callback_params.monitor=val_sacreBLEU \
  +exp_manager.checkpoint_callback_params.mode=max \
  +exp_manager.checkpoint_callback_params.save_best_model=true \
  +exp_manager.checkpoint_callback_params.always_save_nemo=true \
  +exp_manager.checkpoint_callback_params.save_top_k=10

After training completes, the results and checkpoints are saved under ./output/AAYNBaseFineTune. You can use TensorBoard to visualize the loss curve or review the log files.

Fine-tuning the ALMA NMT model

To train the ALMA model, first clone its repository from GitHub and install the additional dependencies inside the NeMo framework container:

git clone https://github.com/fe1ixxu/ALMA.git
cd ALMA
bash install_alma.sh
pip install --upgrade pytest

Data

Place the fine-tuning datasets (train.zh-en.json, valid.zh-en.json, and test.zh-en.json) in the human_written_data/zhen directory of the ALMA repo. In this case, the zhen subdirectory holds the English-Chinese parallel datasets.

Configs

The next step is to modify the parameters in the runs/parallel_ft_lora.sh and configs/deepspeed_train_config.yaml files.

Typical fields to modify in runs/parallel_ft_lora.sh:

  • per_device_train_batch_size: Training batch size per GPU.
  • gradient_accumulation_steps: Number of steps over which gradients are accumulated before each update.
  • learning_rate: Adjust according to the batch size and the number of accumulation steps.

Fields to modify in configs/deepspeed_train_config.yaml:

  • gradient_accumulation_steps: Should match the value in runs/parallel_ft_lora.sh.
  • num_processes: Number of GPU devices.
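The learning-rate adjustment depends on the effective batch size, that is, how many examples contribute to each optimizer update. The following sketch computes it from the three fields above and applies linear learning-rate scaling, which is a common rule of thumb rather than something the ALMA repo prescribes. The function names and example numbers are hypothetical:

```python
def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_processes):
    """Examples per optimizer update across all GPUs."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_processes


def linearly_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate in proportion to the effective batch size (heuristic)."""
    return base_lr * new_batch_size / base_batch_size


# Hypothetical setup: 4 examples per GPU, 4 accumulation steps, 8 GPUs
print(effective_batch_size(4, 4, 8))  # 128
```

For example, if a learning rate worked well at an effective batch size of 128 and you drop to 2 GPUs (effective batch size 32), linear scaling suggests a quarter of that learning rate. Treat the result as a starting point for tuning, not a final value.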

Fine-tuning command

To tune ALMA with LoRA for both English-to-Chinese and Chinese-to-English, run the following command in the ALMA repo’s root directory:

bash runs/parallel_ft_lora.sh output zh-en,en-zh

The results are stored in the output directory passed as the first argument:

  • adapter_config.json: LoRA config.
  • adapter_model.bin: LoRA weights.
  • trainer_state.json: Training losses.
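trainer_state.json matches the state file written by the Hugging Face Trainer, whose log_history list holds the logged metrics. Assuming that format, the following sketch extracts (step, loss) pairs from it for plotting or a quick look at convergence; read_training_losses is a hypothetical helper name:

```python
import json


def read_training_losses(path="output/trainer_state.json"):
    """Return (step, loss) pairs from the log_history of a trainer_state.json file."""
    with open(path, "r", encoding="utf-8") as f:
        state = json.load(f)
    # Training-loss entries carry a "loss" key; evaluation entries use "eval_loss"
    # and the final summary uses "train_loss", so both are skipped here
    return [
        (entry["step"], entry["loss"])
        for entry in state.get("log_history", [])
        if "loss" in entry and "step" in entry
    ]
```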

Fine-tuned model evaluation

In the first post, you evaluated the performance of the pretrained NeMo and ALMA models without any modification. Now you can benchmark the fine-tuned models by running them on the same test dataset.

Fine-tuned NeMo model evaluation

Use the following commands to run inference on the same test dataset and compute the BLEU score:

python /opt/NeMo/examples/nlp/machine_translation/nmt_transformer_infer.py \
  --model $fine_tuned_nemo_path \
  --srctext input_en.txt \
  --tgtout nemo_ft_out_zh.txt \
  --source_lang en \
  --target_lang zh \
  --batch_size 200 \
  --max_delta_length 20

sacrebleu reference.txt -i nemo_ft_out_zh.txt -m bleu -b -w 4

In these commands:

  • $fine_tuned_nemo_path: Fine-tuned NeMo model path.
  • input_en.txt: English text file.
  • nemo_ft_out_zh.txt: Output translated text file.
  • reference.txt: Reference translation.

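The sacrebleu command prints the BLEU score at the end. Because Chinese text is not separated by spaces, consider adding --tokenize zh to the sacrebleu command, and use the same setting for every model that you compare.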

Fine-tuned ALMA model evaluation 

To evaluate the ALMA model, you must run inference manually on the same evaluation dataset. The following script runs inference with the fine-tuned ALMA model. Its model-loading part differs slightly from the one in the first post because it loads the fine-tuned PEFT (LoRA) weights and config from the local ./output directory.

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer


# Load base model and LoRA weights
peft_config = PeftConfig.from_pretrained("./output")
model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, "./output")
model = model.eval()
tokenizer = LlamaTokenizer.from_pretrained("haoranxu/ALMA-7B-Pretrain", padding_side='left')

# Add the source sentence into the prompt template
prompt_template = "Translate this from English to Chinese:\nEnglish: {}\nChinese:"
prompt = prompt_template.format("AI is powering change in every industry")


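# Tokenize. Note that max_length=40 with truncation=True cuts off longer prompts; increase it for longer source sentences.
input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=40, truncation=True).input_ids.cuda()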
# Inference
with torch.no_grad():
    generated_ids = model.generate(input_ids=input_ids, num_beams=5, max_new_tokens=256, do_sample=True, temperature=0.6, top_p=0.9)

outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
output = outputs[0].replace(prompt, "").strip()
print(output)

You can modify the sample inference script to generate translations for the custom English text file and benchmark its BLEU score as described earlier.
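One way to do that is to wrap the generation code in a function and apply it to the input file in batches. In the following sketch, translate_batch is a hypothetical callable that you would implement with the tokenizer and model.generate code above: it takes a list of prompts and returns the decoded outputs. The function name translate_file and the batch size of 8 are arbitrary choices:

```python
PROMPT_TEMPLATE = "Translate this from English to Chinese:\nEnglish: {}\nChinese:"


def translate_file(input_path, output_path, translate_batch, batch_size=8):
    """Translate a text file line by line, writing one translation per line."""
    with open(input_path, "r", encoding="utf-8") as f:
        sources = [line.rstrip("\n") for line in f]

    with open(output_path, "w", encoding="utf-8") as out:
        for start in range(0, len(sources), batch_size):
            prompts = [PROMPT_TEMPLATE.format(s) for s in sources[start:start + batch_size]]
            outputs = translate_batch(prompts)
            for prompt, output in zip(prompts, outputs):
                # Remove the echoed prompt, as in the single-sentence example above
                translation = output.replace(prompt, "").strip()
                # Keep only the first line so each source sentence maps to one output line
                out.write(translation.split("\n")[0].strip() + "\n")
```

Keeping exactly one output line per input line preserves the alignment that sacrebleu relies on when it compares the output file with reference.txt.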

Conclusion

The NeMo framework container provides a convenient environment for various inference and customization tasks.

In this series, we walked through the NeMo NMT model and ALMA NMT model fine-tuning recipes end to end in the container, covering pretrained model inference and evaluation, data collection and preprocessing, model fine-tuning, and final evaluation.

For more information about building and deploying fully customizable multilingual conversational AI pipelines, see NVIDIA Riva. You can also learn more about real-time bilingual and multilingual speech-to-speech and speech-to-text translation APIs.

To get started with more LLM and distributed training tasks in the NeMo framework, explore the playbooks and developer documentation.

For more information about tackling data curation and evaluation tasks, see the recently released NeMo Curator and NVIDIA NeMo Evaluator microservices and apply for NeMo Microservices early access.
