深圳市多語言網(wǎng)站建設(shè)公司廣州百度關(guān)鍵詞排名
梯度累加
梯度累加(Gradient Accmulation)是一種增大訓(xùn)練時batch size的技巧。當(dāng)batch size在一張卡放不下時,可以將很大的batch size分解為一個個小的mini batch,分別計(jì)算每一個mini batch的梯度,然后將其累加起來優(yōu)化
正常的pytorch訓(xùn)練流程如下(來自知乎)
for i, (image, label) in enumerate(train_loader):pred = model(image) # 1loss = criterion(pred, label) # 2optimizer.zero_grad() # 3loss.backward() # 4optimizer.step() # 5
- 神經(jīng)網(wǎng)絡(luò)forward過程
- 獲取loss,通過pred和label計(jì)算你損失函數(shù)
- 清空網(wǎng)絡(luò)中參數(shù)的梯度
- 反向傳播,計(jì)算當(dāng)前梯度
- 根據(jù)梯度更新網(wǎng)絡(luò)參數(shù)
使用梯度累加的方法如下
for i,(image, label) in enumerate(train_loader):# 1. input outputpred = model(image)loss = criterion(pred, label)# 2.1 loss regularizationloss = loss / accumulation_steps # 2.2 back propagationloss.backward()# 3. update parameters of netif (i+1) % accumulation_steps == 0:# optimizer the netoptimizer.step() # update parameters of netoptimizer.zero_grad() # reset gradient
- 神經(jīng)網(wǎng)絡(luò)forward過程,同時計(jì)算損失函數(shù)
- 反向傳播計(jì)算當(dāng)前梯度(在backward時,計(jì)算的loss要除batch的大小得到均值)
- 不斷重復(fù)1、2步驟,重復(fù)獲取梯度
- 梯度累加到一定次數(shù)后,先optimizer.step()更新網(wǎng)絡(luò)參數(shù),隨后zero_grad()清除梯度,為下一次梯度累加做準(zhǔn)備
DDP中的梯度累加
問題:在DDP中所有卡的梯度all_reduce階段發(fā)生在loss.bachward()階段,也就是說執(zhí)行l(wèi)oss.backward()之后,所有卡的梯度會進(jìn)行一次匯總,但是如果我們?nèi)绻褂锰荻壤奂硬呗?#xff0c;假設(shè)梯度累加K=2,就需要all_reduce匯總兩次,會帶來額外的計(jì)算錯誤和時間開銷
解決方案:知乎寫的很好,這里參考其解決方案,只需要在前K-1次取消梯度同步即可,DDP提供了一個暫時取消梯度同步的context函數(shù)no_sync(),在這個函數(shù)下,DDP不會進(jìn)行梯度同步
model = DDP(model)for 每次梯度累加循環(huán)optimizer.zero_grad()# 前accumulation_step-1個step,不進(jìn)行梯度同步,每張卡分別累積梯度。for _ in range(K-1)::with model.no_sync():prediction = model(data)loss = loss_fn(prediction, label) / Kloss.backward() # 積累梯度,但是多卡之間不進(jìn)行同步# 第K個stepprediction = model(data)loss = loss_fn(prediction, label) / Kloss.backward() # 進(jìn)行多卡之間的梯度同步optimizer.step()
優(yōu)雅寫法
from contextlib import nullcontext
# 如果你的python版本小于3.7,請注釋掉上面一行,使用下面這個:
# from contextlib import suppress as nullcontextif local_rank != -1:model = DDP(model)optimizer.zero_grad()
for i, (data, label) in enumerate(dataloader):# 只在DDP模式下,輪數(shù)不是K整數(shù)倍的時候使用no_syncmy_context = model.no_sync if local_rank != -1 and i % K != 0 else nullcontextwith my_context():prediction = model(data)loss = loss_fn(prediction, label) / Kloss.backward() # 積累梯度,不應(yīng)用梯度改變if i % K == 0:optimizer.step()optimizer.zero_grad()
梯度累加的影響
BN的影響