Fusing Epilog Operations with Matrix Multiplication Using nvmath-python

Code showing how to use epilogs with matrix multiplication in nvmath-python.

nvmath-python (Beta) is an open-source Python library that provides Python programmers with access to high-performance mathematical operations from NVIDIA CUDA-X math libraries. nvmath-python offers both low-level bindings to the underlying libraries and higher-level Pythonic abstractions. It is interoperable with existing Python packages, such as PyTorch and CuPy.

In this post, I show how to use epilogs with matrix multiplication in nvmath-python. Epilogs are operations that can be fused with the mathematical operation being performed, like FFT or matrix multiplication. Available epilogs cover the most common deep-learning computations. I demonstrate their usage by implementing the common forward and backward pass operations of a simple neural network.

To install nvmath-python, follow the installation instructions.
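
Once installed, you can check which epilogs are available by enumerating the MatmulEpilog enum (a quick, illustrative check; it assumes MatmulEpilog behaves like a standard Python enum):

from nvmath.linalg.advanced import MatmulEpilog

# List the names of the available matmul epilogs, including RELU_BIAS,
# RELU_AUX_BIAS, and DRELU_BGRAD used later in this post.
print([epilog.name for epilog in MatmulEpilog])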

Optimizing the forward pass with the RELU_BIAS epilog

In this section, I demonstrate how to use epilogs to implement a forward pass of a simple linear layer. This layer first multiplies the input vectors by a weights matrix, then adds a bias to each element of the resulting matrix, and finally applies the ReLU activation function.

ReLU, short for Rectified Linear Unit, is a commonly used activation function that replaces negative values with zeros while leaving positive values unchanged.
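
As a quick illustration (a minimal sketch using CuPy, to match the arrays used later in this post):

import cupy

v = cupy.array([-2.0, -0.5, 0.0, 1.5, 3.0])
relu_v = cupy.maximum(v, 0)  # negative values become 0, positive values pass through
print(relu_v)                # [0.  0.  0.  1.5 3. ]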

In terms of matrix operations, the layer can be expressed as follows:

relu(Wx + B)

In the equation, the following definitions are true:

  • x is a batch of input vectors of shape n \times b:
    • n is the number of the layer’s inputs.
    • b is the batch size.
  • W is the weight matrix of shape m \times n:
    • m is the number of the layer’s outputs.
    • n is the number of its inputs.
  • B is the bias vector of length m, which is added to each column of the resulting matrix.

Assume that you have your inputs, weights, and bias as CuPy arrays:

import cupy

num_inputs, num_outputs = 784, 100
batch_size = 256

weights = cupy.random.rand(num_outputs, num_inputs)
bias = cupy.random.rand(num_outputs)
x = cupy.zeros((num_inputs, batch_size))

In the most basic version, you can implement this linear layer by using nvmath-python for calculating Wx, and then handling bias and ReLU manually, as in the following code example.

In this example, I use a stateful API, in which you can separate initialization and planning from the actual execution of the multiplication. I recommend this approach when you must perform multiple similar multiplications, as it enables you to amortize the initial cost of planning. For more information about Matmul, see nvmath.linalg.advanced.Matmul.

from nvmath.linalg.advanced import Matmul

mm = Matmul(weights, x)
mm.plan()

def forward():
    y = mm.execute()            # y = Wx
    y += bias[:, cupy.newaxis]  # add the bias to each column
    y[y < 0] = 0                # apply ReLU
    return y

To improve the performance of the code, take advantage of the RELU_BIAS epilog to perform all three operations in a single, fused cuBLAS operation. This epilog first adds the bias to the result of the multiplication and then applies the ReLU function.

You can specify the epilog using the epilog argument of the Matmul.plan method. Some epilogs, including RELU_BIAS, take extra inputs, which can be specified in the epilog_inputs dictionary. For more information about epilogs, see nvmath.linalg.advanced.Matmul.

from nvmath.linalg.advanced import MatmulEpilog

mm = Matmul(weights, x)
mm.plan(epilog=MatmulEpilog.RELU_BIAS, epilog_inputs={"bias": bias})

def forward():
    y = mm.execute()
    return y
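
To convince yourself that the fused version matches the unfused one, you can compare it against a plain CuPy computation (a quick sanity check, not part of the original example; it assumes weights, bias, and x are the CuPy arrays defined earlier and mm is the Matmul planned with the RELU_BIAS epilog):

# Reference result computed with explicit bias addition and ReLU in CuPy.
y_ref = weights @ x + bias[:, cupy.newaxis]
y_ref[y_ref < 0] = 0

# Fused result from the RELU_BIAS epilog.
y_fused = forward()

assert cupy.allclose(y_ref, y_fused)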

As I explain later, to backpropagate through the ReLU function, you must know which inputs to the ReLU were positive and which ones were negative. This auxiliary information, called the ReLU mask, can be obtained with the RELU_AUX_BIAS epilog.

When an epilog with auxiliary outputs is used, Matmul.execute returns a tuple containing the actual result and a dictionary of auxiliary outputs. In the case of RELU_AUX_BIAS, the auxiliary output dictionary has one key, relu_aux, which contains the ReLU mask. This mask is bit-encoded and might be hard to read, but there are dedicated epilogs that apply it for you during the backward pass.

from nvmath.linalg.advanced import MatmulEpilog

mm = Matmul(weights, x)
mm.plan(epilog=MatmulEpilog.RELU_AUX_BIAS, epilog_inputs={"bias": bias})

relu_mask = None

def forward():
    global relu_mask
    y, aux_outputs = mm.execute()
    relu_mask = aux_outputs["relu_aux"]  # bit-encoded ReLU mask, needed for the backward pass
    return y

Figure 1. Operations of the forward pass covered by Matmul with the RELU_AUX_BIAS epilog: multiplication by the weights, addition of the bias, and application of ReLU, with the ReLU mask produced as an auxiliary output

The implementation using the RELU_AUX_BIAS epilog is faster than its naive counterpart, providing a significant performance gain.

Figure 2. Performance comparison of forward pass implementations: the naive implementation reaches 62.8% of peak TFLOP/s, while the RELU_AUX_BIAS implementation reaches 79.7%

Figure 2 shows the performance of matrix multiplication of float16 matrices of sizes (65536, 16384) and (16384, 8192), followed by bias addition and ReLU. The performance was measured on an NVIDIA H200 GPU.
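
The exact benchmark setup is not shown here, but a comparison like this can be timed with CuPy's profiler (a rough sketch; forward_naive and forward_fused are hypothetical wrappers around the two implementations shown earlier):

from cupyx.profiler import benchmark

# Time both variants; benchmark() synchronizes the GPU and reports
# average CPU and GPU times over the repetitions.
print(benchmark(forward_naive, (), n_repeat=100))
print(benchmark(forward_fused, (), n_repeat=100))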

Optimizing the backward pass with the DRELU_BGRAD epilog

During the backward pass of a neural network, the gradient of the loss function with respect to the output is propagated back through the network layers to compute the gradients for each parameter.

Intuitively, for each operation, when the effect of its output on the loss is known, it becomes possible to determine how its inputs and parameters (such as the values in a weight matrix) influence the loss. For more information, see Backpropagation.
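
As a one-line illustration in generic notation (not tied to the specific layer above): if an operation computes y = f(x) and \frac{\partial L}{\partial y} is already known, the chain rule gives

\frac{\partial L}{\partial x} = \left(\frac{\partial f}{\partial x}\right)^T \frac{\partial L}{\partial y}

where \frac{\partial f}{\partial x} is the Jacobian of the operation. The gradient formulas used below are special cases of this rule.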

In this section, I assume that there are several linear layers stacked together. I implement backpropagation over a sequence of operations that would normally be considered as belonging to different layers: adding the bias, applying ReLU, and multiplying by the weights.

Figure 3. Operations implemented in the forward pass (multiplication by weights, bias addition, and ReLU, repeated across layers), with the part covered by the backward pass marked: adding the bias, applying ReLU, and multiplying by the weights

Let t_0 be the input to the part of the network shown earlier, and denote the intermediate results by t_1, t_2, and t_3, respectively:

  • t_1 = t_0 + B
  • t_2 = relu(t_1)
  • t_3 = Wt_2

In backpropagation, when you know how the loss function L is affected by t_3, which is \frac{\partial L}{\partial t_3}, it is possible to calculate the gradients with respect to other parameters. For more information about the derivations of the formulas used to compute the gradients, see Automatic Differentiation and Neural Networks.

  • \frac{\partial L}{\partial W} = \frac{\partial L}{\partial t_3} t_2^T
  • \frac{\partial L}{\partial t_2} = W^T \frac{\partial L}{\partial t_3}
  • \frac{\partial L}{\partial t_1} = 0 where t_1 was negative, and \frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial t_2} where t_1 was non-negative (the ReLU mask contains exactly this information)
  • \frac{\partial L}{\partial B} is \frac{\partial L}{\partial t_1}, summed over the batch dimension

Figure 4. Operations of the backward pass with the formulas for the gradients. Matmul with the DRELU_BGRAD epilog covers computing the gradients with respect to t_2 (multiplying by the weights), t_1 (applying the ReLU mask), and B (the batch sum); computing the gradient with respect to W is not covered by the epilog.
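
For completeness, here is a minimal sketch of how the weight gradient could be computed; it is not covered by the DRELU_BGRAD epilog, and it assumes that t2 (the ReLU output saved from the forward pass) and grad (holding \frac{\partial L}{\partial t_3}) are available as CuPy arrays:

# dL/dW = (dL/dt3) @ t2^T, shape (m, n); an ordinary matrix product,
# so it can be computed with another Matmul or, as here, directly with CuPy.
grad_weights = grad @ t2.T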

The operations required to compute \frac{\partial L}{\partial B} and \frac{\partial L}{\partial t_1} can be naively implemented by using Matmul just for matrix multiplication, and then handling masking and batch sum manually:

# `grad` is assumed to hold the incoming gradient dL/dt3 as a CuPy array.
mm = Matmul(weights.T, grad)
mm.plan()

def backward():
    grad_t1 = mm.execute()
    grad_t1[mask] = 0  # zero out where t1 was negative, assuming `mask = (t1 < 0)`
    grad_bias = cupy.sum(grad_t1, axis=1)  # sum over the batch dimension
    return grad_t1, grad_bias

To optimize your backward pass, use the DRELU_BGRAD epilog. Assume that the gradient \frac{\partial L}{\partial t_3} is available in a CuPy array grad. The DRELU_BGRAD epilog expects one input, relu_aux, containing the mask returned by the RELU_AUX_BIAS epilog. It applies this mask to the result of the multiplication. It also returns an auxiliary output with the sum of the result over the batch dimension, which is exactly \frac{\partial L}{\partial B}.

mm = Matmul(weights.T, grad)
mm.plan(epilog=MatmulEpilog.DRELU_BGRAD, epilog_inputs={"relu_aux": relu_mask})

def backward():
    grad_t1, aux_outputs = mm.execute()     # masked gradient dL/dt1
    grad_bias = aux_outputs["drelu_bgrad"]  # dL/dB, summed over the batch
    return grad_t1, grad_bias

Figure 5. Performance comparison of backward pass implementations: the naive implementation reaches 56.9% of peak TFLOP/s, while the DRELU_BGRAD implementation reaches 66.4%

Figure 5 shows the performance of matrix multiplication of float16 matrices of sizes (65536, 16384) and (16384, 8192), followed by the application of the ReLU mask and the bias gradient computation. The performance was measured on an NVIDIA H200 GPU.

Conclusion

With the epilogs of nvmath-python, you can fuse common deep learning computations in your Python code, which can greatly improve performance. For more information, see the nvmath-python: Unleashing the Full Capabilities of NVIDIA Math Libraries within Python documentation. For an example of an end-to-end implementation of a simple neural network with nvmath-python, see the Backpropagation Jupyter notebook on GitHub.

nvmath-python is an open-source library, so feel free to visit the /NVIDIA/nvmath-python GitHub repo and reach out to us there.
