做網站營銷蘭州seo外包公司
上篇文章主要介紹了hook鉤子函數的大致使用流程,本篇文章主要介紹pytorch中的hook機制register_forward_hook,手動在forward之前注冊hook,hook在forward執(zhí)行以后被自動執(zhí)行。
1、hook背景
Hook被成為鉤子機制,pytorch中包含forward和backward兩個鉤子注冊函數,用于獲取forward和backward中輸入和輸出,按照自己不全面的理解,應該目的是“不改變網絡的定義代碼,也不需要在forward函數中return某個感興趣層的輸出,這樣代碼太冗雜了”。
2、源碼閱讀
register_forward_hook()函數必須在forward()函數調用之前被使用,因為該函數源碼注釋顯示這個函數“ it will not have effect on forward since this is called after :func:forward
is called”,也就是這個函數在forward()之后就沒有作用了!):
作用:獲取forward過程中每層的輸入和輸出,用于對比hook是不是正確記錄。
def register_forward_hook(self, hook):r"""Registers a forward hook on the module.The hook will be called every time after :func:`forward` has computed an output.It should have the following signature::hook(module, input, output) -> None or modified outputThe hook can modify the output. It can modify the input inplace butit will not have effect on forward since this is called after:func:`forward` is called.Returns::class:`torch.utils.hooks.RemovableHandle`:a handle that can be used to remove the added hook by calling``handle.remove()``"""handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle
3、定義一個用于測試hook的類
如果隨機的初始化每個層,那么就無法測試出自己獲取的輸入輸出是不是forward中的輸入輸出了,所以需要將每一層的權重和偏置設置為可識別的值(比如全部初始化為1)。網絡包含兩層(Linear有需要求導的參數被稱為一個層,而ReLU沒有需要求導的參數不被稱作一層),init()中調用initialize函數對所有層進行初始化。
**注意:**在forward()函數返回各個層的輸出,但是ReLU6沒有返回,因為后續(xù)測試的時候不對這一層進行注冊hook。
class TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定義特殊的初始化,用于驗證是不是獲取了權重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True
4、定義hook函數
hook()函數是register_forward_hook()函數必須提供的參數,首先定義幾個容器用于記錄:
定義用于獲取網絡各層輸入輸出tensor的容器:
# 同時定義module_name用于記錄相應的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函數需要三個參數,這三個參數是系統(tǒng)傳給hook函數的,自己不能修改這三個參數:
hook函數負責將獲取的輸入輸出添加到feature列表中;并提供相應的module名字。
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None
5、對需要的層注冊hook
注冊鉤子必須在forward()函數被執(zhí)行之前,也就是定義網絡進行計算之前就要注冊,下面的代碼對網絡除去ReLU6以外的層都進行了注冊(也可以選定某些層進行注冊):
注冊鉤子可以對某些層單獨進行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)
6、測試forward()返回的特征和hook記錄的是否一致
6.1 測試forward()提供的輸入輸出特征
由于前面的forward()函數返回了需要記錄的特征,這里可以直接測試:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)
輸出如下:
*****forward return features*****
(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****
6.2 hook記錄的輸入特征和輸出特征
hook通過list結構進行記錄,所以可以直接print。
測試features_in是否存儲了輸入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)
得到和forward一樣的結果:
*****hook record features*****
[(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.linear.Linear'>,<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****
6.3 把hook記錄的和forward做減法
如果害怕會有小數點后面的數值不一致,或者數據類型的不匹配,可以對hook記錄的特征和forward記錄的特征做減法:
測試forward返回的feautes_in是不是和hook記錄的一致:
print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])
得到的全部都是0,說明hook沒問題:
sub result
tensor([[0., 0.],[0., 0.]])
tensor([[0., 0.],[0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],[0.]], grad_fn=<SubBackward0>)
7、完整代碼
import torch
import torch.nn as nnclass TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定義特殊的初始化,用于驗證是不是獲取了權重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True# 定義用于獲取網絡各層輸入輸出tensor的容器,并定義module_name用于記錄相應的module名字
module_name = []
features_in_hook = []
features_out_hook = []# hook函數負責將獲取的輸入輸出添加到feature列表中,并提供相應的module名字
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# 定義全部是1的輸入:
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])# 注冊鉤子可以對某些層單獨進行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)# 測試網絡輸出:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)# 測試features_in是不是存儲了輸入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)# 測試forward返回的feautes_in是不是和hook記錄的一致:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])