nvmath-python (Beta) is an open-source Python library that provides Python programmers with access to high-performance mathematical operations from the NVIDIA CUDA-X math libraries. nvmath-python offers both low-level bindings to the underlying libraries and higher-level Pythonic abstractions, and it is interoperable with existing Python packages such as PyTorch and CuPy.
In this post, I show how to use epilogs with matrix multiplication in nvmath-python. Epilogs are operations that can be fused with the mathematical operation being performed, like FFT or matrix multiplication. Available epilogs cover the most common deep-learning computations. I demonstrate their usage by implementing the common forward and backward pass operations of a simple neural network.
To install nvmath-python, follow the installation instructions.
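For example, on a typical CUDA 12 setup, the library can be installed with pip. The exact package extras depend on your CUDA version and environment, so treat the following as an illustration and check the installation guide for your configuration:

pip install nvmath-python[cu12]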
Optimizing the forward pass with the RELU_BIAS epilog
In this section, I demonstrate how to use epilogs to implement the forward pass of a simple linear layer. This layer first multiplies a batch of input vectors by a weights matrix, then adds a bias vector to each column of the resulting matrix, and finally applies the ReLU activation function.
ReLU, short for Rectified Linear Unit, is a commonly used activation function that replaces negative values with zeros while leaving positive values unchanged.
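As a quick illustration, ReLU is simply an element-wise maximum with zero. In CuPy it could be written as the following trivial reference function (not part of the nvmath-python API):

import cupy

def relu(a):
    # Element-wise ReLU: negative entries become 0, positive entries are kept.
    return cupy.maximum(a, 0)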
In terms of matrix operations, the layer can be expressed as follows:

f(x) = relu(Wx + B)

In the equation, the following definitions are true:

- x is a batch of input vectors of shape n × b, where:
  - n is the number of the layer's inputs.
  - b is the batch size.
- W is the weight matrix of shape m × n, where:
  - m is the number of the layer's outputs.
  - n is the number of its inputs.
- B is the bias vector of length m, which is added to each column of the resulting matrix.
Assume that you have your inputs, weights, and bias as CuPy arrays:
import cupy

num_inputs, num_outputs = 784, 100
batch_size = 256

weights = cupy.random.rand(num_outputs, num_inputs)
bias = cupy.random.rand(num_outputs)
x = cupy.zeros((num_inputs, batch_size))
In the most basic version, you can implement this linear layer by using nvmath-python for calculating Wx, and then handling the bias and ReLU manually, as in the following code example.
In this example, I use a stateful API, in which you can separate initialization and planning from the actual execution of the multiplication. I recommend this approach when you must perform multiple similar multiplications, as it enables you to amortize the initial cost of planning. For more information about Matmul, see nvmath.linalg.advanced.Matmul.
from nvmath.linalg.advanced import Matmul

mm = Matmul(weights, x)
mm.plan()

def forward():
    y = mm.execute()
    y += bias[:, cupy.newaxis]  # add the bias to each column
    y[y < 0] = 0                # apply ReLU
    return y
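Because planning is separated from execution, the same planned Matmul object can be reused for further batches of the same shape, so the planning cost is paid only once. The following is a minimal sketch of that pattern; it assumes that Matmul.reset_operands accepts a replacement for the b operand, so check the nvmath.linalg.advanced.Matmul documentation for the exact signature:

# Sketch: reuse the planned multiplication for new input batches of the same
# shape (assumes reset_operands accepts a new `b` operand).
for _ in range(10):
    x_new = cupy.random.rand(num_inputs, batch_size)
    mm.reset_operands(b=x_new)
    y = forward()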
To improve the performance of the code, take advantage of the RELU_BIAS epilog to perform all three operations in a single, fused cuBLAS operation. This epilog first adds the bias to the result of the multiplication and then applies the ReLU function.
You can specify the epilog using the epilog argument of the Matmul.plan method. Some epilogs, including RELU_BIAS, take extra inputs, which can be specified in the epilog_inputs dictionary. For more information about epilogs, see nvmath.linalg.advanced.Matmul.
from nvmath.linalg.advanced import MatmulEpilog

mm = Matmul(weights, x)
mm.plan(epilog=MatmulEpilog.RELU_BIAS, epilog_inputs={"bias": bias})

def forward():
    y = mm.execute()
    return y
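To check that the fused version computes the same result as the manual implementation, you can compare it against a plain CuPy reference. This is an illustrative sanity check, not part of the original example:

# Compare the RELU_BIAS result with an explicit bias-add + ReLU reference.
y_fused = forward()
y_reference = cupy.maximum(weights @ x + bias[:, cupy.newaxis], 0)
assert cupy.allclose(y_fused, y_reference)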
As I explain later, to backpropagate through the ReLU function, you must know which inputs to the ReLU were positive and which ones were negative. This auxiliary information, called the ReLU mask, can be obtained with the RELU_AUX_BIAS epilog.
When an epilog with auxiliary outputs is used, Matmul.execute returns a tuple containing the actual result and a dictionary of auxiliary outputs. In the case of RELU_AUX_BIAS, the auxiliary output dictionary has one key, relu_aux, which contains the ReLU mask. This mask is bit-encoded and might be hard to read, but there are dedicated epilogs that interpret it for you during the backward pass.
from nvmath.linalg.advanced import MatmulEpilog

mm = Matmul(weights, x)
mm.plan(epilog=MatmulEpilog.RELU_AUX_BIAS, epilog_inputs={"bias": bias})

relu_mask = None

def forward():
    global relu_mask
    y, aux_outputs = mm.execute()
    relu_mask = aux_outputs["relu_aux"]  # keep the bit-encoded ReLU mask for the backward pass
    return y
The implementation using the RELU_AUX_BIAS epilog is faster than its naive counterpart, providing a significant performance gain.
Figure 2 shows the performance of matrix multiplication of float16 matrices of sizes (65536, 16384) and (16384, 8192), followed by bias addition and ReLU. The performance was measured on an NVIDIA H200 GPU.
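If you want to measure the effect on your own hardware, you can time the forward implementations with CUDA events. The following is a rough sketch, not the benchmark behind Figure 2; results depend on the GPU, the matrix sizes, and the dtype:

# Rough GPU timing helper based on CuPy CUDA events (illustrative only).
def time_forward(fn, iters=10):
    start, end = cupy.cuda.Event(), cupy.cuda.Event()
    fn()  # warm-up call
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    end.synchronize()
    return cupy.cuda.get_elapsed_time(start, end) / iters  # milliseconds

print(f"forward pass: {time_forward(forward):.2f} ms")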
Optimizing the backward pass with the DRELU_BGRAD epilog
During the backward pass of a neural network, the gradient of the loss function with respect to the output is propagated back through the network layers to compute the gradients for each parameter.
Intuitively, for each operation, when the effect of its output on the loss is known, it becomes possible to determine how its inputs and parameters (such as the values in a weight matrix) influence the loss. For more information, see Backpropagation.
In this part, I assume that there are several linear layers stacked together. I implement a backpropagation over the sequence of operations that are normally considered to belong to different layers: adding bias, applying ReLU, and multiplying by the weights.
Let x be the input to the part of the network shown earlier, and denote the intermediate results by t1, t2, and y, respectively:

t1 = x + B
t2 = relu(t1)
y = W t2

In backpropagation, when you know how the loss function L is affected by y, which is the gradient ∂L/∂y, it is possible to calculate the gradients with respect to the other parameters. For more information about the derivations of the formulas used to compute the gradients, see Automatic Differentiation and Neural Networks.
- ∂L/∂t2 is Wᵀ · ∂L/∂y.
- ∂L/∂t1 is 0 where t1 was negative and equal to ∂L/∂t2 where t1 was non-negative (the ReLU mask contains this information).
- ∂L/∂B is ∂L/∂t1, summed over the batch dimension.
The operations required to compute ∂L/∂t1 and ∂L/∂B can be naively implemented by using Matmul just for the matrix multiplication, and then handling the masking and the batch sum manually:
mm = Matmul(weights.T, grad)
mm.plan()

def backward():
    grad_t1 = mm.execute()
    grad_t1[mask] = 0  # assuming that `mask = (t1 < 0)`
    grad_bias = cupy.sum(grad_t1, axis=1)
    return grad_t1, grad_bias
To optimize your backward pass, use the DRELU_BGRAD epilog. Assume that the gradient ∂L/∂y is available in a CuPy array grad. The DRELU_BGRAD epilog expects one input, relu_aux, containing the mask returned by the RELU_AUX_BIAS epilog. It applies this mask to the result of the multiplication. It also returns an auxiliary output with the result summed over the batch dimension, which happens to be ∂L/∂B.
mm = Matmul(weights.T, grad)
mm.plan(epilog=MatmulEpilog.DRELU_BGRAD, epilog_inputs={"relu_aux": relu_mask})

def backward():
    grad_t1, aux_outputs = mm.execute()
    grad_bias = aux_outputs["drelu_bgrad"]  # bias gradient computed by the epilog
    return grad_t1, grad_bias
Figure 5 shows the performance of matrix multiplication of float16 matrices of sizes (65536, 16384) and (16384, 8192), followed by the application of the ReLU mask and the bias gradient computation. The performance was measured on an NVIDIA H200 GPU.
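Finally, stateful Matmul objects hold GPU resources such as workspace memory, so release them when you are done. The following is a minimal sketch, assuming the free method and context-manager support described in the nvmath.linalg.advanced.Matmul documentation:

# Release the resources held by a stateful object once it is no longer needed
# (assumes Matmul.free(); check the nvmath-python documentation).
mm.free()

# Alternatively, scope the object with a `with` block so cleanup is automatic:
with Matmul(weights, x) as scoped_mm:
    scoped_mm.plan(epilog=MatmulEpilog.RELU_BIAS, epilog_inputs={"bias": bias})
    y = scoped_mm.execute()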
Conclusion
With the epilogs of nvmath-python, you can fuse common deep learning computations in your Python code, which can greatly improve performance. For more information, see the nvmath-python: Unleashing the Full Capabilities of NVIDIA Math Libraries within Python documentation. For an end-to-end implementation of a simple neural network with nvmath-python, see the Backpropagation Jupyter notebook on GitHub.
nvmath-python is open source, so feel free to visit the /NVIDIA/nvmath-python GitHub repo and reach out to us there.