A defining feature of the new NVIDIA Volta GPU architecture is Tensor Cores, which give the NVIDIA V100 accelerator a peak throughput that is 12x the 32-bit floating point throughput of the previous-generation NVIDIA P100. Tensor Cores enable you to use mixed-precision for higher throughput without sacrificing accuracy.
Tensor Cores provide a huge boost to convolutions and matrix operations. They are programmable using NVIDIA libraries and directly in CUDA C++ code: CUDA 9 provides a preview API for programming V100 Tensor Cores for mixed-precision matrix arithmetic in deep learning.
Tensor Cores are already supported for deep learning training, either in a main release or through pull requests, in many DL frameworks, including TensorFlow, PyTorch, MXNet, and Caffe2. For more information about enabling Tensor Cores when using these frameworks, see the Mixed-Precision Training Guide. For DL inference, the recent NVIDIA TensorRT 3 release also supports Tensor Cores.
In this post, we show you how to use Tensor Cores in your own application using CUDA Libraries as well as how to program them directly in CUDA C++ device code.
What are Tensor Cores?
NVIDIA V100 Tensor Cores are programmable matrix-multiply-and-accumulate units that can deliver up to 125 Tensor TFLOPS for training and inference applications. The V100 GPU contains 640 Tensor Cores, 8 per SM. Tensor Cores and their associated data paths are custom-crafted to dramatically increase floating-point compute throughput at only modest area and power costs. Clock gating is used extensively to maximize power savings.
Each Tensor Core provides a 4x4x4 matrix processing array that performs the operation D = A * B + C, where A, B, C, and D are 4×4 matrices (Figure 1). The matrix multiply inputs A and B are FP16 matrices, while the accumulation matrices C and D may be FP16 or FP32 matrices.
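To make the operation concrete, here is a scalar C++ sketch of D = A * B + C for 4×4 matrices. It mirrors only the arithmetic, not how the hardware is built. The names Mat4 and mma_4x4x4 are hypothetical, and float stands in for the FP16 inputs, so FP16 rounding is not modeled.

```cpp
#include <array>
#include <cassert>

// 4x4 matrix indexed as m[row][col]. float stands in for FP16 here (assumption).
using Mat4 = std::array<std::array<float, 4>, 4>;

// Scalar reference for D = A * B + C on 4x4 matrices.
Mat4 mma_4x4x4(const Mat4& a, const Mat4& b, const Mat4& c) {
    Mat4 d = c;  // start from the accumulator input C
    int fma_count = 0;
    for (int i = 0; i < 4; ++i) {
        for (int j = 0; j < 4; ++j) {
            for (int k = 0; k < 4; ++k) {
                d[i][j] += a[i][k] * b[k][j];  // one multiply-add
                ++fma_count;
            }
        }
    }
    // One multiply-add per (i, j, k) combination.
    assert(fma_count == 4 * 4 * 4);
    return d;
}
```

The triple loop executes 4 × 4 × 4 = 64 multiply-adds, which is where the "4x4x4" in the name of the processing array comes from.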
Each Tensor Core performs 64 mixed-precision floating-point FMA operations per clock: FP16 inputs are multiplied into a full-precision product, which is accumulated in FP32 (Figure 2). Counting each FMA as two floating-point operations, the 8 Tensor Cores in an SM perform a total of 1024 floating-point operations per clock.
This is a dramatic 8X increase in throughput for deep learning applications per SM compared to Pascal GP100 using standard FP32 operations, resulting in a total 12x increase in throughput for the Volta V100 GPU compared to the Pascal P100 GPU. Tensor Cores operate on FP16 input data with FP32 accumulation. The FP16 multiply results in a full-precision result that is accumulated in FP32 operations with the other products in a given dot product for a 4x4x4 matrix multiply (Figure 2).
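The per-SM and whole-GPU figures can be checked with a little arithmetic. The sketch below uses the counts quoted in the text. The 1530 MHz boost clock is an assumption that the text does not state, and all names are hypothetical.

```cpp
#include <cassert>

// Figures quoted in the text.
constexpr int kFmaPerTensorCorePerClock = 64;  // one 4x4x4 multiply-accumulate
constexpr int kTensorCoresPerSM = 8;
constexpr int kTensorCoresPerGPU = 640;

// Assumption: a 1530 MHz boost clock, which the text does not state.
constexpr double kAssumedBoostClockHz = 1.53e9;

// Each FMA counts as two floating-point operations (a multiply and an add).
constexpr int flops_per_sm_per_clock() {
    return kFmaPerTensorCorePerClock * 2 * kTensorCoresPerSM;
}

// Peak Tensor Core throughput in TFLOPS at the given clock frequency.
double peak_tensor_tflops(double clock_hz) {
    assert(clock_hz > 0.0);
    const double flops_per_clock =
        static_cast<double>(kTensorCoresPerGPU) * kFmaPerTensorCorePerClock * 2;
    return flops_per_clock * clock_hz / 1e12;
}
```

Under that clock assumption, 640 Tensor Cores × 64 FMAs × 2 operations per FMA works out to roughly 125 TFLOPS, in line with the peak figure quoted earlier.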
During program execution, multiple Tensor Cores are used concurrently by a full warp of execution. The threads within a warp provide a larger 16x16x16 matrix operation to be processed by the Tensor Cores. CUDA exposes these operations as warp-level matrix operations in the CUDA C++ WMMA API. These C++ interfaces provide specialized matrix load, matrix multiply and accumulate, and matrix store operations to efficiently use Tensor Cores in CUDA C++ programs.
But before we get into the details of low-level programming of Tensor Cores, here’s how to access their performance through CUDA libraries.
Tensor Cores in CUDA Libraries
Two CUDA libraries that use Tensor Cores are cuBLAS and cuDNN. cuBLAS uses Tensor Cores to speed up GEMM computations (GEMM is the BLAS term for a matrix-matrix multiplication). cuDNN uses Tensor Cores to speed up both convolutions and recurrent neural networks (RNNs).
Many computational applications use GEMMs: signal processing, fluid dynamics, and many, many others. As the data sizes of these applications grow exponentially, these applications require matching increases in processing speed. The mixed-precision GEMM performance chart (Figure 3) shows that Tensor Cores clearly answer this need.
The need for convolution speed improvements is just as great. For example, today’s deep neural networks (DNNs) use many layers of convolutions. AI researchers design deeper and deeper neural nets every year and the number of convolution layers in the deepest nets is now several dozen.
Training DNNs requires the convolution layers to be run repeatedly, during both forward- and back-propagation. The convolution performance chart in Figure 4 shows that Tensor Cores answer the need for convolution performance. For more information, see Mixed-Precision Training of Deep Neural Networks.
Both performance charts show that V100 Tensor Cores provide multiple times the performance of the previous-generation P100. Performance improvements this large change how work is done in computational fields, making interactivity possible, enabling “what-if” scenario studies, or reducing server farm usage. If you use GEMMs or convolutions in your applications, use the following steps to turbocharge your work.
How to use Tensor Cores in cuBLAS
You can take advantage of Tensor Cores by making a few small changes to how your existing code uses the cuBLAS API.
The following example code applies a few simple rules to indicate to cuBLAS that Tensor Cores should be used. These rules are enumerated explicitly after the code.
Example code
The following code example is largely the same as the common code used to invoke a GEMM in cuBLAS on previous architectures.
// First, create a cuBLAS handle:
cublasHandle_t handle;
cublasStatus_t cublasStat = cublasCreate(&handle);
// Set the math mode to allow cuBLAS to use Tensor Cores:
cublasStat = cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
// Allocate and initialize your matrices (only the A matrix is shown):
size_t matrixSizeA = (size_t)rowsA * colsA;
T_ELEM_IN *devPtrA = 0;  // device copy of A
cudaMalloc((void**)&devPtrA, matrixSizeA * sizeof(devPtrA[0]));
T_ELEM_IN *A = (T_ELEM_IN *)malloc(matrixSizeA * sizeof(A[0]));  // host copy of A
memset(A, 0xFF, matrixSizeA * sizeof(A[0]));
cublasStat = cublasSetMatrix(rowsA, colsA, sizeof(A[0]), A, rowsA, devPtrA, rowsA);
// ... allocate and initialize B and C matrices (not shown); their device
// pointers are named devPtrB and devPtrC here for illustration ...
// Invoke the GEMM on the device pointers, ensuring k, lda, ldb, and ldc are
// all multiples of 8, and m is a multiple of 4. alpha and beta are pointers
// to the scaling factors:
cublasStat = cublasGemmEx(handle, transa, transb, m, n, k, alpha,
                          devPtrA, CUDA_R_16F, lda,
                          devPtrB, CUDA_R_16F, ldb,
                          beta, devPtrC, CUDA_R_16F, ldc, CUDA_R_32F, algo);
A few simple rules
cuBLAS users will notice a few changes from the existing cuBLAS GEMM code:
- The routine must be a GEMM. Currently, only GEMMs support Tensor Core execution.
- The math mode must be set to CUBLAS_TENSOR_OP_MATH. Floating point math is not associative, so the results of the Tensor Core math routines are not quite bit-equivalent to the results of the analogous non-Tensor Core math routines. cuBLAS requires you to opt in to the use of Tensor Cores.
- All of k, lda, ldb, and ldc must be a multiple of eight; m must be a multiple of four. The Tensor Core math routines stride through input data in steps of eight values, so the dimensions of the matrices must be multiples of eight.
- The input and output data types for the matrices must be either half-precision or single-precision. (Only CUDA_R_16F is shown in the example, but CUDA_R_32F is also supported.)
GEMMs that do not satisfy these rules fall back to a non-Tensor Core implementation.
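Because a GEMM that breaks the dimension rules silently falls back to a slower path, it can help to check sizes on the host before the call. The following is a minimal sketch of such a check. The helper names are hypothetical and are not part of cuBLAS, and the check covers only the dimension rules, not the math mode or the data types.

```cpp
#include <cassert>

// Hypothetical helper (not part of cuBLAS): true when value is a multiple of n.
inline bool is_multiple_of(int value, int n) {
    assert(n > 0);
    return value % n == 0;
}

// True when the GEMM sizes satisfy the dimension rules listed above:
// k, lda, ldb, and ldc are multiples of 8, and m is a multiple of 4.
bool gemm_dims_allow_tensor_cores(int m, int k, int lda, int ldb, int ldc) {
    return is_multiple_of(m, 4) && is_multiple_of(k, 8) &&
           is_multiple_of(lda, 8) && is_multiple_of(ldb, 8) &&
           is_multiple_of(ldc, 8);
}

// Rounds value up to the next multiple of n, for example to choose a padded
// leading dimension.
inline int round_up(int value, int n) {
    assert(value >= 0 && n > 0);
    return (value + n - 1) / n * n;
}
```

For example, a leading dimension of 1001 fails the check, while padding it with round_up(1001, 8) gives 1008, which passes.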
GEMM performance
As mentioned earlier, Tensor Cores deliver several times the GEMM performance over previous hardware. Figure 3 shows the comparative performance of GP100 (Pascal) to GV100 (Volta) hardware.
How to use Tensor Cores in cuDNN
Using Tensor Cores in cuDNN is also easy, and again involves only slight changes to existing code.
Example code
The example code for using Tensor Cores in cuDNN can be found in conv_sample.cpp in the cuDNN samples directory. We copied some excerpts in this post. The cuDNN samples directory is packaged with the documentation.
// Create a cuDNN handle:
checkCudnnErr( cudnnCreate(&handle_) );
// Create your tensor descriptors:
checkCudnnErr( cudnnCreateTensorDescriptor( &cudnnIdesc ));
checkCudnnErr( cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudnnErr( cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudnnErr( cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
// Set tensor dimensions as multiples of eight (only the input tensor is shown here):
int dimA[] = {1, 8, 32, 32};
int strideA[] = {8192, 1024, 32, 1};
checkCudnnErr( cudnnSetTensorNdDescriptor(cudnnIdesc, getDataType(),
                                          convDim+2, dimA, strideA) );
// Allocate and initialize tensors (again, only the input tensor is shown):
checkCudaErr( cudaMalloc((void**)&(devPtrI), (insize) * sizeof(devPtrI[0]) ));
hostI = (T_ELEM*)calloc (insize, sizeof(hostI[0]) );
initImage(hostI, insize);
checkCudaErr( cudaMemcpy(devPtrI, hostI, sizeof(hostI[0]) * insize, cudaMemcpyHostToDevice));
// Set the compute data type (below as CUDNN_DATA_FLOAT):
checkCudnnErr( cudnnSetConvolutionNdDescriptor(cudnnConvDesc,
                                               convDim,
                                               padA,
                                               convstrideA,
                                               dilationA,
                                               CUDNN_CONVOLUTION,
                                               CUDNN_DATA_FLOAT) );
// Set the math type to allow cuDNN to use Tensor Cores:
checkCudnnErr( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
// Choose a supported algorithm:
cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
// Allocate your workspace:
checkCudnnErr( cudnnGetConvolutionForwardWorkspaceSize(handle_, cudnnIdesc,
                                                       cudnnFdesc, cudnnConvDesc,
                                                       cudnnOdesc, algo, &workSpaceSize) );
if (workSpaceSize > 0) {
    cudaMalloc(&workSpace, workSpaceSize);
}
// Invoke the convolution:
checkCudnnErr( cudnnConvolutionForward(handle_, (void*)(&alpha), cudnnIdesc, devPtrI,
                                       cudnnFdesc, devPtrF, cudnnConvDesc, algo,
                                       workSpace, workSpaceSize, (void*)(&beta),
                                       cudnnOdesc, devPtrO) );
A few simple rules
There are a few changes from the common cuDNN use:
- The convolution algorithm must be ALGO_1 (IMPLICIT_PRECOMP_GEMM for forward). Other convolution algorithms besides ALGO_1 may use Tensor Cores in future cuDNN releases.
- The math type must be set to CUDNN_TENSOR_OP_MATH. As in cuBLAS, the results of the Tensor Core math routines are not quite bit-equivalent to the results of the analogous non-Tensor Core math routines, so cuDNN requires you to opt in to the use of Tensor Cores.
- Both input and output channel dimensions must be a multiple of eight. Again as in cuBLAS, the Tensor Core math routines stride through input data in steps of eight values, so the dimensions of the input data must be multiples of eight.
- The input, filter, and output data types for the convolutions must be half-precision.
Convolutions that do not satisfy these rules fall back to a non-Tensor Core implementation.
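The opt-in requirement in both libraries traces back to floating-point addition not being associative: a routine that accumulates products in a different order can produce slightly different bits. The sketch below shows the effect with plain float additions on the host. It assumes IEEE 754 single precision evaluated without excess precision, and the function names are hypothetical.

```cpp
#include <cassert>

// Sums three floats left to right: (a + b) + c.
float sum_left_to_right(float a, float b, float c) {
    const float ab = a + b;
    return ab + c;
}

// Sums the same three floats right to left: a + (b + c).
float sum_right_to_left(float a, float b, float c) {
    const float bc = b + c;
    return a + bc;
}

// Near 1e8 adjacent floats are 8 apart, so adding 1.0f to -1e8f is lost to
// rounding in the right-to-left order but survives in the left-to-right order.
void check_order_matters() {
    assert(sum_left_to_right(1e8f, -1e8f, 1.0f) == 1.0f);
    assert(sum_right_to_left(1e8f, -1e8f, 1.0f) == 0.0f);
}
```

Both orders are valid evaluations of the same mathematical sum, which is why neither library can promise bit-identical results across implementations.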
The example code shows NCHW data format. For NHWC support as well, see the conv_sample.cpp sample.
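The strideA values in the excerpt follow directly from dimA for a fully packed NCHW tensor. The sketch below derives them. It also derives NHWC strides under the assumption that the dimensions are still listed in N, C, H, W order and only the strides change. The helper names are hypothetical and are not part of cuDNN.

```cpp
#include <array>
#include <cassert>

// Dimensions and strides in {N, C, H, W} order, as in dimA and strideA.
using Dims4 = std::array<int, 4>;

// Strides of a fully packed NCHW tensor: W is the fastest-varying dimension.
Dims4 packed_nchw_strides(const Dims4& dim) {
    const int c = dim[1], h = dim[2], w = dim[3];
    assert(dim[0] > 0 && c > 0 && h > 0 && w > 0);
    return Dims4{{c * h * w, h * w, w, 1}};
}

// Strides of a fully packed NHWC tensor, still reported in {N, C, H, W}
// order: C is the fastest-varying dimension.
Dims4 packed_nhwc_strides(const Dims4& dim) {
    const int c = dim[1], h = dim[2], w = dim[3];
    assert(dim[0] > 0 && c > 0 && h > 0 && w > 0);
    return Dims4{{h * w * c, 1, w * c, c}};
}
```

For dimA = {1, 8, 32, 32}, the NCHW helper returns {8192, 1024, 32, 1}, matching strideA in the excerpt, while the NHWC helper returns {8192, 1, 256, 8}.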
Convolution performance
As mentioned earlier, Tensor Cores deliver several times the convolution performance over previous hardware. Figure 4 shows the comparative performance of GP100 (Pascal) to GV100 (Volta) hardware.
In Figure 4, the comparison is between the geometric means of run times of the convolution layers from each neural network. Both V100 and P100 use FP16 input/output data and FP32 computation. V100 uses Tensor Cores, while P100 uses FP32 fused multiply-add (FMA).
Programmatic access to Tensor Cores in CUDA 9.0
Access to Tensor Cores in kernels through CUDA 9.0 is available as a preview feature. The data structures, APIs, and code described in this section are subject to change in future CUDA releases.
While cuBLAS and cuDNN cover many of the potential uses for Tensor Cores, you can also program them directly in CUDA C++.
Tensor Cores are exposed in CUDA 9.0 through a set of functions and types in the nvcuda::wmma namespace. These enable you to load or initialize values into the special format required by the Tensor Cores, perform matrix multiply-accumulate (MMA) steps, and store values back out to memory.
During program execution, multiple Tensor Cores are used concurrently by a full warp. This enables the warp to perform a 16x16x16 MMA at very high throughput (Figure 5).
The matrix dimensions change from 4×4 in Figure 1 to 16×16 in Figure 5 because the WMMA API combines multiple Tensor Core operations to perform 16×16 matrix multiply-and-accumulate operations.
Here’s an example that shows how you can use the WMMA (Warp Matrix Multiply Accumulate) API to perform a matrix multiplication. This example is not tuned for high performance and mostly serves as a demonstration of the API. For an example of optimizations you might apply to this code to get better performance, see the cudaTensorCoreGemm sample in the CUDA Toolkit. For highest performance in production code, use cuBLAS, as described earlier.
Headers and namespaces
The WMMA API is contained within the mma.h header file. The complete namespace is nvcuda::wmma::*, but it’s useful to keep wmma explicit throughout the code, so we’ll just be using the nvcuda namespace.
#include <mma.h>
using namespace nvcuda;
Declarations and initialization
The full GEMM specification enables the algorithm to work on transpositions of a or b, and for data strides to be larger than the strides in the matrix. For simplicity, assume that neither a nor b is transposed, and that the leading dimensions of the memory and the matrix are the same.
The strategy to employ is to have a single warp responsible for a single 16×16 section of the output matrix. By using a 2D grid and thread blocks, you can effectively tile the warps across the 2D output matrix.
// The only dimensions currently supported by WMMA
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
__global__ void wmma_example(half *a, half *b, float *c,
                             int M, int N, int K,
                             float alpha, float beta)
{
    // Leading dimensions. Packed with no transpositions.
    int lda = M;
    int ldb = K;
    int ldc = M;
    // Tile using a 2D grid
    int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
Before the MMA operation is performed, the operand matrices must be represented in the registers of the GPU. As an MMA is a warp-wide operation, these registers are distributed among the threads of a warp, with each thread holding a fragment of the overall matrix. The mapping between individual matrix elements and their fragments is opaque, so your program should not make assumptions about it.
In CUDA, a fragment is a templated type with template parameters that describe the following:
- Which matrix the fragment holds (A, B, or accumulator)
- The shape of the overall WMMA operation
- The data type
- For A and B matrices, whether the data is row- or column-major
This last argument can be used to perform the transposition of either the A or B matrix. This example does no transposition, so both matrices are column-major, which is standard for GEMM.
    // Declare the fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
The final part of the initialization step is to fill the accumulator fragment with zeros.
    wmma::fill_fragment(acc_frag, 0.0f);
The inner loop
The strategy for the GEMM is to compute one tile of the output matrix per warp. To do this, loop along the K-dimension, which both matrices share, over the tile's rows of the A matrix and columns of the B matrix. Each iteration consumes a slice WMMA_K wide, and the complete loop produces one WMMA_M x WMMA_N output tile. The load matrix function takes data from memory (global memory in this example, although it could be any memory space) and places it into a fragment.
The third argument to the load is the leading dimension in memory of the matrix. The 16×16 tile being loaded is not contiguous in memory, so the function has to know the stride between consecutive columns (or between consecutive rows, if these were row-major fragments).
The MMA call accumulates in place, so both the first and last arguments are the accumulator fragment previously initialized to zero.
    // Loop over the K-dimension
    for (int i = 0; i < K; i += WMMA_K) {
        int aRow = warpM * WMMA_M;
        int aCol = i;
        int bRow = i;
        int bCol = warpN * WMMA_N;
        // Bounds checking
        if (aRow < M && aCol < K && bRow < K && bCol < N) {
            // Load the inputs
            wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
            wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);
            // Perform the matrix multiplication
            wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
        }
    }
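The pointer arithmetic in the loads, such as a + aRow + aCol * lda, is ordinary column-major addressing. The host-side sketch below spells it out to show why a tile is not contiguous in memory and why the load needs the leading dimension. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Linear offset of element (row, col) in a column-major matrix whose leading
// dimension (the distance in elements between consecutive columns) is ld.
std::size_t col_major_offset(int row, int col, int ld) {
    assert(row >= 0 && col >= 0 && ld > row);
    return static_cast<std::size_t>(row) +
           static_cast<std::size_t>(col) * static_cast<std::size_t>(ld);
}
```

With lda = 64, the tile that starts at row 16 and column 32 begins at offset 16 + 32 * 64 = 2064. Moving down one row within the tile advances the offset by 1, while moving to the next column advances it by 64, so each 16-element tile column is contiguous but consecutive tile columns are not.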
Finishing up
acc_frag now holds the result for this warp’s output tile based on the multiplication of A and B. The complete GEMM specification enables this result to be scaled, and for it to be accumulated on top of a matrix in place.
One way to do this scaling is to perform element-wise operations on the fragment. Although the mapping from matrix coordinates to threads isn’t defined, element-wise operations don’t have to know this mapping so they can still be performed using fragments.
As such, it’s legal to perform scaling operations on fragments or to add the contents of one fragment to another, so long as those two fragments have the same template parameters. If the fragments have different template parameters, the results are undefined. Using this feature, load in the existing data in C and, with the correct scaling, accumulate the result of the computation so far with it.
    // Load in current value of c, scale by beta, and add to result scaled by alpha
    int cRow = warpM * WMMA_M;
    int cCol = warpN * WMMA_N;
    if (cRow < M && cCol < N) {
        wmma::load_matrix_sync(c_frag, c + cRow + cCol * ldc, ldc, wmma::mem_col_major);
        for (int i = 0; i < c_frag.num_elements; i++) {
            c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
        }
Finally, store the data out to memory. Again, the target pointer can be any memory space visible to the GPU and the leading dimension in memory must be specified. There is also an option to specify whether the output is to be written row or column major.
        // Store the output
        wmma::store_matrix_sync(c + cRow + cCol * ldc, c_frag, ldc, wmma::mem_col_major);
    }
}
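When experimenting with a kernel like this, a plain CPU reference is a convenient way to check the output. The following is a minimal sketch of C = alpha * A * B + beta * C under the same assumptions as the kernel: column-major storage, no transposition, and packed leading dimensions. The function name is hypothetical, and float stands in for the half-precision inputs, so results can differ from the GPU by FP16 rounding of A and B.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for C = alpha * A * B + beta * C.
// A is M x K, B is K x N, C is M x N, all column-major with packed leading
// dimensions (lda = M, ldb = K, ldc = M), as in the kernel above.
void reference_gemm(const std::vector<float>& a, const std::vector<float>& b,
                    std::vector<float>& c, int M, int N, int K,
                    float alpha, float beta) {
    assert(a.size() == static_cast<std::size_t>(M) * K);
    assert(b.size() == static_cast<std::size_t>(K) * N);
    assert(c.size() == static_cast<std::size_t>(M) * N);
    for (int col = 0; col < N; ++col) {
        for (int row = 0; row < M; ++row) {
            float acc = 0.0f;  // plays the role of acc_frag
            for (int k = 0; k < K; ++k) {
                acc += a[row + k * M] * b[k + col * K];
            }
            // Same scaling step as the element-wise fragment loop.
            c[row + col * M] = alpha * acc + beta * c[row + col * M];
        }
    }
}
```

Comparing against such a reference with a small tolerance, rather than for exact equality, accounts for the different accumulation order and input precision on the GPU.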
With that, the matrix multiplication is complete. I’ve omitted the host code in this post. However, a complete working example is available on GitHub.
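For orientation, the launch geometry that the host code has to choose can be sketched in plain C++. Because the kernel derives warpM from the x thread index divided by warpSize and warpN from the y thread index, each thread block covers (blockDim.x / 32) × blockDim.y output tiles. The Dim2 type, the function name, and the 128 × 4 block shape used in the note below are assumptions for illustration, not taken from the post.

```cpp
#include <cassert>

// Hypothetical stand-in for the x and y fields of CUDA's dim3.
struct Dim2 {
    int x;
    int y;
};

constexpr int kWarpSize = 32;
constexpr int kWmmaM = 16;
constexpr int kWmmaN = 16;

// Number of thread blocks needed so that every 16x16 tile of an M x N output
// matrix is assigned to a warp. block.x must be a multiple of the warp size.
Dim2 grid_for_output(int M, int N, Dim2 block) {
    assert(M > 0 && N > 0);
    assert(block.x > 0 && block.x % kWarpSize == 0 && block.y > 0);
    const int rows_per_block = kWmmaM * (block.x / kWarpSize);
    const int cols_per_block = kWmmaN * block.y;
    return Dim2{(M + rows_per_block - 1) / rows_per_block,
                (N + cols_per_block - 1) / cols_per_block};
}
```

With a 128 × 4 thread block, each block holds 16 warps and covers a 64 × 64 region of the output, so a 1024 × 1024 matrix needs a 16 × 16 grid. In CUDA host code, these values would be passed as the grid and block dimensions when launching wmma_example.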
Get started with Tensor Cores in CUDA 9 today
Hopefully, this example has given you ideas about how you might use Tensor Cores in your application. For more information, see the CUDA Programming Guide section on wmma.
The CUDA 9 Tensor Core API is a preview feature, so we’d love to hear your feedback. If you have any comments or questions, please don’t hesitate to leave a comment.
CUDA 9 is available free, so download it today.