保健品網(wǎng)站怎么做的如何找友情鏈接
文章目錄
- 1. 創(chuàng)建數(shù)據(jù)集
- 1.1. 直接繼承Dataset類
- 1.2. 使用TensorDataset類
- 2. 加載數(shù)據(jù)集
- 3. 將數(shù)據(jù)轉(zhuǎn)移到GPU
1. 創(chuàng)建數(shù)據(jù)集
主要是將數(shù)據(jù)集讀入內(nèi)存,并用Dataset類封裝。
1.1. 直接繼承Dataset類
必須要重寫__getitem__方法,用于根據(jù)索引獲得相應(yīng)樣本數(shù)據(jù)。必要時還可以重寫__len__方法,用于返回數(shù)據(jù)集的大小。
from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定義波士頓房價數(shù)據(jù)集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]
1.2. 使用TensorDataset類
將多個張量組合成一個數(shù)據(jù)集,要保證所有張量的第一個維度相等,保證每批樣本數(shù)據(jù)格式相同。
import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)
2. 加載數(shù)據(jù)集
使用DataLoader類將Dataset封裝的數(shù)據(jù)集分成批次并進(jìn)行迭代,以便于模型訓(xùn)練。DataLoader常用參數(shù)如下:
- dataset
要加載的數(shù)據(jù)集。 - batch_size
每個數(shù)據(jù)批次中包含的樣本數(shù)。默認(rèn)為1。 - shuffle
是否打亂數(shù)據(jù)集。默認(rèn)為False。 - num_workers
使用幾個進(jìn)程來加載數(shù)據(jù)。默認(rèn)為0,即在主進(jìn)程中加載數(shù)據(jù)。 - drop_last
當(dāng)數(shù)據(jù)集樣本數(shù)不能被batch_size整除時,是否舍棄最后一個不完整的batch。默認(rèn)為False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)
3. 將數(shù)據(jù)轉(zhuǎn)移到GPU
一般在要運(yùn)算時才將數(shù)據(jù)轉(zhuǎn)移到GPU,有以下兩種方法:
- var.to(device)
- var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 將數(shù)據(jù)轉(zhuǎn)移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()