
Supercharge Tree-Based Model Inference with Forest Inference Library in NVIDIA cuML


Tree-ensemble models remain a go-to for tabular data because they’re accurate, comparatively inexpensive to train, and fast. But deploying Python inference on CPUs quickly becomes the bottleneck once you need sub-10 ms latency or millions of predictions per second.

Forest Inference Library (FIL) first appeared in cuML 0.9 in 2019, and has always been about one thing: blazing-fast inference for gradient-boosted trees and random forests trained in XGBoost, LightGBM, scikit-learn, or NVIDIA cuML. In general, if your model can be converted to Treelite, then you can use FIL. 

FIL has been redesigned in RAPIDS 25.04, and the new highlights include:

  • New C++ implementation that allows you to perform batched inference on either GPU or CPU
  • An optimize() function for tuning your inference model
  • New advanced inference APIs (predict_per_tree, apply)
  • Up to 4x faster GPU throughput than cuML 25.02 FIL

In this blog post, you’ll learn about the new capabilities, features, and performance of FIL in cuML 25.04, and the advantages it offers over previous versions of cuML.

Quick-start example (XGBoost → FIL)

Users can train models as usual in XGBoost, LightGBM, or scikit-learn, save them to disk, and then use FIL to reload those models and apply them to new data. That reload can happen on a deployment server or on entirely different hardware than the model was trained on. Here is a quick example of the easy-to-use Python API:

import xgboost as xgb
from cuml.fil import ForestInference

# Train your model as usual and save it
xgb_model = xgb.XGBClassifier()
xgb_model.fit(X_train, y_train)
xgb_model.save_model("xgb_model.ubj")

# Load the saved model into FIL
fil_model = ForestInference.load("xgb_model.ubj")

# Now you can predict with FIL directly 
preds = fil_model.predict(X_test) 
probs = fil_model.predict_proba(X_test)

What’s new in FIL in cuML 

Auto-optimization

Forest inference capabilities in cuML allow users to fine-tune performance with a variety of hyperparameters. It is difficult to predict what the optimal values will be for any given model and batch size, so it is often necessary to determine them empirically. The new version of FIL significantly simplifies this process with a built-in method for auto-optimization at any given batch size:

fil_model = ForestInference.load("xgb_model.ubj")
fil_model.optimize(batch_size=1_000_000)
result = fil_model.predict(data)

Once .optimize is called, subsequent prediction calls will use the optimal performance hyperparameters found for the indicated batch size. You can also check what hyperparameters were selected by looking at the .layout and .default_chunk_size attributes.

print(fil_model.layout)
print(fil_model.default_chunk_size)

New prediction APIs

Usually, the final output of a forest model is all we need, whether it is a class prediction or a combined numeric output from all trees. Sometimes, however, it’s useful to get more granular information about individual trees in the ensemble. That’s why we introduced two new prediction methods: .predict_per_tree and .apply.

The first, .predict_per_tree, gives the predictions of every tree individually. This can be useful for experimenting with novel ensembling techniques or for analyzing how the ensemble reached an overall prediction. For example, one could weight each tree by its age, out-of-bag AUC, or even a data-drift score, then average tree votes for a smarter final decision, with no retraining required. Another example is quickly producing prediction intervals without bootstrapping or retraining:

import cupy

per_tree = fil_model.predict_per_tree(X)   # one prediction per (row, tree)
mean = per_tree.mean(axis=1)
lower = cupy.percentile(per_tree, 10, axis=1)
upper = cupy.percentile(per_tree, 90, axis=1)
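
The same per-tree output makes the tree-weighting idea above easy to try. In this sketch, tree_scores is a hypothetical array of per-tree scores (for example, out-of-bag AUC values) that you compute yourself; FIL only supplies the raw per-tree predictions:

import cupy

tree_weights = cupy.asarray(tree_scores)   # hypothetical per-tree weights, shape (n_trees,)
per_tree = fil_model.predict_per_tree(X)   # shape (n_rows, n_trees)
weighted = (per_tree * tree_weights).sum(axis=1) / tree_weights.sum()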

The second, .apply, gives the ID of the leaf node that each input row reaches in every tree. This opens forest models to novel uses that extend beyond straightforward regression or classification. A very simple application is measuring how “alike” two rows are by counting how many trees send them to the same leaf:

leaf = fil_model.apply(X)          # leaf node IDs, one per (row, tree)
i, j = 0, 1                        # indices of the two rows to compare
sim = (leaf[i] == leaf[j]).mean()  # fraction of trees that send both rows to the same leaf
print(f"{sim:.0%} of trees agree on rows {i} & {j}")

GPU and CPU support

While forest inference in cuML started out as a GPU-accelerated capability, users also want to develop forest inference applications on systems without NVIDIA GPUs. A common use case is local testing on small subsets of data before deploying a model into production. Another valuable use case is scaling down to CPU-only machines when traffic is light and scaling up to GPUs for speed and cost savings when traffic volumes increase.

You can compile FIL in CPU-only mode and call it from C++. This removes any CUDA dependency while still letting you load Treelite-compatible models and spread inference across all available CPU cores with OpenMP.

For Python users, cuML 25.04 introduces a new context manager for executing FIL on the CPU:

from cuml.fil import ForestInference, get_fil_device_type, set_fil_device_type

with set_fil_device_type("cpu"):
    fil_model = ForestInference.load("xgboost_model.ubj")
    result = fil_model.predict(data)
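
The same context manager also makes the hybrid scale-up and scale-down pattern described above straightforward to express. A minimal sketch, assuming a simple environment flag controls device selection and that "gpu" is accepted by set_fil_device_type in the same way as "cpu":

import os
from cuml.fil import ForestInference, set_fil_device_type

# Hypothetical flag; in practice this could come from your deployment configuration
device = "gpu" if os.environ.get("FIL_USE_GPU") == "1" else "cpu"

with set_fil_device_type(device):
    fil_model = ForestInference.load("xgboost_model.ubj")
    # data: input features (NumPy or CuPy array), as in the earlier examples
    result = fil_model.predict(data)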

Future versions will also provide Python packages that can be installed in CPU-only systems. 

How FIL gets its speed

This new release speeds up tree-based models by reducing how often data has to be fetched from memory. Each decision point, or node, of the tree is now automatically stored in the minimum size required (typically 8 or 16 bytes), and the nodes are arranged in smarter layouts. Most of the time, the processor can grab the next node in a single, fast read instead of several slow ones. By default, the depth_first layout is used, which works best for deeper trees (depth ≥ 4). If your trees are shallow, try layered for smaller batches (1–128 rows) or breadth_first for larger batches, but remember that the built-in .optimize function can test them for you.

Additionally, a new performance hyperparameter, align_bytes, was introduced to allow trees in the depth_first and breadth_first layouts to be aligned so that they begin on cache-line boundaries. This sometimes, but not always, improves performance. On CPU, aligning to 64 bytes offers the best performance for most models. On GPU, this alignment rarely offers a benefit, but some models do benefit from 128-byte alignment.
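
If you want to experiment with these settings directly, the sketch below assumes layout and align_bytes can be passed as keyword arguments to ForestInference.load; in most cases, letting .optimize pick them for you is the simpler path:

from cuml.fil import ForestInference

# layout and align_bytes are assumed keyword arguments here; .optimize can select them automatically
fil_model = ForestInference.load(
    "xgb_model.ubj",
    layout="breadth_first",  # or "depth_first" (default) / "layered"
    align_bytes=64,          # start each tree on a 64-byte cache-line boundary (CPU-friendly)
)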

Performance

To obtain the most complete understanding of the new FIL’s performance characteristics, we performed an exhaustive sweep across a broad range of model and batch-size variables, as summarized in the following table:

Variable | Values
Maximum tree depth | 2; 4; 8; 16; 32
Tree count | 16; 128; 1024; 2048
Feature count | 8; 32; 128; 512
Batch size | 1; 16; 128; 1,024; 1,048,576; 16,777,216
Table 1. Model and batch size parameters explored for forest inference benchmarks with cuML 25.04

A RandomForestRegressor model was trained for every combination of maximum tree depth, tree count, and feature count using 10,000 rows of synthetically generated data. For cuML 25.04, hyperparameters were selected with the new .optimize method; for the previous version, a manual grid search was used.

Runtime performance was then tested using synthetically generated input batches from the same distribution as the training data. Input was provided via CuPy arrays for GPU FIL and NumPy arrays for CPU FIL. Both versions were given three warmup rounds. Then performance was measured on five inference rounds for each batch size, and the minimum runtime was taken across those rounds.
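
The timing protocol boils down to a few untimed warmup rounds followed by taking the minimum across timed rounds. The helper below is a minimal sketch of that protocol, not the actual benchmark harness used for these results:

import time
import cupy as cp

def min_runtime(model, batch, warmup=3, rounds=5):
    for _ in range(warmup):
        model.predict(batch)                 # untimed warmup rounds
    times = []
    for _ in range(rounds):
        start = time.perf_counter()
        model.predict(batch)
        cp.cuda.runtime.deviceSynchronize()  # ensure GPU work finishes before stopping the clock
        times.append(time.perf_counter() - start)
    return min(times)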

A single NVIDIA H100 (80GB HBM3) was used for GPU results, and a 2-socket Intel Xeon Platinum 8480CL machine was used for CPU results. Across all of these scenarios, cuML 25.04 outperformed the prior version in 75% of cases. The best, worst, and median relative and absolute performance changes are shown in the table below. A relative speedup of less than one indicates a performance regression. Note that the worst absolute slowdown was 62 milliseconds, while the best absolute speedup was five seconds.

Statistic | Speedup (cuML 25.04 vs 25.02)
Minimum | 0.73x
Median | 1.16x
Maximum | 4.1x
Table 2. Summary statistics comparing forest model inference performance of cuML 25.04 with previous version across a wide range of model parameters and batch sizes.

While these high-level summary statistics offer a general sense of the new FIL’s performance improvements, it’s also useful to review performance for specific use cases. The two scenarios of greatest interest are typically batch size 1 performance and the maximum throughput obtainable at any batch size. The first represents use cases where inference requests must be processed one at a time or minimizing latency is paramount; the second represents use cases where inference can be batched to reduce processing time and expense.

Figure 1. Heat map comparing batch size 1 performance of 25.04 and prior version for a variety of model depths, tree counts, and feature counts.

As shown, at batch size 1, 25.04 outperforms previous versions for 81% of the tested models. It underperforms slightly for models with many deep trees but offers a median speedup of 1.6x overall.

The maximum throughput performance is captured in a similar heatmap below.

Figure 2. Heat map comparing maximum throughput of cuML 25.04 and previous version for a variety of model depths, tree counts, and feature counts.

Here, cuML 25.04 outperforms the previous version’s FIL for 76% of models, with a median speedup of 1.4x and minor regressions on shallow-tree cases.

To put performance into perspective, here’s a comparison of cuML 25.04 to scikit-learn’s RandomForest performance. One of the great features of the scikit-learn codebase is the cleanliness and simplicity of many of its implementations. Users can look at an algorithm’s implementation and quickly understand exactly how it works and how it can be modified.

In the case of RandomForest models, however, this approach doesn’t always produce the highest performance for inference. An important goal for the update is to offer direct acceleration of scikit-learn forest models so that users can get the best possible inference performance without adding additional complexity to the scikit-learn codebase itself.

In 100% of these scenarios, comparing an AMD EPYC 9654P 96-core CPU against a single H100 (80GB HBM3) GPU, FIL outperformed scikit-learn native execution. As before, we summarize overall performance in the following table:

Statistic | Speedup (new FIL vs native scikit-learn forest inference)
Minimum | 13.9x
Median | 147x
Maximum | 882x
Table 3. Summary statistics comparing performance of cuML 25.04 to native scikit-learn inference across a wide range of model parameters and batch sizes.

The speedups for batch size 1 can be seen below: 

Figure 3. Heat map comparing batch size 1 performance of cuML 25.04 and native scikit-learn inference for a variety of model depths, tree counts, and feature counts.

From the heatmaps, it can be calculated that the median batch size 1 speedup is 239x. Maximum throughput performance relative to scikit-learn native is presented in the same way below:

Figure 4. Heat map comparing maximum throughput of cuML 25.04 and native scikit-learn inference for a variety of model depths, tree counts, and feature counts for large batch size. 

Get Started with FIL in NVIDIA cuML today

The rewrite of the Forest Inference Library in cuML offers a number of useful new features, as well as significant performance improvements over prior versions. The new auto-optimization feature makes it simpler to get the most out of those performance enhancements.

This makes FIL ideal for many scenarios:

  • User-facing APIs where every millisecond counts
  • High-volume batch jobs (ad-click scoring, IoT analytics)
  • Hybrid deployments—same model file, choose CPU or GPU at runtime
  • Prototype locally and deploy to GPU accelerated production servers 
  • Cost reduction—one GPU can replace CPUs with 50 cores or more.

Try the new forest inference capabilities in FIL today by downloading the cuML 25.04 release. These capabilities will also be available in a future release of NVIDIA Triton Inference Server.

Upcoming blog posts will share the technical details of this new implementation, further benchmarks, and the integration of FIL with NVIDIA Triton Inference Server.

To learn more about FIL, including performance, API documentation, benchmarks, and more, head to the cuML FIL docs. To learn more about accelerated data science, check out the hands-on courses in our DLI Learning Path.
