廈門哪里做網(wǎng)站網(wǎng)上競價(jià)平臺(tái)
相關(guān)說明
這篇文章的大部分內(nèi)容參考自我的新書《解構(gòu)大語言模型:從線性回歸到通用人工智能》,歡迎有興趣的讀者多多支持。
本文將討論如何利用梯度檢查點(diǎn)算法來減少模型在訓(xùn)練時(shí)候(更準(zhǔn)確地說是運(yùn)行反向傳播算法時(shí))的內(nèi)存開支。這在訓(xùn)練超大規(guī)模的模型時(shí)會(huì)用到。
關(guān)于其他的工程技巧可以參考:
- 大語言模型的工程技巧(一)——GPU計(jì)算
- 大語言模型的工程技巧(二)——混合精度訓(xùn)練
- 大語言模型的工程技巧(三)——分布式計(jì)算
關(guān)于大語言模型的討論請參考:
- 理解大語言模型(二)——從零開始實(shí)現(xiàn)GPT-2
內(nèi)容大綱
- 相關(guān)說明
- 一、標(biāo)準(zhǔn)反向傳播
- 二、內(nèi)存極簡算法
- 三、梯度檢查點(diǎn)
一、標(biāo)準(zhǔn)反向傳播
根據(jù)梯度的定義,變量的梯度與其本身的值密切相關(guān)。因此,要想得到某個(gè)變量的梯度,必須先知道這個(gè)變量的值。這也是為什么在進(jìn)行反向傳播算法之前,需要先對計(jì)算圖進(jìn)行向前傳播,并記錄每個(gè)節(jié)點(diǎn)的計(jì)算結(jié)果,如圖1左側(cè)部分所示。這樣在計(jì)算節(jié)點(diǎn)的梯度時(shí),可以利用這些事先緩存的結(jié)果,直接啟動(dòng)反向傳播過程,從而得到梯度,如圖1中的節(jié)點(diǎn)d所示。這種方法也被稱為標(biāo)準(zhǔn)反向傳播。這種方式能夠確保梯度計(jì)算以最高效的方式進(jìn)行。
二、內(nèi)存極簡算法
然而,采用標(biāo)準(zhǔn)反向傳播算法會(huì)造成較大的內(nèi)存開銷。為了在計(jì)算過程中盡可能地壓縮內(nèi)存使用,可以采用一種以時(shí)間換空間的方法。在這種算法中,一旦向前傳播完成,僅會(huì)保留頂點(diǎn)的計(jì)算結(jié)果,而中間節(jié)點(diǎn)的結(jié)果會(huì)被清空(葉子節(jié)點(diǎn)的值會(huì)保留)。在反向傳播遇到中間計(jì)算節(jié)點(diǎn)沒有緩存時(shí),則重新觸發(fā)向前傳播,以獲取所需節(jié)點(diǎn)的結(jié)果。這就是內(nèi)存極簡的反向傳播算法。以節(jié)點(diǎn)d為例,為了計(jì)算其梯度,需要首先從節(jié)點(diǎn)a開始重新觸發(fā)向前傳播直到節(jié)點(diǎn)d,并緩存計(jì)算結(jié)果。然后使用這個(gè)緩存的結(jié)果以及節(jié)點(diǎn)e的梯度,計(jì)算出節(jié)點(diǎn)d的梯度。對于其他節(jié)點(diǎn),也采用類似的步驟計(jì)算梯度。通過這種方式,在完成反向傳播的同時(shí),節(jié)省了內(nèi)存開銷。以圖1為例,內(nèi)存極簡算法只需要3個(gè)存儲(chǔ)空間,而標(biāo)準(zhǔn)算法需要5個(gè)存儲(chǔ)空間。
三、梯度檢查點(diǎn)
盡管內(nèi)存極簡算法在降低內(nèi)存開銷方面取得了顯著成果,但它涉及大量的重復(fù)計(jì)算,運(yùn)行時(shí)間相對較長。為了在內(nèi)存使用和運(yùn)行時(shí)間之間取得平衡,下面引入梯度檢查點(diǎn)(Gradient Checkpoint)。這一算法的核心思想是選擇一些中間節(jié)點(diǎn)作為存儲(chǔ)點(diǎn),以便在再次觸發(fā)向前傳播時(shí),以這些存儲(chǔ)點(diǎn)作為起點(diǎn)開始傳播,避免從頭開始重復(fù)計(jì)算。這種方式在一定程度上減少重復(fù)計(jì)算,從而提高運(yùn)行效率。需要注意的是,由于需要存儲(chǔ)額外的中間結(jié)果,梯度檢查點(diǎn)會(huì)稍微增加一些內(nèi)存開銷。
關(guān)于梯度檢查點(diǎn)算法,PyTorch中已經(jīng)提供了便捷的封裝函數(shù),即torch.utils.checkpoint。這個(gè)工具能夠幫助我們更方便地應(yīng)用梯度檢查點(diǎn)算法,以平衡內(nèi)存開鎖和運(yùn)行時(shí)間。更多細(xì)節(jié)請參考這個(gè)鏈接。