One of the new features announced in "PyTorch 2.5 Major Update: Performance Optimizations + New Features" is official support for training models on Intel® discrete GPUs!
PyTorch 2.5
Discrete GPU type | Supported OS
Intel® Data Center GPU Max Series | Linux
Intel® Arc™ Series | Linux/Windows
This article walks through the full workflow of building and training an AI model with PyTorch 2.5 on the Arc™ integrated GPU built into an Intel® Core™ Ultra 7 155H.
1. Set up the development environment
First, install the GPU driver by following this guide:
https://dgpu-docs.intel.com/driver/client/overview.html
Next, download and install Anaconda:
https://www.anaconda.com/download
Then create and activate a virtual environment named pytorch_arc with the following commands:
conda create -n pytorch_arc python=3.11    # create the virtual environment
conda activate pytorch_arc                 # activate the virtual environment
python -m pip install --upgrade pip        # upgrade pip to the latest version
Next, install the XPU build of PyTorch:
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
Finally, run the following commands to verify the installation. If the call returns True, the environment is set up correctly:
>>> import torch
>>> torch.xpu.is_available()
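Optionally, you can also query the detected device in the same Python session. This is a minimal sketch, not part of the original guide, assuming the torch.xpu namespace mirrors the familiar torch.cuda device-query helpers (device_count and get_device_name); the reported device name will depend on your hardware:
>>> torch.xpu.device_count()
>>> torch.xpu.get_device_name(0)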
2. Train a ResNet model
Run the training script below to train a ResNet-50 model on the Intel® Arc™ integrated GPU:
import torch
import torchvision

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

# Preprocess CIFAR-10 images: resize to 224x224 for ResNet-50 and normalize.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)
train_len = len(train_loader)

# Build the model, loss, and optimizer, then move them to the Intel GPU ("xpu").
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
model.train()
model = model.to("xpu")
criterion = criterion.to("xpu")

print("Initiating training")
for batch_idx, (data, target) in enumerate(train_loader):
    # Move each batch to the XPU device before the forward pass.
    data = data.to("xpu")
    target = target.to("xpu")
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if (batch_idx + 1) % 10 == 0:
        iteration_loss = loss.item()
        print(f"Iteration [{batch_idx + 1}/{train_len}], Loss: {iteration_loss:.4f}")

# Save the trained weights and optimizer state for later use.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint.pth",
)

print("Execution finished")
3. Summary
Training models with PyTorch on Intel discrete GPUs gives the AI industry a new compute-hardware option!