
Curating Custom Datasets for LLM Parameter-Efficient Fine-Tuning with NVIDIA NeMo Curator

A data curator designed for dataset preparation and enhanced LLM performance.

In a recent post, we discussed how to use NVIDIA NeMo Curator to curate custom datasets for pretraining or continuous training use cases of large language models (LLMs) and small language models (SLMs). 

While such training scenarios are an important part of LLM development, many downstream applications involve fine-tuning existing foundation models on domain-specific datasets. This can be achieved using supervised fine-tuning (SFT) or parameter-efficient fine-tuning (PEFT) methods such as LoRA and p-tuning.

In these workflows, you typically need to iterate quickly and experiment with various ideas and hyperparameter settings, as well as with how the training data are processed and exposed to the model. You often need to process and curate multiple variants of your datasets so that the model effectively learns the nuances of your domain-specific data.

Due to the limited amount of data available in such workflows, high-quality data curation using a flexible processing pipeline is crucial.

This post walks you through creating a custom data curation pipeline using NeMo Curator, focusing specifically on SFT and PEFT use cases. For more information about the basic building blocks that NeMo Curator provides, see Curating Custom Datasets for LLM Training with NVIDIA NeMo Curator.

Overview

For demonstration purposes, this post focuses on a toy example involving email classification. The goal is to curate a small text-based dataset, where each record consists of an email (subject and body) along with a predefined classification label for that email. 

We used the Enron emails dataset for this purpose, in which each email is labeled with one of eight categories. This dataset is publicly available on Hugging Face and contains approximately 1,400 records.

The data curation pipeline involves the following high-level steps:

  1. Define downloader, iterator, and extractor classes to convert the dataset into the JSONL format.
  2. Use existing tools to unify the Unicode representation.
  3. Define custom dataset filters to remove emails that are empty or too long.
  4. Redact all personally identifiable information (PII) from the dataset.
  5. Add instruction prompts to each record.
  6. Put the curation pipeline together.

The execution of this curation pipeline should take less than 5 minutes on consumer-grade hardware. To access the complete code for this tutorial, see the /NVIDIA/NeMo-Curator GitHub repo.

Prerequisites

Before you start, you must install the NeMo Curator framework. Follow the instructions in the NeMo Curator GitHub README file to install the framework. 

Next, run the following commands to verify the installation and install any additional dependencies:

$ python -c "import nemo_curator; print(nemo_curator);"
$ pip3 install requests

Defining custom document builders

The first step of curating a dataset is to implement the document builders that can download and iterate through the dataset.

Downloading the dataset

Implement a subclass of DocumentDownloader that takes the dataset’s URL and downloads it using the requests library.

import os

import requests
from nemo_curator.download.doc_builder import DocumentDownloader

class EmailsDownloader(DocumentDownloader):
    def __init__(self, download_dir: str):
        super().__init__()

        # Create the download directory if it does not exist yet.
        os.makedirs(download_dir, exist_ok=True)

        self._download_dir = download_dir
        print("Download directory: ", self._download_dir)

    def download(self, url: str) -> str:
        filename = os.path.basename(url)
        output_file = os.path.join(self._download_dir, filename)

        # Reuse a previously downloaded copy instead of fetching it again.
        if os.path.exists(output_file):
            print(f"File '{output_file}' already exists, skipping download.")
            return output_file

        print(f"Downloading Enron emails dataset from '{url}'...")
        response = requests.get(url)
        # Fail on HTTP errors instead of saving an error page to disk.
        response.raise_for_status()

        with open(output_file, "wb") as file:
            file.write(response.content)

        return output_file

The downloaded dataset is a text file, and each entry roughly follows this format:

"<s>[system instruction prompts]

Subject:: [email subject]
Body:: [email body]

[/INST] [category label] <s>"

This format can be easily broken into its constituent parts using regular expressions. The key thing to remember is that each entry is wrapped in a quoted "<s> … <s>" sequence and always begins with instruction prompts. Also, the sample delimiter tokens and the system prompt tokens are compatible with the Llama 2 family of tokenizers.

Because you might use this data with other tokenizers or models that don’t support these special tokens, it’s best to discard the instructions and tokens during parsing. Later in this post, we show how to add instruction prompts or special tokens to each entry using the NeMo Curator DocumentModifier utilities.

Parsing and iterating the dataset

Implement subclasses of DocumentIterator and DocumentExtractor to extract the email subject, body, and category (class) labels:

import os
import re
from typing import Dict, Iterator, Optional

from nemo_curator.download.doc_builder import (
    DocumentExtractor,
    DocumentIterator,
)

class EmailsIterator(DocumentIterator):

    def __init__(self):
        super().__init__()
        self._counter = -1
        self._extractor = EmailsExtractor()
        # The regular expression pattern to extract each email.
        self._pattern = re.compile(r"\"<s>.*?<s>\"", re.DOTALL)

    def iterate(self, file_path: str) -> Iterator[Dict[str, str]]:
        self._counter = -1
        file_name = os.path.basename(file_path)

        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()

        # Ignore the first line which contains the header.
        file_content = "".join(lines[1:])
        # Find all the emails in the file.
        it = self._pattern.finditer(file_content)

        for email in it:
            self._counter += 1
            content = email.group().strip('"').strip()
            meta = {
                "filename": file_name,
                "id": f"email-{self._counter}",
            }
            extracted_content = self._extractor.extract(content)

            # Skip if no content extracted
            if not extracted_content:
                continue

            record = {**meta, **extracted_content}
            yield record


class EmailsExtractor(DocumentExtractor):
    def __init__(self):
        super().__init__()
        # The regular expression pattern to extract subject/body/label into groups.
        self._pattern = re.compile(
            r"Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>", re.DOTALL
        )

    def extract(self, content: str) -> Optional[Dict[str, str]]:
        matches = self._pattern.findall(content)

        # Return None when the entry does not follow the expected layout.
        if not matches:
            return None

        matches = matches[0]

        return {
            "subject": matches[0].strip(),
            "body": matches[1].strip(),
            "category": matches[2].strip(),
        }

The iterator uses the regular expression \"<s>.*?<s>\" to find each sample. It then passes the string to the extractor, which uses the regular expression "Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>". This expression uses non-greedy capture groups (.*?) to extract the subject, body, and category.

These extracted parts, along with useful metadata (such as a unique ID for each email), are stored in a dictionary and returned to the caller.
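To see how the two regular expressions cooperate, the following sketch runs them over a small synthetic sample with only the standard library. The patterns are copied from EmailsIterator and EmailsExtractor; the parse_raw helper and the sample text (including the "Personal" and "Business" labels) are hypothetical and are not taken from the real dataset.

```python
import re
from typing import Dict, List

# Patterns copied from EmailsIterator and EmailsExtractor above.
ENTRY_PATTERN = re.compile(r"\"<s>.*?<s>\"", re.DOTALL)
FIELDS_PATTERN = re.compile(
    r"Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>", re.DOTALL
)


def parse_raw(raw: str) -> List[Dict[str, str]]:
    """Split raw text into entries, then extract subject/body/category (hypothetical helper)."""
    records = []
    # Drop the header line, as EmailsIterator does.
    content = "".join(raw.splitlines(keepends=True)[1:])
    for match in ENTRY_PATTERN.finditer(content):
        entry = match.group().strip('"').strip()
        fields = FIELDS_PATTERN.findall(entry)
        if not fields:
            continue  # Entry does not follow the expected layout.
        subject, body, category = fields[0]
        records.append(
            {"subject": subject.strip(), "body": body.strip(), "category": category.strip()}
        )
    return records


# Synthetic sample in the layout described earlier; NOT real dataset text.
SAMPLE = (
    "text\n"
    '"<s>[INST] Classify this email.\n\n'
    "Subject:: Lunch on Friday\nBody:: Are we still meeting at noon?\n\n"
    '[/INST] Personal <s>"\n'
    '"<s>[INST] Classify this email.\n\n'
    "Subject:: Q3 report\nBody:: Please review the attached draft.\n\n"
    '[/INST] Business <s>"\n'
)

for record in parse_raw(SAMPLE):
    print(record)
```

One behavior worth knowing: the body group is non-greedy and is followed by \n, so it stops at the first newline for which the rest of the pattern can still match. With this pattern, a body that spans several lines keeps only its first line.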

You are now ready to convert this dataset to the JSONL format, which is one of the many formats that NeMo Curator supports.

Writing the dataset to the JSONL format

The dataset downloads as a plain text file. Use the EmailsDownloader and EmailsIterator classes that you implemented to download the file, iterate through its records, convert them to the JSONL format, and store every record as a line in a file.

import json
import os

# DATA_DIR (the working directory) and DATASET_URL (the raw dataset URL)
# must be defined before calling this function; they are not shown in this post.

def download_and_convert_to_jsonl() -> str:
    """
    Downloads the emails dataset and converts it to JSONL format.

    Returns:
        str: The path to the JSONL file.
    """

    # Download the dataset in raw format and convert it to JSONL.
    downloader = EmailsDownloader(DATA_DIR)
    output_path = os.path.join(DATA_DIR, "emails.jsonl")
    raw_fp = downloader.download(DATASET_URL)

    iterator = EmailsIterator()

    # Parse the raw data and write it to a JSONL file, one record per line.
    # UTF-8 is set explicitly because ensure_ascii=False keeps non-ASCII characters.
    with open(output_path, "w", encoding="utf-8") as f:
        for record in iterator.iterate(raw_fp):
            json_record = json.dumps(record, ensure_ascii=False)
            f.write(json_record + "\n")

    return output_path

The information from each record in the dataset is written across multiple JSON fields:

  • subject
  • body
  • category 
  • Metadata:
    • id 
    • filename

This is necessary because many data curation operations in NeMo Curator must know which field in each record to operate on. Storing each piece of information in its own field makes it easy to target specific parts of a record with different NeMo Curator operations.
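To make this record layout concrete, the following sketch writes one record to a JSONL file and reads it back using only the standard library. The write_jsonl and read_jsonl helpers are hypothetical, and the field values (including the filename raw_emails.txt) are made up for illustration.

```python
import json
import os
import tempfile
from typing import Dict, Iterable, List


def write_jsonl(records: Iterable[Dict[str, str]], path: str) -> None:
    """Write one JSON object per line, as download_and_convert_to_jsonl does."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: str) -> List[Dict[str, str]]:
    """Read a JSONL file back into a list of dictionaries, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Illustrative record; all values are hypothetical.
record = {
    "filename": "raw_emails.txt",
    "id": "email-0",
    "subject": "Lunch on Friday",
    "body": "Are we still meeting at noon?",
    "category": "Personal",
}

path = os.path.join(tempfile.mkdtemp(), "emails.jsonl")
write_jsonl([record], path)

# Each line holds one record, with every piece of information under its own key.
with open(path, encoding="utf-8") as f:
    print(f.read(), end="")
print(read_jsonl(path)[0]["category"])
```

Because every field is a separate JSON key, a later step can target, for example, only "body" without having to parse the other fields out of a single text blob.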

Loading the dataset using the document builders

In NeMo Curator, datasets are represented as objects of type DocumentDataset. This class provides helpers to load datasets from disk in various formats. Use the following code to load the dataset and start working with it:

from nemo_curator.datasets import DocumentDataset

# Path to the JSONL file created by download_and_convert_to_jsonl().
filepath = download_and_convert_to_jsonl()
dataset = DocumentDataset.read_json(filepath, add_filename=True)

You now have everything needed to define a custom dataset curation pipeline and prepare the data.

Using existing tools to unify the Unicode formatting

It is good practice to fix Unicode issues in datasets, because text scraped from online sources may contain inconsistencies or Unicode errors.

To modify documents, NeMo Curator provides the DocumentModifier interface, which defines how the text of each document should be modified, along with the Modify helper, which applies a modifier to a given field of every document in a dataset. For more information about implementing your own custom document modifiers, see the Text cleaning and unification section in the previous post.

In this example, it is sufficient to apply UnicodeReformatter to the dataset. Because each record has multiple fields, apply the operation once to each relevant field. These operations can be chained together through the Sequential class:

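from nemo_curator import Modify, Sequential
from nemo_curator.modifiers import UnicodeReformatter

# Apply the same Unicode fix-up to each text field in turn.
Sequential([
    Modify(UnicodeReformatter(), text_field="subject"),
    Modify(UnicodeReformatter(), text_field="body"),
    Modify(UnicodeReformatter(), text_field="category"),
])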
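As a rough, single-process mental model of what Modify and Sequential do together, the following sketch applies a per-field text function to every record and then chains the steps in order. The modify_field and sequential helpers are hypothetical stand-ins, not NeMo Curator APIs, and NFC normalization from the standard library's unicodedata module is used as a simple substitute for UnicodeReformatter; it is not what that class does internally.

```python
import unicodedata
from typing import Callable, Dict, List

Record = Dict[str, str]
Step = Callable[[List[Record]], List[Record]]


def modify_field(fn: Callable[[str], str], text_field: str) -> Step:
    """Return a step that applies `fn` to one field of every record (hypothetical)."""
    def step(records: List[Record]) -> List[Record]:
        # Build new dictionaries so the input records are left unchanged.
        return [{**r, text_field: fn(r[text_field])} for r in records]
    return step


def sequential(steps: List[Step]) -> Step:
    """Return a step that runs each step in order on the output of the previous one (hypothetical)."""
    def run(records: List[Record]) -> List[Record]:
        for step in steps:
            records = step(records)
        return records
    return run


def to_nfc(text: str) -> str:
    # Compose decomposed characters, for example "e" + U+0301 becomes a single "é".
    return unicodedata.normalize("NFC", text)


unify_unicode = sequential([
    modify_field(to_nfc, "subject"),
    modify_field(to_nfc, "body"),
    modify_field(to_nfc, "category"),
])

records = [{"subject": "Cafe\u0301 meeting", "body": "Re\u0301sume\u0301 attached", "category": "Business"}]
print(unify_unicode(records))
```

The point of the design is that each step only knows about one field, and the pipeline is just an ordered list of steps, which mirrors how the Sequential calls in this post are assembled.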

Designing custom dataset filters

In many PEFT use cases, refining the dataset involves filtering out records that may be irrelevant or low quality, or those with specific unsuitable attributes. In the email dataset, some emails are too long or empty. For demonstration purposes, remove all such records from the dataset by implementing custom DocumentFilter classes:

from nemo_curator.filters import DocumentFilter

class FilterEmailsWithLongBody(DocumentFilter):
    """
    If the email is too long, discard.
    """

    def __init__(self, max_length: int = 5000):
        super().__init__()
        self.max_length = max_length

    def score_document(self, text: str) -> bool:
        return len(text) <= self.max_length

    def keep_document(self, score) -> bool:
        return score

class FilterEmptyEmails(DocumentFilter):
    """
    Detects empty emails (either empty body, or labeled as empty).
    """

    def score_document(self, text: str) -> bool:
        return (
            not isinstance(text, str)  # The text is not a string
            or len(text.strip()) == 0  # The text is empty
            or "Empty message" in text  # The email is labeled as empty
        )

    def keep_document(self, score) -> bool:
        return score

The FilterEmailsWithLongBody class counts the number of characters in the provided text and returns True if the length is acceptable, or False otherwise. You must explicitly apply this filter to the body field of every record.

The FilterEmptyEmails class checks the type and the content of a given text to determine whether it signifies an empty email, and returns True if the email is deemed to be empty, or False otherwise. You must explicitly apply this filter to all relevant fields: the subject, body, and category fields of every record.

The returned value is consistent with the naming of the class, which promotes code readability. However, as the goal is to discard empty emails, the results from this filter must be inverted. In other words, discard the record if the filter returns True and keep the record if the filter returns False. You can do this by passing the invert flag to the ScoreFilter helper:

from nemo_curator import ScoreFilter, Sequential

Sequential([
    # Apply only to the `body` field.
    ScoreFilter(FilterEmailsWithLongBody(), text_field="body", score_type=bool),
    # Apply to all fields, also invert the action.
    ScoreFilter(FilterEmptyEmails(), text_field="subject", score_type=bool, invert=True),
    ScoreFilter(FilterEmptyEmails(), text_field="body", score_type=bool, invert=True),
    ScoreFilter(FilterEmptyEmails(), text_field="category", score_type=bool, invert=True),
])

Setting invert=True instructs ScoreFilter to discard documents for which the filter returns True. Setting score_type=bool explicitly declares the return type of each filter, which avoids type inference during execution.
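The keep-or-discard logic, including the effect of invert=True, can be sketched in plain Python. The score functions below reuse the exact rules from FilterEmailsWithLongBody and FilterEmptyEmails, while score_filter is a hypothetical, simplified stand-in for ScoreFilter that works on a list of dictionaries; the sample records are made up.

```python
from typing import Callable, Dict, List

Record = Dict[str, object]


def is_short_enough(text: str, max_length: int = 5000) -> bool:
    # Same rule as FilterEmailsWithLongBody.score_document.
    return len(text) <= max_length


def is_empty_email(text: object) -> bool:
    # Same rule as FilterEmptyEmails.score_document.
    return (
        not isinstance(text, str)
        or len(text.strip()) == 0
        or "Empty message" in text
    )


def score_filter(
    records: List[Record],
    score_fn: Callable[..., bool],
    text_field: str,
    invert: bool = False,
) -> List[Record]:
    """Keep records whose score is True, or whose score is False when invert=True (hypothetical)."""
    kept = []
    for record in records:
        keep = bool(score_fn(record[text_field]))
        if invert:
            keep = not keep  # Discard records for which the filter returns True.
        if keep:
            kept.append(record)
    return kept


emails = [
    {"id": "email-0", "subject": "Hi", "body": "See you at noon.", "category": "Personal"},
    {"id": "email-1", "subject": "Update", "body": "   ", "category": "Business"},
    {"id": "email-2", "subject": "Notes", "body": "x" * 6000, "category": "Business"},
    {"id": "email-3", "subject": "Re:", "body": "Empty message", "category": "Personal"},
]

# Drop empty emails first (as in the final pipeline), so non-string values never reach len().
kept = emails
for field in ("subject", "body", "category"):
    kept = score_filter(kept, is_empty_email, field, invert=True)
kept = score_filter(kept, is_short_enough, "body")
print([r["id"] for r in kept])
```

Without invert=True, the empty-email filter would do the opposite of what is wanted and keep only the empty records.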

Redacting all personally identifiable information

Next, define a processing step to redact all personally identifiable information (PII) from the subject and the body of each record. This dataset contains many instances of PII, such as email addresses, phone or fax numbers, names, and addresses.

NeMo Curator makes it easy to specify the types of PII to detect and the action to take for each detection. In this example, every detection is replaced with a special token:

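from nemo_curator import Modify
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.pii_modifier import PiiModifier

def redact_pii(dataset: DocumentDataset, text_field: str) -> DocumentDataset:
    redactor = Modify(
        PiiModifier(
            # The PII entity types to detect.
            supported_entities=[
                "ADDRESS",
                "EMAIL_ADDRESS",
                "LOCATION",
                "PERSON",
                "URL",
                "PHONE_NUMBER",
            ],
            # Replace each detected entity with a special token.
            anonymize_action="replace",
            device="cpu",
        ),
        text_field=text_field,
    )
    return redactor(dataset)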

You can apply these operations to the subject and body fields separately using the Python functools.partial helper:

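from functools import partial

from nemo_curator import Sequential

# Bind text_field so that each step takes only the dataset as its argument.
redact_pii_subject = partial(redact_pii, text_field="subject")
redact_pii_body = partial(redact_pii, text_field="body")

Sequential([
    redact_pii_subject,
    redact_pii_body,
])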
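PiiModifier detects entity types such as names and locations, which a short example cannot reproduce. The toy sketch below only illustrates the shape of this step: a function that takes the records and a text_field, replaces matches with placeholder tokens, and is specialized per field with functools.partial. The two regular expressions and the <EMAIL_ADDRESS> and <PHONE_NUMBER> placeholder format are assumptions for illustration; they catch only simple cases and are not a substitute for PiiModifier.

```python
import re
from functools import partial
from typing import Dict, List

Record = Dict[str, str]

# Deliberately simple patterns (assumptions for this sketch only).
EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+")
PHONE_RE = re.compile(r"\(?\b\d{3}\)?[-.\s]?\d{3}[-.\s]\d{4}\b")


def redact_text(text: str) -> str:
    # Replace email addresses first, then phone numbers, with placeholder tokens.
    text = EMAIL_RE.sub("<EMAIL_ADDRESS>", text)
    return PHONE_RE.sub("<PHONE_NUMBER>", text)


def redact_records(records: List[Record], text_field: str) -> List[Record]:
    """Redact one field of every record (hypothetical helper)."""
    return [{**r, text_field: redact_text(r[text_field])} for r in records]


# partial binds text_field, leaving a one-argument callable that can be chained.
redact_subject = partial(redact_records, text_field="subject")
redact_body = partial(redact_records, text_field="body")

records = [{
    "subject": "Call 713-555-0142",
    "body": "Email jane.doe@example.com or call (713) 555-0142.",
}]
for step in (redact_subject, redact_body):
    records = step(records)
print(records)
```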

Adding instruction prompts

The last step of the data curation pipeline involves adding instruction prompts to every record and ensuring that every category value terminates with a period. Both can be accomplished by implementing the relevant DocumentModifier classes:

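from nemo_curator.modifiers import DocumentModifier

# SYS_PROMPT_TEMPLATE is a %-style format string with a single %s placeholder
# for the text; its value is not shown in this post.

class AddSystemPrompt(DocumentModifier):
    def modify_document(self, text: str) -> str:
        # Wrap the text in the instruction prompt template.
        return SYS_PROMPT_TEMPLATE % text


class AddPeriod(DocumentModifier):
    def modify_document(self, text: str) -> str:
        # Append a period to the end of the text.
        return text + "."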

In the code example, the SYS_PROMPT_TEMPLATE variable contains a formatting string that can be used for adding instruction prompts around the text. These modifiers can be chained together:

Sequential([
    Modify(AddSystemPrompt(), text_field="body"),
    Modify(AddPeriod(), text_field="category"),
])
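The post does not show the value of SYS_PROMPT_TEMPLATE, so the following sketch assumes a hypothetical template with a single %s placeholder and expresses both modifications as plain functions. The add_period function is a variant, not the AddPeriod shown above: it appends a period only when one is missing, which matches the stated goal of ensuring each category ends with a period without producing a double period.

```python
# Hypothetical template; the real SYS_PROMPT_TEMPLATE is not shown in this post.
SYS_PROMPT_TEMPLATE = (
    "You are given an email. Classify it into one of the predefined categories.\n\n"
    "%s"
)


def add_system_prompt(text: str) -> str:
    # Same operation as AddSystemPrompt.modify_document: %-format the text into the template.
    return SYS_PROMPT_TEMPLATE % text


def add_period(text: str) -> str:
    # Variant of AddPeriod that skips texts that already end with a period.
    return text if text.endswith(".") else text + "."


print(add_system_prompt("Subject:: Lunch\nBody:: Noon?"))
print(add_period("Business"))
```

Because the template is %-formatted, any literal percent sign in the template itself must be written as %%; percent signs in the email text are safe, because only the template is parsed for format specifiers.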

Putting the curation pipeline together

Having implemented each step of the curation pipeline, it’s time to put everything together and sequentially apply each operation on the dataset. You can use the Sequential class to chain curation operations together:

curation_steps = Sequential(
    [
        #
        # Unify the text encoding to Unicode.
        #
        Modify(UnicodeReformatter(), text_field="subject"),
        Modify(UnicodeReformatter(), text_field="body"),
        Modify(UnicodeReformatter(), text_field="category"),

        #
        # Filtering
        #
        ScoreFilter(
            FilterEmptyEmails(), text_field="subject", score_type=bool, invert=True
        ),
        ScoreFilter(
            FilterEmptyEmails(), text_field="body", score_type=bool, invert=True
        ),
        ScoreFilter(
            FilterEmptyEmails(), text_field="category", score_type=bool, invert=True
        ),
        ScoreFilter(FilterEmailsWithLongBody(), text_field="body", score_type=bool),

        #
        # Redact personally identifiable information (PII).
        #

        redact_pii_subject,
        redact_pii_body,

        #
        # Final modifications.
        #
        Modify(AddSystemPrompt(), text_field="body"),
        Modify(AddPeriod(), text_field="category"),
    ]
)

dataset = curation_steps(dataset)
dataset = dataset.persist()
dataset.to_json("/output/path", write_to_filename=True)

NeMo Curator uses Dask to work with the dataset in a distributed manner. Because Dask evaluates operations lazily, you must call the .persist function to instruct Dask to execute the operations and keep the results in memory. After processing finishes, you can write the dataset to disk in the JSONL format by calling the .to_json function with an output path.

Next steps

This tutorial demonstrated how to create a custom data curation pipeline using NeMo Curator, focusing specifically on SFT and PEFT use cases. 

For easy access, we uploaded the tutorial to the /NVIDIA/NeMo-Curator GitHub repo. Star the repo to stay up-to-date with the latest developments and receive notifications about new features, bug fixes, and updates.

Now that you’ve curated the data, you can fine-tune an LLM, such as a Llama 2 model, for email classification with LoRA. For more information, see the NeMo framework PEFT with Llama 2 playbook.

You can also request access to the NVIDIA NeMo Curator microservice, which provides the easiest path for enterprises to get started with data curation from anywhere. To apply, see NeMo Curator Microservice Early Access.
