動態(tài)網(wǎng)站開發(fā)商城網(wǎng)站seo百度網(wǎng)站排名軟件
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)——圖像著色
- 0. 前言
- 1. 模型與數(shù)據(jù)集分析
- 1.1 數(shù)據(jù)集介紹
- 1.2 模型策略
- 2. 實(shí)現(xiàn)圖像著色
- 相關(guān)鏈接
0. 前言
圖像著色指的是將黑白或灰度圖像轉(zhuǎn)換為彩色圖像的過程,傳統(tǒng)的圖像處理技術(shù)通?;谥狈綀D匹配和顏色傳遞的方法或基于用戶交互的方法等完成圖像著色操作,不但耗時(shí)且需要專業(yè)知識,而基于深度學(xué)習(xí)的方法能夠?qū)崿F(xiàn)自動著色,極大的提高了效率。在訓(xùn)練圖著色模型時(shí),我們可以將原始圖像轉(zhuǎn)換為黑白圖像作為網(wǎng)絡(luò)輸入,原始彩色圖像作為輸出。
1. 模型與數(shù)據(jù)集分析
在本節(jié)中,我們將利用 CIFAR-10
數(shù)據(jù)集執(zhí)行圖像著色。
1.1 數(shù)據(jù)集介紹
CIFAR-10
數(shù)據(jù)集是一個(gè)廣泛應(yīng)用于計(jì)算機(jī)視覺領(lǐng)域的圖像分類數(shù)據(jù)集。它由 10
個(gè)不同類別的彩色圖像組成,每個(gè)類別包含 6000
張 32 x 32
像素的圖像。該數(shù)據(jù)集涵蓋了各種不同的對象類別,包括飛機(jī)、汽車、鳥類、貓、鹿、狗、青蛙、馬、船和卡車。與一些只包含灰度圖像的數(shù)據(jù)集相比,CIFAR-10
數(shù)據(jù)集的圖像是彩色的,但由于圖像分辨率相對較低,圖像中的細(xì)節(jié)和特征相對較少。
CIFAR-10
數(shù)據(jù)集在計(jì)算機(jī)視覺領(lǐng)域的研究和開發(fā)中得到了廣泛的應(yīng)用,許多圖像分類算法和深度學(xué)習(xí)模型都在 CIFAR-10
上進(jìn)行了測試和驗(yàn)證。它提供了一個(gè)標(biāo)準(zhǔn)化的基準(zhǔn),用于比較不同算法的性能。
1.2 模型策略
了解了所用數(shù)據(jù)集后,本節(jié)中,我們繼續(xù)介紹圖像著色模型策略:
- 獲取訓(xùn)練數(shù)據(jù)集中的原始彩色圖像,將其轉(zhuǎn)換為灰度圖像,構(gòu)造輸入(灰度)-輸出(原始彩色圖像)對
- 執(zhí)行歸一化輸入和輸出圖像
- 構(gòu)建
U-Net
架構(gòu) - 訓(xùn)練模型
2. 實(shí)現(xiàn)圖像著色
接下來,使用 PyTorch
實(shí)現(xiàn)以上策略,構(gòu)建圖像著色模型。
(1) 導(dǎo)入所需庫:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt
(2) 下載數(shù)據(jù)集,并定義訓(xùn)練、驗(yàn)證數(shù)據(jù)集和數(shù)據(jù)加載器。
下載數(shù)據(jù)集:
data_folder = 'cifar10/cifar/'
datasets.CIFAR10(data_folder, download=True)
定義訓(xùn)練、驗(yàn)證數(shù)據(jù)集和數(shù)據(jù)加載器:
class Colorize(torchvision.datasets.CIFAR10):def __init__(self, root, train):super().__init__(root, train)def __getitem__(self, ix):im, _ = super().__getitem__(ix)bw = im.convert('L').convert('RGB')bw, im = np.array(bw)/255., np.array(im)/255.bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]return bw, imtrn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)
輸入和輸出圖像的樣本如下:
a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()
(3) 定義網(wǎng)絡(luò)架構(gòu):
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass DownConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.model = nn.Sequential(nn.MaxPool2d(2) if maxpool else Identity(),nn.Conv2d(ni, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.model(x)class UpConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)self.convlayers = nn.Sequential(nn.Conv2d(no+no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x, y):x = self.convtranspose(x)x = torch.cat([x,y], axis=1)x = self.convlayers(x)return xclass UNet(nn.Module):def __init__(self):super().__init__()self.d1 = DownConv( 3, 64, maxpool=False)self.d2 = DownConv( 64, 128)self.d3 = DownConv( 128, 256)self.d4 = DownConv( 256, 512)self.d5 = DownConv( 512, 1024)self.u5 = UpConv (1024, 512)self.u4 = UpConv ( 512, 256)self.u3 = UpConv ( 256, 128)self.u2 = UpConv ( 128, 64)self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)def forward(self, x):x0 = self.d1( x) # 32x1 = self.d2(x0) # 16x2 = self.d3(x1) # 8x3 = self.d4(x2) # 4x4 = self.d5(x3) # 2X4 = self.u5(x4, x3)# 4X3 = self.u4(X4, x2)# 8X2 = self.u3(X3, x1)# 16X1 = self.u2(X2, x0)# 32X0 = self.u1(X1) # 3return X0
(4) 定義模型、優(yōu)化器和損失函數(shù):
def get_model():model = UNet().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)loss_fn = nn.MSELoss()return model, optimizer, loss_fn
(5) 定義模型在批數(shù)據(jù)進(jìn)行訓(xùn)練和驗(yàn)證的函數(shù):
def train_batch(model, data, optimizer, criterion):model.train()x, y = data_y = model(x)optimizer.zero_grad()loss = criterion(_y, y)loss.backward()optimizer.step()return loss.item()@torch.no_grad()
def validate_batch(model, data, criterion):model.eval()x, y = data_y = model(x)loss = criterion(_y, y)return loss.item()
(6) 訓(xùn)練模型:
model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []for ex in range(n_epochs):N = len(trn_dl)trn_loss = []val_loss = []for bx, data in enumerate(trn_dl):loss = train_batch(model, data, optimizer, criterion)pos = (ex + (bx+1)/N)trn_loss.append(loss)train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for bx, data in enumerate(val_dl):loss = validate_batch(model, data, criterion)pos = (ex + (bx+1)/N)val_loss.append(loss)val_loss_epochs.append(np.average(val_loss))exp_lr_scheduler.step()if (ex+1)%10 == 0:for _ in range(5):a,b = next(iter(_val_dl))_b = model(a)plt.subplot(131)plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')plt.subplot(132)plt.imshow(b[0].permute(1,2,0).cpu())plt.subplot(133)plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()
從前面的輸出中,可以看到模型能夠很好地為灰度圖像著色。
相關(guān)鏈接
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(1)——神經(jīng)網(wǎng)絡(luò)與模型訓(xùn)練過程詳解
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(2)——PyTorch基礎(chǔ)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(3)——使用PyTorch構(gòu)建神經(jīng)網(wǎng)絡(luò)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(4)——常用激活函數(shù)和損失函數(shù)詳解
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(5)——計(jì)算機(jī)視覺基礎(chǔ)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(6)——神經(jīng)網(wǎng)絡(luò)性能優(yōu)化技術(shù)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(7)——批大小對神經(jīng)網(wǎng)絡(luò)訓(xùn)練的影響
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(8)——批歸一化
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(9)——學(xué)習(xí)率優(yōu)化
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(10)——過擬合及其解決方法
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(11)——卷積神經(jīng)網(wǎng)絡(luò)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(12)——數(shù)據(jù)增強(qiáng)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(13)——可視化神經(jīng)網(wǎng)絡(luò)中間層輸出
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(14)——類激活圖
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(15)——遷移學(xué)習(xí)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(16)——面部關(guān)鍵點(diǎn)檢測
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(17)——多任務(wù)學(xué)習(xí)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(18)——目標(biāo)檢測基礎(chǔ)
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(19)——從零開始實(shí)現(xiàn)R-CNN目標(biāo)檢測
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(20)——從零開始實(shí)現(xiàn)Fast R-CNN目標(biāo)檢測
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(21)——從零開始實(shí)現(xiàn)Faster R-CNN目標(biāo)檢測
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(22)——從零開始實(shí)現(xiàn)YOLO目標(biāo)檢測
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(23)——使用U-Net架構(gòu)進(jìn)行圖像分割
PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(24)——從零開始實(shí)現(xiàn)Mask R-CNN實(shí)例分割