
Adaptive Inference in NVIDIA TensorRT for RTX Enables Automatic Optimization

Deploying AI applications across diverse consumer hardware has traditionally forced a trade-off. You can optimize for specific GPU configurations and achieve peak performance at the cost of portability. Alternatively, you can build generic, portable engines and leave performance on the table. Bridging this gap often requires manual tuning, multiple build targets, or accepting compromises.

NVIDIA TensorRT for RTX seeks to eliminate this trade-off. At under 200 MB, this lean inference library provides a Just-In-Time (JIT) optimizer that compiles engines in under 30 seconds. This makes it ideal for real-time, responsive AI applications on consumer-grade devices. 

TensorRT for RTX introduces adaptive inference—engines that optimize automatically at runtime for your specific system, progressively improving compilation and inference performance as your application runs. No manual tuning, no multiple build targets, no intervention required.

Build a lightweight, portable engine once, deploy it anywhere, and let it adapt to the user’s hardware. At runtime, the engine automatically compiles GPU-specific specialized kernels, learns from your workload patterns, and improves performance over time—all without any developer intervention. For more details, see the NVIDIA TensorRT for RTX documentation.

Adaptive inference

With TensorRT for RTX, runtime performance improves over time without any manual intervention. Three features work in tandem to enable this self-optimization: Dynamic Shape specialized kernels tune performance to your workloads’ shapes, CUDA Graphs eliminate overhead when executing those kernels, and runtime caching persists these improvements across sessions. The result: your engine gets faster as it runs.

  • Dynamic Shapes Kernel Specialization: Automatically compiles faster kernels for shapes encountered at runtime and seamlessly swaps them in, improving performance in real time by specializing for actual workload conditions. 
  • Built-in CUDA Graphs: Automatically captures, instantiates, and executes kernels as a single batch, reducing kernel launch overhead and boosting inference performance, while integrating with Dynamic Shapes. 
  • Runtime caching: Stores compiled kernels across sessions, reducing JIT overhead and avoiding redundant compilation. 

For a live demonstration of these features working together on a real diffusion pipeline with concrete speedups, see the Adaptive Inference Acceleration With TensorRT for RTX walkthrough video.

Static optimization versus adaptive inference workflows

Traditional inference frameworks require developers to predict input shapes and build optimized engines for each target configuration at compile time. TensorRT for RTX takes a different approach: engines adapt to actual workloads at runtime. Table 1 compares these two workflows.

Component | Static workflow | Adaptive inference
Build targets | Multiple engines per GPU | Single portable engine
Shape flexibility | Optimized at build time for predicted shapes | Optimized automatically at runtime for the shapes actually encountered
Inference run 1 | Optimal performance (if the shape was pretuned) | Near-optimal performance
Inference run N | Same performance | Performance improves over time as new shapes are encountered (plus cached specializations)
Developer effort | Manual tuning per config | Zero intervention
Table 1. A comparison between static optimization and adaptive inference workflows

Adaptive inference closes the gap with the static workflow, offering optimal performance while eliminating build complexity and developer effort.

Performance comparison: Adaptive versus static

To demonstrate the performance of adaptive inference, we ran the FLUX.1 [dev] model in FP8 precision at 512×512 with dynamic shapes on an RTX 5090 (Windows 11), comparing TensorRT for RTX 1.3 against a static optimizer. As shown in Figure 1, adaptive inference surpasses static optimization by iteration 2 and reaches 1.32x faster with all features enabled. Runtime caching also accelerates JIT compilation from 31.92s to 1.95s (16x), enabling subsequent sessions to start at peak performance immediately.

Bar chart comparing inference times across five configurations. Adaptive Iteration 1 (Fallback Kernels): 6.46s. Static Inference: 3.64s (blue bar). Adaptive Iteration 2 (Fallback + Specialized): 3.15s. Adaptive Iteration 3 (Specialized Kernels): 2.94s. Adaptive Iteration 3 with CUDA Graphs and Runtime Cache: 2.76s. Green bars represent adaptive inference; blue bar represents static inference.
Figure 1. Adaptive inference performance progression on FLUX.1 [dev] FP8 512×512 (RTX 5090, Windows 11)

Motivating example

As a motivating example, start by creating a TensorRT engine from an ONNX model:

import tensorrt_rtx as trt_rtx

logger = trt_rtx.Logger(trt_rtx.Logger.WARNING)
builder = trt_rtx.Builder(logger)
network = builder.create_network()
config = builder.create_builder_config()  # holds build settings such as optimization profiles

parser = trt_rtx.OnnxParser(network, logger)
with open("your_model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError("Failed to parse ONNX model")

Dynamic Shapes Kernel Specialization

Models frequently encounter varying input dimensions: different image resolutions, variable sequence lengths, or dynamic batch sizes. Dynamic Shapes Kernel Specialization automatically generates optimized kernels for the shapes your application encounters at runtime, tailored specifically to the actual input dimensions. These kernels are cached and reused, so subsequent inferences with the same shape run at peak performance, minimizing the compromise between flexibility and speed.

Figure 2 presents the inference speedup with TensorRT for RTX Dynamic Shapes Kernel Specialization across model categories on NVIDIA GeForce RTX 5090 (Windows 11). Each bar shows the average performance gain when specialized kernels are automatically generated and swapped in for encountered input shapes versus using generic “fallback” kernels.

Bar chart showing average speedup factors from Dynamic Shapes Kernel Specialization across four model categories on RTX 5090. Categories shown are Stable Diffusion 2.1 FP16, Language Models, Audio Models, and Convolution-Based Image Models, with speedup values ranging from 1.43x on the Stable Diffusion 2.1 FP16 pipeline models to 3.15x on Convolution-based image models.
Figure 2. Inference speedup with TensorRT for RTX Dynamic Shapes Kernel Specialization across model categories

The benefits scale with your workload variety. Models that process diverse input shapes see consistent performance across all configurations, while maintaining the flexibility to handle whatever comes next. Learn more about working with dynamic shapes.

Continuing with the initial example:

# Define optimization profile: min/opt/max shapes for dynamic dimensions
profile = builder.create_optimization_profile()
profile.set_shape("input",
    min=(1, 3, 224, 224),
    opt=(8, 3, 224, 224),
    max=(32, 3, 224, 224)
)
config.add_optimization_profile(profile)

# ... build engine ...
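# One way the elided build step might look, assuming the TensorRT for RTX
# builder and runtime API mirrors standard TensorRT (verify against the
# TensorRT for RTX documentation):
serialized_engine = builder.build_serialized_network(network, config)
runtime = trt_rtx.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)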

# Configure dynamic shape kernel specialization strategy
# The default is Lazy compilation, explicitly set below for illustrative purposes
# Lazy compilation automatically swaps in kernels compiled in the background, adaptively improving perf for shapes encountered at runtime
runtime_config = engine.create_runtime_config()
runtime_config.dynamic_shapes_kernel_specialization_strategy = (
    trt_rtx.DynamicShapesKernelSpecializationStrategy.LAZY
)

Built-in CUDA Graphs 

Modern neural networks can execute hundreds of individual GPU kernels per inference. Each kernel launch carries overhead—typically 5-15 microseconds of CPU and driver work. For example, a model that launches 300 kernels at 10 microseconds each spends roughly 3 ms per inference on launch overhead alone. For models dominated by small operations (compact convolutions, small matrix multiplications, elementwise operations), this launch time becomes a bottleneck.

When per-kernel launch overhead dominates execution time, the GPU idles while the CPU queues work—the enqueue time approaches or exceeds actual GPU compute time. This condition, known as being “enqueue-bound,” can be addressed with CUDA Graphs.
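One rough way to check whether a workload is enqueue-bound is to compare the CPU time spent enqueuing an inference with the end-to-end time including GPU execution. The sketch below is an assumption-laden illustration: it presumes the TensorRT for RTX execution API mirrors standard TensorRT 10 (execute_async_v3 returns once work is enqueued), that a context with tensor addresses already set is available, and it uses PyTorch only for a CUDA stream and synchronization.

import time
import torch  # assumption: used only for the CUDA stream and synchronization

# Rough enqueue-bound check (sketch): `context` is an execution context with
# input/output tensor addresses already set.
stream = torch.cuda.Stream()

torch.cuda.synchronize()
t0 = time.perf_counter()
context.execute_async_v3(stream.cuda_stream)  # returns once the work is enqueued
t1 = time.perf_counter()                      # CPU enqueue time ends here
torch.cuda.synchronize()                      # wait for the GPU to finish
t2 = time.perf_counter()

enqueue_ms = (t1 - t0) * 1e3
total_ms = (t2 - t0) * 1e3
print(f"enqueue: {enqueue_ms:.2f} ms, total: {total_ms:.2f} ms")
# If enqueue time approaches total time, the workload is likely enqueue-bound.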

CUDA Graphs capture the entire inference sequence as a graph structure, eliminating kernel-launch overhead and optimizing common use cases including repeated model calls. TensorRT for RTX launches the complete computation graph in a single operation, instead of launching kernels individually. 

This can shave milliseconds off every inference iteration; for instance, it provides a 1.8 ms (23%) boost on every run of the SD 2.1 UNet model, as measured on a Windows machine with an RTX 5090 GPU. This feature is particularly beneficial on Windows systems with Hardware-Accelerated GPU Scheduling enabled. Models with many small kernels, whose workloads are typically enqueue-bound, see the greatest benefit.

Moreover, in the context of dynamic shapes, the built-in CUDA Graphs support only captures and executes the shape-specialized kernels. This approach ensures that the CUDA Graph focuses on accelerating the most performant kernels—typically those that are used most frequently. Read more about working with built-in CUDA Graphs.

Figure 3 shows the inference speedup with TensorRT for RTX using built-in CUDA Graphs on an RTX 5090 GPU (Windows 11, Hardware-Accelerated GPU Scheduling enabled). Note that gains for CUDA Graphs are more pronounced on image networks with many relatively short-running kernels.

Bar chart showing performance improvements from TensorRT for RTX built-in CUDA Graphs across 14 different neural network models on RTX 5090 with Windows 11 and Hardware-Accelerated GPU Scheduling enabled. Each bar displays speedup factor (xFactor) with CUDA Graphs on, with a gray baseline from 0 to 1.0x and green extension showing the speedup gain. Models are sorted by performance improvement and include Diffuser UNet models, transformers, convolution networks, and text encoders. Speedups range from approximately 1.06x to nearly 5x. The chart demonstrates that built-in CUDA Graphs can deliver strong performance improvements across diverse model architectures.
Figure 3. Inference speedup with TensorRT for RTX using built-in CUDA Graphs

Adding to the example:

# Enable CUDA Graph capture for reduced kernel launch overhead
runtime_config.cuda_graph_strategy = trt_rtx.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE

Runtime caching

JIT compilation provides portability and automatic GPU-specific optimization in TensorRT for RTX. Runtime caching takes this further by preserving compiled kernels—including the specialized dynamic shape kernels referenced previously—across sessions, eliminating redundant compilation work.

Bar chart showing TensorRT for RTX JIT compilation time speedup with runtime cache across five diffuser core models, with speedups ranging from 17.8x to 53.5x.
Figure 4. JIT compilation time speedup with runtime cache in TensorRT for RTX on core diffuser models

To use runtime caching, begin by running your initial inferences with the shapes your application uses most often. This warm-up generates specialized kernels tailored to those shapes. Using the runtime cache API, these kernels can then be serialized into a binary blob and saved to disk for future reuse.

By loading this binary blob in subsequent sessions, you ensure that the most optimized kernels are available immediately—eliminating the need for a warm-up period, avoiding any performance regression, and preventing fallback to generic kernels. This enables your application to achieve peak performance from the very first inference run.

In addition, the runtime cache file can be bundled with your application. If you know your target users’ specific platforms—such as OS, GPU, CUDA, and TensorRT versions—you can pregenerate the runtime cache for those environments. Using your provided runtime cache file, users can bypass any kernel compilation overhead entirely, enabling optimal performance from the very first run. Read more about working with runtime caching.
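For example, an application could key bundled cache files by platform so each user picks up the right pregenerated cache at install time. The helper below is purely hypothetical; the GPU name and library versions would come from your own detection logic.

import os
import platform

# Hypothetical helper: choose which bundled runtime cache file to load.
# gpu_name and trt_rtx_version come from your own platform-detection code.
def bundled_cache_path(cache_dir: str, gpu_name: str, trt_rtx_version: str) -> str:
    key = f"{platform.system()}_{gpu_name}_{trt_rtx_version}".replace(" ", "_")
    return os.path.join(cache_dir, f"{key}.runtimecache")

# Example: caches/Windows_GeForce_RTX_5090_1.3.runtimecache
cache_path = bundled_cache_path("caches", "GeForce RTX 5090", "1.3")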

Completing the example:

from polygraphy import util

# Create runtime cache to persist compiled kernels across runs
runtime_cache = runtime_config.create_runtime_cache()

# Load existing cache if available
runtime_cache_file = "runtime.cache"
with util.LockFile(runtime_cache_file):
    try:
        loaded_cache_bytes = util.load_file(runtime_cache_file)
        if loaded_cache_bytes:
            runtime_cache.deserialize(loaded_cache_bytes)
    except Exception:
        pass  # No cache yet; it will be populated during inference

runtime_config.set_runtime_cache(runtime_cache)
context = engine.create_execution_context(runtime_config)

# ... run inference ...

# Save cache for future runs
runtime_cache = runtime_config.get_runtime_cache()
with util.LockFile(runtime_cache_file):
    with runtime_cache.serialize() as buffer:
        util.save_file(buffer, runtime_cache_file, description="runtime cache")
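Once kernel specialization, CUDA Graphs, and the runtime cache are configured, you can watch the adaptive speedup by timing successive runs. The sketch below stands in for the inference step; it assumes the execution API mirrors standard TensorRT 10 (set_input_shape, set_tensor_address, execute_async_v3) and uses PyTorch for device buffers, with a hypothetical "output" tensor name and shape.

import time
import torch  # assumption: used for device buffers, stream, and synchronization

# Sketch: time successive iterations to observe adaptive speedup.
# Replace "input"/"output" and the shapes with your model's actual I/O.
input_tensor = torch.randn(8, 3, 224, 224, device="cuda")
output_tensor = torch.empty(8, 1000, device="cuda")  # hypothetical output shape

context.set_input_shape("input", tuple(input_tensor.shape))
context.set_tensor_address("input", input_tensor.data_ptr())
context.set_tensor_address("output", output_tensor.data_ptr())

stream = torch.cuda.Stream()
for i in range(5):
    torch.cuda.synchronize()
    start = time.perf_counter()
    context.execute_async_v3(stream.cuda_stream)
    torch.cuda.synchronize()
    print(f"iteration {i}: {(time.perf_counter() - start) * 1e3:.1f} ms")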

Get started with adaptive inference

Three technologies work together to make adaptive inference optimizations easy:

  • Dynamic Shapes Kernel Specialization ensures each shape runs optimally.
  • CUDA Graphs eliminate the overhead of executing those optimized kernels.
  • Runtime caching makes those optimizations persistent across sessions.

AI applications can adapt to any input dimension while maintaining the performance characteristics of static-shape inference. No compromises or artificial constraints on your application design. Read more about TensorRT for RTX best practices for performance.
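Putting it together, the adaptive-inference setup from this post condenses to a few lines once an engine has been built or deserialized. The sketch below reuses only the APIs shown earlier and assumes `engine` already exists.

import tensorrt_rtx as trt_rtx

# Sketch: enable all three adaptive-inference features on an existing engine.
# Assumes `engine` was built or deserialized as in the earlier examples.
runtime_config = engine.create_runtime_config()

# 1. Lazily specialize kernels for shapes seen at runtime
runtime_config.dynamic_shapes_kernel_specialization_strategy = (
    trt_rtx.DynamicShapesKernelSpecializationStrategy.LAZY
)

# 2. Capture the whole graph to cut kernel launch overhead
runtime_config.cuda_graph_strategy = trt_rtx.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE

# 3. Persist compiled kernels across sessions
runtime_cache = runtime_config.create_runtime_cache()
runtime_config.set_runtime_cache(runtime_cache)

context = engine.create_execution_context(runtime_config)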

To experience adaptive inference with NVIDIA TensorRT for RTX, visit the NVIDIA/TensorRT-RTX GitHub repo and try the FLUX.1 [dev] Pipeline Optimized with TensorRT RTX notebook. You can also view the Adaptive Inference Acceleration with TensorRT for RTX walkthrough video for a live demonstration of these features.

Start building AI apps for NVIDIA RTX PCs to run models faster and more privately on-device, and streamline development with NVIDIA tools, SDKs, and models on Windows.
