邢臺(tái)商城類網(wǎng)站建設(shè)企業(yè)qq郵箱
反向傳播實(shí)際上就是在算各個(gè)階段梯度,每層的傳入實(shí)際是之前各層根據(jù)鏈?zhǔn)椒▌t梯度相乘的結(jié)果。反向傳播最初傳入的Δout是1,Δ通常表示很少量的意思,Δout=1的時(shí)候這樣在反向傳播的時(shí)候算出來的dw和dx剛好就是當(dāng)前梯度。深度神經(jīng)網(wǎng)絡(luò)中每層都會(huì)記錄正向傳播時(shí)該層傳入的x,就是為了反向傳播的時(shí)候計(jì)算dw的時(shí)候用到。反向傳播的時(shí)候也會(huì)利用w計(jì)算出dx來作為下一層的反向傳播的輸入。反向傳播時(shí)每層的輸入都是前幾層梯度相乘的結(jié)果(鏈?zhǔn)椒▌t),每層的輸出也應(yīng)該是本層梯度乘以輸入的結(jié)果(鏈?zhǔn)椒▌t),需要注意的是計(jì)算MatMul節(jié)點(diǎn)的反向傳播時(shí)要注意矩陣形狀,所以需要矩陣轉(zhuǎn)置。反向傳播計(jì)算的各種梯度就是為了梯度下降做準(zhǔn)備工作。
梯度下降的時(shí)候代碼如下:
class SGD:
def __init__(self, lr=0.01):
self.lr = lr
def update(self, params, grads):
for i in range(len(params)):
params[i] -= self.lr * grads[i]
params 是每層神經(jīng)網(wǎng)絡(luò)的w和b,grads 對應(yīng)的是各層參數(shù)的梯度。 params[i] -= self.lr * grads[i] 表示每層的w和b都要梯度下降,這是因?yàn)榉聪騻鞑サ臅r(shí)候,每層的梯度都是損失函數(shù)f_loss(x)對x的導(dǎo)數(shù)的一部分,根據(jù)鏈?zhǔn)椒▌t,因?yàn)殒準(zhǔn)椒▌t是相乘關(guān)系所以每個(gè)因子梯度下降總體梯度也是下降的。
代碼來源《深度學(xué)習(xí)進(jìn)階-自然語言處理》第一章