
Accelerating Long-Context Inference with Skip Softmax in NVIDIA TensorRT-LLM

For machine learning engineers deploying LLMs at scale, the equation is familiar and unforgiving: as context length increases, attention computation costs explode. Whether you’re dealing with retrieval-augmented generation (RAG) pipelines, agentic AI workflows, or long-form content generation, the O(N^2) complexity of attention remains a primary bottleneck.

This post explains Skip Softmax, a hardware-friendly, drop-in sparse attention technique that accelerates inference without any retraining. Read on to learn how Skip Softmax delivers up to 1.4x faster time-to-first-token (TTFT) and up to 1.4x faster time-per-output-token (TPOT), and how to get started with the technique in NVIDIA TensorRT-LLM.

How does Skip Softmax work?

At its core, Skip Softmax is a dynamic way to prune attention blocks. It exploits a fundamental property of the softmax function: \exp(\text{large negative number}) \approx 0.
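
As a quick illustration of this property (a standalone NumPy sketch, not TensorRT-LLM code), a logit that sits well below the block maximum contributes almost nothing to the normalized attention weights:

import numpy as np

# Logits for one query row: one dominant score and several much smaller ones.
logits = np.array([8.0, -2.0, -4.0, -6.0])

# Numerically stable softmax: subtract the max before exponentiating.
shifted = logits - logits.max()
weights = np.exp(shifted) / np.exp(shifted).sum()

print(weights)  # ~[1.0, 4.5e-05, 6.1e-06, 8.3e-07]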

In standard FlashAttention, the GPU computes attention scores (logits) for blocks of queries (Q) and keys (K). It then applies softmax to normalize these scores into probabilities (P) and multiplies them by values (V).

However, attention is intrinsically sparse. For many blocks, the attention scores are so low compared to those of the dominant tokens that their contribution to the final output is negligible. Skip Softmax modifies the FlashAttention loop to detect these blocks early and skip them entirely.

The Skip Softmax algorithm

Implemented directly within the FlashAttention kernel, the logic follows this heuristic:

  1. Compute local max: Calculate the maximum logit for the current block (Q \cdot K^T).
  2. Compare to running max: Check whether the difference between the current block’s local max (m_{i}^{(j)}) and the running global max (m_{i}^{(j-1)}) exceeds a calibrated threshold (\lambda).
  3. Skip: If the condition is met, the kernel skips the softmax and BMM2 computation for that block and, crucially, skips loading the V block from High Bandwidth Memory (HBM), as sketched below.
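
The following standalone NumPy sketch is one way to picture where that check sits inside an online-softmax (FlashAttention-style) loop for a single query row. It is a simplified illustration, not the TensorRT-LLM kernel: the block size, the value of the threshold lam, and the exact form of the comparison are assumptions.

import numpy as np

def skip_softmax_attention(q, K, V, block_size=128, lam=5.0):
    """Online-softmax attention for one query row that skips low-scoring blocks.

    lam is a calibrated margin: a block whose maximum logit falls more than lam
    below the running max is assumed to contribute ~0 and is skipped.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = -np.inf                      # running max logit
    l = 0.0                          # running softmax denominator
    acc = np.zeros_like(q)           # running weighted sum of V rows

    for start in range(0, K.shape[0], block_size):
        s = (q @ K[start:start + block_size].T) * scale   # BMM1: block logits
        m_local = s.max()

        # Skip check: the whole block scores far below the running max, so its
        # softmax weights are ~0. Skip softmax, BMM2, and the V block load.
        if m_local < m - lam:
            continue

        m_new = max(m, m_local)
        p = np.exp(s - m_new)                       # block softmax numerators
        correction = np.exp(m - m_new)              # rescale previous partials
        l = l * correction + p.sum()
        acc = acc * correction + p @ V[start:start + block_size]   # BMM2
        m = m_new

    return acc / l

# Example: one query row against a 4K-token context.
rng = np.random.default_rng(0)
q = rng.normal(size=64)
K = rng.normal(size=(4096, 64))
V = rng.normal(size=(4096, 64))
out = skip_softmax_attention(q, K, V)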

What are the benefits of using Skip Softmax?

Skip Softmax offers drop-in compatibility, hardware efficiency, flexibility, and versatility. 

Unlike approaches that require architectural modifications (such as linear attention), Skip Softmax is compatible with existing pretrained models that use standard attention mechanisms such as MHA, GQA, or MLA. It is optimized for the Tensor Cores and memory hierarchy of NVIDIA Hopper and NVIDIA Blackwell GPUs. It can also be combined with other optimization methods; for instance, using XAttention during prefill and Skip Softmax during decoding has been shown to deliver substantial speedups without compromising accuracy.

Skip Softmax is versatile because it addresses bottlenecks in both the prefill and decode phases. Based on performance data on Hopper and Blackwell architectures, Skip Softmax is beneficial during bandwidth-bound decoding and compute-bound prefilling, especially in long-context scenarios.

Bandwidth-bound decoding

During the generation (decode) phase, LLM inference is typically bound by memory bandwidth. The GPU spends more time moving KV cache data than computing.

  • Benefit: By identifying unimportant blocks early, Skip Softmax avoids loading the associated V blocks entirely (a rough estimate follows this list).
  • Data: On Llama 3.3 70B (NVIDIA GB200 NVL72), Skip Softmax achieves a projected 1.36x end-to-end speedup during decoding.
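
To make the bandwidth argument concrete, here is a back-of-envelope estimate in Python. The arithmetic is mine, not published benchmark data, and the model shape (80 layers, 8 KV heads, head dim 128), BF16 KV cache, and 50% skip ratio are assumptions for illustration:

def kv_read_gib(num_layers, num_kv_heads, head_dim, context_len,
                dtype_bytes=2, v_skip_ratio=0.0):
    """Approximate KV-cache bytes read per decode step, in GiB.

    K is always read because the logits are needed for the skip decision;
    only the skipped fraction of V block reads is avoided.
    """
    per_token = num_kv_heads * head_dim * dtype_bytes
    k_bytes = num_layers * context_len * per_token
    v_bytes = k_bytes * (1.0 - v_skip_ratio)
    return (k_bytes + v_bytes) / 2**30

dense = kv_read_gib(80, 8, 128, 128 * 1024)
sparse = kv_read_gib(80, 8, 128, 128 * 1024, v_skip_ratio=0.5)
print(f"{dense:.0f} GiB -> {sparse:.0f} GiB read per decode step")  # 40 -> 30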

Compute-bound prefilling

During the prefill phase (processing the input prompt), the system is compute-bound.

  • Benefit: Skipping the softmax and the second matrix multiplication (BMM2) saves significant FLOPs (a rough estimate follows this list).
  • Data: For the same Llama 3.3 70B model (NVIDIA GB200 NVL72), prefill sees an estimated 1.4x end-to-end speedup at 128K context length.
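
A similarly rough FLOP count (again my own simplification, with an assumed 64 query heads and head dim 128; real end-to-end gains also depend on the non-attention layers) shows where the prefill savings come from:

def attention_flops(seq_len, head_dim, num_heads, skip_ratio=0.0):
    """Approximate attention FLOPs for one layer during prefill.

    BMM1 (Q @ K^T) is always computed because the logits drive the skip
    decision; only BMM2 (P @ V) and the softmax for skipped blocks are saved.
    """
    bmm1 = 2 * seq_len * seq_len * head_dim * num_heads
    bmm2 = bmm1 * (1.0 - skip_ratio)
    return bmm1 + bmm2

dense = attention_flops(128 * 1024, 128, 64)
sparse = attention_flops(128 * 1024, 128, 64, skip_ratio=0.5)
print(f"attention FLOPs at 50% sparsity: {sparse / dense:.0%} of dense")  # 75%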

Long-context scenarios

The efficacy of Skip Softmax increases with sequence length. The threshold for skipping is mathematically related to the context length (L) by the relationship \text{Threshold} \propto 1/L. This means that, as context grows, the opportunity to safely identify and skip sparse blocks increases.

The tradeoff between accuracy and sparsity

The obvious question for any approximation technique is, “How does this approach impact accuracy?”

Extensive testing on the RULER (synthetic long-context) and LongBench (realistic long-context) benchmarks suggests a clear “safe zone” for sparsity.

  • Safe zone: A sparsity ratio of roughly 50% (skipping half of the attention blocks) is the practical sweet spot. In tests with Llama 3.1 8B and Qwen3-8B, running at ~50% sparsity resulted in near-lossless accuracy across most tasks.
  • Danger zone: Pushing sparsity beyond 60% often leads to sharp accuracy drops, particularly in complex “needle-in-a-haystack” multikey tasks.
  • Long generation: For tasks requiring long output generation such as MATH-500, Skip Softmax maintains accuracy parity with dense attention, unlike some static KV cache compression methods.
Model | Dataset | Sparsity | Accuracy delta versus baseline
Llama 3.1 8B | RULER-16K | ~50% at prefill stage | -0.19%
Qwen3-8B | MATH-500 | ~50% at decode stage | +0.36%
Table 1. Accuracy delta versus baseline without sparsity
Scenario | Threshold | Speedup (BF16) | Baseline accuracy | Sparse accuracy | Accuracy delta
Context only | 0.2 | 130.63% | 37.21% | 36.74% | -0.47%
Context plus generation | 0.6 | 138.37% | 35.81% | 34.42% | -1.39%
Table 2. Speedup with the Qwen3-30B-Instruct model at a 128K sequence length

Additional deployment-time optimizations include the following:

  • Automated calibration procedures to determine the optimal thresholds for target sparsity levels (a conceptual sketch follows this list).
  • Sparsity-aware training to make models more robust to sparse attention patterns.
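
The calibration flow is not covered in detail here, but conceptually it reduces to a one-dimensional search: find the threshold scale factor whose measured block-skip ratio matches the target sparsity. The sketch below illustrates that idea only; measure_block_sparsity is a hypothetical stand-in for instrumented kernel statistics, not a real TensorRT-LLM or Model Optimizer API, and the skip ratio is assumed to grow monotonically with the scale factor:

def calibrate_scale_factor(measure_block_sparsity, target_sparsity=0.5,
                           lo=0.0, hi=1_000_000.0, iters=30):
    """Bisect for a threshold_scale_factor that hits a target block-skip ratio.

    measure_block_sparsity(scale_factor) -> fraction of attention blocks skipped
    on a set of calibration prompts (hypothetical helper, assumed monotonic).
    """
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        if measure_block_sparsity(mid) < target_sparsity:
            lo = mid    # not enough blocks skipped: raise the scale factor
        else:
            hi = mid    # too many blocks skipped: lower the scale factor
    return (lo + hi) / 2.0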

Get started with Skip Softmax in NVIDIA TensorRT-LLM

Skip Softmax Attention is integrated directly into NVIDIA TensorRT-LLM and supported on NVIDIA Hopper and NVIDIA Blackwell data center GPUs. This lets you further accelerate attention computation on top of the state-of-the-art LLM inference performance that TensorRT-LLM already delivers.

Skip Softmax Attention can be enabled through the sparse attention configuration of the LLM API:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SkipSoftmaxAttentionConfig

# Single threshold scale factor shared by prefill and decode.
sparse_attention_config = SkipSoftmaxAttentionConfig(threshold_scale_factor=1000.0)

# Alternatively, configure the prefill and decode thresholds separately.
sparse_attention_config = SkipSoftmaxAttentionConfig(
    threshold_scale_factor={"prefill": 1000.0, "decode": 500.0}
)

llm = LLM(
    model="Qwen/Qwen3-30B-A3B-Instruct-2507",
    sparse_attention_config=sparse_attention_config,
    # Other LLM arguments...
)

The actual threshold value equals the threshold_scale_factor divided by the context length.
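
For example, with the formula above (the numbers are only illustrative), a fixed threshold_scale_factor of 1000.0 yields smaller per-block thresholds as the context grows, matching the \text{Threshold} \propto 1/L relationship described earlier:

# Effective threshold = threshold_scale_factor / context_length
for context_len in (8 * 1024, 32 * 1024, 128 * 1024):
    print(f"{context_len:>7} tokens -> threshold {1000.0 / context_len:.5f}")
# 8192 -> 0.12207, 32768 -> 0.03052, 131072 -> 0.00763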

The configuration can also be specified through the extra LLM API options YAML file. The following example launches an OpenAI-compatible endpoint:

cat >extra_llm_api_options.yaml <<EOF
sparse_attention_config:
    algorithm: skip_softmax
    threshold_scale_factor: 1000.0
EOF

# Alternatively, the threshold_scale_factor for prefill and decode can be configured separately.
cat >extra_llm_api_options.yaml <<EOF
sparse_attention_config:
    algorithm: skip_softmax
    threshold_scale_factor: 
        prefill: 1000.0
        decode: 500.0
EOF

trtllm-serve Qwen/Qwen3-30B-A3B-Instruct-2507 --extra_llm_api_options extra_llm_api_options.yaml

Learn more

To learn more, see BLASST: Dynamic Blocked Attention Sparsity via Softmax Thresholding, as well as the TensorRT-LLM documentation for the LLM API and CLI. Threshold calibration will be supported by NVIDIA Model Optimizer, which will let you specify a target sparsity and obtain the corresponding threshold scale factors.

The Skip Softmax sparse attention kernel will also be available through the FlashInfer Python API. Stay tuned for the official releases in upcoming TensorRT-LLM, Model Optimizer, and FlashInfer updates.
