丹陽網(wǎng)站設(shè)計網(wǎng)站市場推廣
殘差連接和層規(guī)范化
層規(guī)范化和批量規(guī)范化的目標相同,但層規(guī)范化是基于特征維度進行規(guī)范化。盡管批量規(guī)范化在計算機視覺中被廣泛應(yīng)用,但在自然語言處理任務(wù)中(輸入通常是變長序列)批量規(guī)范化通常不如層規(guī)范化的效果好。
以下代碼對比不同維度的層規(guī)范化和批量規(guī)范化的效果。
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# 在訓練模式下計算X的均值和方差
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
layer norm: tensor([[-1.0000, 1.0000],[-1.0000, 1.0000]], grad_fn=<NativeLayerNormBackward0>) batch norm: tensor([[-1.0000, -1.0000],[ 1.0000, 1.0000]], grad_fn=<NativeBatchNormBackward0>)
現(xiàn)在可以使用殘差連接和層規(guī)范化來實現(xiàn)AddNorm
類。暫退法也被作為正則化方法使用。
#@save
class AddNorm(nn.Module):"""殘差連接后進行層規(guī)范化"""def __init__(self, normalized_shape, dropout, **kwargs):super(AddNorm, self).__init__(**kwargs)self.dropout = nn.Dropout(dropout)self.ln = nn.LayerNorm(normalized_shape)def forward(self, X, Y):return self.ln(self.dropout(Y) + X)
殘差連接要求兩個輸入的形狀相同,以便加法操作后輸出張量的形狀相同。?
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
?torch.Size([2, 3, 4])
?
?
?