Turbocharging Meta Llama 3 Performance with NVIDIA TensorRT-LLM and NVIDIA Triton Inference Server


We’re excited to announce support for the Meta Llama 3 family of models in NVIDIA TensorRT-LLM, accelerating and optimizing your LLM inference performance. You can immediately try Llama 3 8B and Llama 3 70B—the first models in the series—through a browser user interface, or through API endpoints running on a fully accelerated NVIDIA stack from the NVIDIA API catalog, where Llama 3 is packaged as an NVIDIA NIM with a standard API that can be deployed anywhere.

Large language models are computationally intensive. Their size makes them expensive and slow to run, especially without the right techniques. Many optimization techniques are available, from model optimizations such as kernel fusion and quantization to runtime optimizations like C++ implementations, KV caching, continuous in-flight batching, and paged attention. Developers must decide which combination helps their use case. TensorRT-LLM simplifies this work.

TensorRT-LLM is an open-source library that accelerates inference performance on the latest LLMs on NVIDIA GPUs. NeMo, an end-to-end framework for building, customizing, and deploying generative AI applications, uses TensorRT-LLM and NVIDIA Triton Inference Server for generative AI deployments. 

TensorRT-LLM uses the NVIDIA TensorRT deep learning compiler. It includes the latest optimized kernels for cutting-edge implementations of FlashAttention and masked multi-head attention (MHA) for LLM execution. It also consists of pre- and post-processing steps and multi-GPU/multi-node communication primitives in a simple, open-source Python API for groundbreaking LLM inference performance on GPUs.

To get a feel for the library and how to use it, let’s go over an example of how to use and deploy Llama 3 8B with TensorRT-LLM and Triton Inference Server.

For a more in-depth view—including different models, different optimizations, and multi-GPU execution—check out the full list of TensorRT-LLM examples.

Getting started with installation

We’ll begin by cloning the TensorRT-LLM repository and installing the library with the pip command, following the OS-specific install instructions. This is one of the easier ways to install TensorRT-LLM. Alternatively, the library can be built and installed using a dockerfile to retrieve its dependencies.

The following commands pull the open-source library and install the dependencies needed for installing TensorRT-LLM inside the container. 

git clone -b v0.8.0 https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM

Retrieving the model weights

TensorRT-LLM is a library for LLM inference. To use it, you must supply a set of trained weights, which can be pulled from repositories like the Hugging Face Hub or NVIDIA NGC, or come from your own model trained in a framework like NeMo.

The commands in this post automatically pull the weights (and tokenizer files) for the instruction-tuned variant of the 8-billion-parameter Llama 3 model from the Hugging Face Hub. You can also download the weights to use offline with the following command and update the paths in later commands to point to this directory:

git lfs install
git clone https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct

Note that using this model is subject to a particular license. Agree to the terms and authenticate with Hugging Face to download the necessary files.
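
If you prefer a scripted download, the same files can be fetched with the huggingface_hub Python library. This is a minimal sketch of an alternative to the git clone above (it assumes you have accepted the license and logged in with huggingface-cli); the rest of this post works either way, as long as the weights end up in the Meta-Llama-3-8B-Instruct directory.

# Optional: download the weights with huggingface_hub instead of git lfs.
# Assumes the Llama 3 license has been accepted and you are authenticated.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    local_dir="./Meta-Llama-3-8B-Instruct",  # path the later commands expect
)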

Running the TensorRT-LLM container

We’ll launch a base docker container and install the dependencies required by TensorRT-LLM.

# Obtain and start the basic docker image environment.
docker run --rm --runtime=nvidia --gpus all --volume ${PWD}:/TensorRT-LLM --entrypoint /bin/bash -it --workdir /TensorRT-LLM nvidia/cuda:12.1.0-devel-ubuntu22.04

# Install dependencies, TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev

# Install the stable version (corresponding to the cloned branch) of TensorRT-LLM.
pip3 install tensorrt_llm==0.8.0 -U --extra-index-url https://pypi.nvidia.com
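
Before moving on, it can help to confirm that the wheel installed correctly inside the container. A quick, optional check (not part of the original steps) is to import the library and print its version:

# Optional sanity check: run with python3 inside the container.
# If this import fails, revisit the pip install step above.
import tensorrt_llm

print("TensorRT-LLM version:", tensorrt_llm.__version__)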

Compiling the model

The next step in the process is compiling the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API. 

The TensorRT-LLM repository includes several model architectures, and we use the Llama model definition. For further details, including the available plug-ins and quantization options, see the Llama example and the precision documentation.

# Log in to huggingface-cli
# You can get your token from huggingface.co/settings/token
huggingface-cli login --token *****

# Build the Llama 3 8B model using a single GPU and BF16.
python3 examples/llama/convert_checkpoint.py --model_dir ./Meta-Llama-3-8B-Instruct \
            --output_dir ./tllm_checkpoint_1gpu_bf16 \
            --dtype bfloat16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_bf16 \
            --output_dir ./tmp/llama/8B/trt_engines/bf16/1-gpu \
            --gpt_attention_plugin bfloat16 \
            --gemm_plugin bfloat16

When we create the model definition with the TensorRT-LLM API, we build a graph of operations from TensorRT primitives that form the layers of our neural network. These operations map to specific kernels that are prewritten programs for the GPU. 

The TensorRT compiler can sweep through the graph to choose the best kernel for each operation and each available GPU. It can also identify patterns in the graph where multiple operations are good candidates for merging into a single fused kernel, reducing the required amount of memory movement and the overhead of launching multiple GPU kernels. 

Additionally, TensorRT builds the graph of operations into a single NVIDIA CUDA Graph that can be launched all at once, further reducing the overhead of launching kernels.

The TensorRT compiler is efficient at fusing layers and increasing execution speed; however, some complex layer fusions, such as FlashAttention, interleave many operations and can’t be discovered automatically. For those, we can explicitly replace parts of the graph with plug-ins at compile time. In our example, we include the gpt_attention plug-in, which implements a FlashAttention-like fused attention kernel, and the gemm plug-in, which performs matrix multiplication with FP32 accumulation. We also call out our desired precision for the full model as BF16, matching the default precision of the weights we downloaded from Hugging Face.

When we finish running the build script, we should expect to see the following files in the /tmp/llama/8B/trt_engines/bf16/1-gpu folder:

  • rank0.engine is the main output of our build script, containing the executable graph of operations with the model weights embedded. 
  • config.json includes detailed information about the model, like its general structure and precision, as well as information about which plug-ins were incorporated into the engine. 
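
To confirm what actually went into the engine, we can read config.json directly. The exact schema varies between TensorRT-LLM releases, so this sketch simply pretty-prints the file rather than assuming specific keys:

# Inspect the engine's build configuration (path taken from the trtllm-build
# command above). The schema differs across TensorRT-LLM versions, so we just
# dump the whole file and look for the precision and plug-in entries.
import json

with open("./tmp/llama/8B/trt_engines/bf16/1-gpu/config.json") as f:
    print(json.dumps(json.load(f), indent=2))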

Running the model

So, now that we’ve got our model engine, what can we do with it?

The engine file contains the information to execute the model. TensorRT-LLM includes a highly optimized C++ runtime for executing engine files and managing processes like sampling tokens from the model output, managing the KV cache, and batching requests together. 

We can use the runtime directly to execute the model locally, or we can deploy with Triton Inference Server in a production environment to share the model with multiple users.

To run the model locally, we can execute the following command:

python3 examples/run.py --engine_dir=./tmp/llama/8B/trt_engines/bf16/1-gpu --max_output_len 100 --tokenizer_dir ./Meta-Llama-3-8B-Instruct --input_text "How do I count to nine in French?"

Deploying with the Triton Inference Server

Beyond local execution, we can also use Triton Inference Server to create a production-ready deployment of our LLM. The Triton Inference Server backend for TensorRT-LLM uses the TensorRT-LLM C++ runtime for highly performant inference execution. It includes techniques like in-flight batching and paged KV caching that provide high throughput at low latency. The TensorRT-LLM backend is bundled with Triton Inference Server and is available as a pre-built container on NGC.

First, we must create a model repository so the Triton Inference Server can read the model and any associated metadata. 

The tensorrtllm_backend repository includes the setup of a required model repository under all_models/inflight_batcher_llm/ that we can replicate. 

In the directory are four subfolders holding artifacts for different parts of the model execution process. The preprocessing/ and postprocessing/ folders contain scripts for the Triton Inference Server Python backend. These scripts tokenize the text inputs and de-tokenize the model outputs, converting between strings and the token IDs that the model operates on.

The tensorrt_llm folder is where we’ll place the model engine we previously compiled. And finally, the ensemble folder defines a model ensemble that links the previous three components together and tells the Triton Inference Server how to flow data through them. 
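
To make the role of the preprocessing and postprocessing models concrete, here is an illustrative sketch of the string-to-token-ID round trip they perform, using the Hugging Face tokenizer for the model (this stands in for, and is not, the actual backend scripts):

# Illustration only: what preprocessing/ and postprocessing/ do conceptually.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./Meta-Llama-3-8B-Instruct")

# Preprocessing: text in, token IDs out (what the TensorRT engine consumes).
input_ids = tokenizer.encode("How do I count to nine in French?")
print(input_ids)

# Postprocessing: token IDs from the engine back to readable text.
print(tokenizer.decode(input_ids))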

Pull down the example model repository and copy the model you compiled in the previous step over to it.

# After exiting the TensorRT-LLM docker container
cd ..
git clone -b v0.8.0 https://github.com/triton-inference-server/tensorrtllm_backend.git
cd tensorrtllm_backend
cp ../TensorRT-LLM/tmp/llama/8B/trt_engines/bf16/1-gpu/* all_models/inflight_batcher_llm/tensorrt_llm/1/

Next, we must modify the configuration files from the repository skeleton with the location of the compiled model engine. We must also update configuration parameters, such as which tokenizer to use and how to allocate memory for the KV cache when batching requests for inference.

# Set the tokenizer_dir and engine_dir paths
HF_LLAMA_MODEL=TensorRT-LLM/Meta-Llama-3-8B-Instruct
ENGINE_PATH=tensorrtllm_backend/all_models/inflight_batcher_llm/tensorrt_llm/1

python3 tools/fill_template.py -i all_models/inflight_batcher_llm/preprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:auto,triton_max_batch_size:64,preprocessing_instance_count:1

python3 tools/fill_template.py -i all_models/inflight_batcher_llm/postprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:auto,triton_max_batch_size:64,postprocessing_instance_count:1

python3 tools/fill_template.py -i all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False

python3 tools/fill_template.py -i all_models/inflight_batcher_llm/ensemble/config.pbtxt triton_max_batch_size:64

python3 tools/fill_template.py -i all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0

Now, we can spin up our docker container and launch the Triton server. We must specify the world size—the number of GPUs the model was built for—and point to our model_repo that we just set up.

# Change to base working directory
cd ..
docker run -it --rm --gpus all --network host --shm-size=1g \
-v $(pwd):/workspace \
--workdir /workspace \
nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3

# Log in to huggingface-cli to get tokenizer
huggingface-cli login --token *****

# Install python dependencies
pip install sentencepiece protobuf

# Launch Server

python3 tensorrtllm_backend/scripts/launch_triton_server.py --model_repo tensorrtllm_backend/all_models/inflight_batcher_llm --world_size 1
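
Once the server reports that the models are loaded, you can verify it is accepting traffic with Triton’s standard HTTP health endpoint (port 8000 is the default HTTP port). A minimal check from Python, assuming the requests library is installed:

# Readiness check against Triton's default HTTP port.
import requests

r = requests.get("http://localhost:8000/v2/health/ready")
print("Server ready:", r.status_code == 200)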

Sending requests

To send inference requests and receive completions from the running server, you can use one of the Triton Inference Server client libraries or send HTTP requests to the generate endpoint.
The following curl command demonstrates a quick test requesting completions from the running server; for a more fully featured example, review the client script for communicating with the server.

curl -X POST localhost:8000/v2/models/ensemble/generate -d \
'{
"text_input": "How do I count to nine in French?",
"parameters": {
"max_tokens": 100,
"bad_words":[""],
"stop_words":[""]
}
}'
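
The same request can be issued from Python. Here is a minimal client sketch using the requests library against the ensemble’s generate endpoint; in the example model repository, the completion is returned in the response JSON under text_output:

# Minimal Python equivalent of the curl command above.
import requests

payload = {
    "text_input": "How do I count to nine in French?",
    "parameters": {"max_tokens": 100, "bad_words": [""], "stop_words": [""]},
}
response = requests.post(
    "http://localhost:8000/v2/models/ensemble/generate", json=payload
)
print(response.json())  # the generated text appears under "text_output"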

Conclusion

TensorRT-LLM provides tools for optimizing and efficiently running large language models on NVIDIA GPUs. Triton Inference Server is ideal for deploying and efficiently serving large language models such as Llama 3. 

Use this getting started guide to begin your journey leveraging open-source tools to deploy Llama 3 and many other large language models.

NVIDIA AI Enterprise, an end-to-end AI software platform that includes TensorRT, will soon include TensorRT-LLM, for mission-critical AI inference with enterprise-grade security, stability, manageability, and support.

Resources for getting started 

  • Access the TensorRT-LLM open-source library.
  • Learn more about the NVIDIA NeMo open-source library. 
  • Read the Developer Guide for TensorRT and TensorRT-LLM.
  • Explore our sample code, benchmarks, and documentation on GitHub.
  • Learn about NVIDIA NIM on ai.nvidia.com, which leverages TensorRT-LLM for optimized inference.