Meta’s Llama collection of open large language models (LLMs) continues to grow with the recent addition of Llama 3.3 70B, a text-only instruction-tuned model. Llama 3.3 provides enhanced performance relative to the older Llama 3.1 70B model and can even match the capabilities of the larger, more computationally expensive Llama 3.1 405B model on several tasks, including math, reasoning, coding, and multilingual support.
NVIDIA TensorRT-LLM, a powerful inference engine that delivers state-of-the-art performance on the latest LLMs, incorporates many optimizations to deliver outstanding Llama 3.3 70B inference throughput. These include in-flight batching, KV caching, custom FP8 quantization, speculative decoding, and more for fast, cost-efficient LLM serving.
With in-flight batching activated by default as a runtime configuration parameter, TensorRT-LLM supports batching multiple different requests at the same time for higher serving throughput. By interleaving requests in context and generation phases, in-flight batching reduces latency and improves GPU utilization by executing new requests while older requests are still in flight. Finished requests are evicted from the batch, making room for the next set of requests.
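To make the scheduling idea concrete, here is a minimal illustrative sketch of a continuous-batching loop in plain Python. The Request class and the decode_step callback are hypothetical stand-ins for this sketch only; TensorRT-LLM performs this scheduling inside the engine runtime.

from dataclasses import dataclass, field

@dataclass
class Request:
    prompt_tokens: list                            # context (prefill) input
    generated: list = field(default_factory=list)  # tokens produced so far
    max_new_tokens: int = 128

    def is_finished(self) -> bool:
        return len(self.generated) >= self.max_new_tokens

def inflight_batching_loop(pending, decode_step, max_batch_size=32):
    """Toy in-flight (continuous) batching: new requests join the batch as
    soon as older ones finish, instead of waiting for the batch to drain."""
    active = []
    while pending or active:
        # Evict finished requests, immediately freeing their batch slots.
        active = [r for r in active if not r.is_finished()]
        # Admit new requests into the freed slots; their context phase runs
        # alongside the generation phase of requests already in flight.
        while pending and len(active) < max_batch_size:
            active.append(pending.pop(0))
        # One engine iteration advances every active request by one token.
        for req, token in zip(active, decode_step(active)):
            req.generated.append(token)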
Caching the key-value (KV) tensors of previous tokens avoids expensive recomputation of those tensors during the generation phase for subsequent tokens. These computational savings translate directly into higher throughput. However, the KV cache grows linearly with the number of batched requests and the sequence context lengths, leading to higher memory requirements.
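The trade-off can be seen in a minimal single-head attention sketch in NumPy (illustrative only, not TensorRT-LLM’s implementation): each generation step computes the key and value for the new token only, but the cache grows by one row per generated token, per layer, per request.

import numpy as np

def decode_step(x_t, W_q, W_k, W_v, kv_cache):
    """One generation step: K/V for previous tokens are read from the cache,
    so only the new token's projections are computed."""
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v            # shapes (1, d)
    kv_cache["k"] = np.concatenate([kv_cache["k"], k])   # cache grows by one
    kv_cache["v"] = np.concatenate([kv_cache["v"], v])   # row per new token
    scores = q @ kv_cache["k"].T / np.sqrt(q.shape[-1])  # (1, seq_len)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ kv_cache["v"], kv_cache             # (1, d) output

d = 64
W_q, W_k, W_v = (0.02 * np.random.randn(d, d) for _ in range(3))
kv_cache = {"k": np.zeros((0, d)), "v": np.zeros((0, d))}
out, kv_cache = decode_step(np.random.randn(1, d), W_q, W_k, W_v, kv_cache)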
TensorRT-LLM KV caching addresses these challenges through several optimizations, including support for paged KV cache, quantized KV cache, circular buffer KV cache, and KV cache reuse. Each of these optimizations addresses the challenging balance between growing memory size and avoiding unnecessary, expensive recomputation.

Speculative decoding is a popular technique for faster, cost-effective LLM inference with built-in verification of the quality of the generated output. It is based on the premise that generating multiple sequences of future (draft) tokens is more efficient than processing a single token per step of autoregressive decoding, an inherently time-consuming process. The target model determines how many of these draft tokens to accept, which is far more efficient than generating one token per iteration. TensorRT-LLM supports a growing list of speculative decoding techniques, including draft target, Medusa, Eagle, and lookahead decoding, among others.
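The draft-target acceptance loop can be sketched as follows (a greedy, token-matching variant for clarity; draft_generate and target_forward are hypothetical placeholders, not the TensorRT-LLM API):

def speculative_step(tokens, draft_generate, target_forward, draft_len=10):
    """One draft-target iteration with greedy (token-matching) acceptance.

    draft_generate(tokens, n): cheaply proposes n future (draft) tokens.
    target_forward(tokens): a single target-model pass returning the greedy
        next-token prediction at every position of the input.
    """
    prompt_len = len(tokens)
    draft = draft_generate(tokens, draft_len)
    preds = target_forward(tokens + draft)    # one pass verifies all drafts
    accepted = []
    for i, tok in enumerate(draft):
        if preds[prompt_len - 1 + i] == tok:  # target agrees with the draft
            accepted.append(tok)
        else:
            break                             # first mismatch stops acceptance
    # The target's own prediction after the accepted prefix is always kept,
    # so each iteration emits at least one token even if no drafts match.
    bonus = preds[prompt_len - 1 + len(accepted)]
    return tokens + accepted + [bonus]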
In this post, we show how the NVIDIA HGX H200 platform with NVLink and NVSwitch, together with TensorRT-LLM, delivers strong throughput when running the latest Llama 3.3 70B model. We describe the step-by-step setup for running speculative decoding with Llama 3.3 70B in TensorRT-LLM. For more information, including other optimizations, different models, and multi-GPU execution, see the full list of TensorRT-LLM examples.
Achieving throughput speedups with draft target speculative decoding
Table 1 and Figure 2 highlight the throughput (output tokens/second) speedups achieved with draft models of various sizes, relative to running the Llama 3.3 70B target model without a draft model (that is, without speculative decoding).
Table 1. Throughput performance – output tokens/second on one NVIDIA H200 Tensor Core GPU

| Draft model / target model | Tokens/sec | Speedup (with vs. without draft model) |
| --- | --- | --- |
| Llama 3.2 1B / Llama 3.3 70B | 181.74 | 3.55x |
| Llama 3.2 3B / Llama 3.3 70B | 161.53 | 3.16x |
| Llama 3.1 8B / Llama 3.3 70B | 134.38 | 2.63x |
| Llama 3.3 70B (no draft model) | 51.14 | N/A |
Data measured on December 11, 2024. Output tokens/second is inclusive of time to generate the first token – tok/s = total generated tokens / total latency. DGX H200, TP1, FP8, batch size=1, TensorRT Model Optimizer version 0.21, TensorRT-LLM version 0.15.0.
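The speedup row is simply each draft-assisted throughput divided by the no-draft baseline, for example:

baseline = 51.14   # tokens/sec, Llama 3.3 70B without a draft model
drafts = [("Llama 3.2 1B", 181.74), ("Llama 3.2 3B", 161.53), ("Llama 3.1 8B", 134.38)]
for name, tps in drafts:
    print(f"{name} draft: {tps / baseline:.2f}x")   # 3.55x, 3.16x, 2.63x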
We provide the steps to reproduce these performance gains using draft target speculative decoding within TensorRT-LLM.
# Download the following model checkpoints from Hugging Face and store them
# in a directory for easy access through the setup process.
git lfs install
# Download target models
git clone https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct
# Download draft models
git clone https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
git clone https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct
git clone https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
After the model checkpoints have been downloaded, install TensorRT-LLM.
# Obtain and start the basic docker image environment (optional).
docker run --rm --ipc=host --runtime=nvidia --gpus all \
    --entrypoint /bin/bash -it nvidia/cuda:12.5.1-devel-ubuntu22.04
# Install dependencies; TensorRT-LLM requires Python 3.10.
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin \
    libopenmpi-dev git git-lfs
# Fetch the library
git clone -b v0.15.0 https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
# Install TensorRT-LLM, pinned to match the cloned v0.15.0 examples.
pip3 install tensorrt_llm==0.15.0 --extra-index-url https://pypi.nvidia.com
# Check installation
python3 -c "import tensorrt_llm"
Next, compile the downloaded model checkpoints into draft and target TensorRT engines. These engines are optimized to run inference with high throughput while preserving accuracy.
cd examples
# Steps to build target and draft models in FP8 precision on 1 H200
# Create FP8 checkpoints
python3 quantization/quantize.py \
    --model_dir <path to draft model repo> \
    --output_dir ./ckpt-draft \
    --dtype float16 --qformat fp8 --kv_cache_dtype fp8 \
    --calib_size 512 --tp_size 1
python3 quantization/quantize.py \
--model_dir=<path to target model repo> \
--output_dir=./ckpt-target-70b \
--dtype=float16 --qformat fp8 --kv_cache_dtype fp8 \
--calib_size 512 --tp_size 1
# Build draft and target engines
# Important flags for the engine build process:
# --use_paged_context_fmha=enable must be specified because KV cache reuse is needed for the draft/target workflow.
# --speculative_decoding_mode=draft_tokens_external and --max_draft_len must be specified for the target model.
trtllm-build \
--checkpoint_dir ./ckpt-draft \
--output_dir=./draft-engine \
--gpt_attention_plugin float16 \
--workers 1 \
--gemm_plugin=fp8 \
--use_paged_context_fmha=enable \
--multiple_profiles enable \
--max_batch_size=32 \
--max_seq_len=131072
trtllm-build \
--checkpoint_dir=./ckpt-target-70b \
--output_dir=./target-engine \
--gpt_attention_plugin float16 \
--workers 1 \
--gemm_plugin=fp8 \
--use_paged_context_fmha=enable \
--multiple_profiles enable \
--max_batch_size=32 \
--max_seq_len=131072 \
--low_latency_gemm_plugin fp8 \
--speculative_decoding_mode=draft_tokens_external \
--max_draft_len 10
Finally, run speculative decoding in TensorRT-LLM.
# Run decoding
# Important flags to set during the run process:
# --draft_engine_dir and --engine_dir must be specified for the draft and target engines.
# --draft_target_model_config sets the Draft-Target-Model configuration. For example,
#   [4,[0],[1],False] means draft_len=4, the draft model runs on GPU0, the target model
#   runs on GPU1, and tokens (rather than logits) are used for acceptance.
# Only the C++ session (using the executor as the low-level API) is supported;
#   the Python session (--use_py_session) is not supported.
# Run with Llama 3.3 70B target model
mpirun -n 1 --allow-run-as-root python3 ./run.py \
--tokenizer_dir <path to draft model repo> \
--draft_engine_dir ./draft-engine \
--engine_dir ./target-engine \
--draft_target_model_config="[10,[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7],False]" \
--kv_cache_free_gpu_memory_fraction=0.35 \
--max_output_len=1024 \
--kv_cache_enable_block_reuse \
--input_text="<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nA
3-digit integer contains one of each of the digits 1,3 and 5. What is the
probability that the integer is divisible by
5.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
# Following is the LLM-generated output:
Output [Text 0 Beam 0]: "## Step 1: Determine the total number of 3-digit integers that can be formed using the digits 1, 3, and 5.
There are 3! = 6 ways to arrange the digits 1, 3, and 5 to form different 3-digit integers.
## Step 2: Identify the condition for a number to be divisible by 5.
A number is divisible by 5 if its last digit is either 0 or 5.
## Step 3: Determine the number of arrangements where 5 is the last digit.
Since the digit 5 must be the last digit for the number to be divisible by 5, we fix the last position with 5. The remaining two positions can be filled with the digits 1 and 3 in 2! = 2 ways.
## Step 4: Calculate the probability that the integer is divisible by 5.
The probability is the number of favorable outcomes (arrangements where 5 is the last digit) divided by the total number of possible outcomes (total arrangements of the digits 1, 3, and 5).
## Step 5: Calculate the probability.
Probability = (Number of favorable outcomes) / (Total number of outcomes) = 2 / 6 = 1/3.
The final answer is: $\boxed{\frac{1}{3}}$"
To benchmark throughput performance without speculative decoding, follow the steps below:
# Run throughput benchmark for the 70B model without the draft model
trtllm-build --checkpoint_dir ./ckpt-target-70b --output_dir /data/70B-TRT/ \
    --gpt_attention_plugin float16 --workers 1 --max_batch_size 32 \
    --max_seq_len 131072 --use_fused_mlp enable --reduce_fusion enable \
    --use_paged_context_fmha enable --multiple_profiles enable --gemm_plugin fp8

python3 /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
    --output token-norm-dist.json --tokenizer /llama-3_3-70b/ \
    token-norm-dist --num-requests 1000 --input-mean 500 --input-stdev 0 \
    --output-mean 200 --output-stdev 0 > /tmp/synthetic.txt
trtllm-bench --model <path to target model repo> latency \
    --engine_dir /data/70B-TRT/ \
    --dataset /tmp/synthetic.txt
Summary
NVIDIA collaborates with Meta on the creation, optimization, and acceleration of the world’s leading open models. NVIDIA supports Llama as part of its commitment to growing open community AI models and software, enabling users to customize and address their own unique workloads. NVIDIA contributes to several open-source projects in partnership with developers, maintainers, and foundations.
NVIDIA TensorRT-LLM provides several features for optimizing and efficiently running LLMs of different model architectures. These optimizations lead to significant speedups on the same hardware, enable fewer resources to serve the same workload, reduce energy costs, and improve total cost of ownership. Available through production-ready NVIDIA NIM microservices, these TensorRT-LLM optimizations accelerate the deployment of your generative AI applications across NVIDIA-accelerated infrastructure anywhere, including cloud, data center, and workstations.