
Simplifying and Accelerating Machine Learning Predictions in Apache Beam with NVIDIA TensorRT

Loading and preprocessing data for running machine learning models at scale often requires seamlessly stitching the data processing framework and inference engine together.

In this post, we walk through the integration of NVIDIA TensorRT with the Apache Beam SDK and show how complex inference scenarios can be fully encapsulated within a data processing pipeline. We also show how terabytes of data from both batch and streaming sources can be processed with a few lines of code for high-throughput, low-latency model inference.

  • NVIDIA TensorRT is an SDK that facilitates high-performance machine learning inference. It is designed to work with deep learning frameworks such as TensorFlow, PyTorch, and MXNet, and focuses specifically on optimizing and running a trained neural network efficiently on NVIDIA GPUs for inference. TensorRT can maximize inference throughput while preserving model accuracy through multiple optimizations, including model quantization, layer and tensor fusion, kernel auto-tuning, multi-stream execution, and efficient tensor memory usage.
  • Proven with 15+ years in production, Google Cloud Dataflow is a no-ops, serverless platform for processing data in batch or in real time for analytical, ML, and application use cases, many of which incorporate pretrained models into data pipelines. Whatever the use case, Dataflow uses Apache Beam as its SDK, so it benefits from Beam's robust community, simplifies your data architecture, and helps you deliver insights with ML.

Build a TensorRT engine for inference

At this stage, using TensorRT with Apache Beam requires a TensorRT engine file converted from a trained model. Here's how to convert a TensorFlow Object Detection SSD MobileNet v2 320×320 model to ONNX, build a TensorRT engine from the ONNX file, and run the engine locally.

Convert the TF model to ONNX

To convert TensorFlow Object Detection SSD MobileNet v2 320×320 to ONNX, use one of the TensorRT example converters. This can be done on an on-premises system if the system has the same GPU that will be used in Dataflow for inference.

To prepare your environment, follow the instructions under Setup. This post follows that guide up to and including the Create ONNX Graph step. Use --batch_size 1, because the example covered later in this post works only with a batch size of 1. You can name the final --onnx file ssd_mobilenet_v2_320x320_coco17_tpu-8.onnx. Building the engine and running inference are handled in GCP.

Make sure that you set up a GCP project with proper credentials and API access to Dataflow, Google Cloud Storage (GCS), and Google Compute Engine (GCE). For more information, see Create a Dataflow pipeline using Python.

Spin up a GCE VM

You need a machine that contains the following installed resources:

  • NVIDIA T4 Tensor Core GPU
  • GPU driver
  • Docker
  • NVIDIA container toolkit

You can do this by creating a new GCE VM. Follow the instructions but use the following settings:

  • Name: tensorrt-demo
  • GPU type: NVIDIA T4
  • Number of GPUs: 1
  • Machine type: n1-standard-2

You may need a more powerful machine if you are working with large models.

In the Boot disk section, choose CHANGE, and go to the PUBLIC IMAGES tab. For Operating system, choose Deep Learning on Linux. There are many versions, but make sure you choose one with CUDA. The version Debian 10 based Deep Learning VM with M98 works for this example.

The other settings can be left to their default values.

Next, connect to the VM using SSH. Install NVIDIA drivers if you are prompted to do so.

Inside the VM, run the following commands to create a few directories to be used later:

mkdir models
mkdir tensorrt_engines

For more information, see Create a VM with attached GPUs.

Build the image

You need a custom container that contains the necessary dependencies to execute the TensorRT code: CUDA, cuDNN, and TensorRT.

You can copy the following example Dockerfile into a new file and name it tensor_rt.dockerfile.

ARG BUILD_IMAGE=nvcr.io/nvidia/tensorrt:22.09-py3

FROM ${BUILD_IMAGE} 

ENV PATH="/usr/src/tensorrt/bin:${PATH}"

WORKDIR /workspace

# Quote package specifiers so the shell does not treat ">" as a redirect
# or "[gcp]" as a glob pattern.
RUN pip install --no-cache-dir "apache-beam[gcp]==2.42.0"
COPY --from=apache/beam_python3.8_sdk:2.42.0 /opt/apache/beam /opt/apache/beam

RUN pip install --upgrade pip \
    && pip install "torch>=1.7.1" \
    && pip install "torchvision>=0.8.2" \
    && pip install "pillow>=8.0.0" \
    && pip install "transformers>=4.18.0" \
    && pip install cuda-python

ENTRYPOINT [ "/opt/apache/beam/boot" ]

View the Docker file used for testing in the Apache Beam repo. Keep in mind that there may be a later version of Beam available than what was used in this post.

Build the image by running the following command, locally or in a GCE VM:

docker build -f tensor_rt.dockerfile -t tensor_rt .

If you built the image locally, follow the next steps to make it available to the GCE VM. Even if you built it on the VM, push it to a registry before running the Dataflow job, because the Dataflow workers pull the custom container from there.

For this post, use Google Container Registry. Tag your image with a URI for your project and then push it to the registry. Make sure to replace GCP_PROJECT and MY_DIR with the appropriate values.

docker tag tensor_rt us.gcr.io/{GCP_PROJECT}/{MY_DIR}/tensor_rt
docker push us.gcr.io/{GCP_PROJECT}/{MY_DIR}/tensor_rt

Creating the TensorRT engine

The following commands are only necessary if you created the image on a different machine from the one on which you intend to build the TensorRT engine. Pull the tensor_rt image from the registry:

docker pull us.gcr.io/{GCP_PROJECT}/{MY_DIR}/tensor_rt
docker tag us.gcr.io/{GCP_PROJECT}/{MY_DIR}/tensor_rt tensor_rt

If the ONNX model is not on the GCE VM, you can copy it from your local machine to the ~/models directory:

gcloud compute scp ~/Downloads/ssd_mobilenet_v2_320x320_coco17_tpu-8.onnx tensorrt-demo:~/models --zone=us-central1-a

You should now have the ONNX model and the built Docker image in the VM. Now it’s time to use them both.

Launch the Docker container interactively:

docker run --rm -it --gpus all -v /home/{username}/:/mnt tensor_rt bash

Create the TensorRT engine out of the ONNX file:

trtexec --onnx=/mnt/models/ssd_mobilenet_v2_320x320_coco17_tpu-8.onnx --saveEngine=/mnt/tensorrt_engines/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt --useCudaGraph --verbose

You should now see the ssd_mobilenet_v2_320x320_coco17_tpu-8.trt file in the ~/tensorrt_engines directory of the VM.

Upload the TensorRT Engine to GCS

Next, copy the engine file to GCS. If you run into issues uploading the file directly from GCE to GCS with gsutil, you may have to copy it to your local machine first:

gcloud compute scp tensorrt-demo:~/tensorrt_engines/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt ~/Downloads/ --zone=us-central1-a

In the GCP console, upload the TensorRT engine file to your chosen GCS bucket:

gs://{GCS_BUCKET}/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt

Testing the TensorRT engine locally

Make sure that you have a Beam pipeline that uses TensorRT RunInference. One example is tensorrt_object_detection.py, which you can set up by running the following commands in your GCE VM. First, exit the Docker container by pressing Ctrl+D.

git clone https://github.com/apache/beam.git
cd beam/sdks/python
pip install --upgrade pip setuptools
pip install -r build-requirements.txt
pip install --user -e ."[gcp,test]"

Also create a file called image_file_names.txt that contains the paths to the images, one per line, and upload it to your GCS bucket. The images can be in an object store like GCS or on the GCE VM:

gs://{GCS_BUCKET}/000000289594.jpg
gs://{GCS_BUCKET}/000000000139.jpg
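If you have many images, a small script can generate image_file_names.txt for you. The following is a minimal sketch, assuming the images are already in a GCS bucket and you know their object names; image_list_lines and write_image_list are hypothetical helpers, not part of Beam. Because ReadFromText emits each line as one element, the file holds exactly one URI per line.

```python
from typing import Iterable, List


def image_list_lines(bucket: str, image_names: Iterable[str]) -> List[str]:
    """Build one gs:// URI per image, preserving the given order.

    Hypothetical helper: assumes the images live at the top level of `bucket`.
    """
    return [f"gs://{bucket}/{name}" for name in image_names]


def write_image_list(path: str, uris: Iterable[str]) -> None:
    """Write the URIs one per line, so ReadFromText yields one path per element."""
    with open(path, "w", encoding="utf-8") as f:
        for uri in uris:
            f.write(uri + "\n")
```

For example, write_image_list("image_file_names.txt", image_list_lines("my-bucket", ["000000289594.jpg", "000000000139.jpg"])) produces the two-line file shown above, which you can then upload with gsutil cp image_file_names.txt gs://{GCS_BUCKET}/.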

Then, run the following command:

docker run --rm -it --gpus all -v /home/{username}/:/mnt -w /mnt/beam/sdks/python tensor_rt python -m apache_beam.examples.inference.tensorrt_object_detection --input gs://{GCS_BUCKET}/image_file_names.txt --output /mnt/tensorrt_predictions.csv --engine_path gs://{GCS_BUCKET}/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt

You should now see a file called tensorrt_predictions.csv in your home directory. Each line contains data separated by a semicolon:

  • The first item is the file name.
  • The second item is a list of dictionaries, where each dictionary corresponds to a single detection.
  • A detection contains box coordinates (ymin, xmin, ymax, xmax), score, and class.
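To work with the results programmatically, you can parse each line back into Python objects. This is a sketch under an assumption: that the detection list is written as a Python literal (what str() produces for a list of dicts). The exact formatting comes from PostProcessor in tensorrt_object_detection.py, so check it against your own output. parse_prediction_line and filter_detections are hypothetical helpers.

```python
import ast
from typing import Any, Dict, List, Tuple


def parse_prediction_line(line: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Split one output line into (file name, list of detection dicts).

    Assumes the format "<file name>;<Python-literal list of dicts>".
    """
    # Split only on the first semicolon so the detection list stays intact.
    file_name, detections_text = line.rstrip("\n").split(";", 1)
    # literal_eval safely parses Python literals (lists, dicts, strings, numbers).
    detections = ast.literal_eval(detections_text)
    return file_name, detections


def filter_detections(detections: List[Dict[str, Any]],
                      min_score: float) -> List[Dict[str, Any]]:
    """Keep detections whose score is at least min_score.

    float() accepts the score whether it was written as a number or a string.
    """
    return [d for d in detections if float(d["score"]) >= min_score]
```

Keeping the split to the first semicolon matters because the file name is the only field before it; everything after it is the detection list.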

For more information about how to set up and run TensorRT RunInference locally, follow the instructions in the Object Detection section.

The TensorRT Support Guide provides an overview of all the supported NVIDIA TensorRT 8.5.1 samples on GitHub and in the product package. These samples are designed to show how to use TensorRT in numerous use cases while highlighting different capabilities of the interface. These samples specifically help in use cases such as recommenders, machine comprehension, character recognition, image classification, and object detection.

Running the TensorRT engine with Dataflow RunInference

Now that you have the TensorRT engine, you can run a pipeline on Dataflow.

The following code example is part of the pipeline. It uses TensorRTEngineHandlerNumPy to load the TensorRT engine and set other inference parameters. The pipeline then reads the images, preprocesses them and attaches keys to them, runs the prediction, and writes the results to a file in GCS.

For the full code example, see tensorrt_object_detection.py.

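import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy

# known_args, pipeline_options, read_image, attach_im_size_to_key,
# preprocess_image, and PostProcessor are defined in tensorrt_object_detection.py.

# KeyedModelHandler passes each element's key through inference unchanged.
# Batch size is fixed at 1 because the ONNX model was exported with --batch_size 1.
engine_handler = KeyedModelHandler(
    TensorRTEngineHandlerNumPy(
        min_batch_size=1,
        max_batch_size=1,
        engine_path=known_args.engine_path))

with beam.Pipeline(options=pipeline_options) as p:
  filename_value_pair = (
      p
      # Each input line is the path of one image.
      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
      | 'ReadImageData' >> beam.Map(
          lambda image_name: read_image(
              image_file_name=image_name, path_to_dir=known_args.images_dir))
      # Attach the original image size to the key for use in post-processing.
      | 'AttachImageSizeToKey' >> beam.Map(attach_im_size_to_key)
      | 'PreprocessImages' >> beam.MapTuple(
          lambda file_name, data: (file_name, preprocess_image(data))))
  predictions = (
      filename_value_pair
      | 'TensorRTRunInference' >> RunInference(engine_handler)
      # Turn the raw engine outputs into one text line per image.
      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

  _ = (
      predictions | "WriteOutputToGCS" >> beam.io.WriteToText(
          known_args.output,
          shard_name_template='',
          append_trailing_newlines=True))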

Make sure that you have completed the Google Cloud setup mentioned in the previous section. You also must have the Beam SDK installed.

To run this job on Dataflow, run the following command locally:

python -m apache_beam.examples.inference.tensorrt_object_detection \
--input gs://{GCS_BUCKET}/image_file_names.txt \
--output gs://{GCS_BUCKET}/predictions.txt \
--engine_path gs://{GCS_BUCKET}/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt \
--runner DataflowRunner \
--experiment=use_runner_v2 \
--machine_type=n1-standard-4 \
--experiment="worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver" \
--disk_size_gb=75 \
--project {GCP_PROJECT} \
--region us-central1 \
--temp_location gs://{GCS_BUCKET}/tmp/ \
--job_name tensorrt-object-detection \
--sdk_container_image="us.gcr.io/{GCP_PROJECT}/{MY_DIR}/tensor_rt"
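If you launch the job from a script instead of a shell, you can build the same flags as a Python list. This is a sketch; dataflow_args is a hypothetical helper, and the project, bucket, and image values are your own. Passing a list without a shell also explains the quoting in the command: the semicolons in the accelerator value are shell command separators, so the shell form must quote them, while an argument list needs no quoting.

```python
from typing import List

# In a shell, ";" separates commands, so the shell command quotes this value.
# Inside an argument list passed without a shell, it is just part of the string.
T4_ACCELERATOR = (
    "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver")


def dataflow_args(project: str, bucket: str, image: str,
                  region: str = "us-central1") -> List[str]:
    """Return the pipeline flags from the Dataflow command as a list.

    Hypothetical helper: project, bucket, and image are placeholders you supply.
    """
    engine = f"gs://{bucket}/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt"
    return [
        "--input", f"gs://{bucket}/image_file_names.txt",
        "--output", f"gs://{bucket}/predictions.txt",
        "--engine_path", engine,
        "--runner", "DataflowRunner",
        "--experiment=use_runner_v2",
        "--machine_type=n1-standard-4",
        f"--experiment={T4_ACCELERATOR}",
        "--disk_size_gb=75",
        "--project", project,
        "--region", region,
        "--temp_location", f"gs://{bucket}/tmp/",
        "--job_name", "tensorrt-object-detection",
        f"--sdk_container_image={image}",
    ]
```

You could then start the job with subprocess.run([sys.executable, "-m", "apache_beam.examples.inference.tensorrt_object_detection", *dataflow_args(...)], check=True).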

Depending on the size of your model, you may want to adjust machine_type, the type and count of the GPU, or disk_size_gb. For more information about Beam pipeline options, see Set Dataflow pipeline options.

TensorRT and TensorFlow object detection benchmarking

To benchmark, we compared the TensorRT and TensorFlow object detection versions of the SSD MobileNet v2 320×320 model described earlier.

We timed every inference call in both the TensorRT and TensorFlow object detection versions and averaged 5,000 inference calls, excluding the first 10 images to avoid ramp-up latencies. The SSD model that we used is small; you can expect even larger speedups when your model makes full use of the GPU.
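The timing methodology above (time each call, discard the first 10 as warm-up, then average) can be sketched as follows. This is a minimal illustration, not necessarily the harness used for these benchmarks; infer is a hypothetical callable standing in for one TensorRT or TensorFlow inference call. Because GPU work is launched asynchronously, infer must block until results are ready (for example, after copying outputs back to the host), or the measured time undercounts the real latency.

```python
import time
from typing import Callable, Sequence


def mean_latency_ms(infer: Callable[[object], object],
                    inputs: Sequence[object],
                    warmup: int = 10) -> float:
    """Time each call to infer and return the mean latency in milliseconds,
    ignoring the first `warmup` calls (ramp-up effects)."""
    if len(inputs) <= warmup:
        raise ValueError("need more inputs than warm-up calls")
    timings = []
    for i, x in enumerate(inputs):
        start = time.perf_counter()
        infer(x)  # Assumed to block until the result is available.
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        if i >= warmup:  # Discard warm-up calls.
            timings.append(elapsed_ms)
    return sum(timings) / len(timings)
```

time.perf_counter is used because it is a monotonic, high-resolution clock intended for measuring short intervals.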

First, we compared TensorFlow and TensorRT directly with a local benchmark. We also aimed to show the additional benefit of TensorRT's reduced-precision modes.

| Framework and precision | Inference latency (ms) |
|---|---|
| TensorFlow Object Detection FP32 (end-to-end) | 29.47 |
| TensorRT FP32 (end-to-end) | 3.72 |
| TensorRT FP32 (GPU compute) | 2.39 |
| TensorRT FP16 (GPU compute) | 1.48 |
| TensorRT INT8 (GPU compute) | 1.34 |

Table 1. Direct performance speedup on TensorRT

The overall end-to-end speedup with TensorRT FP32 is 7.9x. End-to-end latency includes data copies, whereas GPU compute latency includes only the actual inference time. We report them separately because the example model is small, so data copies account for a large share of end-to-end TensorRT latency (about 1.33 ms of the 3.72 ms). With bigger models, especially when inference compute rather than data copying is the bottleneck, you can expect reduced precision to deliver more significant end-to-end improvements.

In GPU compute time, FP16 is 1.6x faster than FP32 with no accuracy penalty. INT8 is 1.8x faster than FP32, but it requires a calibration process and sometimes degrades accuracy. Because accuracy degradation is model-specific, it's always worth measuring the accuracy you get with your own model.
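The speedups quoted above follow directly from the Table 1 latencies. This short calculation reproduces them, along with the part of TensorRT FP32 end-to-end latency not spent in GPU compute.

```python
# Mean latencies from Table 1, in milliseconds.
latency_ms = {
    "tf_fp32_e2e": 29.47,
    "trt_fp32_e2e": 3.72,
    "trt_fp32_gpu": 2.39,
    "trt_fp16_gpu": 1.48,
    "trt_int8_gpu": 1.34,
}


def speedup(baseline: str, candidate: str) -> float:
    """How many times faster `candidate` is than `baseline`."""
    return latency_ms[baseline] / latency_ms[candidate]


print(f"TensorRT FP32 vs TensorFlow FP32 (end-to-end): "
      f"{speedup('tf_fp32_e2e', 'trt_fp32_e2e'):.1f}x")  # 7.9x
print(f"FP16 vs FP32 (GPU compute): "
      f"{speedup('trt_fp32_gpu', 'trt_fp16_gpu'):.1f}x")  # 1.6x
print(f"INT8 vs FP32 (GPU compute): "
      f"{speedup('trt_fp32_gpu', 'trt_int8_gpu'):.1f}x")  # 1.8x

# End-to-end minus GPU compute: time outside inference, mostly data copies.
non_compute_ms = latency_ms["trt_fp32_e2e"] - latency_ms["trt_fp32_gpu"]
print(f"Non-compute time in TensorRT FP32 end-to-end: {non_compute_ms:.2f} ms")  # 1.33 ms
```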

You can also mitigate INT8 accuracy loss by using quantization-aware training (QAT) with the NVIDIA QAT Toolkit. For more information, see Accelerating Quantized Networks with the NVIDIA QAT Toolkit for TensorFlow and NVIDIA TensorRT and the NVIDIA TensorRT Developer Guide.

Dataflow benchmarking

In Dataflow, using the TensorRT engine generated earlier, we ran with the following configuration: n1-standard-4 machines, disk_size_gb=75, and 10 workers.

To simulate a stream of data arriving in Dataflow through Pub/Sub, we set the batch size to 1 by configuring the model handlers with min_batch_size=1 and max_batch_size=1.

| | Stage with RunInference | Mean inference_batch_latency_micro_secs |
|---|---|---|
| TensorFlow with T4 GPU | 12 min 43 sec | 99,242 |
| TensorRT with T4 GPU | 7 min 20 sec | 10,836 |

Table 2. Dataflow benchmarks

The Dataflow runner decomposes a pipeline into multiple stages. To get a clearer picture of RunInference performance, look at the stage that contains the inference call rather than the stages that read and write data. The Stage with RunInference column shows the duration of that stage.

By this metric, TensorRT takes only about 58% of the TensorFlow runtime (7 min 20 sec versus 12 min 43 sec). You can expect the acceleration to grow with a larger model that fully uses the GPU's processing power.

The inference_batch_latency_micro_secs metric is the time, in microseconds, that it takes to run inference on a batch of examples, that is, the time spent calling model_handler.run_inference. In general, this varies over time depending on the dynamic batching decisions of BatchElements and on the particular values or dtypes of the elements. By this metric, TensorRT is about 9.2x faster than TensorFlow.
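Both Table 2 comparisons are simple ratios of the table values; this calculation reproduces them.

```python
def to_seconds(minutes: int, seconds: int) -> int:
    """Convert a "M min S sec" duration to seconds."""
    return minutes * 60 + seconds


# Values from Table 2.
tf_stage_s = to_seconds(12, 43)   # 763 s
trt_stage_s = to_seconds(7, 20)   # 440 s
tf_batch_us = 99_242
trt_batch_us = 10_836

stage_ratio = trt_stage_s / tf_stage_s       # TensorRT stage time relative to TensorFlow
batch_speedup = tf_batch_us / trt_batch_us   # per-batch inference speedup

print(f"TensorRT stage time as a share of TensorFlow: {stage_ratio:.1%}")  # 57.7%
print(f"Per-batch inference speedup: {batch_speedup:.1f}x")  # 9.2x
```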

Conclusion

In this post, we demonstrated how to run machine learning models at scale by seamlessly stitching together a data processing framework (Apache Beam) and an inference engine (TensorRT). We presented an end-to-end example of how an inference workload can be fully integrated within a data processing pipeline.

This integration enables a new inference pipeline that helps reduce production inference cost through better NVIDIA GPU utilization and much-improved inference latency and throughput. The same approach can be applied to many other inference workloads using the many off-the-shelf TensorRT samples. In the future, we plan to further automate TensorRT engine building and work on deeper integration of TensorRT with Apache Beam.
