網(wǎng)站域名服務器查詢百度知道提問
1、蒸餾溫度T
正常的模型學習到的就是在正確的類別上得到最大的概率,但是不正確的分類上也會得到一些概率盡管有時這些概率很小,但是在這些不正確的分類中,有一些分類的可能性仍然是其他類別的很多倍。但是對這些非正確類別的預測概率也能反應模型的泛化能力,例如,一輛寶馬車的圖片,只有很小的概率被誤識別成垃圾車,但是被識別成垃圾車的概率還是比錯誤識別成胡蘿卜的概率高很多倍。(例如一個車,貓,狗3分類的模型識別一張貓的圖片,最后結果是:(cat,99%)
; (dog,0.95%)
;(car,0.05%)
錯誤類別 dog 上的概率仍是錯誤類別 car 的概率的19倍 )
知識蒸餾
這里一個可行的辦法是使用大模型生成的模型類別概率作為“soft targets”(使用蒸餾算法以后的概率,相對應的 head targets 就是正常的原始訓練數(shù)據(jù)集)來訓練小模型,由于 soft targets 包含了更多的信息熵,所以每個訓練樣本都提供給小模型更多的信息用來學習,這樣小模型就只需要用更少的樣本,及更高的學習率去訓練了。
仍然是上面的錯誤分類概率的例子,在 MNIST 數(shù)據(jù)集上訓練的一個大模型基本都能達到 99 % 以上的準確率,假如現(xiàn)在有一個數(shù)字 2 的圖片輸入到大模型中分類,在得到的結果是數(shù)字 3 的概率為 10e-6, 是數(shù)字 7 的概率為 10e-9,這就表示了相比于 7 ,3更接近于 2,這從側面也可以表現(xiàn)數(shù)據(jù)之間的相關性,但是在遷移階段,這樣的概率在交叉熵損失函數(shù)(cross-entropy loss function)只有很小的影響,因為它們的概率都基本為0。 所以這里,本文提出了 “distillation” 的概念, 來軟化上述的結果。
上面的公式就是蒸餾后的 softmax,其中 T 代表 temperature, 蒸餾的溫度。那么 T 有什么作用呢?
假設現(xiàn)在有一個數(shù)組 x=[2,7,10] ,當T = 1,即為正常的 Softmax函數(shù) 輸入上式中可得:
T = 1 ——> y = [0.00032,0.04741,0.95227]
可以理解為上述的一個車,貓,狗3分類網(wǎng)絡,輸入一張貓的圖片,預測為汽車的概率為0.00032, 預測為狗的概率為 0.04741, 預測為貓的概率為 0.95227。
下面再看一下改變 T 的值概率的輸出:
T = 5 ——> y = [0.11532, 0.31348, 0.5712] T = 10 ——> y = [0.20516, 0.33825, 0.45659] T = 20 ——> y = [0.26484, 0.34006, 0.3951]
下面是在(-10,10)之間隨機取多個點然后在 不同的 T 值下繪制的圖像。
?可以看到當 T = 1 是就是常規(guī)的 Softmax,而升溫T,對softmax進行蒸餾,函數(shù)的圖像會變得越來越平滑,這也是文中提高的?soft targets
?的?soft
?一詞來源吧。
假設你是每次都是進行負重登山,雖然過程很辛苦,但是當有一天你取下負重,正常的登山的時候,你就會變得非常輕松,可以比別人登得高登得遠。我們知道對于一個復雜網(wǎng)絡來說往往能夠得到很好的分類效果,錯誤的概率比正確的概率會小很多很多,但是對于一個小網(wǎng)絡來說它是無法學成這個效果的。我們?yōu)榱巳椭【W(wǎng)絡進行學習,就在小網(wǎng)絡的softmax加一個T參數(shù),加上這個T參數(shù)以后錯誤分類再經(jīng)過softmax以后輸出會變大,同樣的正確分類會變小。這就人為的加大了訓練的難度,一旦將T重新設置為1,分類結果會非常的接近于大網(wǎng)絡的分類效果。
最后將小模型在 soft targets
上訓練得到的交叉熵損失函數(shù),加上在真實帶標簽數(shù)據(jù)(hard targets)上訓練得到的交叉熵損失函數(shù)乘以 1/T^2 加在一起作為最后總的損失函數(shù)。這里hard targets 上面乘以一個系數(shù)是因為 soft targets 生成過程中蒸餾后的 softmax 求導會有一個 1/T^2 的系數(shù),為了保持兩個 Loss 所產生的影響接近一樣(各 50%)。
訓練過程
假設這里選取的 T = 10;
Teacher 模型:
( a ) Softmax(T=10)的輸出,生成“Soft targets”
Student 模型:
( a ) 對 Softmax(T = 10)的輸出與Teacher 模型的Softmax(T = 10)的輸出求 Loss1
( b ) 對 Softmax(T = 1)的輸出與原始label 求 Loss2
( c ) Loss = Loss1 + (1/T^2)Loss2
?使用soft target會增加信息量,熵高
發(fā)現(xiàn):T參數(shù)越大,soft target的分布越均勻。因此,我們可以:
- 首先用較大的T值來訓練模型,這時候復雜的神經(jīng)網(wǎng)絡能夠產生更均勻分布(更容易讓小網(wǎng)絡學習)的soft target;
- 之后小規(guī)模的神經(jīng)網(wǎng)絡用相同的T值來學習由大規(guī)模神經(jīng)網(wǎng)絡產生的soft target,接近這個soft target從而學習到數(shù)據(jù)的結構分布特征;
- 最后在實際應用中,將T值恢復到1,讓類別概率偏向正確類別。
在大數(shù)據(jù)集上訓練專家模型
Training ensembles of specialists on very big datasets?
可以用無限大的數(shù)據(jù)集來使用教師網(wǎng)絡訓練學生網(wǎng)絡
- 當數(shù)據(jù)集非常巨大以及模型非常復雜時,訓練多個模型所需要的資源是難以想象的,因此作者提出了一種新的集成模型(ensemble)方法:
- 一個generalist model:使用全部數(shù)據(jù)訓練。
- 多個specialist model(專家模型):對某些容易混淆的類別進行訓練。
- specialist model的訓練集中,一半是由訓練集中包含某些特定類別的子集(special subset)組成,剩下一半是從剩余數(shù)據(jù)集中隨機選取的。
- 這個ensemble的方法中,只有generalist model是使用完整數(shù)據(jù)集訓練的,時間較長,而剩余的所有specialist model由于訓練數(shù)據(jù)相對較少,且相互獨立,可以并行訓練,因此訓練模型的總時間可以節(jié)約很多。
- specialist model由于只使用特定類別的數(shù)據(jù)進行訓練,因此模型對別的類別的判斷能力幾乎為0,導致非常容易過擬合。
- 解決辦法:當 specialist model 通過 hard targets 訓練完成后,再使用由 generalist model 生成的 soft targets 進行微調。這樣做是因為 soft targets 保留了一些對于其他類別數(shù)據(jù)的信息,因此模型可以在原來基礎上學到更多知識,有效避免了過擬合。
?
實現(xiàn)流程:
此部分很有意思,但是不知道具體細節(jié),需要再去看論文。?
?