吉林市網(wǎng)站制作廣州網(wǎng)站制作服務(wù)
Pytorch多GPU訓(xùn)練模型保存和加載
在多GPU訓(xùn)練中,模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,這會(huì)在模型的參數(shù)名前加上module前綴。因此,在保存模型時(shí),需要使用model.module.state_dict()來獲取模型的狀態(tài)字典,以確保保存的參數(shù)名與模型定義中的參數(shù)名一致。(本質(zhì)上原來的model還是存在的,參數(shù)也會(huì)同步更新)
-
多GPU訓(xùn)練模型保存
在多GPU訓(xùn)練時(shí),模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,這會(huì)在模型的參數(shù)名前加上module前綴。因此,在保存模型時(shí),需要使用model.module.state_dict()來獲取模型的狀態(tài)字典,以確保保存的參數(shù)名與模型定義中的參數(shù)名一致。 -
單GPU或CPU加載模型
當(dāng)在單GPU或CPU上加載模型時(shí),如果直接使用model.state_dict()保存的模型,由于缺少module前綴,會(huì)導(dǎo)致參數(shù)名不匹配,從而無法正確加載模型。因此,在保存多GPU訓(xùn)練的模型時(shí),應(yīng)該使用model.module.state_dict()來保存模型的狀態(tài)字典,這樣在單GPU或CPU上加載模型時(shí),可以直接加載,不會(huì)出現(xiàn)參數(shù)名不匹配的問題。 -
示例代碼
以下是一個(gè)示例代碼,展示了如何在多GPU訓(xùn)練時(shí)保存模型,并在單GPU或CPU上加載模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" #設(shè)置GPU編號
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假設(shè)這是你的模型定義
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 創(chuàng)建模型實(shí)例
model = YourModel()# 將模型移動(dòng)到多GPU上
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model = model.to(device)
else:model = model.to(device)
······
# 假設(shè)這是你的訓(xùn)練代碼,訓(xùn)練完成后保存模型
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')# 在單、多GPU或CPU上加載模型
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to(device)
2 在多GPU訓(xùn)練得到的模型加載時(shí),通常需要考慮以下幾個(gè)步驟:
- 模型保存
在多GPU訓(xùn)練時(shí),模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中。因此,在保存模型時(shí),需要確保保存的是模型的state_dict而不是整個(gè)模型對象。例如:
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')
- 模型加載
在加載模型時(shí),首先需要?jiǎng)?chuàng)建模型的實(shí)例,然后使用load_state_dict方法來加載保存的權(quán)重。如果模型是在多GPU環(huán)境下訓(xùn)練的,那么在加載時(shí)也應(yīng)該使用torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel來包裝模型。例如:
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to('cuda')
- 注意事項(xiàng)
在加載模型時(shí),需要注意以下幾點(diǎn):
如果模型是在多GPU環(huán)境下訓(xùn)練的,那么在加載時(shí)也應(yīng)該使用相同數(shù)量的GPU,或者使用torch.nn.DataParallel來包裝模型,即使只有一個(gè)GPU可用。
如果模型是在分布式訓(xùn)練環(huán)境下訓(xùn)練的,那么在加載時(shí)也應(yīng)該使用torch.nn.parallel.DistributedDataParallel來包裝模型。
如果模型是在混合精度訓(xùn)練(如使用了torch.cuda.amp)下訓(xùn)練的,那么在加載模型后,應(yīng)該恢復(fù)之前的精度設(shè)置。
3 為了避免模型保存和加載出錯(cuò)
在多GPU訓(xùn)練的模型使用了torch.nn.DataParallel來包裝模型,但本質(zhì)上原來的model是依然存在的,且參數(shù)會(huì)同步更新:
- torch.nn.DataParallel 的工作原理
torch.nn.DataParallel 是 PyTorch 提供的一個(gè)類,用于在多個(gè) GPU 上并行訓(xùn)練模型。它的工作原理如下:
模型復(fù)制:DataParallel 會(huì)在每個(gè) GPU 上創(chuàng)建模型的副本。
數(shù)據(jù)分發(fā):輸入數(shù)據(jù)會(huì)被分發(fā)到各個(gè) GPU 上。
前向傳播:每個(gè) GPU 上的模型副本會(huì)獨(dú)立進(jìn)行前向傳播計(jì)算。
梯度收集:所有 GPU 上的梯度會(huì)被收集并匯總到主 GPU 上。
參數(shù)更新:主 GPU 上的優(yōu)化器會(huì)根據(jù)匯總后的梯度更新模型參數(shù),然后將更新后的參數(shù)同步回其他 GPU。 - 模型參數(shù)更新
當(dāng)你使用 model_train = torch.nn.DataParallel(model) 后,model_train 實(shí)際上是一個(gè)包裝了原始模型 model 的對象。雖然 model_train 是多GPU并行的版本,但它的參數(shù)更新是通過主 GPU 上的優(yōu)化器完成的,并且這些更新會(huì)同步回原始模型 model。
因此,model 的參數(shù)確實(shí)會(huì)被更新。具體來說:
前向傳播和反向傳播:在 train_model 函數(shù)中,model_train 用于前向傳播和反向傳播。
參數(shù)更新:優(yōu)化器 optimizer 使用的是 model.parameters(),即原始模型的參數(shù)。在每次迭代中,優(yōu)化器會(huì)根據(jù)匯總后的梯度更新這些參數(shù)。
參數(shù)同步:更新后的參數(shù)會(huì)自動(dòng)同步到 model_train 中的各個(gè) GPU 副本。
因此可以使用如下代碼,加載模型和保存模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" #設(shè)置GPU編號
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假設(shè)這是你的模型定義
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 創(chuàng)建模型實(shí)例
model = YourModel()# 將模型移動(dòng)到多GPU上,單GPU依然適用
if torch.cuda.device_count() > 1:model_train = nn.DataParallel(model)model_train = model_train.to(device)
else:model_train = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#注意這是model的參數(shù)
······
output = model_train(input) # 多卡時(shí)訓(xùn)練的輸入和輸出,注意這是model_train# 假設(shè)這是你的訓(xùn)練代碼,訓(xùn)練完成后保存模型
torch.save(model.state_dict(), 'model.pth') #注意這是model
- 再在單/多GPU或CPU上加載模型,都不會(huì)報(bào)錯(cuò),因?yàn)檫@里的model不是包裝體,不帶module
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth',map_location = device))
model = model.to(device)