Large language models (LLMs) are rapidly expanding their context windows, with recent models supporting sequences of 128K tokens, 256K tokens, and beyond. However, training these models with extended context lengths presents significant computational and communication challenges. As context lengths grow, the memory and communication overhead of attention mechanisms scales quadratically, creating bottlenecks that traditional parallelism strategies struggle to address efficiently.
This post demonstrates how integrating the NVSHMEM communication library into the Accelerated Linear Algebra (XLA) compiler optimizes context parallelism. This integration enables efficient training of the Llama 3 8B model in the JAX framework with sequences up to 256K tokens. Our results show that NVSHMEM provides up to a 36% speedup over the NVIDIA Collective Communications Library (NCCL) for long-context training workloads, particularly when combined with tensor parallelism across multiple nodes.
The long-context training challenge
To understand why NVSHMEM provides significant speedups for long-context training, it’s necessary to first understand how context parallelism works and the unique communication patterns it creates. This section explains why the fine-grained, latency-sensitive communication of ring attention makes it an ideal candidate for optimization.
Context parallelism and ring attention
Context parallelism (CP) is a parallelization strategy designed specifically for handling long sequences in transformer models. Unlike data parallelism, which splits the batch, or tensor parallelism, which splits the model, context parallelism splits the sequence dimension across multiple devices.
Ring attention is a memory-efficient implementation of context parallelism that uses a ring-based communication pattern. During attention computation, each device:
- Processes its local portion of the sequence
- Exchanges Key Value (KV) tensors with neighboring devices in a ring topology
- Incrementally computes attention scores as KV blocks circulate around the ring
This approach reduces peak memory usage while maintaining mathematical equivalence to standard attention, making it possible to train with sequences that would otherwise exceed GPU memory capacity.
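To make the pattern concrete, below is a minimal, illustrative ring attention step in JAX. This is a sketch, not the MaxText implementation: the function name and shapes are assumptions, q, k, and v are taken to be the local sequence shards on each device (shape [local_seq, heads, head_dim]) running under shard_map with a context-parallel mesh axis named axis_name, causal masking is omitted, and an online-softmax merge keeps the result equal to standard attention.
import jax
import jax.numpy as jnp

def ring_attention_sketch(q, k, v, cp_size, axis_name='cp'):
    # q, k, v: local shards of shape [local_seq, heads, head_dim], produced by
    # splitting the sequence dimension over `axis_name` (run under shard_map).
    q, k, v = (x.astype(jnp.float32) for x in (q, k, v))  # accumulate in fp32
    perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
    scale = q.shape[-1] ** -0.5
    m = jnp.full(q.shape[:-1], -jnp.inf)   # running max of attention logits
    num = jnp.zeros_like(q)                # running softmax numerator
    den = jnp.zeros(q.shape[:-1])          # running softmax denominator

    def step(carry, _):
        k_blk, v_blk, m, num, den = carry
        # Attention between the local queries and the KV block currently held
        s = jnp.einsum('qhd,khd->qhk', q, k_blk) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        alpha = jnp.exp(m - m_new)                       # rescale old partials
        p = jnp.exp(s - m_new[..., None])
        num = num * alpha[..., None] + jnp.einsum('qhk,khd->qhd', p, v_blk)
        den = den * alpha + p.sum(axis=-1)
        # Pass the KV block to the next device in the ring; this lowers to a
        # CollectivePermute, the operation NVSHMEM accelerates.
        k_blk = jax.lax.ppermute(k_blk, axis_name, perm)
        v_blk = jax.lax.ppermute(v_blk, axis_name, perm)
        return (k_blk, v_blk, m_new, num, den), None

    (_, _, _, num, den), _ = jax.lax.scan(
        step, (k, v, m, num, den), None, length=cp_size)
    return num / den[..., None]
Production implementations issue the ppermute for the next block before computing on the current one, so the KV transfer overlaps with the attention einsums; this sketch keeps the two steps sequential for readability. The key point is that every loop iteration performs exactly the fine-grained, latency-critical point-to-point exchange described above.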
Communication patterns in ring attention
Ring attention involves frequent, fine-grained communication operations:
- Point-to-point transfers: Sending KV tensors to the next device in the ring
- Overlapped compute-communication: Computing attention on current KV blocks while fetching the next blocks
- Low-latency requirement: KV transfers are on the critical path and must complete before attention can proceed
These characteristics make ring attention an ideal candidate for low-latency communication libraries like NVSHMEM.
GPU-optimized communication with NVSHMEM
NVSHMEM is a communication library that implements the OpenSHMEM parallel programming model for NVIDIA GPUs. It provides several key features that distinguish it from traditional communication libraries, including symmetric memory, stream-aware communication, copy engine offloading, and more, as detailed below.
Symmetric memory
NVSHMEM provides a partitioned global address space (PGAS) resident in GPU memory. Applications allocate buffers from this symmetric heap using nvshmem_malloc, and the returned pointers can be used directly in communication operations. For example:
// Allocate source and destination buffers from the symmetric heap
int32_t *src_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
int32_t *dest_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
// Sum-reduce 1024 int32 elements across all PEs, enqueued on the default CUDA stream
int ret = nvshmemx_int32_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dest_d, src_d, 1024, 0);

Stream-aware communication
NVSHMEM provides peer-to-peer (P2P) on-stream APIs (such as put_nbi_on_stream and signal_on_stream) to efficiently move data and provide low-latency synchronization over P2P-connected GPUs.
One key advantage of these APIs over traditional host-initiated communication is their zero-SM footprint: they perform data movement and synchronization by leveraging the copy engine (CE) and stream memory operations capabilities of the GPU hardware, leaving streaming multiprocessors (SMs) free for compute. Some of the underlying CUDA interfaces include:
- Direct GPU-to-GPU transfers: Similar to cudaMemcpyAsync, but with lower latency through optimized data paths
- Fine-grained synchronization: Using cuStreamWriteValue32 and cuStreamWaitValue32 primitives for efficient signaling between devices without CPU involvement
In addition to the P2P on-stream APIs, NVSHMEM also provides the collective operations commonly used in AI workloads, such as AllReduce (for example, reduce_on_stream). These collectives leverage the SHARP in-network reduction and multicast acceleration features of the NVIDIA NVLink Switch to enable latency-optimized one-shot and throughput-optimized two-shot AllReduce algorithms. The underlying CUDA interface includes the multimem ISA, which provides the additional benefit of a reduced-SM footprint, since primitives such as reductions and broadcasts are offloaded to the switch.
Both features enable effective pipelining of compute and communication: because communication consumes few or no SMs, it can be overlapped in time with compute kernels that keep most or all of the GPU SMs busy.
CUDA Graphs interoperability
NVSHMEM operations can be captured into CUDA Graphs, enabling:
- Amortized kernel launch overhead across multiple iterations
- Optimized execution scheduling by the CUDA runtime
- Seamless composition with other graph-captured operations
This composability is crucial for production training frameworks that rely on CUDA Graphs for performance optimization.
Integrating NVSHMEM and XLA
This section describes how NVSHMEM is integrated into the XLA compiler infrastructure, covering runtime flags, automatic backend selection heuristics, and the compilation flow.
Runtime control through debug options
XLA exposes a runtime flag for dynamic control:
XLA_FLAGS="--xla_gpu_experimental_enable_nvshmem=true"
This flag is defined in xla/debug_options_flags.cc and allows users to enable or disable NVSHMEM without recompilation; the default value is false. The experimental designation in the flag name indicates that the API may evolve as the feature matures.
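For example, in a JAX program the flag can be appended to XLA_FLAGS before JAX initializes its GPU backend. This is a minimal sketch; setting the variable in the launch environment works equally well:
import os
# Append the NVSHMEM flag to any XLA flags already set in the environment
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_experimental_enable_nvshmem=true"
)
import jax  # the flag is picked up when the XLA GPU backend initializes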
Automatic backend selection
The CollectiveBackendAssigner pass in the compilation pipeline determines which communication backend to use based on workload characteristics; this pass contains the core selection logic of the system.
Selection heuristics
The compiler analyzes each collective operation and decides whether to use NVSHMEM based on three key criteria:
- Single device: Use NVSHMEM when only one device is visible per process (no network overhead)
- Single partition: Use NVSHMEM when all participating devices in the collective operation are managed by the same process
- NVLink domain: Use NVSHMEM for intranode communication over NVIDIA NVLink fabric
Additionally, message size heuristics apply:
- AllReduce operations: Use NVSHMEM only when the message size is below a threshold (typically 16 MB); for larger messages, fall back to NCCL, which is optimized for bandwidth.
- CollectivePermute operations: Always use NVSHMEM, regardless of message size (no threshold applied).
- Rationale: Large AllReduce messages benefit from NCCL's bandwidth-optimized ring and tree algorithms, while the point-to-point nature of CollectivePermute makes NVSHMEM's low latency ideal at any size (see the sketch below).
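The decision logic can be paraphrased in a few lines of Python. This is an illustrative sketch only: the actual implementation is the C++ CollectiveBackendAssigner pass, the three topology criteria are collapsed here into a single within_nvlink_domain flag, and the 16 MB threshold is the typical value mentioned above rather than a guaranteed constant.
# Illustrative paraphrase of the backend-selection heuristics (not XLA source)
ALLREDUCE_NVSHMEM_THRESHOLD = 16 * 1024 * 1024  # bytes; typical threshold

def choose_backend(op_kind: str, message_bytes: int, within_nvlink_domain: bool) -> str:
    if not within_nvlink_domain:
        return "NCCL"
    if op_kind == "collective-permute":
        return "NVSHMEM"  # point-to-point: low latency wins at any message size
    if op_kind == "all-reduce":
        # Small messages are latency-bound (NVSHMEM); large ones are
        # bandwidth-bound (NCCL ring/tree algorithms).
        return "NVSHMEM" if message_bytes < ALLREDUCE_NVSHMEM_THRESHOLD else "NCCL"
    return "NCCL"  # other collectives keep the default backend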
JAX framework integration
The strength of this architecture lies in its complete transparency to Python frameworks. A JAX developer writes standard collective operations:
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('devices',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('devices'), out_specs=P('devices'))
def collective_permute_example(x):
    # Shift each device's shard to the next device in a ring
    perm = [(i, (i + 1) % mesh.size) for i in range(mesh.size)]
    return jax.lax.ppermute(x, 'devices', perm=perm)

# The compiler automatically selects NVSHMEM when appropriate
data = jnp.ones((jax.device_count() * 1024,))
result = collective_permute_example(data)
The XLA compiler analyzes this ppermute (collective permute) operation and automatically performs the following steps:
- Applies the heuristics: single device, single partition, or within the NVLink domain
- Recognizes a CollectivePermute operation (no message size threshold applies)
- Selects NVSHMEM for optimal point-to-point communication
- Generates thunks that invoke NVSHMEM host APIs at runtime
- The NVSHMEM host APIs enqueue operations on CUDA streams, for example, nvshmemx_float_sum_reduce_on_stream and nvshmemx_float_put_nbi_on_stream
This end-to-end integration means that high-level JAX code automatically benefits from NVSHMEM performance without requiring any user-level changes or annotations.
Experimental methodology
To evaluate NVSHMEM performance benefits, the team conducted experiments on Llama 3 8B across a range of sequence lengths (64K to 256K tokens) and parallelism configurations. This section details the model setup, hardware configuration, and the metrics used to compare NVSHMEM against the NCCL baseline.
Model configuration
The team evaluated NVSHMEM-accelerated context parallelism on the Llama 3 8B model with the following configuration:
- Model: Llama 3 8B
- Precision: BF16
- Context parallel strategy: Ring attention
- Framework: MaxText (JAX-based training framework)
- Hardware: NVIDIA GB200 NVL72
- Docker image: Available through NVIDIA/JAX-Toolbox
- JAX version: JAX 0.6.2 and later
Parallelism configurations
Various combinations of parallelism strategies were tested across different sequence lengths (Table 1).
| Sequence length | Nodes | GPUs | Context parallelism (CP) | Tensor parallelism (TP) | Fully sharded data parallelism (FSDP) | Sequence length per GPU after CP split |
|---|---|---|---|---|---|---|
| 64K | 1-4 | 4-16 | 4-16 | 1 | 1-2 | 4K-16K |
| 128K | 2-8 | 8-32 | 8-32 | 1 | 1-2 | 4K-16K |
| 256K | 8-16 | 32-64 | 16-32 | 2 | 1-2 | 8K-16K |

Table 1. Parallelism configurations tested across sequence lengths
Longer sequences (256K) employed tensor parallelism (TP=2) in addition to context parallelism to fit the model within GPU memory constraints.
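As an illustration of such a hybrid layout, the 256K-token, 8-node row in Table 2 (CP=16, TP=2 on 32 GPUs) corresponds to a two-axis device mesh. The axis names below are hypothetical, not MaxText's configuration keys, and the snippet assumes 32 visible devices.
import jax

# Two-axis mesh matching the 256K / 8-node configuration: 16 * 2 = 32 GPUs
mesh = jax.make_mesh((16, 2), ('context', 'tensor'))
# Attention activations are sharded along 'context'; weights along 'tensor'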
Communication backend comparison
Each configuration was evaluated with two communication backends:
- NCCL (baseline)
- NVSHMEM-enabled implementation
Measurements:
- TFLOP/s per device: GPU computational throughput
- Step time (seconds): Time per training iteration
- Speedup: Relative performance improvement of NVSHMEM over NCCL
All metrics were averaged across iterations 3-20 (skipping the first two warmup iterations) and computed from rank 0 logs to ensure consistency.
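The reported speedup is the relative TFLOP/s improvement of NVSHMEM over the NCCL baseline; for example, the 256K-token, 8-node row of Table 2 works out as follows:
# Speedup for the 256K-token, 8-node configuration (Table 2)
nccl_tflops, nvshmem_tflops = 366.94, 500.22
speedup_pct = (nvshmem_tflops / nccl_tflops - 1) * 100
print(f"{speedup_pct:.1f}%")  # -> 36.3%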
Performance results
As shown in Table 2, the NVSHMEM performance advantage grows significantly with sequence length:
- 64K sequences: 0.3-3.9% speedup (modest improvement)
- 128K sequences: -0.2% to 2.4% speedup (small, mostly positive improvement)
- 256K sequences: 30.4-36.3% speedup (dramatic improvement)
This scaling behavior aligns with the ring attention communication pattern: longer sequences require more KV tensor exchanges around the ring, amplifying the benefits of NVSHMEM's lower-latency communication.
When scaling across nodes, internode communication latency becomes more critical. NVSHMEM's nonblocking host APIs and optimized data paths provide consistent benefits across 8-16 node deployments.
| Sequence length | Nodes | CP | TP | GPUs | Seq/GPU | Default (NCCL) TFLOP/s | NVSHMEM TFLOP/s | Speedup |
|---|---|---|---|---|---|---|---|---|
| 64K | 1 | 4 | 1 | 4 | 16K | 605.64 | 607.36 | +0.3% |
| 64K | 2 | 8 | 1 | 8 | 8K | 549.92 | 557.17 | +1.3% |
| 64K | 4 | 16 | 1 | 16 | 4K | 482.19 | 501.06 | +3.9% |
| 128K | 2 | 8 | 1 | 8 | 16K | 512.22 | 515.87 | +0.7% |
| 128K | 4 | 16 | 1 | 16 | 8K | 473.58 | 472.46 | -0.2% |
| 128K | 8 | 32 | 1 | 32 | 4K | 420.99 | 431.13 | +2.4% |
| 256K | 8 | 16 | 2 | 32 | 16K | 366.94 | 500.22 | +36.3% |
| 256K | 16 | 32 | 2 | 64 | 8K | 346.33 | 451.70 | +30.4% |

Table 2. Per-device throughput with the default NCCL backend and with NVSHMEM across configurations
Practical implications
Based on these results, NVSHMEM provides clear advantages for:
- Long-context training: Sequences ≥ 128K tokens where communication becomes a bottleneck
- Multinode deployments: Scaling beyond single-node NVLink domains
- Ring attention and similar patterns: Workloads with fine-grained, latency-sensitive communication
- Hybrid parallelism: Configurations combining CP, TP, and FSDP
The XLA integration makes NVSHMEM accessible from JAX with no user code changes required: simply use an NVSHMEM-enabled XLA build and set the appropriate environment flag.
Get started with long-context model training
Training LLMs with long-context windows requires efficient communication strategies that can handle fine-grained, latency-sensitive data exchanges. The integration of NVSHMEM into XLA enables transparent acceleration of context parallelism with ring attention, providing up to 36% speedup for 256K token sequences on Llama 3 8B.
Key takeaways:
- The NVSHMEM nonblocking host APIs and low-latency data paths are ideally suited for the ring attention communication pattern
- XLA compiler integration makes NVSHMEM accessible to high-level frameworks without requiring code changes
- Performance benefits scale with sequence length, with dramatic improvements for sequences ≥ 256K tokens
- Multinode deployments see the largest gains, making NVSHMEM essential for production long-context training
As context windows continue to grow, low-latency communication solutions like NVSHMEM will be crucial for making long-context training practical and cost-effective. We encourage the community to try NVSHMEM-enabled XLA builds in the JAX framework and share their experiences with long-context workloads.
To get started, check out the MaxText framework, NVIDIA/JAX-Toolbox, and openxla/xla on GitHub.
Acknowledgments
We would like to express our gratitude to NVSHMEM contributors Seth Howell and Akhil Langer.