
Unlock Efficient Data Processing with the Latest from NVIDIA DALI


NVIDIA DALI, a portable, open source software library for decoding and augmenting images, videos, and speech, recently introduced several features that improve performance and enable new use cases. These updates simplify the integration of DALI into existing PyTorch data processing logic, make data processing pipelines more flexible by enabling GPU-to-CPU flows, and add new video decoding patterns. These new features make DALI an indispensable tool for deep learning practitioners and include:

  • PyTorch DALI Proxy, which integrates seamlessly with PyTorch for efficient GPU utilization in its multiprocess environment. It also provides selective offloading, so users can choose which parts of the pipeline to offload to DALI.
  • Video processing improvements that boost DALI versatility by supporting a broader range of selective decoding patterns and rapid video container indexing without unnecessary decoding overhead.
  • Execution flow enhancements that optimize memory consumption and make execution models more flexible by enabling GPU-to-CPU transfers inside the execution flow.

DALI Proxy: efficient GPU acceleration

DALI Proxy is a game-changer for PyTorch users, enabling them to seamlessly integrate DALI’s high-performance data processing capabilities into their existing PyTorch dataset logic. DALI architecture addresses several limitations of purely Python-based data processing, including:

  • One of the major limitations is the difficulty of using recent multi-core CPU architectures due to the Python global interpreter lock (GIL). The GIL simplifies the multi-threading programming model by allowing only one thread to execute Python bytecode at a time, at the expense of performance on multi-core CPUs.

    The most common approach to address this setback is to run Python interpreters in multiple independent processes, using shared memory or other IPC mechanisms for communication. While this works well for CPU work, it has several limitations for GPU work orchestration:
  1. Each process creates a separate GPU context, adding overhead when switching between tasks scheduled by different processes.
  2. Each process allocates its own GPU memory, inflating overall usage.
  3. Sharing GPU memory between processes adds overhead.

    DALI addresses this by using native multi-threading, which is not constrained by the Python GIL.
  • Inefficient data roundtrips between CPU and GPU can occur when transferring memory between multiple Python processes. This encourages suboptimal patterns where GPU operations are followed by CPU operations, adding more time spent transferring data to the GPU and back, ultimately reducing the benefits of GPU acceleration. DALI discourages this pattern, ensuring that roundtrips are minimized or don’t happen at all.
  • Inefficient GPU work orchestration results from creating a separate context for each process and allocating separate memory, leading to overhead and inflated memory usage.
Three different approaches to data processing. Inside a single process with multiple threads synchronizing on the GIL. Multiprocessing, where each worker independently competes for GPU resources. The DALI approach uses native threads to overcome the Python GIL and utilizes multiple CPU cores for the CPU part, and efficiently orchestrates the GPU from a single thread.
Figure 1. A comparison of different approaches to data processing using Python and DALI

Figure 1 shows different approaches to data processing in Python with their limitations. The left diagram shows the most straightforward approach, where multiple Python threads run inside a single process. However, due to the Python GIL, only one thread can perform processing at a time, leaving CPUs underutilized.

The middle diagram uses independent processes instead of threads. While efficient for the CPU, each process orchestrates GPU work independently, and aggregating results from each process involves expensive IPC.

The right diagram uses DALI with native processing, which can efficiently use all CPU cores and orchestrate GPU work without unnecessary overhead.
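The left-hand case in Figure 1 can be observed with a few lines of plain Python. The following is a minimal sketch, not part of DALI: it times the same CPU-bound function run serially and in a thread pool. The function names and workload sizes are illustrative choices. On a standard CPython build with the GIL enabled, the threaded run typically takes about as long as the serial one, because only one thread executes Python bytecode at a time.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def count_primes(limit):
    """CPU-bound pure-Python work: count primes below `limit` by trial division."""
    count = 0
    for n in range(2, limit):
        if all(n % d for d in range(2, int(n ** 0.5) + 1)):
            count += 1
    return count


def run_serial(limits):
    """Process every task one after another in the calling thread."""
    start = time.perf_counter()
    results = [count_primes(x) for x in limits]
    return results, time.perf_counter() - start


def run_threaded(limits, workers=4):
    """Process the same tasks in a thread pool; threads still share one GIL."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(count_primes, limits))
    return results, time.perf_counter() - start


if __name__ == "__main__":
    limits = [20_000] * 4  # illustrative workload, one task per thread
    _, serial_s = run_serial(limits)
    _, threaded_s = run_threaded(limits)
    print(f"serial:   {serial_s:.3f} s")
    print(f"threaded: {threaded_s:.3f} s")
```

Both helpers return identical results; only the elapsed time is of interest. Exact timings depend on the machine and the Python build, so none are shown here.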

DALI effectively solves these problems at the expense of replacing the PyTorch data loader. While creating a new data pipeline from scratch with DALI is straightforward, rewriting an existing flow can be effort-intensive. The DALI Proxy enables users to selectively offload parts of the existing data pipeline to DALI, making it ideal for multi-modal applications where only specific modalities, such as vision processing, need acceleration while other parts use external libraries. Additionally, it makes GPU acceleration of data processing in PyTorch’s multiprocess environment convenient and efficient.

The concept is based on a DALI server instance running in the main process, which also orchestrates the training. A lightweight DALI proxy object transfers data still on the CPU to the main process, where it is processed in parallel using native code. This enables integration of the DALI pipeline into existing data loading code, accelerating only the time-consuming parts while leaving the rest of the logic, such as elaborate data reading patterns, untouched.

The top part of the architecture represents one of the PyTorch data processing workers. It houses a DALI proxy that accepts inputs from the loader and uses a queue to move them to the main process where the DALI server resides. The server aggregates requests from all workers and runs processing in only the main server process, passing results to the loop (training or inference) where the actual data is consumed.
Figure 2. DALI proxy architecture
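To make the data flow in Figure 2 concrete, here is a toy model of the proxy and server pattern in plain Python. It is a sketch under simplifying assumptions and does not reflect how DALI is implemented: workers are threads rather than processes, ToyServer and run_demo are hypothetical names, and the processing function is a stand-in for a DALI pipeline. It shows only the shape of the interaction: workers hand samples to a proxy and get a placeholder back, a single server thread does the processing, and the consumer loop exchanges placeholders for real data.

```python
import queue
import threading
from concurrent.futures import Future


class ToyServer:
    """Hypothetical stand-in for a processing server living in the main process."""

    def __init__(self, process_fn):
        self._process_fn = process_fn
        self._requests = queue.Queue()
        self._thread = threading.Thread(target=self._serve, daemon=True)

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._requests.put(None)  # sentinel: stop after pending requests
        self._thread.join()

    def proxy(self, sample):
        """Called by workers: enqueue the sample and return a placeholder."""
        ticket = Future()
        self._requests.put((sample, ticket))
        return ticket

    def produce_data(self, ticket):
        """Called by the consumer loop: swap the placeholder for real data."""
        return ticket.result()

    def _serve(self):
        # A single thread drains requests from all workers.
        while True:
            item = self._requests.get()
            if item is None:
                break
            sample, ticket = item
            ticket.set_result(self._process_fn(sample))


def run_demo(num_workers=8):
    """Each worker submits one sample; the main loop collects the results."""
    with ToyServer(lambda x: x * 2) as server:
        tickets = [None] * num_workers

        def worker(i):
            tickets[i] = server.proxy(i)

        workers = [threading.Thread(target=worker, args=(i,))
                   for i in range(num_workers)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        return [server.produce_data(t) for t in tickets]
```

Calling run_demo() returns each submitted value doubled, in worker order. The point of the design is that only one place touches the expensive resource (here the processing function, in DALI the GPU), while any number of workers can feed it.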

The following example shows DALI proxy usage in its most concise form (more examples can be found in the API documentation):

import numpy as np
import torch
import torchvision
from nvidia.dali import fn, pipeline_def, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

# train_dir, batch_size, and nworkers are assumed to be defined by the caller


@pipeline_def(num_threads=4, device_id=0)
def rn50_train_pipe():
    # the PyTorch data loader passes file names and DALI loads the files
    filepaths = fn.external_source(name="images", no_copy=True)
    jpegs = fn.io.file.read(filepaths)

    # decode data on the GPU
    images = fn.decoders.image_random_crop(
        jpegs, device="mixed", output_type=types.RGB)

    # the rest of processing happens on the GPU as well
    images = fn.resize(images, resize_x=256, resize_y=256)
    images = fn.crop_mirror_normalize(
        images,
        crop_h=224,
        crop_w=224,
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        mirror=fn.random.coin_flip())
    # labels come from the PyTorch dataset, so the pipeline returns only images
    return images


pipe = rn50_train_pipe(batch_size=batch_size)
with dali_proxy.DALIServer(pipe) as dali_server:
    # we want torchvision.datasets to pass just filenames to DALI
    def read_filepath(path):
        return np.frombuffer(path.encode(), dtype=np.int8)

    # use the proxy like any other transform
    dataset = torchvision.datasets.ImageFolder(
        train_dir, transform=dali_server.proxy, loader=read_filepath)

    # usual data loader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=nworkers,
        drop_last=True,
    )

    for data, target in loader:
        # replace the proxy's placeholder with the processed batch; this call
        # can be skipped when using
        # nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader
        data = dali_server.produce_data(data)

Video processing improvements

Recent DALI updates significantly expand its video processing capabilities by adding support for decoding videos with variable frame rates and enabling users to extract specific frames directly during decoding. These new features enhance flexibility and control in video data pipelines. In addition to these functional upgrades, the video decoder’s initialization time has been optimized, bringing performance in line with DALI’s previously available, more specialized decoder. This reduction in indexing time is particularly impactful for training video foundation models, which often require efficient handling of millions of video samples.

As deep learning evolves, video has become a key data source. Unlike images, which are easy to load and process individually, videos are collections of frames and must be handled differently. Each video is like a mini-dataset, and researchers often need custom strategies for reading frames. For example, to boost a video’s frame rate, models may require consecutive frames. In contrast, for action recognition, reading every N-th frame helps reduce redundancy and avoids overwhelming the network.

DALI enables all these use cases with the ability to specify the number of frames, the first frame, the last frame, the stride (step between frames), and/or an explicit list of frames. Users can define a padding mode to ensure they always receive frame sequences of the exact size they request, regardless of the video duration. If the video doesn’t contain all the requested frames, the missing frames are filled in according to the padding mode. Supported padding modes include reflect padding, constant padding, and edge padding.
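The interaction between start frame, stride, sequence length, and padding can be illustrated without a decoder. The sketch below is plain Python, not the DALI API: select_frames and its parameter names are hypothetical, and the reflect rule shown (mirroring without repeating the last frame) is one common definition that may differ from the one DALI uses. It returns the index of the source frame used for each output slot, with None marking a slot that would be filled with a constant value.

```python
def _reflect(idx, num_frames):
    """Mirror an out-of-range index back into [0, num_frames)."""
    if num_frames == 1:
        return 0
    period = 2 * (num_frames - 1)
    idx %= period
    return idx if idx < num_frames else period - idx


def select_frames(num_frames, start, stride, length, pad_mode="edge"):
    """Return source frame indices for a fixed-length sequence.

    num_frames: frames available in the video
    start:      index of the first requested frame (assumed >= 0)
    stride:     step between requested frames
    length:     number of frames the caller wants back
    pad_mode:   "constant", "edge", or "reflect"
    """
    if num_frames <= 0 or stride <= 0 or start < 0:
        raise ValueError("num_frames and stride must be positive, start >= 0")
    indices = []
    for k in range(length):
        idx = start + k * stride
        if idx < num_frames:
            indices.append(idx)             # frame exists in the video
        elif pad_mode == "constant":
            indices.append(None)            # slot filled with a constant frame
        elif pad_mode == "edge":
            indices.append(num_frames - 1)  # repeat the last frame
        elif pad_mode == "reflect":
            indices.append(_reflect(idx, num_frames))
        else:
            raise ValueError(f"unknown pad_mode: {pad_mode}")
    return indices
```

For a five-frame video read with a stride of 2 and a requested length of 4, the fourth slot falls past the end of the video: select_frames(5, 0, 2, 4, "edge") gives [0, 2, 4, 4], while "reflect" gives [0, 2, 4, 2] and "constant" gives [0, 2, 4, None].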

Executor enhancements

The executor enhancements, exposed by the exec_dynamic argument, improve memory management efficiency by enabling the reuse of memory buffers. Originally, DALI aggressively allocated memory on demand and never released it, because allocation is expensive. Recent advancements have enabled asynchronous on-demand allocation and release, so that operations can reuse the same physical memory while ensuring that the memory isn’t overwritten while still in use. This update reduces memory usage, making the processing of large datasets more efficient.
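The idea of reusing released buffers instead of allocating new ones can be shown with a toy pool. This is a simplified, synchronous sketch in plain Python with hypothetical names (BufferPool, run_stages). The real executor manages memory asynchronously, which this sketch does not model.

```python
class BufferPool:
    """Toy pool: hand out a released buffer when one is large enough."""

    def __init__(self):
        self._free = []
        self.allocations = 0  # number of fresh allocations made so far

    def acquire(self, size):
        # Reuse the first free buffer that can hold `size` bytes.
        for i, buf in enumerate(self._free):
            if len(buf) >= size:
                return self._free.pop(i)
        # Nothing suitable is free: allocate a new buffer.
        self.allocations += 1
        return bytearray(size)

    def release(self, buf):
        # The buffer is no longer in use and may back a later request.
        self._free.append(buf)


def run_stages(pool, sizes):
    """Simulate consecutive operations that each need a scratch buffer."""
    for size in sizes:
        buf = pool.acquire(size)
        buf[:size] = bytes(size)  # stand-in for real work on the buffer
        pool.release(buf)
    return pool.allocations
```

With requests of 1024, 512, 1024, and 2048 bytes, run_stages(BufferPool(), [1024, 512, 1024, 2048]) performs two allocations instead of four: the second and third requests reuse the first buffer, and only the 2048-byte request needs new memory. A buffer is handed out again only after it has been released, which is the toy equivalent of ensuring memory is not overwritten while in use.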

Another notable improvement is that the new execution model supports CPU-to-GPU-to-CPU data transfer patterns. Historically, such patterns were discouraged due to substantial data transfer overhead between CPU and GPU. However, architectures such as the NVIDIA GH200 Grace Hopper Superchip and NVIDIA GB200 NVL72, with their fast interconnect between CPU and GPU, have opened new possibilities, making previously inefficient patterns viable. Users can now accelerate parallel parts on the GPU and move data back to the CPU to apply algorithms that are serial by nature or not yet supported by DALI.

Summary

The recent advancements in NVIDIA DALI significantly extend its capabilities as a high-performance data preprocessing library for deep learning. By introducing the DALI Proxy, users gain fine-grained control to integrate DALI into existing PyTorch pipelines while overcoming limitations imposed by Python’s multiprocessing model. Enhanced video processing features make DALI more suitable for modern video-based AI tasks, enabling flexible and efficient handling of complex frame selection scenarios. Meanwhile, executor improvements can reduce memory usage and unlock new execution patterns, particularly in systems with fast CPU-GPU interconnects like GH200 and GB200. Together, these updates make DALI a versatile and efficient solution for scaling data preprocessing across diverse AI workloads.

Try today

To get started with these new features, explore the following resources:

  • Learn more about DALI Proxy and how to easily integrate DALI into your existing data loading workflows.
  • Use the enhanced video decoder for your video model training to leverage improved decoding capabilities.
  • Test the new DALI execution flow and see the benefits of optimized memory management and flexible data transfers.

Ask questions or suggest improvements on the DALI GitHub page.
