怎么做購物平臺網(wǎng)站企業(yè)建站
在做深度學(xué)習(xí)項目時,從頭訓(xùn)練一個模型是需要大量時間和算力的,我們通常采用加載預(yù)訓(xùn)練權(quán)重的方法,而我們往往面臨以下幾種情況:
未修改網(wǎng)絡(luò),A與B一致
很簡單,直接.load_state_dict()
net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))
修改了網(wǎng)絡(luò),A與B不一致
[pytorch官方文檔](Search — PyTorch master documentation):
load_state_dict(state_dict, strict=True)
將 state_dict 中的參數(shù)和緩沖區(qū)復(fù)制到此模塊及其后代中。如果 strict 為 True,則 state_dict 的鍵必須與該模塊的 state_dict() 函數(shù)返回的鍵完全匹配。
state_dict是包含參數(shù)和持久緩沖區(qū)的字典,可以看出 strict默認(rèn)為True,所以默認(rèn)狀態(tài)下是嚴(yán)格要求state_dict中的key與torch.nn.Module.state_dict返回的key完全一致的
load_state_dict()函數(shù)有兩個返回值:
missing_keys 是包含缺失鍵的 str 列表
unexpected_keys 是包含意外鍵的 str 列表
方法一:
將strict改為false,加載鍵值相同的部分。
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #讀取預(yù)訓(xùn)練模型權(quán)重
model.load_state_dict(weights, strict=False) #strict
但是此時還存在一種情況:鍵值相同但shape不同,故應(yīng)進行if…in…的判斷:
ANet = torch.load('ANet.pt') # 加載預(yù)訓(xùn)練權(quán)重模型(.pt文件)參數(shù)
#現(xiàn)成的模型的話,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()
model = Model() # 創(chuàng)建模型
model_dict = model.state_dict() # 得到模型的參數(shù)字典# 判斷預(yù)訓(xùn)練模型中網(wǎng)絡(luò)的模塊是否修改后的網(wǎng)絡(luò)中也存在,并且shape相同,如果相同則取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加載我們真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)
方法二:
1.將權(quán)重導(dǎo)入原模型,之后在加載后的原模型基礎(chǔ)上進行修改。
2.修改權(quán)重文件參數(shù),再進行導(dǎo)入
適用于改動不大的模型