如何做網(wǎng)站遷移網(wǎng)絡(luò)營銷手段
機(jī)器學(xué)習(xí)課程學(xué)習(xí)周報(bào)八
文章目錄
- 機(jī)器學(xué)習(xí)課程學(xué)習(xí)周報(bào)八
- 摘要
- Abstract
- 一、機(jī)器學(xué)習(xí)部分
- 1.1 self-attention的計(jì)算量
- 1.2 人類理解代替自注意力計(jì)算
- 1.2.1 Local Attention/Truncated Attention
- 1.2.2 Stride Attention
- 1.2.3 Global Attention
- 1.2.4 聚類Query和Key
- 1.3 自動(dòng)選擇自注意力計(jì)算
- 1.4 Attention Matrix中的線性組合
- 1.5 通過矩陣乘法推導(dǎo)自注意力計(jì)算
- 1.6 Batch Normalization
- 總結(jié)
摘要
本周的學(xué)習(xí)重點(diǎn)是自注意力機(jī)制的計(jì)算優(yōu)化。我探討了如何通過Local Attention、Stride Attention、Global Attention等方法減少計(jì)算量。此外,還介紹了自動(dòng)選擇注意力計(jì)算和Attention Matrix的線性組合方法。最后,補(bǔ)充了Batch Normalization的知識(shí),為模型訓(xùn)練提供了更好的穩(wěn)定性。
Abstract
This week’s focus is on optimizing the computation of the self-attention mechanism. I explored methods like Local Attention, Stride Attention, and Global Attention to reduce computational load. Additionally, we discussed automatic selection of attention computation and linear combinations in the Attention Matrix. Lastly, we supplemented our understanding with Batch Normalization, enhancing model training stability.
一、機(jī)器學(xué)習(xí)部分
1.1 self-attention的計(jì)算量
如果現(xiàn)在自注意力模型輸入的序列長度為 N N N,則對(duì)應(yīng)的Query為 N N N個(gè),對(duì)應(yīng)的Key也為 N N N個(gè)。它們之間相互計(jì)算關(guān)聯(lián)性(即注意力分?jǐn)?shù)),可以得到上圖中的Attention Matrix,這個(gè)矩陣的復(fù)雜度是 N 2 {N^2} N2,當(dāng) N N N的數(shù)值很大時(shí),該矩陣的計(jì)算量就會(huì)變得很大。因此,這一節(jié)介紹多種方法以加速計(jì)算Attention Matrix的計(jì)算。
Notice:當(dāng) N N N很大時(shí),self-attention的計(jì)算才會(huì)主導(dǎo)整個(gè)模型中計(jì)算量。例如:在Transformer模型中,除了self-attention還有其他模塊的計(jì)算量,self-attention模塊的計(jì)算量占模型整體計(jì)算量是與 N N N有關(guān)的,當(dāng) N N N過小時(shí),對(duì)self-attention的改進(jìn)計(jì)算并不會(huì)明顯提高Transformer模型的運(yùn)算速度。
1.2 人類理解代替自注意力計(jì)算
根據(jù)人類對(duì)問題的理解,對(duì)Attention Matrix某些位置的值直接賦值,跳過計(jì)算步驟,從而減少計(jì)算量。
1.2.1 Local Attention/Truncated Attention
計(jì)算self-attention時(shí),并非計(jì)算整個(gè)序列間的self-attention分?jǐn)?shù),而是只看自己和左右的鄰居,其他的關(guān)聯(lián)性都設(shè)定為0。下圖在Attention Matrix中,表示為灰色的部分都人工設(shè)定為0,只計(jì)算藍(lán)色部分的self-attention分?jǐn)?shù)。這種方法叫做Local Attention或Truncated Attention。
Local Attention與CNN較為相似,主要體現(xiàn)在它們的局部關(guān)注機(jī)制上。這種機(jī)制使得模型在處理輸入數(shù)據(jù)時(shí),只關(guān)注輸入數(shù)據(jù)的局部區(qū)域,而不是整體。卷積神經(jīng)網(wǎng)絡(luò)(CNN)中,卷積層通過滑動(dòng)窗口的方式在輸入數(shù)據(jù)上提取特征。這種操作也可以看作是一種局部關(guān)注機(jī)制,通過卷積核僅關(guān)注輸入數(shù)據(jù)的局部區(qū)域來提取特征。Local attention相比于之前介紹的包含全序列的注意力,更加注重輸入數(shù)據(jù)的局部關(guān)系,與卷積核的滑動(dòng)也很類似。
1.2.2 Stride Attention
根據(jù)自己對(duì)問題的理解,計(jì)算局部的self-attention并不一定是左右鄰居,如下圖,可以是分別計(jì)算序列中兩步前或兩步后的關(guān)聯(lián)性,也可以是分別計(jì)算序列中一步前或一步后的關(guān)聯(lián)性,灰色的地方設(shè)定為0。這種方法叫做Stride Attention。
1.2.3 Global Attention
前面介紹的方法都是以某一個(gè)位置為中心,分別計(jì)算左右的關(guān)聯(lián)性。Global Attention注重于整個(gè)序列,其會(huì)添加特殊的token到原始的序列中,特殊的token分別與整個(gè)序列計(jì)算self-attention,具體做法有兩種:
- 從原來的token序列中,選擇一部分作為特殊的token。
- 外加一部分額外的token。
從上圖的Attention Matrix觀察得到,在原始的序列中,第一和第二個(gè)位置被選擇為特殊的token。從橫軸的角度看,第一和第二個(gè)位置的Query與整個(gè)序列的Key分別做了self-attention。從縱軸的角度看,序列每一個(gè)位置的Query都與第一和第二位置的Key做了self-attention?;疑奈恢迷O(shè)定為0。
在Big Bird中提出了Random attention并且將其與前面的Local Attention和Global Attention一并融合。
1.2.4 聚類Query和Key
第一步,根據(jù)相似度聚類Query和Key,上圖中根據(jù)不同顏色聚類為了4類。
第二步,相同類之間的Query和Key才做self-attention。
1.3 自動(dòng)選擇自注意力計(jì)算
通過神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)出一個(gè)0-1矩陣,深色位置代表1,淺色位置代表0。只有深色位置計(jì)算self-attention,淺色位置不計(jì)算。
輸入序列中的每一個(gè)位置都通過一個(gè)神經(jīng)網(wǎng)絡(luò)產(chǎn)生一個(gè)長度為 N N N的向量,然后將這些向量拼起來得到大小為 N × N N \times N N×N的矩陣。然而現(xiàn)在這個(gè)由向量拼成得到的矩陣中的值,是連續(xù)值,要轉(zhuǎn)換為0-1矩陣,這一部分是可以微分的,所以可以通過學(xué)習(xí)得到,具體需要看Sinkhorn Sorting Network的論文。
1.4 Attention Matrix中的線性組合
計(jì)算Attention Matrix的Rank(秩),得到Low Rank,說明該矩陣的很多列是其它列的線性組合。由此可得,實(shí)際上并不需要 N × N N \times N N×N的矩陣,目前 N × N N \times N N×N的矩陣中包含很多重復(fù)的信息,也許可以通過減少Attention Matrix的大小(主要是列數(shù)量)實(shí)現(xiàn)減少運(yùn)算量。
選擇具有代表性的Key,得到K個(gè)Key,即得到大小為 N × K N \times K N×K的Attention Matrix。接下來考慮self-attention這一層的輸出,同樣地要從N個(gè)Value中挑出具有代表性的K個(gè)Value,一個(gè)Key對(duì)應(yīng)一個(gè)Value向量。然后用Value矩陣乘上Attention Matrix可以得到self-attention層的輸出。
為什么我們不能挑出K個(gè)代表的Query呢?
輸出序列的長度與Query的數(shù)量是一致的,如果減少Q(mào)uery的數(shù)量,輸出序列的長度就會(huì)變短。
挑選具有代表性的Key的方法為:
卷積降維和線性組合(K個(gè)向量是N個(gè)向量的K種線性組合,下圖右)
1.5 通過矩陣乘法推導(dǎo)自注意力計(jì)算
簡要復(fù)習(xí)一下自注意力機(jī)制的矩陣計(jì)算過程:第一步,輸入序列分別做三種不同的變換,得到 d × N d \times N d×N大小的Query和 d × N d \times N d×N大小的Key,其中 d d d是Query和Key的維度, N N N代表序列的長度。并得到 d ′ × N d' \times N d′×N大小的Value,其中特別用 d ′ d' d′表示Value的維度,是因?yàn)閂alue的維度可以與Query、Key不一樣。第二步, K T {K^{\rm T}} KT乘上 Q Q Q得到Attention Matrix,然后通過softmax做歸一化。第三步,用 V V V乘上歸一化后的Attention Matrix( A ′ A' A′)得到自注意力層的輸出 O O O。
如果我們先忽略softmax的操作,self-attention的計(jì)算方法就是上圖中第一行的計(jì)算過程,現(xiàn)在考慮第二行運(yùn)算,先算 V V V乘上 K T {K^{\rm T}} KT的結(jié)果,再乘上 Q Q Q,這樣的計(jì)算順序與第一行有何不同?得到的結(jié)果是一樣的,運(yùn)算量是不一樣的。
盡管 A ( C P ) = ( A C ) P A\left( {CP} \right) = \left( {AC} \right)P A(CP)=(AC)P,但是第一種計(jì)算方式的計(jì)算量是 1 0 6 {10^6} 106,第二種計(jì)算方式的計(jì)算量的 1 0 3 {10^3} 103,兩者計(jì)算量之間的差異很大。因此我們這里先忽略softmax操作,考慮self-attention中矩陣計(jì)算的改進(jìn)。
根據(jù)上圖證明, V ( K T Q ) V({K^{\rm T}}Q) V(KTQ)的計(jì)算量通常大于 ( V K T ) Q (V{K^{\rm T}})Q (VKT)Q的計(jì)算量。
接下來加入softmax,寫出計(jì)算self-attention的數(shù)學(xué)表達(dá)式:
下面通過數(shù)學(xué)證明的角度說明更換矩陣乘法順序,計(jì)算self-attention的過程:
還有一個(gè)問題是, exp ? ( q ? k ) ≈ Φ ( q ) ? Φ ( k ) \exp (q \cdot k) \approx \Phi (q) \cdot \Phi (k) exp(q?k)≈Φ(q)?Φ(k)是如何實(shí)現(xiàn)的,具體需要參考下面的論文。
1.6 Batch Normalization
在Transformer的編碼器中使用到了Layer Normalization,在上一周的周報(bào)中并將其與Batch Normalization做了比較,這里特別補(bǔ)充Batch Normalization的知識(shí)。
做標(biāo)準(zhǔn)化的原因是,希望能把不同維度的特征值規(guī)范到同樣的數(shù)值范圍,從而使得error surface比較平滑,更好訓(xùn)練。
Batch Normalization是對(duì)不同特征向量的同一維度,計(jì)算平均值和標(biāo)準(zhǔn)差,然后將特征值減去平均值再除以標(biāo)準(zhǔn)差,實(shí)現(xiàn)標(biāo)準(zhǔn)化。標(biāo)準(zhǔn)化后,同一維度上的數(shù)值的平均值是0,方差是1,接近高斯分布。
在神經(jīng)網(wǎng)絡(luò)中,輸入特征 x ~ 1 {\tilde x^1} x~1、 x ~ 2 {\tilde x^2} x~2、 x ~ 3 {\tilde x^3} x~3已經(jīng)做過了標(biāo)準(zhǔn)化,在經(jīng)過 W 1 {W^1} W1層后,且輸入 W 2 {W^2} W2層之前仍需要做標(biāo)準(zhǔn)化。至于是對(duì)激活函數(shù)前的 z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3還是之后的 a 1 {a^1} a1、 a 2 {a^2} a2、 a 3 {a^3} a3做標(biāo)準(zhǔn)化,差別不是很大。以 z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3為例, z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3都是向量,做標(biāo)準(zhǔn)化的方法如下:
μ = 1 3 ∑ i = 1 3 z i \mu = \frac{1}{3}\sum\limits_{i = 1}^3 {{z^i}} μ=31?i=1∑3?zi是對(duì)向量 z i {z^i} zi中對(duì)應(yīng)元素進(jìn)行相加,然后取平均。 σ = 1 3 ∑ i = 1 3 ( z i ? μ ) 2 \sigma = \sqrt {\frac{1}{3}\sum\limits_{i = 1}^3 {{{\left( {{z^i} - \mu } \right)}^2}} } σ=31?i=1∑3?(zi?μ)2?是向量 z i {z^i} zi與 μ \mu μ相減,然后逐元素平方,求和平均后,再對(duì)向量的逐元素開根號(hào)。如果直接看公式會(huì)有一些歧義,因?yàn)?span id="vxwlu0yf4" class="katex--inline"> z i {z^i} zi、 μ \mu μ、 σ \sigma σ都是向量,其中的求和,平方,開根號(hào)都是對(duì)向量中逐元素操作。最后標(biāo)準(zhǔn)化公式為:
z ~ i = z i ? μ σ {{\tilde z}^i} = \frac{{{z^i} - \mu }}{\sigma } z~i=σzi?μ?
實(shí)際上,GPU的內(nèi)存不足以把整個(gè)dataset的數(shù)據(jù)一次性加載。因此,只考慮一個(gè)batch中的樣本,對(duì)一個(gè)batch中的樣本做Batch Normalization。在inference中,不可能等到整個(gè)batch數(shù)量的輸入才做推理,具體方法為:在訓(xùn)練時(shí)計(jì)算 μ \mu μ和 σ \sigma σ的moving average,訓(xùn)練時(shí)的第一個(gè)batch為 μ 1 {\mu^1} μ1,第二個(gè)batch為 μ 1 {\mu^1} μ1,直到第t個(gè)batch為 μ t {\mu^t} μt,且不斷地計(jì)算moving average:
μ ˉ ← p μ ˉ + ( 1 ? p ) μ t \bar \mu \leftarrow p\bar \mu + \left( {1 - p} \right){\mu ^t} μˉ?←pμˉ?+(1?p)μt
inference中標(biāo)準(zhǔn)化的公式變?yōu)?#xff1a;
z ~ i = z i ? μ ˉ σ ˉ {{\tilde z}^i} = \frac{{{z^i} - \bar \mu }}{{\bar \sigma }} z~i=σˉzi?μˉ??
總結(jié)
通過本周的學(xué)習(xí),我對(duì)自注意力機(jī)制的優(yōu)化策略有了更深入的了解,不同的注意力方法提供了多樣化的計(jì)算選擇,有助于提高模型的效率。下周還會(huì)圍繞自注意力機(jī)制進(jìn)行拓展學(xué)習(xí)。