Accelerate Attention with FlashAttention-3: New Capabilities and Performance
Attention, as a core layer of the ubiquitous Transformer architecture, is a bottleneck for large language models and long-context applications. FlashAttention and FlashAttention-2 pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and they are now used by most libraries to accelerate Transformer training and inference. This has contributed to a massive increase in LLM context length in the last two years, from 2-4K (GPT-3, OPT) to 128K (GPT-4) or even 1M (Llama 3). Despite this success, FlashAttention has yet to take advantage of new capabilities in modern hardware, with FlashAttention-2 achieving only 35% utilization of theoretical max FLOPs on the H100 GPU.
In this blog post, we describe three main techniques to speed up attention on Hopper GPUs: exploiting the asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp specialization and (2) interleave block-wise matmul and softmax operations; and (3) incoherent processing that leverages hardware support for low-precision FP8. We’re excited to release FlashAttention-3, which incorporates these techniques. It’s 1.5-2.0x faster than FlashAttention-2 with FP16, reaching up to 740 TFLOPS, i.e., 75% utilization of the H100's theoretical max FLOPS. With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.
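Of the three techniques, incoherent processing is the easiest to illustrate outside a GPU kernel. Multiplying both Q and K by the same orthogonal matrix M leaves every attention score unchanged, because (QM)(KM)ᵀ = QMMᵀKᵀ = QKᵀ, but it spreads a single large outlier across all coordinates, which shrinks the range an FP8 format has to represent. The pure-Python sketch below assumes M is a random ±1 sign flip followed by a normalized Walsh-Hadamard transform; the helper names (`fwht`, `rotate_rows`) are hypothetical, the outlier is injected by hand, and the FP8 cast itself is omitted because Python has no FP8 type.

```python
import math
import random

def fwht(vec):
    """Orthonormal fast Walsh-Hadamard transform; len(vec) must be a power of two."""
    v = list(vec)
    n = len(v)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = v[j], v[j + h]
                v[j], v[j + h] = a + b, a - b  # butterfly step
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal
    return [x * scale for x in v]

def rotate_rows(rows, signs):
    """Apply the same random sign flip, then the Hadamard transform, to every row."""
    return [fwht([s * x for s, x in zip(signs, row)]) for row in rows]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

rng = random.Random(0)
d = 8  # head dimension; must be a power of two in this sketch
signs = [rng.choice((-1.0, 1.0)) for _ in range(d)]

q = [[rng.gauss(0.0, 1.0) for _ in range(d)]]
k = [[rng.gauss(0.0, 1.0) for _ in range(d)]]
k[0][3] = 50.0  # artificial outlier channel (assumption for illustration)

q_rot, k_rot = rotate_rows(q, signs), rotate_rows(k, signs)

score_before = dot(q[0], k[0])
score_after = dot(q_rot[0], k_rot[0])  # same score, up to float rounding
max_before = max(abs(x) for x in k[0])
max_after = max(abs(x) for x in k_rot[0])  # outlier spread over all d entries

print(f"q.k    before: {score_before:.6f}  after: {score_after:.6f}")
print(f"max|k| before: {max_before:.2f}  after: {max_after:.2f}")
```

The rotation costs O(d log d) per row and is applied once before quantization, so the attention computation itself is unchanged.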
More efficient GPU utilization: FlashAttention-3 uses up to 75% of an H100 GPU's maximum capabilities, up from just 35% with FlashAttention-2. This makes training and running large language models (LLMs) significantly (1.5-2x) faster.
Better performance with lower precision: FlashAttention-3 can work with lower-precision numbers (FP8) while maintaining accuracy. This allows for even faster processing and potentially lower memory usage, leading to cost savings and improved efficiency for customers running large-scale AI operations.
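The utilization percentages are ratios of achieved to peak throughput. An attention forward pass is dominated by two matmuls, QKᵀ and PV, each costing about 2·N²·d FLOPs per head for sequence length N and head dimension d, so roughly 4·B·H·N²·d FLOPs for batch size B and H heads (about half with a causal mask). The sketch below assumes a peak of about 989 TFLOPS, NVIDIA's published dense FP16/BF16 Tensor Core figure for the H100 SXM; the function names are hypothetical.

```python
# Assumed peak: dense FP16/BF16 Tensor Core throughput of an H100 SXM (check your SKU).
H100_PEAK_TFLOPS = 989.0

def attention_fwd_flops(batch, heads, seqlen, head_dim, causal=False):
    """Approximate FLOPs of the two attention matmuls (QK^T and PV); softmax is ignored."""
    flops = 4 * batch * heads * seqlen * seqlen * head_dim
    return flops // 2 if causal else flops  # causal masking skips about half the work

def achieved_tflops(flops, seconds):
    """Convert a FLOP count and a measured runtime into TFLOPS."""
    return flops / seconds / 1e12

def utilization(tflops, peak=H100_PEAK_TFLOPS):
    """Fraction of the assumed hardware peak."""
    return tflops / peak

# Forward-pass throughputs quoted in the text.
for name, tflops in [("FlashAttention-2", 350.0), ("FlashAttention-3", 740.0)]:
    print(f"{name}: {tflops:.0f} TFLOPS = {utilization(tflops):.0%} of peak")
```

With your own timing, `achieved_tflops(attention_fwd_flops(B, H, N, d), seconds)` gives the number to pass to `utilization`.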
Moreover, by speeding up the attention mechanism, FlashAttention-3 enables AI models to work with much longer pieces of text more efficiently. This could allow for applications that can understand and generate longer, more complex content without slowing down. FlashAttention-3 is available on GitHub.
FlashAttention is an algorithm that reorders the attention computation and leverages tiling and recomputation to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. We use tiling to load blocks of inputs from HBM (GPU memory) to SRAM (fast on-chip memory), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, resulting in a 2-4x wallclock speedup over standard attention.
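Tiling works because softmax can be computed incrementally. The pure-Python sketch below shows this "online softmax" for a single attention head: keys and values are visited block by block, and a running maximum and running denominator are rescaled whenever a new block raises the maximum, so only one block of scores per query exists at a time instead of a full row of N. It assumes small Python lists in place of GPU tensors and SRAM tiles, tiles only over K/V, omits the backward pass and its recomputation, and the function names are illustrative. The final lines compare it with a naive implementation.

```python
import math
import random

def naive_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V that materializes all N scores per query."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block=2):
    """Visit K/V one block at a time, keeping only a running max (m), running
    softmax denominator (l) and unnormalized output (acc) per query row."""
    d = len(Q[0])
    out = []
    for q in Q:
        m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
        for start in range(0, len(K), block):
            Kb, Vb = K[start:start + block], V[start:start + block]
            s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)  # rescales statistics computed under the old max
            p = [math.exp(x - m_new) for x in s]
            l = l * scale + sum(p)
            acc = [a * scale + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once at the end
    return out

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

N, d = 5, 4  # N = 5 with block=2 also exercises a ragged last block
Q, K, V = rand_matrix(N, d), rand_matrix(N, d), rand_matrix(N, d)

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=2)
max_err = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print(f"max abs difference: {max_err:.2e}")
```

The result matches the naive version up to floating-point rounding, while the per-query working set depends on `block` rather than on N.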
Hopper GPUs add several new hardware features, including WGMMA (warpgroup matrix multiply-accumulate) instructions, the Tensor Memory Accelerator (TMA), and FP8 Tensor Core support. FlashAttention-3 makes use of these features, using powerful abstractions from NVIDIA’s CUTLASS library. By rewriting FlashAttention to utilize them, we can already significantly speed it up (e.g., from 350 TFLOPS in the FlashAttention-2 FP16 forward pass to around 540-570 TFLOPS). However, the asynchronous nature of the new instructions on Hopper (WGMMA and TMA) opens up additional algorithmic opportunities to overlap operations and thereby extract even greater performance.
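To see why asynchrony matters, consider a toy timing model (an illustration, not a description of the real kernel). Each K/V block needs a GEMM S_j = Q K_jᵀ on the Tensor Cores, a softmax on S_j, and a second GEMM O += P_j V_j. If every operation waits for the previous one, nothing overlaps. If the GEMMs are instead issued asynchronously and S_j is issued before P_{j-1} V_{j-1}, the softmax on block j can run while the Tensor Cores work on the previous block's PV product. The sketch assumes unit cost for every operation and one operation at a time per unit; the class and function names are hypothetical, and real WGMMA and softmax latencies differ.

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    unit: str    # "tensor" (GEMMs) or "softmax" (max/exp/sum work)
    cost: float
    deps: list = field(default_factory=list)

def makespan(ops, synchronous=False):
    """In-order issue: each op starts once its unit is free and its deps have finished.
    With synchronous=True every op also waits for the previously issued op (no overlap)."""
    unit_free, finish, prev_end = {}, {}, 0.0
    for op in ops:
        start = max([unit_free.get(op.unit, 0.0)] + [finish[d] for d in op.deps])
        if synchronous:
            start = max(start, prev_end)
        end = start + op.cost
        unit_free[op.unit] = end
        finish[op.name] = end
        prev_end = end
    return max(finish.values())

def sequential_order(n, gemm=1.0, softmax=1.0):
    """S_j, softmax_j, PV_j for each block in turn."""
    ops = []
    for j in range(n):
        ops += [Op(f"S{j}", "tensor", gemm),
                Op(f"softmax{j}", "softmax", softmax, [f"S{j}"]),
                Op(f"PV{j}", "tensor", gemm, [f"softmax{j}"])]
    return ops

def pipelined_order(n, gemm=1.0, softmax=1.0):
    """Issue S_j before PV_{j-1} so softmax_j overlaps the previous block's PV GEMM."""
    ops = [Op("S0", "tensor", gemm), Op("softmax0", "softmax", softmax, ["S0"])]
    for j in range(1, n):
        ops += [Op(f"S{j}", "tensor", gemm),
                Op(f"PV{j-1}", "tensor", gemm, [f"softmax{j-1}"]),
                Op(f"softmax{j}", "softmax", softmax, [f"S{j}"])]
    ops.append(Op(f"PV{n-1}", "tensor", gemm, [f"softmax{n-1}"]))
    return ops

n_blocks = 4
print("serialized:", makespan(sequential_order(n_blocks), synchronous=True))
print("pipelined: ", makespan(pipelined_order(n_blocks)))
```

With these unit costs, 4 blocks take 12 time units when fully serialized versus 8 when pipelined, because every softmax is hidden behind a Tensor Core GEMM.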