吉林市網(wǎng)站制作視頻剪輯培訓
Pytorch多GPU訓練模型保存和加載
在多GPU訓練中,模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,這會在模型的參數(shù)名前加上module前綴。因此,在保存模型時,需要使用model.module.state_dict()來獲取模型的狀態(tài)字典,以確保保存的參數(shù)名與模型定義中的參數(shù)名一致。(本質(zhì)上原來的model還是存在的,參數(shù)也會同步更新)
-
多GPU訓練模型保存
在多GPU訓練時,模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,這會在模型的參數(shù)名前加上module前綴。因此,在保存模型時,需要使用model.module.state_dict()來獲取模型的狀態(tài)字典,以確保保存的參數(shù)名與模型定義中的參數(shù)名一致。 -
單GPU或CPU加載模型
當在單GPU或CPU上加載模型時,如果直接使用model.state_dict()保存的模型,由于缺少module前綴,會導致參數(shù)名不匹配,從而無法正確加載模型。因此,在保存多GPU訓練的模型時,應該使用model.module.state_dict()來保存模型的狀態(tài)字典,這樣在單GPU或CPU上加載模型時,可以直接加載,不會出現(xiàn)參數(shù)名不匹配的問題。 -
示例代碼
以下是一個示例代碼,展示了如何在多GPU訓練時保存模型,并在單GPU或CPU上加載模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" #設(shè)置GPU編號
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假設(shè)這是你的模型定義
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 創(chuàng)建模型實例
model = YourModel()# 將模型移動到多GPU上
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model = model.to(device)
else:model = model.to(device)
······
# 假設(shè)這是你的訓練代碼,訓練完成后保存模型
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')# 在單、多GPU或CPU上加載模型
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to(device)
2 在多GPU訓練得到的模型加載時,通常需要考慮以下幾個步驟:
- 模型保存
在多GPU訓練時,模型通常被包裝在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中。因此,在保存模型時,需要確保保存的是模型的state_dict而不是整個模型對象。例如:
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')
- 模型加載
在加載模型時,首先需要創(chuàng)建模型的實例,然后使用load_state_dict方法來加載保存的權(quán)重。如果模型是在多GPU環(huán)境下訓練的,那么在加載時也應該使用torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel來包裝模型。例如:
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to('cuda')
- 注意事項
在加載模型時,需要注意以下幾點:
如果模型是在多GPU環(huán)境下訓練的,那么在加載時也應該使用相同數(shù)量的GPU,或者使用torch.nn.DataParallel來包裝模型,即使只有一個GPU可用。
如果模型是在分布式訓練環(huán)境下訓練的,那么在加載時也應該使用torch.nn.parallel.DistributedDataParallel來包裝模型。
如果模型是在混合精度訓練(如使用了torch.cuda.amp)下訓練的,那么在加載模型后,應該恢復之前的精度設(shè)置。
3 為了避免模型保存和加載出錯
在多GPU訓練的模型使用了torch.nn.DataParallel來包裝模型,但本質(zhì)上原來的model是依然存在的,且參數(shù)會同步更新:
- torch.nn.DataParallel 的工作原理
torch.nn.DataParallel 是 PyTorch 提供的一個類,用于在多個 GPU 上并行訓練模型。它的工作原理如下:
模型復制:DataParallel 會在每個 GPU 上創(chuàng)建模型的副本。
數(shù)據(jù)分發(fā):輸入數(shù)據(jù)會被分發(fā)到各個 GPU 上。
前向傳播:每個 GPU 上的模型副本會獨立進行前向傳播計算。
梯度收集:所有 GPU 上的梯度會被收集并匯總到主 GPU 上。
參數(shù)更新:主 GPU 上的優(yōu)化器會根據(jù)匯總后的梯度更新模型參數(shù),然后將更新后的參數(shù)同步回其他 GPU。 - 模型參數(shù)更新
當你使用 model_train = torch.nn.DataParallel(model) 后,model_train 實際上是一個包裝了原始模型 model 的對象。雖然 model_train 是多GPU并行的版本,但它的參數(shù)更新是通過主 GPU 上的優(yōu)化器完成的,并且這些更新會同步回原始模型 model。
因此,model 的參數(shù)確實會被更新。具體來說:
前向傳播和反向傳播:在 train_model 函數(shù)中,model_train 用于前向傳播和反向傳播。
參數(shù)更新:優(yōu)化器 optimizer 使用的是 model.parameters(),即原始模型的參數(shù)。在每次迭代中,優(yōu)化器會根據(jù)匯總后的梯度更新這些參數(shù)。
參數(shù)同步:更新后的參數(shù)會自動同步到 model_train 中的各個 GPU 副本。
因此可以使用如下代碼,加載模型和保存模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" #設(shè)置GPU編號
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假設(shè)這是你的模型定義
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 創(chuàng)建模型實例
model = YourModel()# 將模型移動到多GPU上,單GPU依然適用
if torch.cuda.device_count() > 1:model_train = nn.DataParallel(model)model_train = model_train.to(device)
else:model_train = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#注意這是model的參數(shù)
······
output = model_train(input) # 多卡時訓練的輸入和輸出,注意這是model_train# 假設(shè)這是你的訓練代碼,訓練完成后保存模型
torch.save(model.state_dict(), 'model.pth') #注意這是model
- 再在單/多GPU或CPU上加載模型,都不會報錯,因為這里的model不是包裝體,不帶module
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth',map_location = device))
model = model.to(device)