白交 發(fā)自 凹非寺
量子位 | 公眾號 QbitAI
Flash is all you need!
最近,一個超快且省內(nèi)存的注意力算法FlashAttention火了。
通過感知顯存讀取/寫入,F(xiàn)lashAttention的運行速度比PyTorch標準Attention快了2-4倍,所需內(nèi)存也僅是其5%-20%。
而它的表現(xiàn)還不止于此。
- 訓練BERT速度相較于MLPerf訓練記錄提升15%;
- 訓練GPT-2的速度提高3.5倍;
- 訓練Transformer的速度比現(xiàn)有基線快。
網(wǎng)友們紛紛表示驚嘆:Great Job!這項工作對我來說很有用。
來看看這是一項什么樣的研究~
FlashAttention
本文提出了一種IO感知精確注意力算法。
隨著Transformer變得越來越大、越來越深,但它在長序列上仍然處理的很慢、且耗費內(nèi)存。(自注意力時間和顯存復雜度與序列長度成二次方)
現(xiàn)有近似注意力方法,在試圖通過去犧牲模型質(zhì)量,以降低計算復雜度來解決該問題。
但存在一定的局限性,即不能提升運行時的訓練速度。
研究者認為,應該讓注意力算法具有IO感知,即考慮顯存級間的讀寫,比如大但慢的HBM(High Bandwidth Memory)技術(shù)與小但快的SRAM。
基于這樣的背景,研究人員提出了FlashAttention,具體有兩種加速技術(shù):按塊遞增計算即平鋪、并在后向傳遞中重新計算注意力,將所有注意力操作融合到CUDA內(nèi)核中。
FlashAttention使用平鋪來防止大的×注意力矩陣(虛線框)在GPU HBM上物化(materialization)。在外部循環(huán)中(紅色箭頭),F(xiàn)lashAttention循環(huán)通過K和V矩陣的塊,并將其加載到SRAM。
在每個區(qū)塊中,F(xiàn)lashAttention 循環(huán)Q矩陣的區(qū)塊(藍色箭頭)將其加載到 SRAM,并將注意力計算的輸出寫回 HBM。
這樣就產(chǎn)生了一種注意力算法,在實際耗時(wall-clock time)內(nèi),其內(nèi)存效率和速度都很高,相比于標準的注意力算法可以更少地訪問HBM。
結(jié)果比現(xiàn)有注意力算法都快
研究人員評估了FlashAttention來訓練Transformer的影響,包括訓練時間、模型準確性,以及注意力運行時間和內(nèi)存效率。
首先在訓練速度上。FlashAttention比MLPerf 1.1的BERT速度記錄高出15%。
在實現(xiàn)GPT-2上,比HuggingFace速度高出3倍,比Megatron的標準Transformer速度高出1.8倍,F(xiàn)lashAttention將LRA(long-range arena)的基準速度提高了2.4倍。
在模型質(zhì)量,F(xiàn)lashAttention將Transformer擴展到更長的序列,并且質(zhì)量更好。
長上下文的語言建模。
如圖所示,使用FlashAttention可以讓GPT-2上下文長度增加4倍的情況下,訓練時間還比Megatron-LM優(yōu)化實現(xiàn)快30%,同時也獲得了0.7的困惑度(困惑度越低,說明語言模型越好)。
長文檔分類
對較長序列的Transformer訓練可以提高MIMIC-III和ECtHR數(shù)據(jù)集的性能,比如序列長度為16K在MIMIC上比長度512多出4.3分。
MIMIC-III:包含重癥監(jiān)護室病人的出院總結(jié),每個都有多個標簽注釋;ECtHR:包含歐洲人權(quán)法案的法律案件;兩個數(shù)據(jù)集都包含很長的文本文件。
此外,還完成了第一個能在Path-X和Path-256任務(wù)中實現(xiàn)非隨機性能的Transformer模型。
之后,研究人員還完成了基準測試,測量FlashAttention和塊狀稀疏(Block-Sparse)FlashAttention的運行時間和內(nèi)存性能,并與帶有40GB HBM的A100 GPU上的各種注意力基線進行了比較。
結(jié)果顯示,F(xiàn)lashAttention的運行時間,比PyTorch注意力實現(xiàn)快3倍;在短序列情況下,F(xiàn)lashAttention在短序列中仍比近似和稀疏注意力運行得快;至于塊狀稀疏的FlashAttention,在所有的序列長度上都比現(xiàn)有注意力實現(xiàn)都快。
至于在顯存效率方面,F(xiàn)lashAttention比PyTorch注意力基線高20倍。
在64k序列長度、其他所有算法都已經(jīng)耗盡顯存的情況下,F(xiàn)lashAttention的效率仍比Linformer高2倍。
斯坦福博士一作
這篇研究來自斯坦福大學計算機系以及紐約州立大學布法羅分校。共同一作是兩位斯坦福計算機博士生Tri Dao和Dan Fu。
感興趣的朋友,可戳下方論文鏈接了解更多~
論文鏈接:
https://arxiv.org/abs/2205.14135
GitHub鏈接:
https://github.com/HazyResearch/flash-attention
參考鏈接:
https://twitter.com/tri_dao/status/1531437619791290369




