
Speeding Up Deep Learning Inference Using TensorRT

Looking for more? Check out the hands-on DLI training course: Optimization and Deployment of TensorFlow Models with TensorRT

This is an updated version of How to Speed Up Deep Learning Inference Using TensorRT. This version starts from a PyTorch model instead of the ONNX model, upgrades the sample application to use TensorRT 7, and replaces the ResNet-50 classification model with UNet, which is a segmentation model.

Figure 1. TensorRT logo

NVIDIA TensorRT is an SDK for deep learning inference. TensorRT provides APIs and parsers to import trained models from all major deep learning frameworks. It then generates optimized runtime engines deployable in the datacenter as well as in automotive and embedded environments.

This post provides a simple introduction to using TensorRT. You learn how to deploy a deep learning application onto a GPU, increasing throughput and reducing latency during inference. It uses a C++ example to walk you through converting a PyTorch model into an ONNX model and importing it into TensorRT, applying optimizations, and generating a high-performance runtime engine for the datacenter environment.

TensorRT supports both C++ and Python; if you use either, this workflow discussion could be useful. If you prefer to use Python, see Using the Python API in the TensorRT documentation.

Deep learning applies to a wide range of applications such as natural language processing, recommender systems, and image and video analysis. As more applications use deep learning in production, demands on accuracy and performance have led to strong growth in model complexity and size.

Safety-critical applications such as automotive place strict requirements on throughput and latency expected from deep learning models. The same holds true for some consumer applications, including recommendation systems.

TensorRT is designed to help deploy deep learning for these use cases. With support for every major framework, TensorRT helps process large amounts of data with low latency through powerful optimizations, use of reduced precision, and efficient memory use.

To follow along with this post, you need a computer with a CUDA-capable GPU or a cloud instance with GPUs and an installation of TensorRT. On Linux, the easiest place to get started is by downloading the GPU-accelerated PyTorch container with TensorRT integration from the NVIDIA NGC container registry.

The sample application uses input data from Brain MRI segmentation data from Kaggle to perform inference.

Simple TensorRT example

Following are the four steps for this example application: 

  1. Convert the pretrained image segmentation PyTorch model into ONNX.
  2. Import the ONNX model into TensorRT.
  3. Apply optimizations and generate an engine.
  4. Perform inference on the GPU. 

Importing the ONNX model includes loading it from a saved file on disk and converting it to a TensorRT network from its native framework or format. ONNX is a standard for representing deep learning models, enabling them to be transferred between frameworks.

Many frameworks such as Caffe2, Chainer, CNTK, PaddlePaddle, PyTorch, and MXNet support the ONNX format. Next, an optimized TensorRT engine is built based on the input model, target GPU platform, and other configuration parameters specified. The last step is to provide input data to the TensorRT engine to perform inference.

The application uses the following components in TensorRT:

  • ONNX parser: Takes a trained PyTorch model that has been converted into the ONNX format as input and populates a network object in TensorRT. 
  • Builder: Takes a network in TensorRT and generates an engine that is optimized for the target platform. 
  • Engine: Takes input data, performs inferences, and emits inference output.
  • Logger: Associated with the builder and engine to capture errors, warnings, and other information during the build and inference phases.

Convert the pretrained image segmentation PyTorch model into ONNX

Start with the PyTorch container from the NGC registry to get the framework and CUDA components pre-installed and ready to go. After you have installed the PyTorch container successfully, run the following commands to download everything needed to run this sample application (example code, test input data, and reference outputs), update dependencies, and compile the application with the makefile provided.

>> sudo apt-get install libprotobuf-dev protobuf-compiler # protobuf is a prerequisite library
>> git clone --recursive https://github.com/onnx/onnx.git # Pull the ONNX repository from GitHub 
>> cd onnx
>> mkdir build && cd build 
>> cmake .. # Compile and install ONNX
>> make # Use the ‘-j’ option for parallel jobs, for example, ‘make -j $(nproc)’ 
>> make install 
>> cd ../..
>> git clone https://github.com/parallel-forall/code-samples.git
>> cd code-samples/posts/TensorRT-introduction
>> make clean && make # Compile the TensorRT C++ code
>> cd ..
>> wget https://developer.download.nvidia.com/devblogs/speeding-up-unet.7z # Get the ONNX model and test data
>> tar xvf speeding-up-unet.7z # Unpack the model data into the unet folder
>> cd unet
>> python create_network.py #Inside the unet folder, it creates the unet.onnx file

Convert the PyTorch-trained UNet model into ONNX, as shown in the following code example:

import torch
from torch.autograd import Variable
import torch.onnx as torch_onnx
import onnx
def main():
    input_shape = (3, 256, 256)
    model_onnx_path = "unet.onnx"
    dummy_input = Variable(torch.randn(1, *input_shape))
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
      in_channels=3, out_channels=1, init_features=32, pretrained=True)
    model.train(False)
    
    inputs = ['input.1']
    outputs = ['186']
    dynamic_axes = {'input.1': {0: 'batch'}, '186':{0:'batch'}}
    out = torch.onnx.export(model, dummy_input, model_onnx_path, input_names=inputs, output_names=outputs, dynamic_axes=dynamic_axes)

if __name__=='__main__':
    main() 

Next, prepare the input data for inference. Download all images from the Kaggle directory. Copy any three images that don’t have _mask in their filename, along with the utils.py file from the brain-segmentation-pytorch repository, to the /unet directory. These three images are used as input data later in this post. To prepare the input_0.pb and output_0.pb files for use later, run the following code example (saved as prepareData.py, the name used by the commands that follow):

import torch 
import argparse
import numpy as np
from torchvision import transforms                    
from skimage.io import imread
from onnx import numpy_helper
from utils import normalize_volume
def main(args):
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
      in_channels=3, out_channels=1, init_features=32, pretrained=True)
    model.train(False)
    
    filename = args.input_image
    input_image = imread(filename)
    input_image = normalize_volume(input_image)
    input_image = np.asarray(input_image, dtype='float32')
    
    preprocess = transforms.Compose([
      transforms.ToTensor(),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)
    
    tensor1 = numpy_helper.from_array(input_batch.numpy())
    with open(args.input_tensor, 'wb') as f:
        f.write(tensor1.SerializeToString())
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model = model.to('cuda')
    with torch.no_grad():
        output = model(input_batch)
    
    tensor = numpy_helper.from_array(output[0].cpu().numpy())
    with open(args.output_tensor, 'wb') as f:
        f.write(tensor.SerializeToString())
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str)
    parser.add_argument('--input_tensor', type=str, default='input_0.pb')
    parser.add_argument('--output_tensor', type=str, default='output_0.pb')
    args=parser.parse_args()
    main(args) 

To generate processed input data for inference, run the following commands:

>> pip install medpy #dependency for utils.py file
>> mkdir test_data_set_0
>> mkdir test_data_set_1
>> mkdir test_data_set_2
>> python prepareData.py --input_image your_image1 --input_tensor test_data_set_0/input_0.pb --output_tensor test_data_set_0/output_0.pb   # This creates input_0.pb and output_0.pb
>> python prepareData.py --input_image your_image2 --input_tensor test_data_set_1/input_0.pb --output_tensor test_data_set_1/output_0.pb   # This creates input_0.pb and output_0.pb
>> python prepareData.py --input_image your_image3 --input_tensor test_data_set_2/input_0.pb --output_tensor test_data_set_2/output_0.pb   # This creates input_0.pb and output_0.pb 

That’s it, you have the input data ready to perform inference. Begin with a simplified version of the application, simpleOnnx_1.cpp, and build on it. Subsequent versions are available in the same folder: simpleOnnx_2.cpp and simpleOnnx.cpp.

Import the ONNX model into TensorRT, generate the engine, and perform inference 

Run the sample application with the trained model and input data passed as inputs. The data is provided as an ONNX protobuf file. The sample application compares output generated from TensorRT with reference values available as ONNX .pb files in the same folder and summarizes the result on the prompt.

It can take a few seconds to import the UNet ONNX model and generate the engine. It also generates the output image in the portable gray map (PGM) format as output.pgm. 

>> cd code-samples/posts/TensorRT-introduction
>> ./simpleOnnx_1 path/to/unet/unet.onnx path/to/unet/test_data_set_0/input_0.pb # The sample application expects output reference values in path/to/unet/test_data_set_0/output_0.pb
...
 Tactic: 0 is the only option, timing skipped
: Fastest Tactic: 0 Time: 0
: Formats and tactics selection completed in 2.26589 seconds.
: After reformat layers: 32 layers
: Block size 1073741824
: Block size 536870912

...
: Total Activation Memory: 2248146944
INFO: Detected 1 inputs and 1 output network tensors.
Engine generation completed in 3.37261 seconds.
OK 

And that’s it, you have an application that is optimized with TensorRT and running on your GPU. Figure 2 shows the output of a sample test case.
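The PGM output mentioned above is a very simple image format, so it is easy to inspect or reproduce. The following Python sketch encodes a small mask as a binary (P5) PGM image. The function name and the assumption that mask values lie in the range [0, 1] are illustrative; the sample application's own writer may scale values differently.

```python
def to_pgm_bytes(pixels, width, height):
    """Encode a row-major list of floats in [0, 1] as a binary (P5) PGM image."""
    if len(pixels) != width * height:
        raise ValueError("pixel count does not match width * height")
    # Clamp each value to [0, 1] and scale it to the 8-bit range (maxval 255).
    body = bytes(int(round(min(max(p, 0.0), 1.0) * 255)) for p in pixels)
    header = f"P5\n{width} {height}\n255\n".encode("ascii")
    return header + body


# A 2x2 mask with background (0.0) and foreground (1.0) pixels.
image = to_pgm_bytes([0.0, 1.0, 1.0, 0.0], width=2, height=2)
```

Writing the returned bytes to a file with a .pgm extension should produce an image that most viewers can open.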

Here are a few key code examples used in the earlier sample application.

The main function in the following code example starts by declaring a CUDA engine to hold the network definition and trained parameters. The engine is generated in the createCudaEngine function that takes the path to the ONNX model as input.

// Declare the CUDA engine
unique_ptr<ICudaEngine, Destroy<ICudaEngine>> engine{nullptr};
...
// Create the CUDA engine
engine.reset(createCudaEngine(onnxModelPath));

The createCudaEngine function parses the ONNX model and holds it in the network object. To handle the dynamic input dimensions of input images and shape tensors for U-Net model, you must create an optimization profile from the builder class, as shown in the following code example.

The optimization profile enables you to set the minimum, optimum, and maximum input dimensions. The builder selects the kernel that results in the lowest runtime for the optimum input tensor dimensions and that is valid for all input tensor dimensions in the range between the minimum and maximum dimensions. It also converts the network object into a TensorRT engine.
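To make the role of the minimum and maximum dimensions concrete, here is a minimal Python sketch that checks whether a runtime input shape falls inside a profile. The dictionary layout and the function name are hypothetical; only the dimension values mirror the profile used in this post.

```python
# Hypothetical mirror of the profile in this post: (batch, channels, height, width).
PROFILE = {
    "min": (1, 3, 256, 256),
    "opt": (1, 3, 256, 256),
    "max": (32, 3, 256, 256),
}


def shape_in_profile(shape, profile=PROFILE):
    """Return True if every dimension lies between the profile's min and max."""
    if len(shape) != len(profile["min"]):
        return False
    return all(lo <= dim <= hi
               for lo, dim, hi in zip(profile["min"], shape, profile["max"]))


print(shape_in_profile((4, 3, 256, 256)))   # batch of 4 is inside the range
print(shape_in_profile((64, 3, 256, 256)))  # batch of 64 exceeds the maximum
```

In this sketch only the batch dimension varies, which matches the profile in the post, where height and width are fixed at 256.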

The setMaxBatchSize function in the following code example is used to specify the maximum batch size that a TensorRT engine expects. The setMaxWorkspaceSize function allows you to increase the GPU memory footprint during the engine building phase.

nvinfer1::ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize){
    unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
    unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
    unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};
 
    if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
    {
            cout << "ERROR: could not parse input engine." << endl;
            return nullptr;
    }
    builder->setMaxBatchSize(batchSize);
    config->setMaxWorkspaceSize((1 << 30));
    
    auto profile = builder->createOptimizationProfile();
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});    
    config->addOptimizationProfile(profile);
    return builder->buildEngineWithConfig(*network, *config);
} 

After an engine has been created, create an execution context to hold intermediate activation values generated during inference. The following code shows how to create the execution context.

// Declare the execution context
unique_ptr<IExecutionContext, Destroy<IExecutionContext>> context{nullptr};
...
// Create the execution context
context.reset(engine->createExecutionContext()); 

This application places inference requests on the GPU asynchronously in the function launchInference shown in the following code example. Inputs are copied from host (CPU) to device (GPU) within launchInference, inference is then performed with the enqueueV2 function, and results are copied back asynchronously.

The example uses CUDA streams to manage asynchronous work on the GPU. Asynchronous inference execution generally increases performance because overlapping compute with data transfers keeps the GPU busy. The enqueueV2 function places inference requests on CUDA streams and takes as input pointers to the input and output buffers, plus the CUDA stream to be used for kernel execution. Asynchronous data transfers are performed from the host to device and the reverse using cudaMemcpyAsync.

void launchInference(IExecutionContext* context, cudaStream_t stream, vector<float> const& inputTensor, vector<float>& outputTensor, void** bindings, int batchSize)
{
    int inputId = getBindingInputIndex(context);

    cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
    context->enqueueV2(bindings, stream, nullptr);
    cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
} 

Using the cudaStreamSynchronize function after calling launchInference ensures GPU computations complete before the results are accessed. The number of inputs and outputs, as well as the value and dimension of each, can be queried using functions from the ICudaEngine class. The sample finally compares reference output with TensorRT-generated inferences and prints discrepancies to the prompt.
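The comparison step can be sketched in a few lines of Python. This is an illustration only: the function name and the tolerance values are assumptions, and the sample application may use a different comparison rule.

```python
import math


def find_discrepancies(reference, actual, rel_tol=1e-3, abs_tol=1e-5):
    """Return (index, reference, actual) for every element outside the tolerance."""
    if len(reference) != len(actual):
        raise ValueError("outputs differ in length")
    return [(i, r, a)
            for i, (r, a) in enumerate(zip(reference, actual))
            if not math.isclose(r, a, rel_tol=rel_tol, abs_tol=abs_tol)]


# Toy reference and inference outputs; the last element disagrees.
mismatches = find_discrepancies([0.10, 0.50, 0.90], [0.10, 0.50, 0.95])
print("OK" if not mismatches else mismatches)
```

A tolerance-based comparison is used here because floating-point results from an optimized engine rarely match the framework output bit for bit.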

For more information about classes, see the TensorRT Class List. The complete code example is in simpleOnnx_1.cpp.

Batch your inputs

This application example expects a single input and returns output after performing inference on it. Real applications commonly batch inputs to achieve higher performance and efficiency. A batch of inputs that are identical in shape and size can be computed in parallel on different layers of the neural network.

Larger batches generally enable more efficient use of GPU resources. For example, batch sizes using multiples of 32 may be particularly fast and efficient in lower precision on Volta and Turing GPUs because TensorRT can use special kernels for matrix multiply and fully connected layers that leverage Tensor Cores.
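As a rough illustration of what batching means for the input buffer, the following Python sketch concatenates equally sized, flattened inputs into one contiguous batch. The function name is hypothetical and the tiny lists stand in for real 3 x 256 x 256 images.

```python
def pack_batch(samples):
    """Concatenate equally sized flat samples into one contiguous batch buffer."""
    if not samples:
        raise ValueError("at least one sample is required")
    sample_size = len(samples[0])
    if any(len(s) != sample_size for s in samples):
        raise ValueError("all samples in a batch must have the same size")
    batch = [value for sample in samples for value in sample]
    return batch, len(samples)


# Three tiny stand-ins for flattened images, four values each.
batch, batch_size = pack_batch([[0.0] * 4, [0.5] * 4, [1.0] * 4])
```

The size check reflects the requirement stated above: inputs in one batch must be identical in shape and size.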

Pass the images to the application on the command line using the following code. The number of images (.pb files) passed as input arguments on the command line determines the batch size in this example. Use test_data_set_* to take all the input_0.pb files from all the directories. Instead of reading just one input, the following command reads all inputs available in the folders.

Currently, the downloaded data has three input directories, so batch size is 3. This version of the example profiles the application and prints the result to the prompt. For more information, see the next section, Profile the application.

>> ./simpleOnnx_2 path/to/unet.onnx path/to/unet/test_data_set_*/input_0.pb # Use all available test data sets.
...
: Formats and tactics selection completed in 2.33156 seconds.
: After reformat layers: 32 layers
: Block size 1073741824
: Block size 536870912
...
: Total Activation Memory: 2248146944
INFO: Detected 1 inputs and 1 output network tensors.
: Engine generation completed in 3.45499 seconds.
Inference batch size 3 average over 10 runs is 5.23616ms
OK 

To process multiple images in one inference pass, make a couple of changes to the application. First, collect all images (.pb files) in a loop to use as input in the application:

// Replaces the single-input line: input_files.push_back(string{argv[2]});
for (int i = 2; i < argc; ++i)
   input_files.push_back(string{argv[i]}); 

Next, specify the maximum batch size that a TensorRT engine expects using the setMaxBatchSize function. The builder then generates an engine tuned for that batch size by choosing algorithms that maximize its performance on the target platform. While the engine does not accept larger batch sizes, using smaller batch sizes at runtime is allowed.

The choice of maxBatchSize value depends on the application as well as the expected inference traffic (for example, the number of images) at any given time. A common practice is to build multiple engines optimized for different batch sizes (using different maxBatchSize values), and then choosing the most optimized engine at runtime.
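One way to pick among several prebuilt engines at runtime is to choose the smallest maximum batch size that still fits the request. The Python sketch below shows that selection rule only; the function name and the list of batch sizes are hypothetical, and other selection policies are possible.

```python
import bisect


def pick_engine(engine_batch_sizes, request_batch_size):
    """Return the smallest built batch size that can serve the request, or None."""
    sizes = sorted(engine_batch_sizes)
    i = bisect.bisect_left(sizes, request_batch_size)
    return sizes[i] if i < len(sizes) else None


# Hypothetical set of engines built with maxBatchSize 1, 4, 16, and 32.
built = [1, 4, 16, 32]
print(pick_engine(built, 3))   # served by the engine built for batch size 4
print(pick_engine(built, 33))  # None: no engine accepts a batch this large
```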

When not specified, the default batch size is 1, meaning that the engine does not process batch sizes greater than 1. Set this parameter as shown in the following code example:

 builder->setMaxBatchSize(batchSize); 

Profile the application

Now that you’ve seen an example, here’s how to measure its performance. The simplest performance measurement for network inference is the time elapsed between an input being presented to the network and an output being returned, referred to as latency.

For many applications on embedded platforms, latency is critical, while consumer applications require quality-of-service. Lower latencies make these applications better. This example measures the average latency of an application using time stamps on the GPU. There are many ways to profile your application in CUDA. For more information, see How to Implement Performance Metrics in CUDA C/C++.

CUDA offers lightweight event API functions to create, destroy, and record events, as well as calculate the time between them. The application can record events in the CUDA stream, one before initiating inference and another after the inference completes, shown in the following code example.

In some cases, you might care about including the time it takes to transfer data between the GPU and CPU before inference initiates and after inference completes. Techniques exist to pre-fetch data to the GPU as well as overlap compute with data transfers that can significantly hide data transfer overhead. The function cudaEventElapsedTime measures the time between these two events being encountered in the CUDA stream.

Use the code example at the beginning of the last section to run this sample and review profiling output. To profile the application, wrap the inference launch within the function doInference in simpleOnnx_2.cpp. This example includes an updated function call.

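// Previously:
//   launchInference(context, stream, inputTensor, outputTensor, bindings, batchSize);
//   cudaStreamSynchronize(stream); // Wait until the work is finished
// Now: doInference launches inference, waits for it to finish, and reports the average latency
doInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 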

Calculate latency within doInference as follows:

// Number of times to run inference and calculate average time
constexpr int ITERATIONS = 10;
...
void doInference(IExecutionContext* context, cudaStream_t stream, vector<float> const& inputTensor, vector<float>& outputTensor, void** bindings, int batchSize)
{
    CudaEvent start;
    CudaEvent end;
    double totalTime = 0.0;

    for (int i = 0; i < ITERATIONS; ++i)
    {
        float elapsedTime;

        // Measure time that it takes to copy input to GPU, run inference, and move output back to CPU
        cudaEventRecord(start, stream);
        launchInference(context, stream, inputTensor, outputTensor, bindings, batchSize);
        cudaEventRecord(end, stream);

        // Wait until the work is finished
        cudaStreamSynchronize(stream);
        cudaEventElapsedTime(&elapsedTime, start, end);

        totalTime += elapsedTime;
    } 

    cout << "Inference batch size " << batchSize << " average over " << ITERATIONS << " runs is " << totalTime / ITERATIONS << "ms" << endl;
} 
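For comparison, the same averaging logic can be written with a host-side wall-clock timer. This Python sketch is an analogy rather than what the sample does: it uses time.perf_counter instead of CUDA events, and it assumes the timed callable blocks until all work is complete. The function name and the stand-in workload are hypothetical.

```python
import time


def average_latency_ms(run_once, iterations=10):
    """Call run_once() repeatedly and return the mean wall-clock time in ms."""
    total = 0.0
    for _ in range(iterations):
        start = time.perf_counter()
        run_once()  # must block until the work is finished
        total += (time.perf_counter() - start) * 1000.0
    return total / iterations


# Stand-in workload; a real caller would launch inference and synchronize.
latency = average_latency_ms(lambda: sum(range(10_000)))
print(f"average over 10 runs is {latency:.5f} ms")
```

Host timers also capture launch and synchronization overhead, so the numbers are not expected to match GPU event timings exactly.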

Many applications perform inferences on large amounts of input data accumulated and batched for offline processing. The maximum number of inferences possible per second, known as throughput, is a valuable metric for these applications.

You measure throughput by generating optimized engines for larger specific batch sizes, running inference, and measuring the number of batches that can be processed per second. Multiply the number of batches per second by the batch size to calculate the number of inferences per second. Measuring throughput is out of scope for this post. 
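The arithmetic is simple enough to sketch in Python. The function name is hypothetical, and the calculation assumes batches run back to back with no idle time. The latency value is the one reported by simpleOnnx_2 earlier in this post, which includes host-device copies, so the result is only a rough estimate.

```python
def throughput_per_second(batch_size, batch_latency_ms):
    """Inferences per second when batches run back to back."""
    if batch_latency_ms <= 0:
        raise ValueError("latency must be positive")
    batches_per_second = 1000.0 / batch_latency_ms
    return batches_per_second * batch_size


# Batch of 3 at 5.23616 ms per batch, as reported by simpleOnnx_2.
rate = throughput_per_second(3, 5.23616)
print(f"{rate:.1f} inferences per second")
```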

Optimize your application

Now that you know how to run inference in batches and profile your application, optimize it. The key strength of TensorRT is its flexibility and use of techniques including mixed precision, efficient optimizations on all GPU platforms, and the ability to optimize across a wide range of model types.

In this section, we describe a few techniques to increase throughput and reduce latency from applications. For more information, see Best Practices for TensorRT Performance.

Here are a few common techniques: 

  • Use mixed precision computation
  • Change the workspace size
  • Reuse the TensorRT engine

Use mixed precision computation

TensorRT uses FP32 algorithms for performing inference to obtain the highest possible inference accuracy by default. However, you can use FP16 and INT8 precision for inference with minimal impact to accuracy of results in many cases.

Using reduced precision to represent models enables you to fit larger models in memory and achieve higher performance, given the lower data transfer requirements of reduced precision. You can also mix computations in FP32 and FP16 precision with TensorRT, referred to as mixed precision, or use INT8 quantized precision for weights and activations and to execute layers.
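To get a feel for the memory saving and the rounding cost of FP16, the following Python sketch round-trips a value through IEEE 754 half precision using the standard struct module. It only illustrates the number format; it does not involve TensorRT, and the function name is hypothetical.

```python
import struct


def to_fp16_and_back(value):
    """Round a Python float to IEEE 754 half precision and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


fp32_bytes = struct.calcsize("<f")  # storage per FP32 value
fp16_bytes = struct.calcsize("<e")  # storage per FP16 value
rounding_error = abs(to_fp16_and_back(0.1) - 0.1)

print(fp32_bytes, fp16_bytes, rounding_error)
```

Each FP16 value takes half the storage of an FP32 value, at the cost of a small rounding error and a much narrower representable range.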

Enable FP16 kernels by setting the setFp16Mode parameter to true for devices that support fast FP16 math.

 builder->setFp16Mode(builder->platformHasFastFp16());

The setFp16Mode parameter indicates to the builder that a lower precision is acceptable for computations. TensorRT uses FP16 optimized kernels if they perform better with the chosen configuration and target platform.

With this mode turned on, weights can be specified in FP16 or FP32, and are converted automatically to the appropriate precision for the computation. You also have the flexibility of specifying 16-bit floating point data type for input and output tensors, which is out of scope for this post.

Change the workspace size

TensorRT allows you to increase GPU memory footprint during the engine building phase with the setMaxWorkspaceSize function. Increasing the limit may affect the number of applications that could share the GPU at the same time. Setting this limit too low may filter out several algorithms and create a suboptimal engine. TensorRT allocates just the memory required even if the amount set in IBuilder::setMaxWorkspaceSize is much higher. Applications should therefore allow the TensorRT builder as much workspace as they can afford. TensorRT allocates no more than this and typically less.

This example uses 1 GB, which lets TensorRT pick any algorithm available.

// Allow TensorRT to use up to 1 GB of GPU memory for tactic selection
constexpr size_t MAX_WORKSPACE_SIZE = 1ULL << 30; // 1 GB worked well for this example
...
// Set the builder flag
builder->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE); 

Reuse the TensorRT engine

When building the engine, the builder object selects the most optimized kernels for the chosen platform and configuration. Building the engine from a network definition file can be time-consuming and should not be repeated each time you perform inference, unless the model, platform, or configuration changes.

Figure 3 shows that you can transform the format of the engine after generation and store it on disk for reuse later, known as serializing the engine. Deserializing occurs when you load the engine from disk into memory and continue to use it for inference.

Figure 3. Serializing and deserializing the TensorRT engine.

The runtime object deserializes the engine.

Instead of creating the engine each time, simpleOnnx.cpp contains the getCudaEngine function to load and use an engine if it exists. If the engine is not available, it creates and saves the engine in the current directory with the name unet_batch4.engine. Before this example tries to build a new engine, it picks this engine if it is available in the current directory.

To force a new engine to be built with updated configuration and parameters, use the make clean_engines command to delete all existing serialized engines stored on disk before re-running the code example.

// Replaces: engine.reset(createCudaEngine(onnxModelPath, batchSize));
engine.reset(getCudaEngine(onnxModelPath, batchSize)); 
 
ICudaEngine* getCudaEngine(string const& onnxModelPath, int batchSize)
{
    // The engine file name includes the batch size, for example unet_batch4.engine
    string enginePath{getBasename(onnxModelPath) + "_batch" + to_string(batchSize) + ".engine"};
    ICudaEngine* engine{nullptr};

    string buffer = readBuffer(enginePath);
    if (buffer.size())
    {
        // Try to deserialize the engine
        unique_ptr<IRuntime, Destroy<IRuntime>> runtime{createInferRuntime(gLogger)};
        engine = runtime->deserializeCudaEngine(buffer.data(), buffer.size(), nullptr);
    }

    if (!engine)
    {
        // Fall back to creating the engine from scratch
        engine = createCudaEngine(onnxModelPath, batchSize);

        if (engine)
        {
            unique_ptr<IHostMemory, Destroy<IHostMemory>> engine_plan{engine->serialize()};
            // Try to save the engine for future uses
            writeBuffer(engine_plan->data(), engine_plan->size(), enginePath);
        }
    }
    return engine;
} 
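The load-or-build pattern is independent of TensorRT and can be sketched in Python with plain files. In this illustration, build is a hypothetical callable standing in for the slow engine build, and the cache file name follows the unet_batch4.engine naming described earlier.

```python
from pathlib import Path


def engine_path(model_path, batch_size, cache_dir="."):
    """Build a cache file name such as unet_batch4.engine."""
    return Path(cache_dir) / f"{Path(model_path).stem}_batch{batch_size}.engine"


def get_engine_bytes(model_path, batch_size, build, cache_dir="."):
    """Return cached engine bytes if present, otherwise build and save them."""
    path = engine_path(model_path, batch_size, cache_dir)
    if path.is_file() and path.stat().st_size > 0:
        return path.read_bytes()
    data = build(model_path, batch_size)  # hypothetical, slow build step
    path.write_bytes(data)
    return data
```

Including the batch size in the file name keeps engines built with different settings from overwriting each other. Deleting the cached file forces a rebuild on the next call.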

Use this saved engine with different batch sizes. The following code example takes input data, repeats it as many times as your batch size variable, and then passes this appended input to the sample. The first run creates the engine and the second run tries to deserialize the engine.

>> for x in $(seq 1 4); do echo path/to/unet/test_data_set_0/input_0.pb ; done | xargs ./simpleOnnx path/to/unet/unet.onnx
...

: Tactic: 0 is the only option, timing skipped
: Fastest Tactic: 0 Time: 0
: Formats and tactics selection completed in 2.3837 seconds.
: After reformat layers: 32 layers
: Block size 1073741824
: Block size 536870912
...
: Total Activation Memory: 2248146944
Inference batch size 4 average over 10 runs is 6.86188ms

>> for x in $(seq 1 4); do echo unet/test_data_set_0/input_0.pb ; done | xargs ./simpleOnnx unet/unet.onnx
: Deserialize required 1400284 microseconds.
Inference batch size 4 average over 10 runs is 6.80197ms
OK 

You’ve now learned how to speed up inference of a simple application using TensorRT. We measured the earlier performance on NVIDIA TITAN V GPUs with TensorRT 7.

Next steps

Real-world applications have much higher computing demands with larger deep learning models, more data processing needs, and tighter latency bounds. TensorRT offers high-performance optimizations for compute-heavy deep learning applications and is an invaluable tool for inference.

Hopefully, this post has familiarized you with the key concepts needed to get amazing performance with TensorRT. Here are some ideas to apply what you have learned, use other models, and explore the impact of design and performance tradeoffs by changing parameters introduced in this post. 

  1. The TensorRT support matrix provides a look into supported features and software for TensorRT APIs, parsers, and layers. While this example used C++, TensorRT provides both C++ and Python APIs. To run the sample application included in this post, see the APIs and Python and C++ code examples in the TensorRT Developer Guide.
  2. Set the setFp16Mode parameter to true or false to change the allowable precision for the models, and profile the applications to see the difference in performance.
  3. Change the batch size used at run time for inference and see how that impacts the performance (latency, throughput) of your model and dataset.
  4. Change the maxBatchSize parameter from 64 to 4 and see different kernels get selected among the top five. Use nvprof to see the kernels in the profiling results.

One topic not covered in this post is performing inference accurately in TensorRT with INT8 precision. TensorRT automatically converts an FP32 network for deployment with INT8 reduced precision while minimizing accuracy loss. To achieve this goal, TensorRT uses a calibration process that minimizes the information loss when approximating the FP32 network with a limited 8-bit integer representation. For more information, see Fast INT8 Inference for Autonomous Vehicles with TensorRT 3.
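To illustrate what an 8-bit integer representation involves, here is a minimal Python sketch of symmetric quantization. It assumes the simplest possible choice of scale, the maximum absolute value in the tensor. TensorRT calibration chooses its ranges more carefully to limit information loss, so treat this as an explanation of the number mapping, not of the calibration algorithm. All names are hypothetical.

```python
def quantize_int8(values):
    """Symmetric quantization: map [-max|x|, +max|x|] onto [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map 8-bit integer codes back to approximate floating-point values."""
    return [q * scale for q in quantized]


q, scale = quantize_int8([-1.27, 0.0, 0.64, 1.27])
restored = dequantize(q, scale)
```

With a max-based scale, a single outlier stretches the range and wastes resolution on the remaining values, which is one reason a calibration step is useful.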

There are numerous resources to help you accelerate applications for image/video, speech apps, and recommendation systems. These range from code samples, self-paced Deep Learning Institute labs and tutorials to developer tools for profiling and debugging applications. 

If you have issues with TensorRT, check the NVIDIA TensorRT Developer Forum to see if other members of the TensorRT community have a resolution first. NVIDIA Registered Developers can also file bugs on the Developer Program page.
