DriveOS LLM SDK: TensorRT’s Large Language Model Inference Framework for Auto Platforms#

Introduction#

DriveOS LLM SDK is a lightweight C++ software toolkit that showcases TensorRT’s capability and performance for deploying Large Language Models (LLMs) and Vision Language Models (VLMs) on Auto platforms. With DriveOS LLM SDK, users can:

  1. Quantize and export a PyTorch model to ONNX format on a Linux x86 system.

  2. Build a TensorRT engine and run end-to-end (E2E) LLM inference, including tokenization and sampling, on the Auto platform.

Prerequisite#

A Linux x86 host with a GPU is required to export the model into ONNX format. Once the ONNX model is exported, the only dependencies are the TensorRT C++ library and the CUDA runtime; DriveOS LLM SDK has no other external C++ dependencies.

Supported Platforms, Models, and Precisions#

The following LLM models under ./examples/llm are supported by DriveOS LLM SDK in the listed precisions:

Model               | FP16 | INT4 | FP8 | NVFP4
--------------------|------|------|-----|------
Llama3-8b-instruct  | Yes  | Yes  | Yes | Yes
Llama3.1-8B         | Yes  | Yes  | Yes | Yes
Llama3.2-3B         | Yes  | Yes  | Yes | Yes
Qwen2.5-7B-instruct | Yes  | Yes  | Yes | Yes
Qwen2-7B-instruct   | Yes  | Yes  | Yes | Yes
Qwen2.5-0.5B        | Yes  | Yes  | Yes | Yes

The following VLM models under ./examples/vlm are supported by DriveOS LLM SDK in the listed precisions. Note that the ViT (vision encoder) always runs in FP16 precision.

Model                | FP16 | INT4 | FP8 | NVFP4
---------------------|------|------|-----|------
Qwen2-VL-2B-instruct | Yes  | Yes  | Yes | Yes
Qwen2-VL-7B-instruct | Yes  | Yes  | Yes | Yes

Precisions Explained and Notes#

  • FP16: All the weights and compute are in FP16.

  • FP8 (W8A8): All weights and GEMMs are in FP8, while the KV cache, LayerNorm, attention, and lm_head remain in FP16 precision. FP8 both reduces memory footprint and improves inference latency.

  • INT4 (W4A16): All weights are quantized to INT4 using the AWQ recipe, while all compute remains in FP16 precision. INT4 reduces memory footprint and significantly improves generation latency by cutting weight loading by 4x compared to FP16. Note that because TensorRT’s native (out-of-the-box, or OOTB) INT4 kernels have some performance issues, an Int4GroupwiseGemmPlugin is provided as the default option for INT4.

  • NVFP4 (W4A4): Similar to FP8, all weights and GEMMs are in NVFP4 while the other parts remain in FP16 precision. NVFP4 significantly reduces memory footprint and improves inference latency, especially in the context phase. Generation-phase (mainly GEMV) performance of NVFP4 is good but still has room for improvement; those improvements will ship in the next few releases.

Customized Models#

  • Decoder-only Llama-series and Qwen-series models are likely to be supported if they fit in Thor memory, but they are not fully tested. For VLMs, only Qwen2-VL is likely to be supported.

  • Other model series will likely not be supported due to differences in tokenizer implementation and model architecture.

Other Platforms#

  • DriveOS LLM SDK does not support TensorRT 8.6 and is therefore not compatible with Orin or any DriveOS 6.x release.

  • As a preview feature, DriveOS LLM SDK can also run on x86 Linux Data Center or gaming GPUs with SM80, SM86, or SM89 in FP16 and INT4; SM89 additionally supports FP8 E2E inference. However, no model support is guaranteed, and the resulting performance will not be comparable to TensorRT-LLM, which remains the recommended deployment path for all Data Center use cases.

Getting Started#

Build the C++ Project#

In the NVIDIA DriveOS container, the LLM SDK is located in the /drive/extra/driveos_llm_sdk-0.0.3.4.tar.gz archive. DRIVEInstaller is not used to install the LLM SDK; instead, you must extract the tarball and build the project:

tar -xzvf driveos_llm_sdk-0.0.3.4.tar.gz
cd driveos_llm_sdk
mkdir build

Cross-build the C++ project on the x86 Linux host:

cd build
cmake .. -DTRT_PACKAGE_DIR=/usr/aarch64-linux-gnu -DCMAKE_TOOLCHAIN_FILE=cmake/aarch64_cross_toolchain.cmake -DAUTO_TARGET=thor
make

Natively build the C++ project on Thor:

cd build
cmake .. -DTRT_PACKAGE_DIR=/usr/aarch64-linux-gnu -DCMAKE_TOOLCHAIN_FILE=cmake/aarch64_native_toolchain.cmake -DAUTO_TARGET=thor
make

To build and run DriveOS LLM SDK on x86 machines for rapid development, -DCMAKE_TOOLCHAIN_FILE and -DAUTO_TARGET are not needed. The example binaries are generated in the examples folder for later use, and the AttentionPlugin library is built as libAttentionPlugin.so.
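For example, a minimal sketch of the x86 build; the TensorRT location below is a placeholder for an x86 TensorRT installation:

cd build
# Placeholder: point TRT_PACKAGE_DIR to the x86 TensorRT package
cmake .. -DTRT_PACKAGE_DIR=/path/to/x86/TensorRT
make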

Export ONNX from PyTorch Checkpoint#

First, export the PyTorch model to ONNX on an x86 Linux host with a GPU. If quantization is needed, a Data Center GPU such as the H100 is recommended (or even required). Once the ONNX model is available, Python is no longer needed.

Build Engine and Run E2E LLM Inference on C++#

Once the model is exported, you can follow the examples to build and run E2E LLM inference with C++.

DriveOS LLM SDK ONNX Exporter#

This folder contains the scripts to export an ONNX model from a PyTorch model. The exported ONNX model follows the format required by the DriveOS LLM SDK runtime, so it can later be converted into a TensorRT engine for the E2E LLM inference application on the Auto platform.

Prerequisite#

ONNX export is platform agnostic, but it is strongly recommended to run the script on a Linux x86 platform with an Ampere or later GPU. Even though FP8 deployment only works on Ada and later, and NVFP4 deployment only works on Blackwell and later, the simulated quantization script can run on any GPU. To avoid OOM during quantization and ONNX export, it is recommended to run the quantization on GPUs with 80 GB of memory.

Usage#

  1. Download the Hugging Face checkpoint and save it locally.

  2. Run cd export, then run pip3 install -r requirements.txt.

  3. If you are working with NVFP4, uninstall onnx and nvidia-modelopt with pip3 uninstall onnx and pip3 uninstall nvidia-modelopt, and then run pip3 install -r requirements_nvfp4.txt.

  4. Call the export script.

python3 llm_export.py --torch_dir $TORCH_DIR --dtype [fp16|fp8|int4|nvfp4|int4_ootb] --output_dir $ONNX_DIR
python3 multimodal_export.py --torch_dir $TORCH_DIR --dtype [fp16|fp8|int4|nvfp4|int4_ootb] --output_dir $ONNX_DIR

The ONNX model with the desired data type is exported to $ONNX_DIR.

Notes:#

  1. TensorRT out-of-the-box (OOTB) has a known performance issue with INT4 GEMV: even though the accuracy is good, the performance is not as desired. Therefore, an Int4GroupwiseGemmPlugin is provided, and the DQ + GEMM pattern is replaced by the plugin by default for INT4 as a temporary solution. If you do not want to use this plugin, pass int4_ootb as the data type to the export script.

  2. Even though FP8 or NVFP4 is supported for ONNX export, Orin does not support FP8 or NVFP4.

  3. Pass in --keep_original to save the original exported ONNX in the ${ONNX_DIR}_raw folder. For FP16 and INT4 this is an FP16 ONNX, while for FP8 or NVFP4 it is an FP8 or NVFP4 ONNX with FP32 weight storage. This ONNX can be reused by passing --onnx_path to save ONNX export time.

  4. Pass --dataset_dir to skip downloading the quantization calibration dataset.

  5. The default is --max_seq_length=4096. Change this field if a different sequence length is required (see the sketch after these notes).
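For example, a sketch that combines the optional flags above, run from the export folder; the checkpoint, output, and dataset paths are placeholders:

# Keep the raw ONNX, use a local calibration dataset, and raise the maximum sequence length
python3 llm_export.py --torch_dir ./llama3-8b-instruct-hf --dtype int4 \
    --output_dir ./llama3_int4_onnx --keep_original \
    --dataset_dir ./calib_data --max_seq_length 8192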

DriveOS LLM SDK Example: Decoder-only Language Models#

Prerequisite#

An ONNX model that complies with the DriveOS LLM SDK runtime should be ready after following the ONNX export steps. To run inference with real data, a tokenizer file is also required.

Build Engine#

The llm_build binary is used to build the TensorRT engine. All exported ONNX models share the same I/O names and data types, so the build process is identical regardless of model and precision.

Example command:

./build/examples/llm/llm_build --onnxPath=llama3_fp16/model.onnx --enginePath=llama3_fp16.engine --batchSize=1 --maxInputLen=128 --maxSeqLen=4096

Notes:#

  1. --maxSeqLen includes --maxInputLen, so it must be greater than --maxInputLen. The maximum number of new tokens equals maxSeqLen - maxInputLen.

  2. Notice that maxSeqLen must be identical to the kv_cache_capacity field of the ONNX AttentionPlugin node. This field can be adjusted with --max_seq_length during ONNX export (see the sketch after these notes).

  3. Static multi-batch is supported as long as --batchSize is smaller than the max_batch_size field of the ONNX AttentionPlugin node.
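For example, a sketch (paths are placeholders) that re-exports the ONNX with a matching sequence length from the export folder and then builds the engine from the SDK root:

# From the export folder: export with --max_seq_length matching the intended --maxSeqLen
python3 llm_export.py --torch_dir ./llama3-8b-instruct-hf --dtype fp16 \
    --output_dir ./llama3_fp16_8k --max_seq_length 8192
# From the SDK root: build with the same sequence length
./build/examples/llm/llm_build --onnxPath=llama3_fp16_8k/model.onnx \
    --enginePath=llama3_fp16_8k.engine --batchSize=1 --maxInputLen=128 --maxSeqLen=8192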

Infer Engine#

The llm_benchmark, llm_accuracy, and llm_chat binaries are examples that show E2E C++ LLM inference using greedy decoding. Example usages:

Interactive Chat#

./build/examples/llm/llm_chat --tokenizerPath=llama-v3-8b-instruct-hf/ --enginePath=llama3_fp16.engine --maxLength=64

Note:#

  1. Chat prompts for an input prompt for each batch until all batches have an input prompt.

Benchmark Performance#

./build/examples/llm/llm_benchmark --enginePath=llama3_fp16.engine --maxLength=256 --inputLength=24 [--warmUp 2 --numRuns 10]

Evaluate with MMLU#

To run the MMLU accuracy evaluation, the MMLU dataset is required:

wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
tar -xf data.tar
./build/examples/llm/llm_accuracy --tokenizerPath=llama-v3-8b-instruct-hf/  --enginePath=llama3_fp16.engine --datasetPath data

Python reference

python scripts/mmlu.py

DriveOS LLM SDK Example: Multimodal Models#

This document shows how to run multimodal pipelines with DriveOS LLM SDK, for example from image + text input modalities to text output.

A multimodal model’s LLM part and multimodal part are separated into two TensorRT engines. While the LLM part is similar to LLM-only models, the multimodal part is model specific; the multimodal runner combines the two parts. The multimodal features of shape [batch_size, num_multimodal_features, multimodal_hidden_dim] are flattened to [batch_size * num_multimodal_features, multimodal_hidden_dim] and passed like a prompt embedding table together with other model-specific inputs.

We describe how to run supported models in the next section.

Qwen2-VL#

Prerequisite#

  1. Download the Hugging Face weights.

    git lfs install
    export MODEL_NAME="Qwen2-VL-7B-Instruct" # or Qwen2-VL-2B-Instruct
    git clone https://huggingface.co/Qwen/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
    
  2. Export to ONNX.

    An ONNX model that complies with the DriveOS LLM SDK runtime should be ready after following the ONNX export steps. To run inference with real data, a tokenizer file is also required. The visual and LLM parts are exported to two separate ONNX files.

    python3 ../../export/multimodal_export.py \
    --torch_dir tmp/hf_models/${MODEL_NAME} \
    --output_dir tmp/onnx/${MODEL_NAME} \
    --dtype [fp16|fp8|int4|nvfp4] \
    --model_type qwen2_vl
    

Build Engine#

The vlm_build binary is used to build the TensorRT engines. Corresponding to the two ONNX files, the visual engine and the LLM engine are built respectively.

  1. Static shape. Specify --batchSize and --imageTokens.

    ./build/examples/multimodal/vlm_build \
    --llmOnnxPath=tmp/onnx/${MODEL_NAME}/llm_onnx/model.onnx \
    --llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
    --visualOnnxPath=tmp/onnx/${MODEL_NAME}/visual_enc_onnx/model.onnx \
    --visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
    --modelType="qwen2_vl" \
    --maxInputLen=1024 --maxSeqLen=4096 \
    --batchSize=1 --imageTokens=512
    
  2. Dynamic shape. Specify --maxBatchSize, --minImageTokens and --maxImageTokens.

    ./build/examples/multimodal/vlm_build \
    --llmOnnxPath=tmp/onnx/${MODEL_NAME}/llm_onnx/model.onnx \
    --llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
    --visualOnnxPath=tmp/onnx/${MODEL_NAME}/visual_enc_onnx/model.onnx \
    --visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
    --modelType="qwen2_vl" \
    --maxInputLen=1024 --maxSeqLen=4096 \
    --dynamicShape \
    --maxBatchSize=2 --minImageTokens=1280 --maxImageTokens=6620
    

Infer Engine#

The vlm_chat, vlm_benchmark, and vlm_accuracy binaries are examples that show E2E C++ VLM inference using greedy decoding. Example usages:

VLM Chat#

# BS=2
./build/examples/multimodal/vlm_chat \
--tokenizerPath=tmp/hf_models/${MODEL_NAME} \
--llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
--visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
--modelType="qwen2_vl" \
--maxLength=1024 \
--inputString="Describe the picture." \
--imagePaths="examples/multimodal/qwen2vl/pics/demo.jpeg" \
--inputString="Identify the similarities between these images." \
--imagePaths="examples/multimodal/qwen2vl/pics/image1.jpeg,examples/multimodal/qwen2vl/pics/image2.jpeg"

Note:#

  1. --inputString takes the input prompt for one batch. --imagePaths takes the image paths for one batch. Multiple image paths in one batch should be separated by commas ','.

  2. One --inputString and one --imagePaths are paired as the inputs for one batch. batchSize equals the maximum of the number of --inputString arguments and the number of --imagePaths arguments.

  3. For any batch that contains --imagePaths only, --inputString defaults to the prompt "Describe this image.". For any batch that contains --inputString only, --imagePaths is set to empty, which is equivalent to pure LLM inference (see the sketch after these notes).
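For example, a single-batch, image-only sketch that falls back to the default prompt; the engine and image paths follow the build steps and repository layout above:

./build/examples/multimodal/vlm_chat \
--tokenizerPath=tmp/hf_models/${MODEL_NAME} \
--llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
--visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
--modelType="qwen2_vl" \
--maxLength=1024 \
--imagePaths="examples/multimodal/qwen2vl/pics/demo.jpeg"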

Benchmark Performance#

  1. Benchmark the E2E pipeline performance for a given input size.

    ./build/examples/multimodal/vlm_benchmark \
    --llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
    --visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
    --modelType="qwen2_vl" \
    --textTokenLength=512 --imageTokenLength=512 --outputLength=256 [--warmUp=2 --numRuns=10]
    
  2. Benchmark visual encoder and LLM separately.

    • Use the Qwen2 model as an approximation of the Qwen2-VL LLM part performance: build a Qwen2 engine of the same size and precision, then benchmark it with the ../llm/llm_benchmark binary.

    • Use trtexec to benchmark the visual encoder engine (see the sketch after this list).

    • E2E latency = visual encoder latency + LLM latency
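For example, a trtexec sketch for a static-shape visual engine built as above; additional shape flags may be needed for dynamic-shape engines:

trtexec --loadEngine=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine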

Evaluate Accuracy with MMMU#

To match the MMMU evaluation config and MMMU image sizes, generate the ONNX model and TensorRT engines with the following config: --minImageTokens=1280, --maxImageTokens=6620, --maxInputLen=7168, --maxSeqLen=8192.

  1. Export ONNX.

    python3 ../../export/multimodal_export.py \
    --torch_dir tmp/hf_models/${MODEL_NAME} \
    --output_dir tmp/onnx/${MODEL_NAME} \
    --dtype [fp16|fp8|int4] \
    --model_type qwen2_vl \
    --max_seq_length 8192
    
  2. Build engine.

    ./build/examples/multimodal/vlm_build \
    --llmOnnxPath=tmp/onnx/${MODEL_NAME}/llm_onnx/model.onnx \
    --llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
    --visualOnnxPath=tmp/onnx/${MODEL_NAME}/visual_enc_onnx/model.onnx \
    --visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
    --modelType="qwen2_vl" \
    --maxInputLen=7168 --maxSeqLen=8192 \
    --dynamicShape \
    --maxBatchSize=1 --minImageTokens=1280 --maxImageTokens=6620
    
  3. Collect inference results on MMMU-val dataset.

    wget https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv
    
    ./build/examples/multimodal/vlm_accuracy \
    --tokenizerPath=tmp/hf_models/${MODEL_NAME} \
    --llmEnginePath=tmp/trt_engines/${MODEL_NAME}/llm.engine \
    --visualEnginePath=tmp/trt_engines/${MODEL_NAME}/visual_enc_fp16.engine \
    --modelType=qwen2_vl \
    --datasetPath=./MMMU_DEV_VAL.tsv \
    --outputPath=./mmmu-qwen2vl.csv
    
  4. Evaluate the results with the Python script.

    python scripts/mmmu.py --csv_path=./mmmu-qwen2vl.csv --output_path=./mmmu-qwen2vl-eval.json
    

Limitations and Known Issues#

Python Export#

Qwen export requires torch<2.5.0. With torch>=2.5.0, you will encounter the issue below. Therefore, the torch version is pinned to torch==2.4.1.

_C._jit_pass_onnx_graph_shape_type_inference(
RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library. Therefore the output file must be a file path, so that the ONNX external data can be written to the same directory. Please specify the output file name.

nvidia-modelopt>0.19.0 has accuracy issues with the INT4 recipe, so for the mainstream flow it is pinned to 0.19.0.
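A sketch of the corresponding pins, assuming the packages are installed from PyPI:

pip3 install torch==2.4.1 nvidia-modelopt==0.19.0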

NVFP4 Export#

NVFP4 ONNX export has not matured yet because onnx==1.18.0 has not been released. If you want to export NVFP4 ONNX, first run pip3 install -r requirements.txt, then uninstall onnx and modelopt with pip3 uninstall onnx and pip3 uninstall nvidia-modelopt, and then, in the export folder, run pip3 install -r requirements_nvfp4.txt, which installs the preview onnx-weekly package and nvidia-modelopt==0.23.0. You will likely encounter the issue below. As a WAR, manually change split_complex_to_pairs to _split_complex_to_pairs in the file (see the sketch after the traceback), because the function was renamed by a recent ONNX commit. The issue should be fixed once onnx==1.18.0 is formally released.

File "/usr/local/lib/python3.10/dist-packages/onnxmltools/proto/__init__.py", line 14, in <module>
  from onnx.helper import split_complex_to_pairs
ImportError: cannot import name 'split_complex_to_pairs' from 'onnx.helper' (/usr/local/lib/python3.10/dist-packages/onnx/helper.py)
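One way to apply this WAR is a one-line substitution, a sketch that assumes the installation path from the traceback and should be run only once:

sed -i 's/split_complex_to_pairs/_split_complex_to_pairs/g' \
    /usr/local/lib/python3.10/dist-packages/onnxmltools/proto/__init__.py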

You may also encounter the issue below for NVFP4 export. This issue comes from ModelOpt. As a WAR, manually add get_quantization_format(modules[0]) != QUANTIZATION_NONE as the first condition of the failing check. The ModelOpt team will fix it in the next release.

File "/home/.local/lib/python3.10/site-packages/modelopt/torch/export/unified_export_hf.py", line 109, in requantize_resmooth_fused_llm_layers
  if tensor in output_to_layernorm.keys() and "awq" in get_quantization_format(modules[0]):
TypeError: argument of type 'NoneType' is not iterable

Engine Build#

Since Qwen’s and Llama’s vocab size is large (~100,000), using --dynamicShape with --maxBatchSize > 1 is not supported and will crash the engine build. The TensorRT team is aware of this issue and will fix it in a later version.

Inference#

There is a known issue where cudaMallocAsync fails when the allocated memory size is large (greater than roughly 5 GB). If you build an engine larger than 5 GB, it will fail to load. Please use the following command as a WAR to prevent this issue. The DriveOS team is aware of this issue and will fix it in the next release.

echo 24576 | sudo tee /proc/sys/vm/nr_hugepages

If you encounter an issue with mmap while loading the engine, you can set DISABLE_MMAP_LOAD=1 to use the default IStreamReader to load the engine.
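For reference, the fallback can be enabled in the shell that launches the example binaries:

export DISABLE_MMAP_LOAD=1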