
Overcoming Compute and Memory Bottlenecks with FlashAttention-4 on NVIDIA Blackwell 


The transformer architecture is the foundational breakthrough driving the generative AI revolution, powering large language models (LLMs) such as GPT, DeepSeek, and Llama. At the heart of the transformer is the self-attention mechanism, which enables a model to process an entire input sequence in parallel rather than word by word. This parallelism is what makes it possible to capture long-range dependencies.

While the self-attention mechanism is powerful, its computational and memory complexity is quadratic in sequence length. This creates a memory bottleneck when dealing with the long context windows of modern LLMs.
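
As a rough illustration of that quadratic growth, the following NumPy sketch materializes the full score matrix the way a naive attention implementation does (the sizes are illustrative assumptions, not tied to any particular model):

```python
import numpy as np

n, d = 4096, 128                                  # sequence length, head dimension
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(n, d)).astype(np.float32) for _ in range(3))

# Standard attention materializes the full N x N score matrix in memory.
scores = q @ k.T / np.sqrt(d)                     # shape (4096, 4096)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v                                 # shape (4096, 128)

print(f"{scores.nbytes / 2**20:.0f} MiB for the score matrix alone")  # 64 MiB
# At a 32,768-token context, the same matrix grows to 4 GiB per head: O(N^2) growth.
```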

In this post, we’ll discuss FlashAttention, an algorithmic breakthrough that mitigates this bottleneck by reducing memory traffic and memory complexity while computing exactly the same attention output.

What is FlashAttention?

FlashAttention is an input/output-aware (IO-aware) algorithm that computes the same mathematical result as standard attention, but does so more efficiently. FlashAttention achieves this with:

  • Reduced memory access: FlashAttention minimizes the slow transfer of data between a GPU’s main high-bandwidth memory (HBM) and the faster but much smaller on-chip static random access memory (SRAM). It does this by combining computational steps (such as matrix multiplication and softmax) into a single optimized GPU kernel, a technique called kernel fusion.
  • Near-linear memory usage: using techniques such as tiling (breaking the computation into smaller blocks) and online softmax (normalizing the distribution incrementally), FlashAttention reduces memory complexity from O(N²) to O(N) with respect to sequence length N. A CPU sketch of this tiled, online-softmax loop follows this list.
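
Below is a minimal CPU sketch of these two ideas in NumPy, assuming a simplified single-head, non-causal setting. It tiles the keys and values, maintains a running maximum and softmax denominator, and matches the naive result exactly; the real CUDA kernel additionally tiles the queries across thread blocks and fuses everything into one kernel.

```python
import numpy as np

def tiled_attention(q, k, v, block=128):
    """Tiled attention with online softmax (illustrative CPU sketch only).

    Keys and values are processed one block at a time, so the largest
    temporary is (N x block) rather than the full (N x N) score matrix.
    """
    n, d = q.shape
    out = np.zeros_like(q)
    m = np.full(n, -np.inf)        # running row-wise maximum
    l = np.zeros(n)                # running softmax denominator

    for start in range(0, n, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)                  # (n, block) block of scores
        m_new = np.maximum(m, s.max(axis=1))
        scale = np.exp(m - m_new)                  # rescale earlier partial sums
        p = np.exp(s - m_new[:, None])
        out = out * scale[:, None] + p @ vb
        l = l * scale + p.sum(axis=1)
        m = m_new
    return out / l[:, None]

# Verify against the naive implementation.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(512, 64)) for _ in range(3))
s = q @ k.T / np.sqrt(64)
w = np.exp(s - s.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ v
assert np.allclose(tiled_attention(q, k, v), ref)
```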

These optimizations lead to faster training and inference. They also enable models to handle longer sequences of tokens for applications that require maintaining long-running conversations or processing high-resolution images.

Figure 1. FlashAttention achieves 7.6x faster execution and 20x lower memory usage than standard PyTorch baselines
Figure 2. FlashAttention algorithm with NVIDIA HGX B200

FlashAttention-4 

FlashAttention-4 (FA4) is the latest iteration of optimized attention CUDA kernels and delivers a leap in efficiency. It’s hardware-software co-designed and tailored to maximize performance on the NVIDIA Blackwell architecture, such as the NVIDIA HGX B200.

FA4 achieves a peak performance of 1,605 TFLOPS, harnessing 71% of the hardware’s theoretical maximum. By redesigning the attention mechanism to address Blackwell’s asymmetric scaling (where compute power scales much faster than memory bandwidth), FA4 outperforms standard baselines, delivering up to a 1.3x speedup over NVIDIA cuDNN and 2.4x over Triton kernel implementations.
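
As a quick sanity check on these numbers, 1,605 TFLOPS at 71% utilization implies a hardware peak consistent with the ~2.25 PFLOPS Tensor Core throughput listed in Table 1 below:

```python
achieved_tflops = 1605   # FA4 peak attention throughput on B200 (quoted above)
utilization = 0.71       # fraction of the theoretical maximum (quoted above)

theoretical = achieved_tflops / utilization
print(f"Implied hardware peak: ~{theoretical:.0f} TFLOPS (~{theoretical / 1000:.2f} PFLOPS)")
# ~2261 TFLOPS, in line with the ~2.25 PFLOPS figure in Table 1.
```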

These gains extend to the backward pass, where FA4 uses tensor memory (TMEM), a dedicated memory located close to the Tensor Cores that effectively expands the available register capacity, to bypass register accumulation and relieve register pressure. This enables larger tiles (up to 128×128) and deeper pipelines, while reducing shared memory (SMEM) traffic and maximizing operation overlap. As a result, training speed keeps pace with the doubled throughput of the new Tensor Cores rather than being bottlenecked by memory logistics.

FA4 co-designs the algorithm and kernel implementation around the following new Blackwell features, the bottlenecks they expose, and the corresponding mitigation strategies:

| Blackwell hardware feature | Bottleneck | FA4 technique |
|---|---|---|
| TMEM: 256 KB of on-chip memory per SM; 5th-gen Tensor Cores asynchronously write outputs directly to TMEM | Standard backward passes overuse shared memory (SMEM) for intermediates, creating a bandwidth bottleneck relative to Tensor Cores | TMEM-based backward pass: FA4 stores backward intermediates (S, P, dP, dS, dQ) directly in TMEM, drastically reducing SMEM traffic |
| SMEM | SMEM bandwidth becomes limiting as Tensor Core performance scales faster than memory movement | Reduced SMEM pressure by relocating intermediates to TMEM |
| Asymmetric scaling | Tensor Core throughput roughly doubles (~2.25 PFLOPS), while MUFU throughput remains unchanged from the prior generation (16 ops/clock) | Compute rebalancing to reduce reliance on MUFU-heavy paths |
| Exponential units (MUFU) | Softmax exponentials dominate runtime, exceeding matmul time by ~25–60% | Software-emulated exponentials using FMA-based polynomial approximations alongside MUFU (sketched after Table 1) |
| Expanded MMA tile size (128×128) | Larger tiles increase register pressure and impose stricter scheduling constraints | New CTA scheduling and register allocation, including LPT scheduling for causal masking |
| Fully asynchronous Tensor Cores | Sequential MMA-softmax dependencies can leave compute units idle if not overlapped | Redesigned asynchronous pipelines to maximize overlap across MMA, softmax, and memory operations |
| Finite non-matmul resources | Non-matmul ALUs scale more slowly than Tensor Cores | Algorithmic minimization of non-matmul work |
| Online softmax | Redundant rescaling wastes non-matmul cycles | Conditional softmax rescaling, updating only when the running max crosses a threshold |
| CUDA 13 and CUDA-X tooling | Kernel complexity slows tuning and optimization | Kernel-level graphs and performance tools used to optimize the FA4 kernels |
| Developer productivity | Complex C++ templates slow compile times and hinder iteration | CuTe DSL in Python, achieving 20–30x faster compile times than FA3 while preserving kernel expressivity |

Table 1. Blackwell hardware features and the performance bottlenecks targeted by FA4 techniques
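
To make the MUFU mitigation concrete, here is a minimal NumPy sketch of the idea behind software-emulated exponentials: split the exponent into integer and fractional parts, approximate 2**f with a low-degree polynomial evaluated as chained multiply-adds (which map to FMA instructions on a GPU), and reassemble the result through exponent manipulation. The polynomial fit and degree here are illustrative assumptions, not the actual coefficients used by FA4.

```python
import numpy as np

# Fit a low-degree polynomial to 2**f on the fractional range [0, 1).
# A real kernel would bake in fixed minimax coefficients and evaluate them
# with fused multiply-adds (Horner's rule); np.polyfit is a stand-in here.
frac = np.linspace(0.0, 1.0, 1025)
coeffs = np.polyfit(frac, 2.0 ** frac, deg=4)

def exp2_poly(x):
    """Approximate 2**x by splitting x into integer and fractional parts."""
    i = np.floor(x)
    f = x - i                                  # f lies in [0, 1)
    p = np.polyval(coeffs, f)                  # Horner evaluation: chained multiply-adds
    return np.ldexp(p, i.astype(np.int32))     # scale by 2**i via the exponent bits

x = np.linspace(-20.0, 20.0, 1001)
rel_err = np.max(np.abs(exp2_poly(x) / np.exp2(x) - 1.0))
print(f"max relative error: {rel_err:.2e}")    # roughly 1e-5 for a degree-4 fit
```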

The forward and backward pass performance gains on a Blackwell GPU for different sequence lengths are shown in Figures 3 and 4, respectively.

Figure 3. Forward pass TFLOPS on a B200 with a head dimension of 128. FA4 achieves a 3.6x speedup over FA2 at a sequence length of 32,768
Figure 4. Backward pass TFLOPS on a B200 with a head dimension of 128. FA4 achieves a 3.15x speedup over FA2 at a sequence length of 32,768

Learn more

The FlashAttention-4 algorithm was developed through hardware-software co-design and a kernel pipeline that mitigates the bottlenecks introduced by modern accelerators. FA4 uses the NVIDIA Blackwell Tensor Cores and Tensor Memory to increase performance and power efficiency, especially in multi-GPU, multi-node (MGMN) distributed configurations. The forward and backward pass kernel designs incorporate various optimizations that achieve speedups over previous versions of the FlashAttention algorithm.

Inference frameworks such as SGLang and vLLM are compatible with FlashAttention-4 prefill, and NVIDIA has incorporated FA4 techniques into NVIDIA cuDNN 9.14.
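
As one hedged example of what that integration can look like from user code: in a recent PyTorch build (2.5 or later, an assumption here), you can restrict scaled dot-product attention to the cuDNN backend so that a sufficiently new cuDNN (9.14 or later, per the note above) supplies the fused attention kernels on Blackwell GPUs. This is a minimal sketch, not an official integration recipe.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Hypothetical shapes: batch 2, 16 heads, 4,096 tokens, head dimension 128.
q, k, v = (torch.randn(2, 16, 4096, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Restrict scaled dot-product attention to the cuDNN fused-attention backend.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```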

Learn more about cuDNN and how to unlock deep learning performance on Blackwell with cuDNN.

