
Model Quantization: Turn FP8 Checkpoints into High-Performance Inference Engines with NVIDIA TensorRT


Converting a quantized checkpoint into an NVIDIA TensorRT engine bridges the gap between model optimization and production deployment, enabling faster inference, higher throughput, and more efficient GPU utilization at scale.

In a previous post, we produced a high-quality FP8-quantized Contrastive Language-Image Pretraining (CLIP) checkpoint with NVIDIA TensorRT Model Optimizer.

This post picks up where we left off, walking through how to export the checkpoint to ONNX and compile it into an NVIDIA TensorRT engine ready for production inference. We also profile the resulting FP8 TensorRT engine against the FP16 baseline to measure the real-world speedup the quantized model delivers.

Figure 1 shows the five stages of a typical end-to-end quantization workflow. This is the standard pipeline for deploying a quantized CLIP model. Quantized LLMs follow a different path through TensorRT-LLM, which is covered in this tutorial.

Export model to ONNX format

The first step is to export the ModelOpt checkpoint to ONNX. The following code does this for the FP8-quantized CLIP checkpoint using a built-in ModelOpt helper. The export targets ONNX opset 20, where FP8 QuantizeLinear/DequantizeLinear is fully supported. The helper folds each weight-side quantize-then-dequantize (Q/DQ) pair into an FP8-stored, DQ-only chain, which noticeably shrinks the ONNX file.

In principle, native torch.onnx.export also works, but it requires writing a custom conversion script.

import torch
from transformers import CLIPModel, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPAttention
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
from modelopt.torch.quantization.plugins.diffusion.diffusers import _QuantAttention

# Thin wrappers expose a single forward to the ONNX exporter. Each forward argument
# is named to match the ONNX input name used in dynamic_axes and trtexec --shapes.
class TextEncoder(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, text_input):
        return self.m.get_text_features(text_input)


class ImageEncoder(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, image_input):
        return self.m.get_image_features(image_input)

def prepare_for_fp8_onnx_export(model):
    # 1) turn on FP8 attention fusion (off by default, lost on reload).
    # 2) clear CLIP's float `scale` — exporter chokes on it.
    for _, mod in model.named_modules():
        if isinstance(mod, _QuantAttention):
            mod._disable_fp8_mha = False
        if isinstance(mod, CLIPAttention) and getattr(mod, "scale", None) is not None:
            mod.scale = None

def export(wrapper, dummy, axis_name, out_name):
    """ModelOpt's exporter folds Q+DQ on weights into FP8-stored DQ-only chains
    and rewrites TRT custom ops to native ONNX QDQ — output is TRT-ready."""
    onnx_bytes, _ = get_onnx_bytes_and_metadata(
        model=wrapper, dummy_input=(dummy,), model_name=out_name,
        dynamic_axes={axis_name: {0: "batch"}}, onnx_opset=20, weights_dtype="fp16",
    )
    OnnxBytes.from_bytes(onnx_bytes).write_to_disk("./onnx_output", clean_dir=False)

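# Placeholder paths (replace with your own): the FP8 ModelOpt checkpoint from the
# previous post, and the original HuggingFace CLIP checkpoint used for the tokenizer
modelopt_ckpt = "<path/to/fp8_modelopt_clip_checkpoint>"
model_ckpt = "<huggingface/clip-model-id>"

# Restore the FP8-quantized CLIPModel from the ModelOpt checkpoint
mto.enable_huggingface_checkpointing()
mtq.QuantModuleRegistry.register({CLIPAttention: "CLIPAttention"})(_QuantAttention)
model = (
    CLIPModel.from_pretrained(modelopt_ckpt, attn_implementation="sdpa", torch_dtype=torch.float16)
    .eval().cuda()
)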
prepare_for_fp8_onnx_export(model)

# Export the text encoder to ONNX
tok = CLIPTokenizer.from_pretrained(model_ckpt)
dummy_text = tok(["a photo of a cat"], return_tensors="pt", padding="max_length", max_length=77)["input_ids"].cuda()
export(TextEncoder(model), dummy_text, "text_input", "text_clip_fp8")

# Export the image encoder to ONNX
dummy_image = torch.randn(16, 3, 224, 224, dtype=torch.float16).cuda()
export(ImageEncoder(model), dummy_image, "image_input", "image_clip_fp8")

Model component           FP8 ModelOpt checkpoint   FP16 HuggingFace checkpoint   Size reduction
CLIP text encoder ONNX    156 MB                    237 MB                        ~34%
CLIP image encoder ONNX   292 MB                    582 MB                        ~50%
Table 1. CLIP ONNX model size: FP8 vs FP16

Table 1 compares the ONNX file sizes of the FP8 ModelOpt checkpoint export against the original FP16 HuggingFace checkpoint export. The FP8 checkpoint export produces noticeably smaller ONNX files, ~34% smaller for the text encoder and ~50% smaller for the image encoder. 

Note that shrinking the ONNX file is a convenience, not a requirement: TensorRT folds the weight-side Q node into the FP8 weight at engine build time anyway. ModelOpt's ONNX exporter simply performs this folding earlier, on the ONNX side, to keep the on-disk file smaller.

We can inspect the exported ONNX file with NVIDIA Nsight Deep Learning Designer, a tool for ONNX model editing, performance profiling, and TensorRT engine building.

Figure 2 shows a portion of the exported ONNX graph visualized in Nsight Deep Learning Designer. The graph now contains QuantizeLinear/DequantizeLinear (Q/DQ) nodes, which mark the FP8 boundaries.

During engine building, TensorRT fuses these nodes with adjacent layers to optimize inference performance. This fusion eliminates unnecessary quantize-then-dequantize transitions, enabling the use of optimized FP8 kernels for computation.

Profile ONNX model with TensorRT

With the FP8 ONNX models exported, the next step is to pass them to TensorRT and measure how fast they run. Before we begin, make sure TensorRT is downloaded and installed by following this tutorial. We then use trtexec, TensorRT's command-line wrapper, to build and benchmark an engine for each encoder with the following commands:

# Set up the TensorRT environment
export PATH=<TensorRT-${version}/bin>:$PATH
export LD_LIBRARY_PATH=<TensorRT-${version}/lib>:$LD_LIBRARY_PATH

# Benchmark the ONNX model with trtexec
trtexec --onnx=text_clip_fp8.onnx \
        --shapes=text_input:128x77 \
        --stronglyTyped \
        --saveEngine=text_clip_fp8.plan

trtexec --onnx=image_clip_fp8.onnx \
        --shapes=image_input:128x3x224x224 \
        --stronglyTyped \
        --saveEngine=image_clip_fp8.plan
  • --onnx specifies the input ONNX model that TensorRT will build the engine from.
  • --shapes pins the input shape so TensorRT can build an optimized engine for that exact size.
  • --stronglyTyped forces TensorRT to respect the precision annotations that ModelOpt baked into the ONNX graph, ensuring our FP8 weights and activations actually execute in FP8. 
  • --saveEngine writes the built TensorRT engine to disk for later reuse, either for standalone TensorRT inference runtime or for serving through NVIDIA Triton Inference Server (see this example).
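The same build can also be scripted with the TensorRT Python API. The sketch below is an approximation of the trtexec commands above, not an exact equivalent: the STRONGLY_TYPED network flag plays the role of --stronglyTyped, an optimization profile with min = opt = max mirrors --shapes, and writing the serialized network to disk mirrors --saveEngine. The helper names are hypothetical, and the ONNX file must already pass strong typing.

```python
def parse_shapes(spec: str) -> dict:
    """Parse a trtexec-style shape spec such as "text_input:128x77"
    (comma-separated for several inputs) into {name: (dim, ...)}."""
    shapes = {}
    for item in spec.split(","):
        name, dims = item.rsplit(":", 1)
        shapes[name] = tuple(int(d) for d in dims.split("x"))
    return shapes


def build_strongly_typed_engine(onnx_path: str, shape_spec: str, plan_path: str) -> None:
    """Parse an ONNX file into a strongly typed network, pin every input
    to a static shape, and serialize the built engine to plan_path."""
    import tensorrt as trt  # imported lazily so parse_shapes works without TensorRT

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # Strongly typed: tensor types come from the ONNX graph (Q/DQ, FP8, FP16)
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        errors = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError(f"ONNX parsing failed:\n{errors}")

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    for name, shape in parse_shapes(shape_spec).items():
        profile.set_shape(name, shape, shape, shape)  # min = opt = max
    config.add_optimization_profile(profile)

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed")
    with open(plan_path, "wb") as f:
        f.write(serialized)
```

For example, build_strongly_typed_engine("text_clip_fp8.onnx", "text_input:128x77", "text_clip_fp8.plan") corresponds to the first trtexec command, minus the benchmarking run.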

One caveat: ModelOpt's exporter wraps the attention scaling in an FP32 round-trip, which --stronglyTyped rejects (you may see a Float vs. Half type mismatch error). Before running trtexec, we therefore re-type all FP32 initializers (including these scale constants) and all Cast-to-FP32 ops to FP16 to get a clean, strongly typed engine.

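# Re-type all FP32 initializers and Cast(to=FP32) ops in the FP8 ONNX files to FP16
# so that trtexec --stronglyTyped accepts them.
import numpy as np
import onnx
from onnx import TensorProto, numpy_helper, shape_inference


def retype_fp32_to_fp16(src_path, dst_path):
    model = onnx.load(src_path)

    # FP32 constants (such as the attention scale factors) -> FP16
    for init in model.graph.initializer:
        if init.data_type == TensorProto.FLOAT:
            arr = numpy_helper.to_array(init).astype(np.float16)
            init.CopyFrom(numpy_helper.from_array(arr, name=init.name))

    # Cast(to=FP32) -> Cast(to=FP16)
    for node in model.graph.node:
        if node.op_type == "Cast":
            to_attr = next(a for a in node.attribute if a.name == "to")
            if to_attr.i == TensorProto.FLOAT:
                to_attr.i = TensorProto.FLOAT16

    # Drop stale FP32 value_info entries, then re-run shape/type inference
    del model.graph.value_info[:]
    model = shape_inference.infer_shapes(model, data_prop=True, check_type=False)
    onnx.save(model, dst_path)


# Pass the resulting *_strongtyped.onnx files to trtexec --onnx
for name in ("text_clip_fp8", "image_clip_fp8"):
    retype_fp32_to_fp16(f"{name}.onnx", f"{name}_strongtyped.onnx")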

Alternatively, we can profile the ONNX model with TensorRT in Nsight Deep Learning Designer by following the profiling section of the official user guide.

We run the benchmark on an NVIDIA RTX 6000 Ada GPU with TensorRT 10.16 using trtexec, with a static batch size of 128. Each reported latency is the median across all inference iterations within trtexec's default measurement window. Note that FP8 matrix multiplications (GEMMs) are supported only on Ada and later architectures (compute capability 8.9 or above). For a detailed breakdown of which data types are supported on which GPUs, see the TensorRT support matrix.
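Because FP8 GEMMs need compute capability 8.9 or above, it can be worth checking the GPU before building an FP8 engine. The sketch below is a hypothetical helper that compares a (major, minor) compute capability against (8, 9) and, optionally, queries the current device through PyTorch's torch.cuda.get_device_capability.

```python
def supports_fp8_gemm(major: int, minor: int) -> bool:
    """True if this compute capability is 8.9 (Ada) or newer."""
    return (major, minor) >= (8, 9)


def current_gpu_supports_fp8(device: int = 0) -> bool:
    """Check a CUDA device via PyTorch; returns False if no GPU is visible."""
    import torch  # imported lazily so supports_fp8_gemm has no dependencies

    if not torch.cuda.is_available():
        return False
    return supports_fp8_gemm(*torch.cuda.get_device_capability(device))
```

Tuple comparison keeps the check correct for future major versions: an RTX 6000 Ada (8.9) or an H100 (9.0) passes, while an A100 (8.0) does not.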

Figure 3 shows the benefits of FP8 quantization over FP16 in both TensorRT engine size and inference latency. On the left, the image encoder shrinks from 588 MB to 306 MB (a 48% reduction) and the text encoder from 238 MB to 156 MB (a 34% reduction), cutting the combined on-disk footprint from 826 MB to 462 MB, about 44% smaller. The savings carry over to GPU memory at inference time, because smaller engines need less memory to load and run.

On the right, the latency story is just as compelling. The image encoder drops from 166.2 ms to 119.8 ms and the text encoder from 13.2 ms to 9.1 ms, delivering a 1.39x speedup on the image side and 1.45x on the text side.
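The percentages and speedups above follow directly from the Figure 3 numbers. This small sketch recomputes them; the implied throughput figure is our own derivation and assumes each reported latency covers one full batch of 128 samples.

```python
def pct_reduction(before: float, after: float) -> float:
    """Percentage reduction going from `before` to `after`."""
    return 100.0 * (before - after) / before


def speedup(baseline_ms: float, optimized_ms: float) -> float:
    """Latency speedup of the optimized engine over the baseline."""
    return baseline_ms / optimized_ms


def throughput(batch: int, latency_ms: float) -> float:
    """Samples per second when one batch completes in latency_ms."""
    return batch * 1000.0 / latency_ms


# (FP16, FP8) values reported in Figure 3, batch size 128
engine_mb = {"image": (588, 306), "text": (238, 156)}
latency_ms = {"image": (166.2, 119.8), "text": (13.2, 9.1)}

for enc in ("image", "text"):
    fp16_mb, fp8_mb = engine_mb[enc]
    fp16_ms, fp8_ms = latency_ms[enc]
    print(
        f"{enc}: engine size -{pct_reduction(fp16_mb, fp8_mb):.0f}%, "
        f"speedup {speedup(fp16_ms, fp8_ms):.2f}x, "
        f"FP8 throughput ~{throughput(128, fp8_ms):.0f} samples/s"
    )
```

Running it reproduces the 48% and 34% size reductions and the 1.39x and 1.45x speedups quoted above.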

Where exactly does the FP8 speedup come from? Beyond the raw numbers trtexec reports, Nsight Deep Learning Designer offers a richer visual breakdown to give us a clear answer. 

Figure 4 places the FP16 and FP8 image encoder profiles side by side, and three differences immediately stand out. 

  1. The GEMM bar drops from roughly 1.8 ms to 0.84 ms, more than a 2x speedup on the dominant matmul layer, delivered by the FP8 Tensor Core kernels on the NVIDIA RTX 6000 Ada GPU. 
  2. The “fusion” layer category visible in the FP16 profile is gone in the FP8 profile because TensorRT now routes the entire attention block through a specialized FP8 MHA kernel, which yields a more streamlined execution path. 
  3. The precision donut shifts from a mostly orange (FP16) plot to a mostly purple (FP8) one. These signals confirm that our quantized weights and activations are running on FP8 Tensor Cores, which is exactly where FP8's gains come from: higher computational throughput and lower memory bandwidth usage in every matmul-heavy step.
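For a text-based version of this breakdown, trtexec can also export per-layer timings, for example by adding --dumpProfile --separateProfileRun --exportProfile=image_fp8_profile.json to one of the trtexec commands above. The sketch below summarizes such a file. It assumes the JSON layout trtexec typically writes (a list whose first element holds the iteration count and whose remaining elements carry "name" and "averageMs" keys), and the keyword bucketing is a hypothetical heuristic, since layer names depend on how TensorRT fused the graph.

```python
import json
from collections import defaultdict


def load_layer_profile(path):
    """Load per-layer entries from a trtexec --exportProfile JSON file.

    Assumed layout: [{"count": N}, {"name": ..., "averageMs": ...}, ...].
    Entries without a "name" key (such as the leading count) are skipped.
    """
    with open(path) as f:
        entries = json.load(f)
    return [e for e in entries if "name" in e]


def top_layers(layers, k=5):
    """Return the k layers with the largest average time per iteration."""
    return sorted(layers, key=lambda e: e["averageMs"], reverse=True)[:k]


def time_by_keyword(layers, keywords):
    """Hypothetical heuristic: add each layer's average time to the first
    keyword found in its lower-cased name; unmatched layers go to "other"."""
    totals = defaultdict(float)
    for e in layers:
        name = e["name"].lower()
        bucket = next((kw for kw in keywords if kw.lower() in name), "other")
        totals[bucket] += e["averageMs"]
    return dict(totals)
```

Comparing time_by_keyword(load_layer_profile(path), ["matmul", "attn"]) for the FP16 and FP8 profiles gives a rough, scriptable counterpart to the side-by-side view in Figure 4.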

How quantization works in TensorRT

When importing an ONNX model, TensorRT looks for QuantizeLinear / DequantizeLinear (Q/DQ) nodes, which mark the points in the graph where a tensor transitions between full- and low-precision data types such as FP8.

Internally, TensorRT requires a Q/DQ layer pair on each input of every quantizable layer. At engine build time, the optimizer fuses these Q/DQ nodes into their adjacent layers and replaces the original layer with a specialized kernel that operates directly on the low-precision tensors. This eliminates the round-trip quantize-then-dequantize transitions and lets the engine execute with higher compute throughput and lower memory bandwidth. 

Figure 5 shows this transformation for an FP8 GEMM. In the Q/DQ representation, both the activation \(x_f\) and the weight tensor pass through a QuantizeLinear/DequantizeLinear pair (in the ModelOpt export, the weight-side Q is already folded into an FP8-stored weight, leaving only the DQ). After TensorRT's optimizer fuses them, what remains is a single FP8 GEMM kernel that takes the FP8-quantized activation and the pre-stored FP8 weight tensor directly.
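The arithmetic behind this fusion can be written out under a simplifying assumption: per-tensor scales \(s_x\) and \(s_w\) with a zero point of 0, following the ONNX QuantizeLinear/DequantizeLinear definitions, with rounding to the nearest representable FP8 E4M3 value and saturation at ±448. For \(x_f\) of shape [batch, in] and \(W\) of shape [out, in]:

```latex
\begin{aligned}
x_q &= \mathrm{Q}(x_f) = \mathrm{clamp}_{[-448,\,448]}\!\left(\mathrm{round}_{\mathrm{FP8}}(x_f / s_x)\right),
  \qquad \mathrm{DQ}(x_q) = s_x\, x_q \\
W_q &= \mathrm{Q}(W) \ \text{(computed once, stored in FP8)},
  \qquad \mathrm{DQ}(W_q) = s_w\, W_q \\
y &= \mathrm{DQ}(x_q)\,\mathrm{DQ}(W_q)^{\top}
   = (s_x\, x_q)(s_w\, W_q)^{\top}
   = s_x s_w \left(x_q W_q^{\top}\right)
\end{aligned}
```

Because the scales factor out of the product, the kernel can multiply \(x_q\) and \(W_q\) directly on FP8 Tensor Cores (with higher-precision accumulation) and apply the single combined scale \(s_x s_w\) to the output, so no standalone dequantize step has to survive in the engine. With per-output-channel weight scales the same argument holds, with \(s_w\) applied column-wise to \(x_q W_q^{\top}\).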

For a deeper dive into TensorRT quantization mechanics, see the documentation.

Get started

In this post, we walked through the full ModelOpt → ONNX → TensorRT workflow for deploying a quantized model. We exported a CLIP checkpoint to ONNX with Q/DQ nodes, built a TensorRT engine, and benchmarked it against the FP16 baseline with both trtexec and Nsight Deep Learning Designer. 

The results show that FP8 quantization delivers significant speed and memory-footprint improvements over the original FP16 models on an RTX 6000 Ada GPU: 1.39x to 1.45x speedups and 34% to 48% smaller engines. We also gave a brief overview of how TensorRT realizes those gains by fusing Q/DQ nodes into specialized low-precision kernels at build time.

Try NVIDIA Model Optimizer and NVIDIA TensorRT and explore the efficiency gains that model quantization can deliver.
