(FP32) arithmetic by default. However, using FP32 for all operations is not essential to achieve full accuracy for many state-of-the-art deep neural networks (DNNs). In 2017, NVIDIA researchers developed a methodology for mixed-precision training in which a few operations are executed in FP32 while the majority of the network is executed using 16-bit floating point (FP16) arithmetic. FP16 arithmetic offers the following additional performance benefits on Volta GPUs:
- FP16 reduces memory bandwidth and storage requirements by 2x. Bandwidth-bound operations can realize up to 2x speedup immediately.
- FP16 arithmetic enables Tensor Cores, which in Volta GPUs offer 125 TFLOPS of computational throughput on generalized matrix-matrix multiplications (GEMMs) and convolutions, an 8x increase over FP32.
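The 2x storage figure follows directly from element size: FP16 stores each value in 2 bytes, while FP32 uses 4. The minimal sketch below uses Python's `struct` module, whose `e` format is IEEE half precision, to compute storage for a model. The 25-million-parameter count and the `storage_mib` helper are hypothetical and only illustrate the arithmetic.

```python
import struct

FP32_BYTES = struct.calcsize("f")  # IEEE single precision: 4 bytes per value
FP16_BYTES = struct.calcsize("e")  # IEEE half precision: 2 bytes per value

def storage_mib(num_values: int, bytes_per_value: int) -> float:
    """Memory needed to hold num_values elements, in MiB (hypothetical helper)."""
    return num_values * bytes_per_value / 2**20

# Hypothetical parameter count, chosen only for illustration.
num_params = 25_000_000
fp32_mib = storage_mib(num_params, FP32_BYTES)
fp16_mib = storage_mib(num_params, FP16_BYTES)
print(f"FP32: {fp32_mib:.1f} MiB, FP16: {fp16_mib:.1f} MiB, ratio {fp32_mib / fp16_mib:.1f}x")
```

The same halving applies to every tensor that a bandwidth-bound kernel reads or writes. This is why such kernels can approach a 2x speedup.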
With mixed precision training, networks receive almost all the memory savings and improved throughput of pure FP16 training while matching the accuracy of FP32 training. A number of recently published results demonstrate the accuracy and high performance of the mixed precision recipe:
- Facebook AI Research’s FAIRseq translation network achieves a nearly 5x speedup over pure FP32 training on the same number of GPUs, and a state-of-the-art BLEU score on an English-to-German translation task.
- Researchers from NVIDIA demonstrated speedups between 3.25x and 4.25x on BERT pretraining. This example code is open-sourced as part of NVIDIA’s deep learning examples.
- Researchers from NVIDIA and Baidu recently showed that a wide range of bellwether networks, applied to a wide range of tasks, achieve comparable or superior test accuracy when trained with mixed precision, using the same hyperparameters and training schedules as pure FP32 baselines.
In summary, the mixed-precision training recipe consists of:
- Weight updates carried out in FP32.
- Loss scaling to prevent gradients from underflowing.
- A few operations (e.g., large reductions) kept in FP32.
- Everything else (the majority of the network) executed in FP16.
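Loss scaling exists because small gradient values fall below FP16's representable range. The toy sketch below uses Python's `struct` half-precision format (`e`) to show the effect. A value of 1e-8 rounds to zero in FP16. Multiplying it by a scale factor before the FP16 round trip and dividing afterward in higher precision preserves it. The `to_fp16` helper and the scale of 2**16 are illustrative assumptions, and Python's 64-bit floats stand in for FP32 on the unscaling side.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

tiny_grad = 1e-8        # below FP16's smallest subnormal (about 6e-8)
loss_scale = 2.0**16    # illustrative scale factor

# Without scaling, the gradient flushes to zero when stored in FP16.
unscaled = to_fp16(tiny_grad)

# With scaling: multiply before the FP16 step, then divide afterwards in higher precision.
scaled = to_fp16(tiny_grad * loss_scale) / loss_scale

print(unscaled, scaled)
```

Scaling by a power of two only shifts the exponent, so the scale and unscale steps add no rounding error of their own. The only error left is FP16's roughly 3-significant-digit precision.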
Mixed-Precision in PyTorch
PyTorch has comprehensive built-in support for mixed-precision training. Calling .half() on a module converts its parameters to FP16, and calling .half() on a tensor converts its data to FP16. Any operations performed on such modules or tensors are carried out using fast FP16 arithmetic. PyTorch also has strong built-in support for NVIDIA math libraries (cuBLAS and cuDNN). These libraries use Tensor Cores to perform GEMMs (e.g., fully connected layers) and convolutions on FP16 data. For cuBLAS to use Tensor Cores on a GEMM with dimensions [M, K] x [K, N] -> [M, N], M, K, and N must additionally be multiples of 8.
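A common way to meet the multiple-of-8 requirement is to pad awkward dimensions, such as an odd vocabulary size in an output layer, up to the next multiple of 8. The article does not prescribe this technique; it is a general practice. The helpers `tensor_core_friendly` and `pad_to_multiple` below are hypothetical, and the layer sizes are made up for illustration.

```python
def tensor_core_friendly(m: int, k: int, n: int, multiple: int = 8) -> bool:
    """True if every dimension of an [M, K] x [K, N] GEMM is a multiple of `multiple`."""
    return all(d % multiple == 0 for d in (m, k, n))

def pad_to_multiple(d: int, multiple: int = 8) -> int:
    """Smallest value >= d that is a multiple of `multiple` (ceiling division)."""
    return -(-d // multiple) * multiple

# Hypothetical fully connected output layer: batch x hidden times hidden x vocab.
batch, hidden, vocab = 256, 1024, 33_709
print(tensor_core_friendly(batch, hidden, vocab))   # vocab is not a multiple of 8
padded_vocab = pad_to_multiple(vocab)
print(padded_vocab, tensor_core_friendly(batch, hidden, padded_vocab))
```

When a dimension is padded this way, the model has to ignore or mask the extra rows or columns, for example by never using the padded vocabulary entries as targets.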
We developed Apex to streamline the mixed precision user experience and enable researchers to leverage mixed precision training in their models more conveniently. Apex is a lightweight PyTorch extension containing (among other utilities) Amp, short for Automatic Mixed-Precision. Amp enables users to take advantage of mixed precision training by adding just a few lines to their networks. Apex was released at CVPR 2018, and the current incarnation of Amp was announced at GTC San Jose 2019. Since its release, Apex has seen good adoption by the PyTorch community, with nearly 3,000 stars on GitHub.
Amp emphasizes simplicity by performing relatively low-level modifications to the running model so you don’t need to worry about mixed types when writing or running your model training script. Models that use PyTorch in less common ways may find Amp’s assumptions don’t fit as well, but hooks exist to modify those assumptions as needed.
Drop-in Mixed-Precision Training: Amp
Amp provides all the benefits of mixed-precision training without any explicit management of loss scaling or type conversions.
Integrating Amp into an existing PyTorch model
The following steps are required to integrate Amp into an existing PyTorch script:
- Import Amp from the Apex library.
- Initialize Amp so it can insert the necessary modifications to the model, optimizer, and PyTorch internal functions.
- Mark where backpropagation (.backward()) occurs so that Amp can both scale the loss and clear per-iteration state.
Step 1 is a single line of code:
from apex import amp
Step 2 is also a single line of code. It requires that both the neural network model and the optimizer used for training already be defined:
model, optimizer = amp.initialize(model, optimizer)
You can pass additional options that give you finer control of how Amp adjusts the tensor and operation types.
As for step 3, identify where in your code the backward pass occurs. You’ll see a few lines of code like the following:
loss = criterion(…)
loss.backward()
optimizer.step()
To enable loss scaling, you simply wrap the backward pass in the Amp context manager:
loss = criterion(…)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
And that’s it. You can now re-run your script and have mixed-precision training enabled.
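The amp.scale_loss context manager hides the loss-scaling bookkeeping. The sketch below models one common scheme, dynamic loss scaling, in plain Python. Gradients computed from the scaled loss are divided by the scale and checked for inf/NaN. On overflow, the optimizer step is skipped and the scale is reduced. After a run of clean iterations, the scale is increased again. This is a toy model under stated assumptions, not Apex's implementation. The class and function names, the initial scale of 2**16, the growth interval, and the factor of 2 are all illustrative.

```python
import math

def unscale_and_check(scaled_grads, scale):
    """Divide gradients by the loss scale and report whether any were inf/NaN."""
    overflow = any(math.isinf(g) or math.isnan(g) for g in scaled_grads)
    grads = [g / scale for g in scaled_grads]
    return grads, overflow

class ToyDynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling (not Apex's actual class)."""

    def __init__(self, init_scale=2.0**16, growth_interval=2000, factor=2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.factor = factor
        self.good_steps = 0

    def update(self, overflow: bool) -> bool:
        """Adjust the scale after one iteration; return True if the step must be skipped."""
        if overflow:
            self.scale /= self.factor   # back off and discard this iteration's gradients
            self.good_steps = 0
            return True
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            self.scale *= self.factor   # probe a larger scale after a run of clean steps
            self.good_steps = 0
        return False

scaler = ToyDynamicLossScaler(growth_interval=3)
# An iteration whose gradients overflowed: skip the optimizer step and halve the scale.
_, overflow = unscale_and_check([1.0, float("inf")], scaler.scale)
print(scaler.update(overflow), scaler.scale)
# Three clean iterations in a row: the scale grows back.
for _ in range(3):
    grads, overflow = unscale_and_check([4.0, 8.0], scaler.scale)
    scaler.update(overflow)
print(scaler.scale)
```

Lowering the scale only when an overflow actually occurs keeps the scale as large as the gradients allow. Very small gradients therefore stay representable without the user having to choose a fixed scale by hand.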
The Amp API offers additional features to handle complications like multiple optimizers, multiple backward passes, and custom C++ or CUDA layers that are not part of native PyTorch. Complete documentation is available in the Apex documentation.
How Amp works
At the logical level, Amp works by employing a whitelist/blacklist model. PyTorch’s tensor operations include neural network functions like torch.nn.functional.conv2d, basic math functions like torch.log, and tensor methods like torch.Tensor.__add__ (called when you write a + b for two tensors). Note that these functions are a level below the neural network Module API: modules (e.g., torch.nn.Conv2d) call into the corresponding functions for their implementation.
We divide the universe of functions into three sets:
- Whitelist. Functions where we expect a speedup with FP16 math. The most common examples of these are the matrix multiply and convolution functions.
- Blacklist. Functions for which 16 bits of precision may not be sufficient, so we want to ensure that inputs are in FP32. The most common examples of these are the neural net loss functions like softmax with cross entropy.
- Everything else (whatever functions are left over). These include functions that can run in FP16, but for which the speedup is too small to justify the cost of an FP32 -> FP16 cast.
In principle, the job of Amp is straightforward. Whenever a PyTorch function is called, Amp checks whether it is on the whitelist, the blacklist, or neither. If it is whitelisted, Amp casts all arguments to FP16; if it is blacklisted, Amp casts all arguments to FP32; and if it is neither, Amp simply ensures all arguments are of the same type (casting to the widest type if they are not). In practice, though, implementing this policy is not entirely straightforward.
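The casting policy can be expressed as a small decision function. The sketch below represents function names and dtypes as plain strings so it runs without PyTorch. The `WHITELIST` and `BLACKLIST` contents are illustrative subsets chosen from the examples in the text, not Amp's real lists, and `target_dtypes` is a hypothetical name.

```python
# Illustrative subsets only; Amp's real lists are longer and differ in detail.
WHITELIST = {"conv2d", "linear", "matmul"}        # expected to speed up in FP16
BLACKLIST = {"softmax", "cross_entropy", "log"}   # keep in FP32 for precision

# Bit width of each dtype, used to pick the "widest" type for everything else.
WIDTH = {"float16": 16, "float32": 32}

def target_dtypes(fn_name, arg_dtypes):
    """Return the dtype each floating-point argument should be cast to."""
    if not arg_dtypes:
        return []
    if fn_name in WHITELIST:
        return ["float16"] * len(arg_dtypes)
    if fn_name in BLACKLIST:
        return ["float32"] * len(arg_dtypes)
    # Neither list: promote every argument to the widest type present.
    widest = max(arg_dtypes, key=WIDTH.__getitem__)
    return [widest] * len(arg_dtypes)

print(target_dtypes("linear", ["float32", "float32"]))   # whitelist -> FP16
print(target_dtypes("softmax", ["float16"]))             # blacklist -> FP32
print(target_dtypes("add", ["float16", "float32"]))      # neither -> widest (FP32)
```

In Amp itself, this decision is applied inside the wrapped functions at call time, as the next section describes.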
Capturing function calls
Because PyTorch is so flexible and dynamic (a good thing!), it lacks a static model object or graph to latch onto and insert the casts described above. Instead, Amp inserts them dynamically by “monkey patching” the necessary functions to intercept and cast their arguments. For example, to ensure that torch.nn.functional.linear always casts its arguments to FP16, you can write code like this:
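import torch

# Keep a reference to the original function so the wrapper can still call it.
orig_linear = torch.nn.functional.linear

def wrapped_linear(*args):
    casted_args = []
    for arg in args:
        # Cast floating-point tensors to FP16; pass everything else through unchanged.
        if torch.is_tensor(arg) and torch.is_floating_point(arg):
            casted_args.append(arg.to(torch.float16))
        else:
            casted_args.append(arg)
    return orig_linear(*casted_args)

# Replace the library function with the wrapper (the "monkey patch").
torch.nn.functional.linear = wrapped_linear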
Other subtleties exist to make the code more robust (different argument types, keyword arguments), but what Amp essentially does on a call to amp.initialize() is insert monkey patches on all of the relevant PyTorch functions so that arguments are cast appropriately at runtime.
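Patching one function by hand generalizes to a loop over a list of function names. The toy sketch below wraps every named function in a namespace so that its tensor arguments, including keyword arguments, are cast. It also returns the originals so the patches can be undone. To keep it runnable without PyTorch, the example makes several assumptions. `ToyTensor` only records a dtype string, and the fake namespace `F` stands in for torch.nn.functional. `cast`, `patch_functions`, and `unpatch_functions` are hypothetical names, not Apex APIs.

```python
import functools
from types import SimpleNamespace

class ToyTensor:
    """Stand-in for a tensor that only records its dtype."""
    def __init__(self, dtype):
        self.dtype = dtype

def cast(x, dtype):
    # Only tensors are cast; other arguments (ints, None, ...) pass through.
    return ToyTensor(dtype) if isinstance(x, ToyTensor) else x

def patch_functions(namespace, names, dtype):
    """Wrap each named function so its tensor arguments are cast to `dtype`.
    Returns the original functions so the patches can be removed later."""
    originals = {}
    for name in names:
        orig = getattr(namespace, name)
        originals[name] = orig

        # `_orig=orig` binds the current function now, avoiding late binding in the loop.
        @functools.wraps(orig)
        def wrapper(*args, _orig=orig, **kwargs):
            args = [cast(a, dtype) for a in args]
            kwargs = {k: cast(v, dtype) for k, v in kwargs.items()}
            return _orig(*args, **kwargs)

        setattr(namespace, name, wrapper)
    return originals

def unpatch_functions(namespace, originals):
    for name, orig in originals.items():
        setattr(namespace, name, orig)

# A fake "functional" namespace whose ops just report their input dtypes.
F = SimpleNamespace(
    linear=lambda x, w, bias=None: (x.dtype, w.dtype, bias),
    conv2d=lambda x, w: (x.dtype, w.dtype),
)
originals = patch_functions(F, ["linear", "conv2d"], "float16")
print(F.linear(ToyTensor("float32"), ToyTensor("float32"), bias=None))
unpatch_functions(F, originals)
print(F.linear(ToyTensor("float32"), ToyTensor("float32")))
```

Because callers look the function up on the namespace at call time, replacing the attribute is enough to change their behavior. This is the same reason patching torch.nn.functional.linear affects modules that call it.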
One additional challenge with the function-based casting approach remains. Applied naively, any sort of parameter sharing induces multiple casts of the same weight on each iteration. For example, an nn.RNNCell module applied over a sequence calls its underlying RNN function once per timestep with the same FP32 weight arguments.
To ensure each weight is cast FP32 -> FP16 no more than once per iteration, Amp keeps an internal cache of parameter casts and reuses the cast versions when appropriate. The context manager around the backward pass tells Amp when to clear the cache at each iteration.
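The caching idea can be sketched as a dictionary keyed by parameter identity that is cleared once per iteration. In the toy below, a Python list stands in for an FP32 weight, and a half-precision round trip through `struct` plays the role of the FP32 -> FP16 cast. `CastCache` and `to_fp16_list` are hypothetical names, and this is not Amp's internal data structure.

```python
import struct

def to_fp16_list(values):
    """Placeholder 'cast': round each float through IEEE half precision."""
    return [struct.unpack("<e", struct.pack("<e", v))[0] for v in values]

class CastCache:
    """Hypothetical per-iteration cache of FP32 -> FP16 weight casts."""

    def __init__(self, cast_fn):
        self._cast_fn = cast_fn
        self._cache = {}
        self.num_casts = 0

    def get(self, param):
        # Key on object identity: the same weight object maps to one cached copy.
        key = id(param)
        if key not in self._cache:
            # Store the param itself too, so it stays alive and its id() is not reused.
            self._cache[key] = (param, self._cast_fn(param))
            self.num_casts += 1
        return self._cache[key][1]

    def clear(self):
        # Called once per iteration; weights change after the optimizer step.
        self._cache.clear()

weight = [0.5, -1.25, 3.0]           # stands in for a shared FP32 parameter
cache = CastCache(cast_fn=to_fp16_list)

for timestep in range(10):           # e.g., an RNN cell reusing the same weight each step
    w16 = cache.get(weight)
print(cache.num_casts)               # one cast despite ten uses
cache.clear()                        # next iteration: the weight may have been updated
cache.get(weight)
print(cache.num_casts)
```

Clearing the cache every iteration matters for correctness. Otherwise the model would keep using a stale FP16 copy after the optimizer has updated the FP32 weight.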
Get Started with Apex
Installation instructions can be found on the Apex GitHub page, and complete API documentation is available in the Apex documentation. Apex was developed in dialogue with deep learning researchers at NVIDIA and the external community. It’s an open source project, and we welcome any suggestions, feature requests, bug reports, or contributions. Feel free to submit PRs and issues on GitHub, or leave a comment below.
Our GTC 2019 presentation covers the theory and PyTorch usage of automatic mixed precision in-depth.