NVIDIA Holoscan is the NVIDIA domain-agnostic, multimodal, real-time AI sensor processing platform that delivers the foundation for developers to build end-to-end sensor processing pipelines. NVIDIA Holoscan SDK features include:
- Combined hardware systems for low-latency sensor and network connectivity
- Optimized libraries for data processing and AI
- Flexible deployment: edge or cloud
- Various programming languages, such as Python and C++
Holoscan SDK can be used to build streaming AI pipelines for a range of industries and use cases, including medical devices, high-performance computing at the edge, and industrial inspection. For more information, see Developing Production-Ready AI Sensor Processing Applications with NVIDIA Holoscan.
The Holoscan SDK accelerates streaming AI applications by leveraging both software and hardware, and it can be combined with RDMA technology to further improve end-to-end pipeline performance with GPU acceleration. An end-to-end sensor processing pipeline typically includes:
- Sensor data ingress
- Accelerated computing and AI inference
- Real-time visualization, actuation, and data stream egress
All the data within this pipeline is stored in GPU memory and can be accessed directly by Holoscan native operators without host-device memory transfer.
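For reference, the sketch below shows the general shape of such a pipeline with the Holoscan SDK Python API: native operators exchange GPU-resident tensors, and the application connects their ports with add_flow. The operator names and the single source-to-sink flow are illustrative placeholders rather than the application discussed later in this post.

```python
import cupy as cp
import holoscan as hs
from holoscan.conditions import CountCondition
from holoscan.core import Application, Operator, OperatorSpec
from holoscan.gxf import Entity


class SourceOp(Operator):
    """Stands in for sensor data ingress: emits a GPU-resident frame."""

    def setup(self, spec: OperatorSpec):
        spec.output("out")

    def compute(self, op_input, op_output, context):
        frame = cp.zeros((480, 640, 3), dtype=cp.uint8)  # lives in GPU memory
        out_message = Entity(context)
        out_message.add(hs.as_tensor(frame), "frame")
        op_output.emit(out_message, "out")


class SinkOp(Operator):
    """Stands in for downstream compute/visualization: consumes the frame."""

    def setup(self, spec: OperatorSpec):
        spec.input("in")

    def compute(self, op_input, op_output, context):
        message = op_input.receive("in")
        frame = cp.asarray(message.get("frame"))  # zero-copy view of the same GPU buffer
        print("received frame:", frame.shape, frame.dtype)


class MinimalPipelineApp(Application):
    def compose(self):
        source = SourceOp(self, CountCondition(self, 1), name="source")
        sink = SinkOp(self, name="sink")
        self.add_flow(source, sink, {("out", "in")})


if __name__ == "__main__":
    MinimalPipelineApp().run()
```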
This post explains how to achieve end-to-end GPU-accelerated workflows without additional memory transfer by integrating Holoscan SDK and the open-source library OpenCV.
What is OpenCV?
OpenCV (Open Source Computer Vision Library) is a comprehensive open-source computer vision library that contains over 2,500 algorithms, covering areas such as image and video manipulation, object and face detection, and deep learning through the OpenCV DNN module.
OpenCV supports GPU acceleration, including a CUDA module that provides a set of classes and functions to utilize CUDA computational capabilities. It’s implemented using the NVIDIA CUDA Runtime API and provides utility functions, low-level vision primitives, and high-level algorithms.
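For illustration, the snippet below shows the typical cv2.cuda usage pattern, assuming an OpenCV build with the CUDA modules enabled: an image is uploaded to a GpuMat, processed on the GPU, and downloaded back to the host.

```python
import cv2
import numpy as np

# Host image (random data stands in for a decoded frame).
host_img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

# Upload to device memory, run a CUDA-accelerated operation, download the result.
gpu_img = cv2.cuda.GpuMat()
gpu_img.upload(host_img)
gpu_gray = cv2.cuda.cvtColor(gpu_img, cv2.COLOR_BGR2GRAY)
gray = gpu_gray.download()
```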
With the comprehensive GPU-accelerated algorithms and operators provided in OpenCV, developers can implement even more sophisticated pipelines based on the Holoscan SDK (Figure 2).
Integrate OpenCV operators in a Holoscan SDK pipeline
To get started integrating OpenCV operators in a Holoscan SDK pipeline, you need the following:
- OpenCV >= 4.8.0
- Holoscan SDK >= v0.6
To install OpenCV with the CUDA module, follow the guide available through opencv/opencv_contrib. To build an image with the Holoscan SDK and OpenCV CUDA, see the nvidia-holoscan/holohub Dockerfile.
The Tensor, which serves as the data type within the Holoscan SDK, is defined as a multidimensional array of elements of a single data type. The Tensor class is a wrapper around the DLManagedTensorCtx struct that holds the DLManagedTensor object. The Tensor class supports both DLPack and the NumPy array interfaces (__array_interface__ and __cuda_array_interface__), so it can be used with other Python libraries such as CuPy, PyTorch, JAX, TensorFlow, and Numba.
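As a quick illustration of this interoperability, the sketch below round-trips a CuPy array through a Holoscan Tensor using the as_tensor helper that also appears in the application code later in this post; both conversions are zero-copy.

```python
import cupy as cp
import holoscan as hs

# Round trip between a CuPy array and a Holoscan Tensor; no copies are made
# because both sides exchange __cuda_array_interface__/DLPack metadata.
cp_array = cp.arange(12, dtype=cp.float32).reshape(3, 4)

tensor = hs.as_tensor(cp_array)  # wrap the CuPy array as a Holoscan Tensor
cp_view = cp.asarray(tensor)     # wrap the Tensor back as a CuPy array

assert cp_view.__cuda_array_interface__["data"][0] == cp_array.__cuda_array_interface__["data"][0]
```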
However, OpenCV's GPU data type, GpuMat, implements neither __cuda_array_interface__ nor the standard DLPack interface. Achieving an end-to-end GPU-accelerated pipeline or application therefore requires implementing two functions that convert a GpuMat to a CuPy array, which can be accessed directly as a Holoscan Tensor, and vice versa.
Seamless zero-copy from GpuMat to CuPy array
The GpuMat object of the OpenCV Python bindings provides a cudaPtr method that returns the GPU memory address of the underlying buffer. This pointer can be used to initialize a CuPy array directly, enabling efficient data handling by avoiding unnecessary data transfers between the host and device.
The function below is used to create a CuPy array from GpuMat. The source code is provided in the HoloHub Endoscopy Depth Estimation application.
```python
import cv2
import cupy as cp


def gpumat_to_cupy(gpu_mat: cv2.cuda.GpuMat) -> cp.ndarray:
    w, h = gpu_mat.size()  # GpuMat.size() returns (width, height)
    size_in_bytes = gpu_mat.step * h  # step is the number of bytes per row
    shapes = (h, w, gpu_mat.channels())
    assert gpu_mat.channels() <= 3, "Unsupported GpuMat channels"

    dtype = None
    if gpu_mat.type() in [cv2.CV_8U, cv2.CV_8UC1, cv2.CV_8UC2, cv2.CV_8UC3]:
        dtype = cp.uint8
    elif gpu_mat.type() == cv2.CV_8S:
        dtype = cp.int8
    elif gpu_mat.type() == cv2.CV_16U:
        dtype = cp.uint16
    elif gpu_mat.type() == cv2.CV_16S:
        dtype = cp.int16
    elif gpu_mat.type() == cv2.CV_32S:
        dtype = cp.int32
    elif gpu_mat.type() == cv2.CV_32F:
        dtype = cp.float32
    elif gpu_mat.type() == cv2.CV_64F:
        dtype = cp.float64
    assert dtype is not None, "Unsupported GpuMat type"

    # Wrap the GpuMat device buffer without copying; owner=gpu_mat keeps the
    # GpuMat alive for as long as the CuPy array references its memory.
    mem = cp.cuda.UnownedMemory(gpu_mat.cudaPtr(), size_in_bytes, owner=gpu_mat)
    memptr = cp.cuda.MemoryPointer(mem, offset=0)
    cp_out = cp.ndarray(
        shapes,
        dtype=dtype,
        memptr=memptr,
        strides=(gpu_mat.step, gpu_mat.elemSize(), gpu_mat.elemSize1()),
    )
    return cp_out
```
Note that this function uses the CuPy UnownedMemory API to create the CuPy array. In some cases, an OpenCV operator allocates new device memory whose lifetime is not limited to a single operator but spans the whole pipeline, so that memory must remain valid for as long as CuPy uses it. Passing owner=gpu_mat ensures the CuPy array initiated from the GpuMat keeps a reference to its owner, preventing the underlying memory from being freed prematurely. For more details, see the CuPy interoperability documentation.
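As an illustrative usage of gpumat_to_cupy (not part of the application, and the frame content is arbitrary), the sketch below uploads a host frame, wraps the resulting GpuMat as a CuPy array, and confirms that both objects reference the same device buffer.

```python
import cv2
import numpy as np

# Upload a host frame, wrap the resulting GpuMat with gpumat_to_cupy (defined
# above), and check that the conversion really is zero-copy.
host = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
gpu_mat = cv2.cuda.GpuMat()
gpu_mat.upload(host)

cp_view = gpumat_to_cupy(gpu_mat)
assert cp_view.shape == (480, 640, 3)
assert cp_view.__cuda_array_interface__["data"][0] == gpu_mat.cudaPtr()
```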
Seamless zero-copy from Holoscan Tensor to GpuMat
With the release of OpenCV 4.8, the Python bindings for OpenCV now support the initialization of GpuMat objects directly from GPU memory pointers. This capability facilitates more efficient data handling and processing by enabling direct interaction with GPU-resident data, bypassing the need for data transfer between host and device memory.
Within pipeline applications based on the Holoscan SDK, the GPU memory pointer can be obtained through the __cuda_array_interface__ provided by CuPy arrays. The function below creates a GpuMat object from a CuPy array. For a detailed implementation, see the source code provided in the HoloHub Endoscopy Depth Estimation application.
```python
import cv2
import cupy as cp


def gpumat_from_cp_array(arr: cp.ndarray) -> cv2.cuda.GpuMat:
    assert len(arr.shape) in (2, 3), "CuPy array must have 2 or 3 dimensions to be a valid GpuMat"

    # Map the CuPy dtype to the corresponding OpenCV depth flag.
    type_map = {
        cp.dtype('uint8'): cv2.CV_8U,
        cp.dtype('int8'): cv2.CV_8S,
        cp.dtype('uint16'): cv2.CV_16U,
        cp.dtype('int16'): cv2.CV_16S,
        cp.dtype('int32'): cv2.CV_32S,
        cp.dtype('float32'): cv2.CV_32F,
        cp.dtype('float64'): cv2.CV_64F,
    }
    depth = type_map.get(arr.dtype)
    assert depth is not None, "Unsupported CuPy array dtype"

    # Combine depth and channel count into an OpenCV type (CV_MAKETYPE).
    channels = 1 if len(arr.shape) == 2 else arr.shape[2]
    mat_type = depth + ((channels - 1) << 3)

    # Wrap the existing device buffer; the size argument is (width, height).
    mat = cv2.cuda.createGpuMatFromCudaMemory(
        arr.__cuda_array_interface__['shape'][1::-1],
        mat_type,
        arr.__cuda_array_interface__['data'][0],
    )
    return mat
```
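For illustration, the sketch below (not part of the application) round-trips a GPU frame through both helpers; it assumes a C-contiguous CuPy array, since the GpuMat is created without an explicit step.

```python
import cupy as cp

# CuPy array -> GpuMat -> CuPy array, with all three views sharing one device buffer.
cp_frame = cp.zeros((480, 640, 3), dtype=cp.uint8)  # already C-contiguous

gpu_mat = gpumat_from_cp_array(cp_frame)    # wraps the existing device memory
print(gpu_mat.size(), gpu_mat.channels())   # (640, 480) 3

cp_back = gpumat_to_cupy(gpu_mat)           # still zero-copy
assert cp_back.__cuda_array_interface__["data"][0] == cp_frame.__cuda_array_interface__["data"][0]
```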
Integrate OpenCV operators
With the two functions above, you can use any OpenCV CUDA operation in a Holoscan SDK-based pipeline without memory transfers. Use the following steps:
- Create a customized operator in which the OpenCV operator is called. For details, see the Holoscan SDK example documentation.
- In the operator's compute function:
  - Receive the message from the previous operator and create a CuPy array from the Holoscan Tensor.
  - Call gpumat_from_cp_array to create the GpuMat.
  - Process the frame with your chosen OpenCV operator.
  - Call gpumat_to_cupy to create a CuPy array from the GpuMat.
See the demonstration code below. For the complete source code, see the HoloHub Endoscopy Depth Estimation application.
```python
def compute(self, op_input, op_output, context):
    # Assumes: import cv2, cupy as cp, holoscan as hs; from holoscan.gxf import Entity
    stream = cv2.cuda_Stream()  # CUDA stream for OpenCV CUDA calls
    message = op_input.receive("in")

    cp_frame = cp.asarray(message.get(""))      # CuPy array from the Holoscan Tensor
    cv_frame = gpumat_from_cp_array(cp_frame)   # GPU OpenCV mat

    ## Call the OpenCV operator (cvtColor shown as an example)
    cv_frame = cv2.cuda.cvtColor(cv_frame, cv2.COLOR_HSV2RGB)

    cp_frame = gpumat_to_cupy(cv_frame)
    cp_frame = cp.ascontiguousarray(cp_frame)

    out_message = Entity(context)
    out_message.add(hs.as_tensor(cp_frame), "")
    op_output.emit(out_message, "out")
```
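For reference, the following is a minimal sketch of the custom operator class that such a compute method belongs to. The class name is illustrative and the port names match the snippet above; see the HoloHub application for the complete implementation.

```python
from holoscan.core import Operator, OperatorSpec


class OpenCVColorConvertOp(Operator):
    """Illustrative custom operator wrapping the OpenCV CUDA call shown above."""

    def setup(self, spec: OperatorSpec):
        # Declare the ports used by compute(): "in" receives the upstream
        # message and "out" emits the processed tensor downstream.
        spec.input("in")
        spec.output("out")

    # compute(self, op_input, op_output, context) is implemented as shown above.
```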
Summary
Incorporating OpenCV CUDA operators into applications built on the Holoscan SDK simply requires implementing two functions that facilitate the conversion between the OpenCV GpuMat and CuPy arrays. These functions enable direct access to Holoscan Tensors within customized operators. By invoking these functions, you can seamlessly create end-to-end GPU-accelerated applications without memory transfer for enhanced performance.
To get started, download Holoscan SDK 2.0 and check out the release notes. Ask questions and share information in the NVIDIA Developer forums.
Learn more about how to integrate other external libraries into a Holoscan SDK pipeline in the companion NVIDIA HoloHub tutorial. You can also start from the sample code and applications at nvidia-holoscan/holohub, the central repository for the NVIDIA Holoscan AI sensor processing community.