
How to Achieve 4x Faster Inference for Math Problem Solving


Large language models can solve challenging math problems. However, making them work efficiently at scale requires more than a strong checkpoint. You need the right serving stack, quantization strategy, and decoding methods—often spread across different tools that don’t work together cleanly. Teams end up juggling containers, conversion scripts, and ad‑hoc glue code to compare BF16 vs FP8 or to test a speculative decoding setup.

This post shows how to build a fast, reproducible inference pipeline with the NVIDIA NeMo-Skills library to manage NVIDIA TensorRT-LLM. It is a streamlined version of the setup we used to win the AI Mathematical Olympiad Prize 2024, which achieved 4x faster batched inference on two NVIDIA H100 GPUs with FP8 quantization and ReDrafter speculative decoding. The same workflow can run on a single workstation or scale out on a cluster with minimal changes.

By the end of this blog post, you’ll learn how to:

  1. Prepare and quantize an OpenMath model to an FP8 TensorRT-LLM engine.
  2. Train and integrate a ReDrafter draft model for speculative decoding.
  3. Launch an optimized inference server with optional tool-calling through a secure code sandbox.
  4. Benchmark latency and throughput across BF16, FP8, and FP8+ReDrafter configurations.

If you’re following along, we recommend a machine with two H100 (or comparable FP8-capable) GPUs or a Slurm cluster with similar nodes. 

Setting up your environment 

Our first step is to establish a consistent, isolated environment. We’ll use an NVIDIA PyTorch NGC container and install the essential libraries: TensorRT-LLM for model optimization and NeMo-Skills for overall pipeline management. FP8 inference requires a GPU with FP8 support, such as the NVIDIA Ada Lovelace, NVIDIA Hopper, NVIDIA Blackwell, or NVIDIA Rubin architectures. For this example, we assume two GPUs are available.

Container setup and library installation

Once inside the nvcr.io/nvidia/pytorch:25.05-py3 container, run the following commands to install TensorRT-LLM and NeMo-Skills:

# Ensure no conflicting TensorRT installations and install TensorRT-LLM
[ -f /etc/pip/constraint.txt ] && : > /etc/pip/constraint.txt
pip uninstall -y tensorrt
pip3 install tensorrt_llm==1.1.0rc0

# Install NeMo-Skills
pip install git+https://github.com/NVIDIA/NeMo-Skills.git
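
Optionally, run a quick sanity check to confirm that the expected TensorRT-LLM build imports cleanly and that both GPUs are visible:

# Optional sanity check: report the TensorRT-LLM version and visible GPUs.
import torch
import tensorrt_llm

print(tensorrt_llm.__version__)   # expect 1.1.0rc0
print(torch.cuda.device_count())  # expect 2 for this walkthrough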

Preparing model weights

The next step is preparing our large language model (LLM). We’ll download the nvidia/OpenMath-Nemotron-14B-Kaggle model and transform it into an optimized TensorRT-LLM engine using FP8 quantization.

Note on FP8 Quantization: FP8 (8-bit floating point) quantization is highly efficient but requires GPUs that support E4M3 FP8 (like NVIDIA Hopper GPUs). For other GPUs, int8_wo (8-bit integer with weight-only quantization) is recommended and doesn’t require calibration.
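
To see what this precision trade-off looks like numerically, here is a quick illustration of E4M3 rounding using PyTorch’s float8 dtype (assuming a recent PyTorch build with torch.float8_e4m3fn, as shipped in the NGC container):

# E4M3 FP8: 1 sign bit, 4 exponent bits, 3 mantissa bits (max normal value is 448).
import torch

x = torch.tensor([3.14159, 448.0, 0.0001])
print(x.to(torch.float8_e4m3fn).to(torch.float32))
# Values survive only at coarse precision (and very small values underflow),
# which is why a representative calibration dataset matters.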

Downloading model weights and datasets

Generate a Hugging Face token and export it as an environment variable. Then use the Hugging Face CLI to download the necessary models and datasets.

# Export your Hugging Face token
export HF_TOKEN=hf_YOUR_HUGGING_FACE_TOKEN 

# Install Hugging Face CLI
pip install -U "huggingface_hub[cli]"

# Download the 14B parameter main model
huggingface-cli download nvidia/OpenMath-Nemotron-14B-kaggle --local-dir OpenMath-Nemotron-14B-kaggle

# Download the OpenMathReasoning dataset for calibration
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --local-dir OpenMathReasoning

Preparing the calibration dataset for FP8 quantization

For FP8 quantization, a small calibration dataset representative of your inference data is essential. We’ll use a subset of the OpenMathReasoning dataset to create it in Hugging Face format, as sketched below.
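
Here is a minimal sketch of one way to build that calibration set with the Hugging Face datasets library. The “tir” split name and the exact layout that ns convert expects at --calib_dataset are assumptions, so adapt it to your setup:

# Minimal sketch: split name and calibration layout are assumptions.
from datasets import Dataset, load_dataset

# Stream a few hundred representative samples rather than loading the full split.
# You could also point load_dataset at the local OpenMathReasoning download.
stream = load_dataset("nvidia/OpenMathReasoning", split="tir", streaming=True)
calib = Dataset.from_list(list(stream.take(512)))

calib.save_to_disk("./calibration_dataset")
print(calib)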

Converting and quantizing to TensorRT-LLM engine

Now, convert the Hugging Face model to a TensorRT-LLM engine, applying FP8 quantization and using the prepared calibration dataset. This step generates the FP8 quantized LLM inference engine.

ns convert \
    --input_model OpenMath-Nemotron-14B-kaggle \
    --output_model OpenMath-Nemotron-14B-kaggle-fp8-trtllm \
    --convert_from hf \
    --convert_to trtllm \
    --num_gpus 2 \
    --dtype fp8 \
    --hf_model_name nvidia/OpenMath-Nemotron-14B-kaggle \
    --model_type qwen \
    --max_input_len 30000 \
    --max_seq_len 32000 \
    --no-trt_reuse_tmp_engine \
    --calib_dataset ./calibration_dataset

After this command, your FP8 LLM engine is ready for deployment.

Accelerating inference with ReDrafter

To push our inference efficiency further, we integrate ReDrafter, an RNN-based speculative decoding technique developed by Apple. A smaller recurrent “draft” model proposes a few tokens ahead, which the main LLM then verifies, so each main-model step can emit more than one token. The ReDrafter implementation is compatible with most models supported within the TensorRT-LLM library.
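
To make the mechanism concrete, here is a toy, conceptual sketch of a single speculative decoding step with stand-in models. It illustrates the accept-then-extend idea only; it is not how TensorRT-LLM implements ReDrafter, which verifies the whole draft in one batched forward pass.

# Conceptual sketch of one speculative decoding step (greedy acceptance).
import random

def speculative_step(main_next_token, draft_next_token, context, draft_len=3):
    # 1) The cheap draft model proposes draft_len tokens.
    proposal, ctx = [], list(context)
    for _ in range(draft_len):
        token = draft_next_token(ctx)
        proposal.append(token)
        ctx.append(token)

    # 2) The main model keeps the longest prefix of the proposal it agrees
    #    with, then always contributes one token of its own.
    accepted, ctx = [], list(context)
    for token in proposal:
        if main_next_token(ctx) != token:
            break
        accepted.append(token)
        ctx.append(token)
    accepted.append(main_next_token(ctx))
    return accepted  # between 1 and draft_len + 1 tokens per step

# Toy stand-ins: a deterministic "main model" and a mostly agreeing "draft model".
vocab = ["2", "+", "3", "=", "5"]
main = lambda ctx: vocab[len(ctx) % len(vocab)]
draft = lambda ctx: main(ctx) if random.random() < 0.8 else "?"

print(speculative_step(main, draft, []))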

Installing and training ReDrafter

First, install the ReDrafter library. The tokenizer and training data for the draft model should be the same as those used for the base model. If the original training data is not available, base model generations can also be used for training the draft model.

# Install the ReDrafter library
pip install --no-binary=protobuf --ignore-requires-python \
    "git+https://github.com/apple/ml-recurrent-drafter.git#egg=recurrent-drafting[dev,train]"

# Train the ReDrafter model
ns run_cmd --log_dir ./logs/ \
torchrun --nproc_per_node=2 -m nemo_skills.training.train_redrafter \
    --llm_name_or_path 'OpenMath-Nemotron-14B-kaggle' \
    --dataset "OpenMathReasoning" \
    --dataset_split "tir" \
    --bf16 True \
    --output_dir "redrafter_output" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --save_strategy "no" \
    --learning_rate 0.001 \
    --weight_decay 0. \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 20 \
    --tf32 True \
    --model_max_length 2048 \
    --dataset_nrows 50000 \
    --drafter_predict_n_tokens 3 \
    --drafter_num_layers 2 \
    --rnn True \
    --phase train \
    --report_to wandb # Remove if not using wandb

During training, watch the redrafter2_top1 score. A score above 0.6 corresponds to close to 2x runtime performance, since roughly 60% of steps accept all three drafted tokens.
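
As a rough back-of-the-envelope check, assuming all-or-nothing acceptance of the draft (a simplification of what redrafter2_top1 measures), you can translate an acceptance rate into tokens generated per decoding step:

# Simplified estimate: each step accepts either the full draft or none of it.
accept_rate = 0.6   # target redrafter2_top1
draft_len = 3       # --drafter_predict_n_tokens

tokens_per_step = accept_rate * (1 + draft_len) + (1 - accept_rate) * 1
print(f"~{tokens_per_step:.1f} tokens per step vs. 1.0 without a drafter")
# Drafter overhead and partial acceptance pull the realized speedup closer to ~2x.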

Building the TensorRT-LLM engine for the ReDrafter model

Now, we’ll convert our trained ReDrafter model into a TensorRT-LLM checkpoint and then combine it with our main LLM to create the final, accelerated TensorRT-LLM engine.

First, clone the TensorRT-LLM repository to access its conversion scripts:

git clone https://github.com/NVIDIA/TensorRT-LLM/

Next, convert the trained ReDrafter PyTorch checkpoint to a TensorRT-LLM checkpoint.

# Base model intermediate checkpoint from FP8 quantization step
export BASE_TRTLLM_CKPT=$(pwd)/OpenMath-Nemotron-14B-kaggle-fp8-trtllm-tmp-ckpt
# Trained draft checkpoint
export REDRAFTER_PYTORCH_CKPT=$(pwd)/redrafter_output/redrafter__redrafter_OpenMath-Nemotron-14B-kaggle_n_3_lr_0.001_layers_2
export REDRAFTER_TRTLLM_CKPT=$(pwd)/OpenMath-Nemotron-14B-kaggle-fp8-draft-ckpt

cd ./TensorRT-LLM/examples/redrafter
python convert_checkpoint.py \
    --base_model_checkpoint_dir $BASE_TRTLLM_CKPT \
    --drafter_model_dir $REDRAFTER_PYTORCH_CKPT \
    --output_dir $REDRAFTER_TRTLLM_CKPT \
    --dtype bfloat16 \
    --tp_size 2 \
    --redrafter_num_beams 1 \
    --redrafter_draft_len_per_beam 3
cd ../../../

Finally, build the combined TensorRT-LLM engine: the base model extended with a draft head for speculative decoding.

trtllm-build \
    --checkpoint_dir $REDRAFTER_TRTLLM_CKPT \
    --output_dir OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm \
    --gemm_plugin fp8 \
    --use_paged_context_fmha=enable \
    --max_batch_size 32 \
    --max_seq_len 32000 \
    --max_input_len 32000 \
    --max_num_tokens 32000 \
    --speculative_decoding_mode explicit_draft_tokens \
    --max_beam_width 1 \
    --kv_cache_type paged

Your TensorRT-LLM engine, now supercharged with ReDrafter, is ready to be served!

Benchmarking and results

We’ve prepared a companion notebook where you can try out the full pipeline yourself. The notebook uses the same container setup and installations described above, with two H100 GPUs for inference. In the notebook, you can:

  • Run inference on different TensorRT-LLM engines (BF16, FP8, FP8+ReDrafter).
  • Compare performance benchmarks such as time to first token and throughput per device.
  • Explore advanced controls, such as early stopping after a fixed time or terminating after the first N generations complete.
  • Run inference with tool-calling.

Here’s a sample of the kind of benchmark results you’ll see:

| Metric                            | BF16  | FP8  | FP8+ReDrafter |
|-----------------------------------|-------|------|---------------|
| Total generation time (s)         | 144.2 | 64.7 | 30.5          |
| Average sample throughput (tok/s) | 34.6  | 75.2 | 138.5         |

Table 1. TensorRT-LLM performance comparison across different configurations on two H100 GPUs
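
To connect Table 1 to the headline claim, the ratios below are computed directly from the total generation times:

# Speedup relative to BF16, from the total generation times in Table 1.
times_s = {"BF16": 144.2, "FP8": 64.7, "FP8+ReDrafter": 30.5}
for config, t in times_s.items():
    print(f"{config:>14}: {times_s['BF16'] / t:.1f}x")
# BF16: 1.0x, FP8: ~2.2x, FP8+ReDrafter: ~4.7x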

Full benchmarks and code are available in the notebook. Check out the AIMO-2 Winning Solution paper for more results.

Optional: Enabling tool-calling and the code execution sandbox

The OpenMath LLM is a powerful tool-integrated reasoning (TIR) model. This means it doesn’t just generate text; it can also write and execute Python code in a secure sandbox to solve problems. In the companion notebook, we provide an example of how to launch both the LLM server and its accompanying code execution sandbox.

The interaction works like this:

  1. The LLM generates Python code wrapped in <tool_call> and </tool_call> tokens.
  2. The inference engine extracts and sends this code to the sandbox.
  3. The sandbox executes the code and returns the results.
  4. The output is fed back to the LLM for continued generation or to finalize its answer.

Here’s an example of such an interaction:

<tool_call>
# Initialize a list to store valid bases
valid_bases = []


# Check bases from 10 upwards
for b in range(10, 10000):  # Arbitrary large upper limit
    num1 = 9 * b + 7
    num2 = b + 7
    if num1 % num2 == 0:
        valid_bases.append(b)
        print(f"Found base: {b}")


# Sum the valid bases
sum_bases = sum(valid_bases)
print(f"Sum: {sum_bases}")


# If sum is over 1000, take modulo 1000
if sum_bases > 1000:
    result = sum_bases % 1000
else:
    result = sum_bases


print(f"Final Result: {result}")
</tool_call>
```output
Found base: 21
Found base: 49
Sum: 70
Final Result: 70
```

To turn off tool-calling in the companion notebook, use get_model instead of get_code_execution_model as shown in the NeMo-Skills docs.
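
For orientation, here is a rough client-side sketch of that switch. The import paths, constructor arguments, and return format below are assumptions rather than the exact NeMo-Skills API, so treat the companion notebook and the NeMo-Skills docs as the source of truth.

# NOTE: import paths and arguments below are assumptions; check the NeMo-Skills
# docs and the companion notebook for the exact API.
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.inference.model import get_code_execution_model, get_model

sandbox = get_sandbox()  # assumes the code execution sandbox is already running
# With tool-calling: generated <tool_call> blocks are executed in the sandbox.
llm = get_code_execution_model(server_type="trtllm", sandbox=sandbox)
# Without tool-calling: plain text generation against the same server.
# llm = get_model(server_type="trtllm")

outputs = llm.generate(
    prompts=["Find the sum of all bases b >= 10 such that 9b + 7 is divisible by b + 7."]
)
print(outputs[0])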

Try it yourself. Run the companion notebook to benchmark these performance improvements on your hardware and experiment with tool-calling capabilities.
