北京 高端網(wǎng)站設(shè)計網(wǎng)站分析報告范文
Transformer升級之路:7、長度外推性與局部注意力
Transformer升級之路:8、長度外推性與位置魯棒性
Bias項的神奇作用:RoPE + Bias = 更好的長度外推性
長度外推
1.1 什么是長度外推性?
長度外推性=train short, test long
train short:1)受限于訓練成本;2)大部分文本的長度不會特別長,訓練時的max_length特別特別大其實意義不大(長尾)。
test long:這里long是指比訓練時的max_length長,希望不用微調(diào)就能在長文本上也有不錯的效果。
1.2 為了做到長度外推性,需要解決兩個主要問題:
1)預測時位置編碼的外推:沒見過的就無法保證很好的泛化,不僅學習式位置編碼如此;像正弦位置編碼、RoPE也有這樣的問題,它們自身雖然不用學習,但是會影響上層參數(shù)的學習;
2)預測時序列更長,導致注意力相比訓練時更分散:序列長度增大意味著attention分布的熵增大了,注意力更分散了;
1.3 長度外推性的推測
可見,長度外推性問題并不完全與設(shè)計一個良好的位置編碼等價。
然后,還有個問題是,雖然PE一直是transformer類模型中的重要的基礎(chǔ)組件,很多位置編碼也在嘗試做一些外推性的工作,但整體來看早期的LLM其實沒有特別關(guān)注或者說糾結(jié)長度外推性,直到后面各種NLG模型的崛起,尤其是ChatGPT的出現(xiàn),大家才驚覺原來上下文可以做的這么長了?
為什么目前市面上的LLM鮮有使用呢(據(jù)目前所知,好像只有BLOOM/MPT/采用了ALiBi)?可能的原因:
1)專注于長度外推性的工作主要是在21/22年后才逐漸出現(xiàn),效果尚未經(jīng)過充分檢驗;
2)長度外推性的評測指標與LLM的評測指標并不完全match:目前長度外推性主要看PPL,這其實不夠全面。PPL這類語言模型的指標,可能更關(guān)注局部上下文的預測,因此局部注意力相關(guān)的方案可能在這類評測上天然占優(yōu)。
3)目前的長度外推性工作似乎更多的在強調(diào)外推性如何如何,但更重要的應(yīng)該還是max_length內(nèi)的效果,從LLM的角度來看,應(yīng)該在保證max_length內(nèi)的效果后再去追求外推性。比如,從GLM的消融實驗來看,ALiBi的效果還是不如RoPE的。
一直的誤解
第一篇明確研究Transformer長度外推性的工作應(yīng)該是ALIBI,出自2021年中期,距今也不算太久。為什么這么晚(相比Transformer首次發(fā)表的2017年)才有人專門做這個課題呢?估計是因為我們長期以來,都想當然地認為Transformer的長度外推性是位置編碼的問題,找到更好的位置編碼就行了。
事實上,通過對比現(xiàn)有的一些位置編碼的外推效果,確實能找到支撐該觀點的一些論據(jù)。比如后面分享的多篇實驗效果顯示,相對位置編碼的長度外推性,平均好于絕對位置編碼的;像RoPE這樣的函數(shù)式相對位置編碼,又會比訓練式相對位置編碼的外推效果好些。所以看上去,似乎只要我們不斷優(yōu)化位置編碼形式,最終就能給Transformer提供更好的長度外推性,從而解決這個問題。然而,情況沒有那么樂觀,像RoPE算是外推能力較好的位置編碼,也只能外推10%到20%左右的長度而保持效果不變差,再長效果就會驟降。這個比例與預期差太遠了,設(shè)想中好歹能外推個幾倍長度才算是有價值的外推,所以不難想象,單靠改進位置編碼改進Transformer的長度外推性,就不知道要等多久才能實現(xiàn)更長的效果了。
在直覺上,相信很多讀者覺得像Sinusoidal或RoPE之類的函數(shù)式位置編碼,它們沒有訓練參數(shù),長度外推性應(yīng)該很好才對,但事實上并非如此,這類位置編碼并沒有在長度外推方面表現(xiàn)出什么優(yōu)勢。為什么會這樣呢?其實是大家在假設(shè)函數(shù)式位置編碼的外推性時,忘了它的基本前提——“光滑性”。
其實,外推性就是局部推斷整體,對此我們應(yīng)該并不陌生,泰勒級數(shù)近似就是經(jīng)典的例子,它只需要知道函數(shù)某點處若干階導數(shù)的值,就可以對一個鄰域內(nèi)的值做有效估計,它依賴的就是給定函數(shù)的高階光滑性(高階導數(shù)存在且有界)。但是Sinusoidal或RoPE是這種函數(shù)嗎?并不是。它們是一系列正余弦函數(shù)的組合,其相位函數(shù)是k/100002i/d,當2i/d≈0時,函數(shù)近似就是sink,cosk,這算是關(guān)于位置編碼k的高頻振蕩函數(shù)了,而不是直線或者漸近趨于直線之類的函數(shù),所以基于它的模型往往外推行為難以預估。能否設(shè)計不振蕩的位置編碼?很難,位置編碼函數(shù)如果不振蕩,那么往往缺乏足夠的容量去編碼足夠多的位置信息,也就是某種意義上來說,位置編碼函數(shù)的復雜性本身也是編碼位置的要求。
1.4 實現(xiàn)長度外推性的超強基線
長度外推性是一個訓練和預測的長度不一致的問題。具體來說,不一致的地方有兩點:
1、預測的時候用到了沒訓練過的位置編碼(不管絕對還是相對);
2、預測的時候注意力機制所處理的token數(shù)量遠超訓練時的數(shù)量。
第1點可能大家都容易理解,沒訓練過的就沒法保證能處理好,這是DL中很現(xiàn)實的現(xiàn)象,哪怕是Sinusoidal或RoPE這種函數(shù)式位置編碼也是如此。
關(guān)于第2點,可能讀者會有些迷惑,Attention理論上不就是可以處理任意長度的序列嗎?訓練和預測長度不一致影響什么呢?答案是熵,我們在《從熵不變性看Attention的Scale操作》也已經(jīng)分析過這個問題,越多的token去平均注意力,意味著最后的分布相對來說越“均勻”(熵更大),即注意力越分散;而訓練長度短,則意味著注意力的熵更低,注意力越集中,這也是一種訓練和預測的差異性,也會影響效果。事實上,對于相對位置編碼的Transformer模型,通過一個非常簡單的Attention Mask,就可以一次性解決以上兩個問題,并且取得接近SOTA的效果:
不難理解,這就是將預測時的Attention變?yōu)橐粋€局部Attention,每個token只能看到訓練長度個token。這樣一來,每個token可以看到的token數(shù)跟訓練時一致,這就解決了第2個問題,同時由于是相對位置編碼,位置的計數(shù)以當前token為原點,因此這樣的局部Attention也不會比訓練時使用更多的未知編碼,這就解決了第1個問題。所以,就這個簡單的Attention Mask一次性解決了長度外推的2個難點,還不用重新訓練模型,更令人驚嘆的是,各種實驗結(jié)果顯示,如果以它為baseline,那么各種同類工作的相對提升就弱得可憐了,也就是它本身已經(jīng)很接近SOTA了,可謂是又快又好的“超強基線”。
對于第二點:
其中m是訓練長度,n是預測長度。經(jīng)過這樣修改(下面簡稱為“l(fā)ogn縮放注意力”),注意力的熵隨著長度的變化更加平穩(wěn),緩解了這個不一致問題。個人的實驗結(jié)果顯示,至少在MLM任務(wù)上,“l(fā)ogn縮放注意力”的長度外推表現(xiàn)更好。
第1點不一致性,即“預測的時候用到了沒訓練過的位置編碼”,那么為了解決它,就應(yīng)該做到“訓練階段把預測所用的位置編碼也訓練一下”。一篇ACL22還在匿名評審的論文《Randomized Positional Encodings Boost Length Generalization of Transformers》首次從這個角度考慮了該問題,并且提出了解決方案。
論文的思路很簡單:隨機位置訓練 設(shè)N為訓練長度(論文N=40),M為預測長度(論文M=500),那么選定一個較大L>M(這是一個超參,論文L=2048),訓練階段原本長度為N的序列對應(yīng)的位置序列是[0,1,?,N?2,N?1],現(xiàn)在改為從{0,1,?,L?2,L?1}中隨機不重復地選N個并從小到大排列,作為當前序列的位置序列。
預測階段,也可以同樣的方式隨機采樣位置序列,也可以直接在區(qū)間中均勻取點(個人的實驗效果顯示均勻取點的效果一般好些),這就解決了預測階段的位置編碼沒有被訓練過的問題。不難理解,這是一個很樸素的訓練技巧(下面稱之為“隨機位置訓練”),目標是希望Transformer能對位置的選擇更加魯棒一些,但后面我們將看到,它能取得長度外推效果的明顯提升。筆者也在MLM任務(wù)上做了實驗,結(jié)果顯示在MLM上也是有效果的,并且配合“l(fā)ogn縮放注意力”提升幅度更明顯(原論文沒有“l(fā)ogn縮放注意力”這一步)。
1.5 長度外推性的新基準
Google去年在論文《Neural Networks and the Chomsky Hierarchy》專門提出的一個長度外泛化基準(下面簡稱該測試基準為“CHE基準”,即“Chomsky Hierarchy Evaluation Benchmark”),這給我們提供了理解長度外推的一個新視角。
這個基準包含多個任務(wù),分為R(Regular)、DCF(Deterministic Context-Free)、CS(Context-Sensitive)三個級別,每個級別的難度依次遞增,每個任務(wù)的簡介如下:
- Even Pairs,難度R,給定二元序列?