
Optimizing Inference on Large Language Models with NVIDIA TensorRT-LLM, Now Publicly Available

Stylized image of a workflow, with nodes labelled LLM, Optimize, and Deploy.

Today, NVIDIA announces the public release of TensorRT-LLM to accelerate and optimize inference performance for the latest LLMs on NVIDIA GPUs. This open-source library is now available for free on the /NVIDIA/TensorRT-LLM GitHub repo and as part of the NVIDIA NeMo framework.

Large language models (LLMs) have revolutionized the field of artificial intelligence and created entirely new ways of interacting with the digital world. But, as organizations and application developers around the world look to incorporate LLMs into their work, some of the challenges with running these models become apparent.

Put simply, LLMs are large. That fact can make them expensive and slow to run without the right techniques.

Many optimization techniques have emerged to deal with this, from model optimizations like kernel fusion and quantization to runtime optimizations like C++ implementations, KV caching, continuous in-flight batching, and paged attention. It can be difficult to decide which of these are right for your use case, and to navigate the interactions between these techniques and their sometimes-incompatible implementations.
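To make one of these techniques concrete, here is a minimal sketch of symmetric INT8 quantization in plain Python. It is an illustration only, not how TensorRT-LLM implements quantization; the function names and sample weights are made up for this example.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate floats from the stored integers and the shared scale."""
    return [q * scale for q in quantized]


# Hypothetical weights, chosen only for illustration.
weights = [0.02, -0.51, 0.33, 1.27, -0.08]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, scale, max_error)
```

Each weight is stored in one byte instead of two or four, at the cost of a rounding error of at most half the scale per value.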

That’s why NVIDIA introduced TensorRT-LLM, a comprehensive library for compiling and optimizing LLMs for inference. TensorRT-LLM incorporates all of those optimizations and more while providing an intuitive Python API for defining and building new models.

The TensorRT-LLM open-source library accelerates inference performance on the latest LLMs on NVIDIA GPUs. It is used as the optimization backbone for LLM inference in NVIDIA NeMo, an end-to-end framework to build, customize, and deploy generative AI applications into production. NeMo provides complete containers, including TensorRT-LLM and NVIDIA Triton, for generative AI deployments.

TensorRT-LLM is also now available for native Windows as a beta release. Application developers and AI enthusiasts can now benefit from accelerated LLMs running locally on PCs and workstations powered by NVIDIA RTX and NVIDIA GeForce RTX GPUs.

TensorRT-LLM wraps TensorRT’s deep learning compiler and includes the latest optimized kernels made for cutting-edge implementations of FlashAttention and masked multi-head attention (MHA) for LLM execution.

TensorRT-LLM also includes pre- and post-processing steps and multi-GPU/multi-node communication primitives in a simple, open-source Python API for groundbreaking LLM inference performance on GPUs.

Highlights of TensorRT-LLM include the following:

  • Support for LLMs such as Llama 1 and 2, ChatGLM, Falcon, MPT, Baichuan, and Starcoder
  • In-flight batching and paged attention
  • Multi-GPU multi-node (MGMN) inference
  • NVIDIA Hopper transformer engine with FP8
  • Support for NVIDIA Ampere architecture, NVIDIA Ada Lovelace architecture, and NVIDIA Hopper GPUs
  • Native Windows support (beta)

Over the past 2 years, NVIDIA has been working closely with leading LLM companies, including Anyscale, Baichuan, Cohere, Deci, Grammarly, Meta, Mistral AI, MosaicML (now part of Databricks), OctoML, Perplexity AI, Tabnine, Together.ai, Zhipu, and many others to accelerate and optimize LLM inference.

To help you get a feel for the library and how to use it, here’s an example of how to use and deploy Llama 2, a popular publicly available LLM, with TensorRT-LLM and NVIDIA Triton on Linux. To get started with the beta release, see the TensorRT-LLM for native Windows GitHub repo.

For more information, including different models, different optimizations, and multi-GPU execution, see the full list of TensorRT-LLM examples.

Getting started with installation

Here are the steps to get started with TensorRT-LLM:

  • Retrieving the model weights
  • Installing the TensorRT-LLM library
  • Compiling the model
  • Running the model
  • Deploying with Triton Inference Server
  • Sending requests

Retrieving the model weights

TensorRT-LLM is a library for LLM inference, and so to use it, you must supply a set of trained weights. You can either use your own model weights trained in a framework like NVIDIA NeMo or pull a set of pretrained weights from repositories like the Hugging Face Hub.

The commands in this post automatically pull the weights and tokenizer files for the chat-tuned variant of the 7B parameter Llama 2 model from the Hugging Face Hub. You can also download the weights yourself to use offline with the following commands. If you do, update the paths in later commands to point to the downloaded directory:

git lfs install
git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf

Usage of this model is subject to a particular license. To download the necessary files, agree to the terms and authenticate with Hugging Face.

Installing the TensorRT-LLM library

Launch a Docker container and install the TensorRT-LLM library.

# Obtain and start the basic docker image environment.
docker run --rm --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04

# Install dependencies, TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev

# Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.
# If you want to install the stable version (corresponding to the release branch), please
# remove the `--pre` option.
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com

# Check installation
python3 -c "import tensorrt_llm"
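If you prefer a check that also reports the installed version, the following sketch uses only the Python standard library and does not import the package itself. The helper name describe_install is made up for this example, and it assumes the distribution is registered under the same name as the importable module.

```python
import importlib.metadata
import importlib.util


def describe_install(module_name):
    """Report whether a module can be found and, if so, its installed version."""
    if importlib.util.find_spec(module_name) is None:
        return f"{module_name}: not installed"
    try:
        # Assumption: the distribution name matches the module name.
        version = importlib.metadata.version(module_name)
    except importlib.metadata.PackageNotFoundError:
        version = "unknown version"
    return f"{module_name}: {version}"


print(describe_install("tensorrt_llm"))
```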

Compiling the model

The next step in the process is to compile the model into a TensorRT engine. For this, you need the model weights as well as a model definition written in the TensorRT-LLM Python API.

The TensorRT-LLM repository contains a wide variety of predefined model architectures. For this post, you use the included Llama model definition instead of writing your own. This is a minimal example of some of the optimizations available in TensorRT-LLM.

For more information about available plug-ins and quantizations, see the full Llama example and Numerical Precision.

# Log in to huggingface-cli
# You can get your token from huggingface.co/settings/token
huggingface-cli login --token *****

# Build the LLaMA 7B model using a single GPU and BF16.
python3 convert_checkpoint.py --model_dir ./meta-llama/Llama-2-7b-chat-hf \
                               --output_dir ./tllm_checkpoint_1gpu_bf16 \
                               --dtype bfloat16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_bf16 \
            --output_dir ./tmp/llama/7B/trt_engines/bf16/1-gpu \
            --gpt_attention_plugin bfloat16 \
            --gemm_plugin bfloat16

When you create the model definition with the TensorRT-LLM API, you build a graph of operations from NVIDIA TensorRT primitives that form the layers of your neural network. These operations map to specific kernels: prewritten programs for the GPU.

The TensorRT compiler can sweep through the graph to choose the best kernel for each operation and available GPU. Crucially, it can also identify patterns in the graph where multiple operations are good candidates for being fused into a single kernel. This reduces the required amount of memory movement and the overhead of launching multiple GPU kernels.
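The effect of fusion can be sketched in plain Python. This toy model treats each list comprehension as one "kernel" and counts launches and intermediate writes; it only illustrates the idea and says nothing about how TensorRT selects or generates real GPU kernels.

```python
def unfused(xs, scale, bias):
    """Three separate 'kernels', each writing a full intermediate buffer."""
    scaled = [x * scale for x in xs]             # kernel 1
    shifted = [s + bias for s in scaled]         # kernel 2
    activated = [max(0.0, t) for t in shifted]   # kernel 3 (ReLU)
    launches, elements_written = 3, 3 * len(xs)
    return activated, launches, elements_written


def fused(xs, scale, bias):
    """One 'kernel' that applies scale, bias, and ReLU in a single pass."""
    out = [max(0.0, x * scale + bias) for x in xs]
    launches, elements_written = 1, len(xs)
    return out, launches, elements_written


data = [-1.0, 0.5, 2.0]
unfused_out, unfused_launches, unfused_writes = unfused(data, 2.0, -0.5)
fused_out, fused_launches, fused_writes = fused(data, 2.0, -0.5)
print(unfused_out == fused_out, unfused_launches, fused_launches)
```

Both versions compute the same result, but the fused version launches once and never writes the two intermediate buffers back to memory.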

TensorRT also compiles the graph of operations into a single CUDA Graph that can be launched all at one time, further reducing the kernel launch overhead.

The TensorRT compiler is extremely powerful for fusing layers and increasing execution speed, but some complex layer fusions, like FlashAttention, involve interleaving many operations together and can’t be automatically discovered. For those, you can explicitly replace parts of the graph with plug-ins at compile time.

In this example, you include the gpt_attention plug-in, which implements a FlashAttention-like fused attention kernel, and the gemm plug-in, which performs matrix multiplication with FP32 accumulation. You also set the precision for the full model to BF16 with the --dtype bfloat16 option during checkpoint conversion, and pass the same type to both plug-ins.
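To see why the precision of the accumulator matters in a matrix multiplication, consider this toy dot product. It is a sketch under stated assumptions: it uses FP16 rather than BF16 for the inputs because the Python standard library can round to FP16 through the struct module, and it has no connection to the actual gemm plug-in code.

```python
import struct


def round_to(fmt, value):
    """Round a Python float to a narrower IEEE 754 format ('e' = FP16, 'f' = FP32)."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def dot(a, b, accumulator_fmt):
    """Dot product of FP16 inputs; the running sum is rounded to accumulator_fmt."""
    total = 0.0
    for x, y in zip(a, b):
        product = round_to("e", x) * round_to("e", y)
        total = round_to(accumulator_fmt, total + product)
    return total


a = [1.0] * 4096
b = [1.0] * 4096
fp16_result = dot(a, b, "e")  # low-precision accumulator
fp32_result = dot(a, b, "f")  # FP32 accumulator
print(fp16_result, fp32_result)
```

With the FP32 accumulator, the sum reaches the exact answer of 4096. With the FP16 accumulator, the running sum stops growing once adding 1 no longer changes the rounded value, so the result falls short.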

Here’s what the build produces when it finishes. In the ./tmp/llama/7B/trt_engines/bf16/1-gpu folder, there are now the following files:

  • Llama_float16_tp1_rank0.engine: The main output of the build script, containing the executable graph of operations with the model weights embedded.
  • config.json: Includes detailed information about the model, like its general structure and precision, as well as information about which plug-ins were incorporated into the engine.
  • model.cache: Caches some of the timing and optimization information from model compilation, making successive builds quicker.

Running the model

So, now that you’ve got your model engine, what can you do with it?

The engine file contains the information that you need for executing the model, but LLM usage in practice requires much more than a single forward pass through the model. TensorRT-LLM includes a highly optimized C++ runtime for executing built LLM engines and managing processes like sampling tokens from the model output, managing the KV cache, and batching requests together.
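The value of batching requests together, and of in-flight batching in particular, can be sketched with a small simulation. This is a toy model, not the TensorRT-LLM scheduler: it assumes every active request produces one token per decoding step, and the request lengths are made up for illustration.

```python
from collections import deque


def static_batching_steps(output_lengths, batch_size):
    """Each batch runs until its longest request finishes before the next starts."""
    steps = 0
    for start in range(0, len(output_lengths), batch_size):
        steps += max(output_lengths[start:start + batch_size])
    return steps


def inflight_batching_steps(output_lengths, batch_size):
    """A finished request frees its slot, and a waiting request takes it at once."""
    waiting = deque(output_lengths)
    active = []  # tokens still to generate for each running request
    steps = 0
    while waiting or active:
        # Fill free slots from the queue before each decoding step.
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())
        steps += 1  # one step generates one token per active request
        active = [remaining - 1 for remaining in active if remaining > 1]
    return steps


# Hypothetical output lengths (in tokens) for eight queued requests.
lengths = [100, 5, 5, 5, 100, 5, 5, 5]
static_steps = static_batching_steps(lengths, batch_size=4)
inflight_steps = inflight_batching_steps(lengths, batch_size=4)
print(static_steps, inflight_steps)
```

In this example, static batching needs 200 decoding steps because short requests wait for the longest request in their batch, while in-flight batching finishes in 105 steps by refilling slots as soon as they free up.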

You can use that runtime directly to execute the model locally, or you can use the TensorRT-LLM runtime backend for NVIDIA Triton Inference Server to serve the model for multiple users.

To run the model locally, execute the following command:

python3 examples/llama/run.py \
   --engine_dir=./tmp/llama/7B/trt_engines/bf16/1-gpu \
   --max_output_len 100 \
   --tokenizer_dir meta-llama/Llama-2-7b-chat-hf \
   --input_text "How do I count to nine in French?"

Deploying with Triton Inference Server

Beyond local execution, you can also use the NVIDIA Triton Inference Server to create a production-ready deployment of your LLM.

NVIDIA is releasing a new Triton Inference Server backend for TensorRT-LLM that leverages the TensorRT-LLM C++ runtime for rapid inference execution and includes techniques like in-flight batching and paged KV-caching. Triton Inference Server with the TensorRT-LLM backend is available as a pre-built container through NGC.
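To get a feel for why KV cache management matters, here is a back-of-the-envelope sketch of the memory involved. It assumes the Llama 2 7B shape (32 layers, 32 KV heads, head dimension 128) and 2 bytes per cached value; the block size of 64 tokens is a hypothetical value chosen for illustration, not a documented TensorRT-LLM default.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_value):
    """Each layer stores one key and one value vector per KV head for every token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def blocks_needed(num_tokens, tokens_per_block):
    """A paged cache allocates fixed-size blocks, so round up to whole blocks."""
    return -(-num_tokens // tokens_per_block)


# Assumed model shape: Llama 2 7B with 2-byte (BF16 or FP16) cache entries.
per_token = kv_cache_bytes_per_token(
    num_layers=32, num_kv_heads=32, head_dim=128, bytes_per_value=2
)

# Four concurrent sequences of 2,048 tokens each.
total_gib = 4 * 2048 * per_token / 2**30

# Hypothetical block size of 64 tokens.
blocks_for_100_tokens = blocks_needed(100, tokens_per_block=64)
print(per_token, total_gib, blocks_for_100_tokens)
```

Under these assumptions, every token costs 512 KiB of cache, so four full-length sequences need about 4 GiB. Allocating the cache in blocks means a 100-token sequence reserves only the blocks it actually uses rather than space for the maximum length.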

First, create a model repository so that Triton Inference Server can read the model and any associated metadata. The tensorrtllm_backend repository includes the skeleton of an appropriate model repository under all_models/inflight_batcher_llm/ that you can use. That directory contains four subfolders that hold artifacts for different parts of the model execution process:

  • /preprocessing and /postprocessing: Contain scripts for the Triton Inference Server Python backend for tokenizing the text inputs and detokenizing the model outputs to convert between strings and the token IDs on which the model operates.
  • /tensorrt_llm: Where you place the model engine that you previously compiled.
  • /ensemble: Defines a model ensemble that links the previous three components together and tells Triton Inference Server how to flow data through them.

Pull down the example model repository and copy the model you compiled in the previous step over to it:

# After exiting the TensorRT-LLM Docker container
cd ..
git clone -b release/0.8.0 https://github.com/triton-inference-server/tensorrtllm_backend.git
cd tensorrtllm_backend
cp ../TensorRT-LLM/examples/llama/out/*   all_models/inflight_batcher_llm/tensorrt_llm/1/

Next, modify some of the configuration files from the repository skeleton with information like the following:

  • Where the compiled model engine is
  • What tokenizer to use
  • How to handle memory allocation for the KV cache when performing inference in batches

python3 tools/fill_template.py --in_place \
      all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt \
      decoupled_mode:true,engine_dir:/all_models/inflight_batcher_llm/tensorrt_llm/1,\
max_tokens_in_paged_kv_cache:,batch_scheduler_policy:guaranteed_completion,kv_cache_free_gpu_mem_fraction:0.2,\
max_num_sequences:4

python3 tools/fill_template.py --in_place \
    all_models/inflight_batcher_llm/preprocessing/config.pbtxt \
    tokenizer_type:llama,tokenizer_dir:meta-llama/Llama-2-7b-chat-hf

python3 tools/fill_template.py --in_place \
    all_models/inflight_batcher_llm/postprocessing/config.pbtxt \
    tokenizer_type:llama,tokenizer_dir:meta-llama/Llama-2-7b-chat-hf
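Conceptually, these commands substitute key:value pairs into placeholders in each config.pbtxt file. The following sketch shows that idea with the Python standard library. It is an assumption about the general approach, not the actual fill_template.py source, and the one-line template is an illustrative fragment rather than a complete configuration.

```python
from string import Template


def parse_substitutions(spec):
    """Turn 'key:value,key:value' into a dict, splitting each pair on its first colon."""
    pairs = {}
    for item in spec.split(","):
        key, _, value = item.partition(":")
        pairs[key.strip()] = value.strip()
    return pairs


def fill(template_text, spec):
    """Replace ${key} placeholders; unknown placeholders are left untouched."""
    return Template(template_text).safe_substitute(parse_substitutions(spec))


# Illustrative fragment only; a real config.pbtxt has many more fields.
template = 'parameters: { key: "tokenizer_dir" value: { string_value: "${tokenizer_dir}" } }'
filled = fill(template, "tokenizer_type:llama,tokenizer_dir:meta-llama/Llama-2-7b-chat-hf")
print(filled)
```

Splitting on the first colon only keeps values such as engine_dir:/all_models/... intact, and a key with an empty value, like max_tokens_in_paged_kv_cache: in the first command, becomes an empty string.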

Now, spin up the Docker container and launch the Triton server. Specify world_size, which is the number of GPUs the model was built for, and point to the model_repo just set up.

docker run -it --rm --gpus all --network host --shm-size=1g \
     -v $(pwd)/all_models:/all_models \
     -v $(pwd)/scripts:/opt/scripts \
     nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3

# Log in to huggingface-cli to get tokenizer
huggingface-cli login --token *****

# Install python dependencies
pip install sentencepiece protobuf

# Launch Server
python /opt/scripts/launch_triton_server.py --model_repo /all_models/inflight_batcher_llm --world_size 1

Sending requests

To send requests to and interact with the running server, you can use one of the Triton Inference Server client libraries or send HTTP requests to the generate endpoint. To get started, you can use the more fully featured client script or the following curl command:

curl -X POST localhost:8000/v2/models/ensemble/generate -d \
'{
"text_input": "How do I count to nine in French?",
"parameters": {
"max_tokens": 100,
"bad_words":[""],
"stop_words":[""]
}
}'
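The same request can be sent from Python with only the standard library. This is a minimal sketch that mirrors the curl command above; the helper names are made up for this example, and it assumes the server is reachable at localhost:8000 and returns a JSON body.

```python
import json
import urllib.request


def build_generate_request(prompt, max_tokens=100, host="localhost:8000", model="ensemble"):
    """Build the POST request for the generate endpoint without sending it."""
    url = f"http://{host}/v2/models/{model}/generate"
    payload = {
        "text_input": prompt,
        "parameters": {"max_tokens": max_tokens, "bad_words": [""], "stop_words": [""]},
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt, **kwargs):
    """Send the request and return the parsed JSON response (assumes a running server)."""
    request = build_generate_request(prompt, **kwargs)
    with urllib.request.urlopen(request, timeout=60) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage, once the Triton server from the previous step is running:
# print(generate("How do I count to nine in French?"))
```

Separating request construction from sending makes it easy to inspect or log the payload before it goes to the server.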

Conclusion

Together, TensorRT-LLM and Triton Inference Server provide an indispensable toolkit for optimizing, deploying, and running LLMs efficiently. With the release of TensorRT-LLM as an open-source library on GitHub, it’s easier than ever for organizations and application developers to harness the potential of these models.

If you’re eager to dive into the world of LLMs, now is the time to get started with TensorRT-LLM. Explore its capabilities, experiment with different models and optimizations, and embark on your journey to unlock the incredible power of AI-driven language models. I can’t wait to see the incredible things you all will build.

For more information about getting started with TensorRT-LLM, see the following resources:
