
How to Deploy an AI Model in Python with PyTriton

AI models are everywhere, in the form of chatbots, classification and summarization tools, image models for segmentation and detection, recommendation models, and more. AI and machine learning (ML) models help automate many business processes, generate insights from data, and deliver new experiences.

Python is one of the most popular languages used in AI/ML development. In this post, you will learn how to use NVIDIA Triton Inference Server to serve models within your Python code and environment using the new PyTriton interface.

More specifically, you will learn how to prototype and test inference of an AI model in a Python development environment with a production-class tool, and how to go to production with the PyTriton interface. You will also learn the advantages of using PyTriton, compared to a generic web framework like FastAPI or Flask. The post includes several code examples to illustrate how you can activate high-performance batching, preprocessing, and multi-node inference; and implement online learning.

What is PyTriton?

PyTriton is a simple interface that enables Python developers to use Triton Inference Server to serve AI models, simple processing functions, or entire inference pipelines within Python code. Triton Inference Server is an open-source multi-framework inference serving software with high performance on CPUs and GPUs.

PyTriton enables rapid prototyping and testing of ML models while still delivering high performance and efficiency (for example, high GPU utilization). A single line of code brings up Triton Inference Server, providing benefits such as dynamic batching, concurrent model execution, and support for GPU and CPU from within the Python code.

PyTriton removes the need to set up model repositories and port models from the development environment to production. Existing inference pipeline code can also be used without modification. This is especially useful for newer types of frameworks like JAX, or complex pipelines that are part of the application code without dedicated backends in Triton Inference Server.

Simplicity of Flask

Flask and FastAPI are generic Python web frameworks used to deploy a wide variety of Python applications. Because of their simplicity and widespread adoption, many developers use them to deploy and run AI models in production. However, significant drawbacks to this approach include the following:

  • General-purpose web servers lack support for AI inference features. There is no out-of-the-box support for taking advantage of accelerators like GPUs, or for turning on dynamic batching or multi-node inference.
  • Users need to build logic to meet the demands of specific use cases, like audio/video streaming input, stateful processing, or preprocessing the input data to fit the model.
  • Metrics on compute and memory utilization or inference latency are not easily accessible to monitor application performance and scale.

Triton Inference Server includes built-in support for features like those listed above, and many more. PyTriton provides the simplicity of Flask and the benefits of Triton in Python. The following example uses PyTriton to deploy a Hugging Face BERT model in JAX that returns the model's last hidden states. For the full code, see the HuggingFace BERT JAX Model.

```python
import logging

import numpy as np
from transformers import BertTokenizer, FlaxBertModel

from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

logger = logging.getLogger("examples.huggingface_bert_jax.server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")


@batch  # receive the queued requests as one batch of stacked inputs
def _infer_fn(**inputs: np.ndarray):
    (sequence_batch,) = inputs.values()
    # convert dtype=object to bytes first,
    # then decode the UTF-8 bytes to str
    sequence_batch = np.char.decode(sequence_batch.astype("bytes"), "utf-8")
    last_hidden_states = []
    for sequence_item in sequence_batch:
        tokenized_sequence = tokenizer(sequence_item.item(), return_tensors="jax")
        results = model(**tokenized_sequence)
        last_hidden_states.append(results.last_hidden_state)
    last_hidden_states = np.array(last_hidden_states, dtype=np.float32)
    return [last_hidden_states]


with Triton() as triton:
    logger.info("Loading BERT model.")
    triton.bind(
        model_name="BERT",  # illustrative name that clients use to address the model
        infer_func=_infer_fn,
        inputs=[
            Tensor(name="sequence", dtype=np.bytes_, shape=(1,)),
        ],
        outputs=[
            Tensor(name="last_hidden_state", dtype=np.float32, shape=(-1,)),
        ],
        config=ModelConfig(max_batch_size=16),  # illustrative upper bound on batch size
    )
    logger.info("Serving inference")
    triton.serve()  # blocks and serves requests until interrupted
```
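PyTriton's `@batch` decorator (imported in the example above) lets an inference function work on many requests at once. The stdlib-only sketch below approximates that idea under simplifying assumptions: each request is a dict of Python lists rather than NumPy arrays, inputs from all requests are concatenated along the first axis, the function runs once, and the outputs are split back per request. `batch_like` and `toy_infer_fn` are hypothetical names; PyTriton's real decorator works on NumPy arrays and handles more cases.

```python
import functools


def batch_like(infer_fn):
    """Hypothetical, simplified batching decorator (lists instead of NumPy arrays).

    Callers pass a list of per-request dicts; the wrapped function is called once
    with every input concatenated along the first axis; outputs are then split
    back so each request gets its own slice.
    """
    @functools.wraps(infer_fn)
    def wrapper(requests):
        first_name = next(iter(requests[0]))
        # Number of rows each request contributes (taken from its first input).
        sizes = [len(req[first_name]) for req in requests]
        batched = {
            name: [row for req in requests for row in req[name]]
            for name in requests[0]
        }
        outputs = infer_fn(**batched)
        # Split each named output back into per-request slices.
        results, start = [], 0
        for size in sizes:
            results.append({name: rows[start:start + size] for name, rows in outputs.items()})
            start += size
        return results
    return wrapper


@batch_like
def toy_infer_fn(sequence):
    # Toy "model": one output row (the string length) per input row.
    return {"length": [len(s) for s in sequence]}


print(toy_infer_fn([{"sequence": ["hi"]}, {"sequence": ["hello", "hey"]}]))
# [{'length': [2]}, {'length': [5, 3]}]
```

The function body is written once against the whole batch, which is what allows a server to run one model call for many queued requests.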

PyTriton offers an interface familiar to Flask users for easy installation and setup, and provides the following benefits: 

  • Bring up NVIDIA Triton with a single line of code
  • No need to set up model repositories and model format conversion (important for a high-performance implementation using Triton Inference Server)
  • Use of existing inference pipeline code without modification
  • Support for many decorators to adapt model input  
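As an illustration of the decorator pattern, the sketch below moves the UTF-8 decoding step from the BERT example into a reusable decorator, so the inference body can assume plain strings. It assumes inputs arrive as Python lists; `decode_utf8` and `upper_fn` are hypothetical names chosen for this sketch, not PyTriton APIs, and PyTriton ships its own decorators (such as `@batch`) for real deployments.

```python
import functools


def decode_utf8(*names):
    """Hypothetical decorator: decode bytes items in the named inputs to str."""
    def wrap(infer_fn):
        @functools.wraps(infer_fn)
        def inner(**inputs):
            # `inputs` is a fresh dict here, so rebinding its entries is safe.
            for name in names:
                inputs[name] = [
                    item.decode("utf-8") if isinstance(item, bytes) else item
                    for item in inputs[name]
                ]
            return infer_fn(**inputs)
        return inner
    return wrap


@decode_utf8("sequence")
def upper_fn(sequence):
    # The body only ever sees str values.
    return {"upper": [s.upper() for s in sequence]}


print(upper_fn(sequence=[b"hello", "ok"]))  # {'upper': ['HELLO', 'OK']}
```

Keeping input adaptation in decorators leaves the model-specific code unchanged, which is the property the list above highlights.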

Whether you are working on a generative AI application or any other model, PyTriton enables you to gain the benefits of Triton Inference Server in your own development environment. It helps you take advantage of the GPU to produce an inference response in a very short time (milliseconds or seconds, depending on the use case). It also helps run the GPU at high capacity and serve many inference requests at the same time, keeping infrastructure costs low.

PyTriton code examples

This section provides a few code examples you can use to get started with PyTriton. They begin on a local machine, which is ideal for testing and prototyping, and provide Kubernetes configuration for scaled deployment.

Dynamic batching support

Dynamic batching is a key difference between Flask/FastAPI and PyTriton. It batches together inference requests from multiple calling applications for the same model while still meeting latency requirements. Two examples are HuggingFace BART PyTorch and HuggingFace ResNET PyTorch.
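To make the batching trade-off concrete, the sketch below simulates a simplified dynamic-batching policy offline: a batch opens when its first request arrives and is dispatched once it reaches a maximum size or once a maximum queue delay has elapsed, whichever comes first. This is an illustrative model, not Triton's scheduler; Triton's dynamic batcher exposes related settings (maximum batch size, maximum queue delay, preferred batch sizes) and behaves in a richer way. `form_batches` and the millisecond timings are hypothetical.

```python
def form_batches(arrivals_ms, max_batch_size, max_delay_ms):
    """Simulate a simplified dynamic batcher.

    arrivals_ms must be sorted. A batch opens at the arrival of its first request
    and is dispatched when it holds max_batch_size requests, or at
    opened_at + max_delay_ms, whichever comes first. Returns a list of
    (dispatch_time_ms, [request indices]) tuples.
    """
    batches = []
    current, opened_at = [], None
    for idx, t in enumerate(arrivals_ms):
        # The open batch timed out before this request arrived: dispatch it.
        if current and t > opened_at + max_delay_ms:
            batches.append((opened_at + max_delay_ms, current))
            current = []
        if not current:
            opened_at = t
        current.append(idx)
        # The batch is full: dispatch it immediately.
        if len(current) == max_batch_size:
            batches.append((t, current))
            current = []
    if current:
        batches.append((opened_at + max_delay_ms, current))
    return batches


# Six requests become three model calls; no request waits more than 5 ms.
print(form_batches([0, 1, 2, 10, 11, 30], max_batch_size=4, max_delay_ms=5))
# [(5, [0, 1, 2]), (15, [3, 4]), (35, [5])]
```

A larger delay lets more requests share a call (better GPU utilization) at the cost of added latency per request; the delay bound is what keeps latency requirements intact.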

Online learning

Online learning is learning from new data continuously in production. With PyTriton, you can control the number of distinct model instances backing your inference server. This enables you to train and serve the same model simultaneously from two different endpoints. Learn more about how to use PyTriton to train a model and run inference on it at the same time on the MNIST dataset.
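The linked MNIST example does this with real models and PyTriton bindings. The sketch below only illustrates the underlying pattern in plain Python: a training path keeps updating shared weights on new data while an inference path keeps serving predictions, with a lock so inference always reads a consistent pair of weights. `SharedLinearModel`, the toy linear model, and the learning rate are hypothetical choices for this sketch.

```python
import threading


class SharedLinearModel:
    """Toy 1-D linear model y = w * x + b shared by training and inference."""

    def __init__(self):
        self.w, self.b = 0.0, 0.0
        self._lock = threading.Lock()

    def infer(self, xs):
        # Inference path: copy a consistent snapshot of the weights, then predict.
        with self._lock:
            w, b = self.w, self.b
        return [w * x + b for x in xs]

    def train_step(self, xs, ys, lr=0.05):
        # Training path: one gradient-descent step on mean squared error.
        with self._lock:
            errors = [self.w * x + self.b - y for x, y in zip(xs, ys)]
            n = len(xs)
            grad_w = 2 * sum(e * x for e, x in zip(errors, xs)) / n
            grad_b = 2 * sum(errors) / n
            self.w -= lr * grad_w
            self.b -= lr * grad_b


def train_loop(model, xs, ys, steps):
    for _ in range(steps):
        model.train_step(xs, ys)


model = SharedLinearModel()
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2 * x + 1 for x in xs]  # "new data" that follows y = 2x + 1

trainer = threading.Thread(target=train_loop, args=(model, xs, ys, 2000))
trainer.start()
for _ in range(100):
    model.infer([1.0])  # predictions keep being served while training runs
trainer.join()
print(model.infer([4.0]))  # close to [9.0] once training has converged
```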

Multi-node inference of large language models

Large language models (LLMs) that are too large to fit into a single GPU's memory must be partitioned across multiple GPUs for inference, and in certain cases across multiple nodes. Check out an example using the Hugging Face OPT model in JAX with inference done on multiple nodes.
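A back-of-the-envelope calculation shows why partitioning becomes necessary. The sketch below estimates the memory taken by the weights alone and the resulting lower bound on GPU count. It assumes 2 bytes per parameter (FP16/BF16) and an 80 GB GPU, and it ignores activations, the KV cache, and framework overhead, all of which push the real requirement higher. The parameter counts in the loop are illustrative.

```python
import math


def weight_memory_gb(num_params, bytes_per_param=2):
    """Memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def min_gpus_for_weights(num_params, gpu_memory_gb=80, bytes_per_param=2):
    """Lower bound on the number of GPUs needed just to hold the weights."""
    return math.ceil(weight_memory_gb(num_params, bytes_per_param) / gpu_memory_gb)


for params in (1.3e9, 66e9, 175e9):
    print(f"{params / 1e9:g}B params: {weight_memory_gb(params):g} GB of FP16 weights, "
          f">= {min_gpus_for_weights(params)} x 80 GB GPU(s)")
# 1.3B params: 2.6 GB of FP16 weights, >= 1 x 80 GB GPU(s)
# 66B params: 132 GB of FP16 weights, >= 2 x 80 GB GPU(s)
# 175B params: 350 GB of FP16 weights, >= 5 x 80 GB GPU(s)
```

Once the lower bound exceeds the number of GPUs in one node, the model has to span nodes, which is the case the multi-node examples address.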

See NeMo Megatron GPT model deployment for a second example, which uses the NVIDIA NeMo 1.3B-parameter model. The example shows how to orchestrate the multi-node inference deployment with both Slurm and Kubernetes.

Stable Diffusion

With PyTriton, you can use preprocessing decorators to perform advanced batching operations with simple definitions, such as batching together images of the same size.

To learn more, check out this example that uses the Stable Diffusion 1.5 image generation pipeline from Hugging Face.
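The linked example does this with PyTriton decorators and NumPy arrays. As a rough, framework-free illustration of the idea, the sketch below splits incoming requests into sub-batches that share an image size, runs each sub-batch with one call, and puts the results back in the original request order. `group_by_value`, `run_grouped`, and `fake_pipeline` are hypothetical names, and the fake pipeline only formats strings instead of generating images.

```python
from collections import defaultdict


def group_by_value(requests, key):
    """Map each distinct value of `key` to the indices of requests that carry it."""
    groups = defaultdict(list)
    for idx, req in enumerate(requests):
        groups[req[key]].append(idx)
    return dict(groups)  # insertion order = order of first appearance


def run_grouped(requests, key, run_batch):
    """Run one call per group, then restore the original request order."""
    outputs = [None] * len(requests)
    for value, indices in group_by_value(requests, key).items():
        batch_out = run_batch([requests[i] for i in indices], value)
        for i, out in zip(indices, batch_out):
            outputs[i] = out
    return outputs


def fake_pipeline(batch, img_size):
    # Stand-in for an image-generation call: every prompt in the batch
    # shares one resolution, which is what makes the batch valid.
    return [f"{req['prompt']} @ {img_size}x{img_size}" for req in batch]


requests = [
    {"prompt": "a cat", "img_size": 512},
    {"prompt": "a dog", "img_size": 768},
    {"prompt": "a fox", "img_size": 512},
]
print(group_by_value(requests, "img_size"))  # {512: [0, 2], 768: [1]}
print(run_grouped(requests, "img_size", fake_pipeline))
# ['a cat @ 512x512', 'a dog @ 768x768', 'a fox @ 512x512']
```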


PyTriton provides a simple interface that enables Python developers to use NVIDIA Triton Inference Server to serve a model, a simple processing function, or an entire inference pipeline. This native support for Triton Inference Server in Python enables rapid prototyping and testing of ML models while maintaining high performance and efficiency. A single line of code brings up Triton Inference Server. Dynamic batching, concurrent model execution, and support for GPU and CPU from within the Python code are among the benefits. PyTriton offers the simplicity of Flask and the benefits of Triton Inference Server in Python.

Try PyTriton using the examples in this post, or using your own model. See Migrating to the Triton Inference Server for information on migrating from Flask to PyTriton and Triton Inference Server. To learn more, visit the Triton Inference Server page and PyTriton repository on GitHub.
