Developer Tools & Techniques

Enhancing GPU-Accelerated Vector Search in Faiss with NVIDIA cuVS

As companies collect more unstructured data and increasingly use large language models (LLMs), they need faster and more scalable systems. Advanced tools for finding information, such as retrieval-augmented generation (RAG), can take hours or even days to process massive amounts of data—sometimes at the scale of terabytes or petabytes.

Meanwhile, online search applications like ad recommendation systems struggle to deliver instant results on CPUs. Thousands of CPUs would be required to meet real-time speed requirements, increasing infrastructure costs.

This post explores how to solve these challenges using NVIDIA cuVS with the Meta Faiss library for efficient similarity search and clustering of dense vectors. cuVS uses GPU acceleration to dramatically speed up both the creation of search indexes and the actual search process. The result is much faster, lower-cost, and more efficient performance, all while maintaining seamless compatibility between CPUs and GPUs.

Specifically, the post covers:

  • The benefits of integrating cuVS and Faiss 
  • How and where cuVS improves vector search performance
  • Performance with GPU-accelerated inverted file index (IVF) and graph-based indexes 
  • Benchmarks and Python code examples demonstrating how to build and search cuVS-powered indexes with Faiss

What are the benefits of integrating cuVS and Faiss?

Whether you’re querying millions of vectors per second, working with large multi-modal embeddings, or building massive indexes with GPUs, the cuVS integration with Faiss unlocks the next level of performance and flexibility.

cuVS enables you to: 

  • Build indexes up to 12x faster on GPU at 95% recall
  • Achieve search latencies up to 8x lower at 95% recall
  • Easily move indexes between GPU and CPU environments to match your deployment needs

GPU acceleration in Faiss

Faiss is a popular library for vector search across research and production environments. It supports standalone usage, integration with PyTorch, and embedding within databases such as RocksDB, OpenSearch, and Milvus.

Faiss pioneered GPU support in 2018 and has continued evolving since then. At the NeurIPS 2021 big-ann-benchmarks competition, NVIDIA claimed first place with GPU-accelerated algorithms. These methods were later contributed to Faiss and now live in the open source cuVS library.

Since Faiss v1.10.0, users can opt into cuVS for enhanced versions of the inverted file indexes IVF-PQ and IVF-Flat, brute-force search (Flat), and CAGRA (Cuda Anns GRAph-based), a high-performance graph-based index built from the ground up for GPUs.

Effortless CPU-GPU interoperability

Accelerating GPU indexes in Faiss with cuVS unlocks new levels of CPU-GPU interoperability. With Faiss, you can build indexes on the GPU and then deploy them to the CPU. This gives Faiss users the ability to accelerate index building with GPUs while maintaining their CPU search architectures. It’s all accomplished seamlessly in the Faiss library.
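
As a minimal illustration of this workflow (the dimensions, list count, and IVF-Flat choice here are hypothetical, not tied to the benchmarks below), an index can be trained and populated on the GPU, then converted with faiss.index_gpu_to_cpu and served entirely on the CPU:

import faiss
import numpy as np

xb = np.random.random((100000, 96)).astype('float32')

res = faiss.StandardGpuResources()
# Build an IVF-Flat index on the GPU (d=96, nlist=1024).
gpu_index = faiss.GpuIndexIVFFlat(res, 96, 1024, faiss.METRIC_L2)
gpu_index.train(xb)
gpu_index.add(xb)

# Convert to a standard CPU index and serve queries without a GPU.
cpu_index = faiss.index_gpu_to_cpu(gpu_index)
D, I = cpu_index.search(xb[:5], 10)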

For a more dramatic example, Hierarchical Navigable Small-World (HNSW) indexes are notoriously slow to build on the CPU, especially at scale, taking several hours or even days. CAGRA indexes, on the other hand, can be built up to 12x faster. These CAGRA graphs can be formatted as HNSW indexes in Faiss and then deployed for search on the CPU.

Benchmarking Faiss with cuVS

Performance benchmarks compared Faiss with and without cuVS enabled on the following two datasets:

  1. Deep100M: A 100M-vector subset of the Deep1B dataset (96 dimensions). 
  2. OpenAI Text Embeddings: 5M vectors (1,536 dimensions) from the text-embedding-ada-002 model.

Tests were run on an NVIDIA H100 Tensor Core GPU and an Intel Xeon Platinum 8480CL CPU. Measurements were taken for:

  • Index build time
  • Single-query latency (online search)
  • Large-batch throughput (offline search)

Because unstructured data is growing so quickly, index build performance must keep pace. However, measuring index build time alone is meaningless without also considering the search performance and quality of the resulting index. For this reason, the team created its own methodology for benchmarking index builds. For more details, see the cuVS documentation.

In addition to considering search performance and quality, it's also important to compare indexes at their best-performing parameter settings. Pareto curves are used to ensure that each comparison is fair: speedups in latency and throughput are reported at the 95% recall level.
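
As a minimal sketch of this methodology (not the benchmarking harness used for these results), a Pareto frontier over (recall, latency) measurements can be computed by keeping only configurations that no other configuration beats on both axes:

import numpy as np

def pareto_frontier(recall, latency):
    """Keep only non-dominated configurations: no other configuration
    has both higher (or equal) recall and lower latency."""
    recall = np.asarray(recall)
    latency = np.asarray(latency)
    # Sort by recall descending, breaking ties by latency ascending.
    order = np.lexsort((latency, -recall))
    best_latency = np.inf
    frontier = []
    for i in order:
        if latency[i] < best_latency:   # strictly better latency than all
            best_latency = latency[i]   # higher-recall points seen so far
            frontier.append((float(recall[i]), float(latency[i])))
    return sorted(frontier)             # ascending recall, for plotting

# Example: latency (ms) measured at several hypothetical parameter settings.
points = pareto_frontier(
    recall=[0.90, 0.95, 0.95, 0.99],
    latency=[0.8, 1.1, 1.6, 4.2],
)
# [(0.90, 0.8), (0.95, 1.1), (0.99, 4.2)]; speedups are then read off
# each index's curve at the 95% recall level.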

IVF: cuVS versus Faiss GPU classic

We first benchmarked the IVF indexes IVF-Flat and IVF-PQ, comparing the classic Faiss GPU implementations against the new cuVS-backed variants:

  • Build time: IVF-PQ and IVF-Flat were built up to 4.7x faster using cuVS (Figure 1)
  • Latency: Search latency was up to 8x lower for IVF-PQ, and 90% lower for IVF-Flat (Figure 1)
  • Throughput: cuVS improved large-batch search throughput up to 3x for IVF-PQ across both datasets (Figure 2), while maintaining comparable performance for IVF-Flat. This makes it well-suited for high-volume and large offline search workloads.

Online latency

Figures 1a and 1b show online search latency and build time across IVF index variants. cuVS consistently delivers faster index builds and significantly lower search latency across both datasets compared to classic Faiss.

Figure 1a. For Deep100M images (100M x 96), average index build times for best-performing configurations (lowest online latency) at specific recall levels (left), and search latency Pareto frontier for single-query online search at k=10—lower is better (right)
Figure 1b. For OpenAI text embeddings, average index build times for best-performing configurations (left) and search latency Pareto frontier—lower is better (right)

Batch (offline) throughput

Figure 2 shows batch throughput across IVF index variants. cuVS improves batch processing performance, serving significantly more queries per second across both image and text embeddings. 

These improvements stem from better GPU clustering (for example, balanced k-means), expanded parameter support (for example, more subquantizers for IVF-PQ), and code-level optimizations.

Graph-based indexes: cuVS CAGRA versus Faiss HNSW (CPU)

CAGRA is a GPU-optimized, fixed-degree flat graph index that offers major performance advantages over CPU-based HNSW, including:

  • Build time: CAGRA builds up to 12.3x faster (Figure 3)
  • Latency: Online search is up to 4.7x faster (Deep100M) (Figure 3)
  • Throughput: In offline search settings, CAGRA delivers up to 18x higher throughput for image data and more than 8x for text embeddings (Figure 4), making it ideal for workloads requiring high-volume inference at low latency.

cuVS enables a CAGRA graph to be converted directly to an HNSW graph, so the index can be built much faster on the GPU and then searched on the CPU with comparable speed and quality.

Online latency

Figures 3a and 3b show online latency and build time for GPU CAGRA versus CPU HNSW. CAGRA dramatically accelerates index builds and lowers online query latency—up to 4.7x faster search compared to HNSW on CPU for Deep100M.

Figure 3a. For Deep100M (100M x 96) for GPU CAGRA versus CPU HNSW: average index build times for best-performing configurations across recall levels (left) and search latency Pareto frontier for single query search—lower is better (right)
Figure 3b. For OpenAI text embeddings (5M x 1,536) for GPU CAGRA versus CPU HNSW: average index build times for best-performing configurations (left) and search latency Pareto frontier—lower is better (right)

Batch (offline) throughput

Figure 4 shows GPU CAGRA versus CPU HNSW batch throughput. CAGRA achieves high throughput in batch scenarios—serving millions of queries per second and outperforming CPU-based HNSW across both datasets.

How to get started with cuVS in Faiss

This section introduces the process for installing Faiss with cuVS support and provides short Python code examples for creating and searching an index.

Installation

You can build Faiss from source with cuVS enabled, or install a prebuilt Conda package:

# Conda install (CUDA 12.x)
conda install -c rapidsai -c conda-forge -c nvidia pytorch::faiss-gpu-cuvs 'cuda-version>=12.0,<=12.9'

Alternatively, you can install the latest nightly build of the cuVS-enabled Faiss package using the following command:

conda install -c rapidsai -c rapidsai-nightly -c conda-forge -c nvidia pytorch/label/nightly::faiss-gpu-cuvs 'cuda-version>=12.0,<=12.9'

Memory management

Use the following snippet to enable GPU memory pooling with RMM (recommended). Pooling avoids repeated CUDA allocations and can improve performance.

import rmm

# Create a pooled allocator backed by plain CUDA memory, with a 1 GiB
# (2**30 bytes) initial pool, and make it the default allocator for the
# current device.
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30
)
rmm.mr.set_current_device_resource(pool)

Build an IVF-PQ index with cuVS

With the faiss-gpu-cuvs package, cuVS is used automatically for supported index types, requiring no code changes to benefit from its performance improvements. An example of creating an IVF-PQ index using the cuVS backend is shown below:

import faiss
import numpy as np

np.random.seed(1234)
xb = np.random.random((1000000, 96)).astype('float32')  # database vectors
xq = np.random.random((10000, 96)).astype('float32')    # query vectors
xt = np.random.random((100000, 96)).astype('float32')   # training vectors

res = faiss.StandardGpuResources()
# Disable the default temporary memory allocation since an RMM pool
# resource has already been set.
res.noTempMemory()

# Case 1: Creating a cuVS GPU index
config = faiss.GpuIndexIVFPQConfig()
config.interleavedLayout = True
# d=96, nlist=1024, M=96 subquantizers, 6 bits per code (an expanded
# parameter set available with cuVS).
index_gpu = faiss.GpuIndexIVFPQ(res, 96, 1024, 96, 6, faiss.METRIC_L2, config)
index_gpu.train(xt)
index_gpu.add(xb)

# Case 2: Cloning a CPU index to a cuVS GPU index
quantizer = faiss.IndexFlatL2(96)
index_cpu = faiss.IndexIVFPQ(quantizer, 96, 1024, 96, 8, faiss.METRIC_L2)
index_cpu.train(xt)
co = faiss.GpuClonerOptions()
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu, co)
# The cuVS index now uses the trained quantizer's centroids as its IVF centroids.
assert index_gpu.is_trained
index_gpu.add(xb)
k = 10
D, I = index_gpu.search(xq, k)
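
To sanity-check the approximate results, you can measure recall against exact brute-force search. The following is a brief sketch (not part of the original example) that computes recall@k for the IVF-PQ results above using a faiss.IndexFlatL2 reference:

# Exact brute-force reference on the CPU (can be slow for large datasets).
index_flat = faiss.IndexFlatL2(96)
index_flat.add(xb)
D_ref, I_ref = index_flat.search(xq, k)  # ground-truth top-k neighbors

# recall@k: average fraction of true neighbors found by the IVF-PQ index.
recall = np.mean([
    len(set(I[i]) & set(I_ref[i])) / k
    for i in range(xq.shape[0])
])
print(f"recall@{k}: {recall:.3f}")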

Build a cuVS CAGRA index

The following example demonstrates how to build and query a CAGRA index using Faiss with cuVS acceleration.

import faiss
import numpy as np

# Step 1: Create the CAGRA index config
config = faiss.GpuIndexCagraConfig()
config.graph_degree = 32
config.intermediate_graph_degree = 64

# Step 2: Initialize the CAGRA index
res = faiss.StandardGpuResources()
gpu_cagra_index = faiss.GpuIndexCagra(res, 96, faiss.METRIC_L2, config)

# Step 3: Build the graph over 1M vectors. For CAGRA, train() both
# builds the graph and adds the vectors to the index.
n = 1000000
data = np.random.random((n, 96)).astype('float32')
gpu_cagra_index.train(data)

# Step 4: Search the index for the top 10 neighbors of each query.
xq = np.random.random((10000, 96)).astype('float32')
D, I = gpu_cagra_index.search(xq, 10)

CAGRA indexes can be automatically converted to HNSW format through the new faiss.IndexHNSWCagra CPU class, enabling GPU-accelerated index builds followed by CPU-based search:

# Create the HNSW index object for vectors with 96 dimensions.
M = 16
cpu_hnsw_index = faiss.IndexHNSWCagra(96, M, faiss.METRIC_L2)
# Keep the full hierarchy so new vectors can be added after conversion.
cpu_hnsw_index.base_level_only = False

# Initialize the HNSW base layer with the CAGRA graph.
gpu_cagra_index.copyTo(cpu_hnsw_index)

# Add new vectors to the hierarchy.
new_vecs = np.random.random((100000, 96)).astype('float32')
cpu_hnsw_index.add(new_vecs)
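
From here, the converted index behaves like any other Faiss HNSW index on the CPU. A short usage sketch (the efSearch value of 64 is illustrative):

# Tune the CPU-side speed/recall trade-off, then search on the CPU.
cpu_hnsw_index.hnsw.efSearch = 64  # higher = better recall, slower search
xq = np.random.random((10000, 96)).astype('float32')
D, I = cpu_hnsw_index.search(xq, 10)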

For full code examples, see the Faiss cuVS notebook.

Get more from your vectors

The integration of NVIDIA cuVS into Faiss delivers substantial improvements in both speed and scalability for approximate nearest neighbor (ANN) search. Whether you're working with inverted file (IVF) indexes or graph-based methods, the cuVS integration in Faiss offers:

  • Faster index builds: Up to 12x acceleration on GPU
  • Lower search latency: Up to 4.7x improvement in real-time search
  • Effortless CPU-GPU interoperability: Build on GPU, search on CPU, and vice versa

The team has also introduced CAGRA, a high-performance, graph-based index purpose-built for GPUs, which outperforms classic CPU-based HNSW in both build time and throughput. Better still, CAGRA graphs can be converted to HNSW for efficient CPU-based inference, offering the best of both worlds for hybrid deployments.

Whether you’re scaling search infrastructure to handle millions of queries per second or rapidly experimenting with new embedding models, integrating Faiss with cuVS gives you the tools to move faster, iterate smarter, and deploy confidently.

Ready to get started? Install the faiss-gpu-cuvs package and explore the example notebook.
