Table of Contents
- 昇思MindSpore Application Practice
- Pix2Pix Image Translation with MindSpore
- 1. Pix2Pix Overview
- 2. U-Net Architecture
- Defining the UNet Skip Connection Block
- The Generator
- 3. PatchGAN-based Discriminator
- 4. Initializing the Pix2Pix Generator and Discriminator
- 5. Model Training
- 6. Model Inference
- Reference
昇思MindSpore Application Practice
This series of posts records my study notes from the 25-day MindSpore (昇思) learning camp.
Pix2Pix Image Translation with MindSpore
1. Pix2Pix Overview
Pix2Pix is a network designed specifically for image-to-image translation tasks: it can turn semantic/label maps into photographs, grayscale images into color images, aerial photos into maps, day scenes into night scenes, and line sketches into realistic objects. It is the classic work that applies the conditional GAN (CGAN) to supervised image-to-image translation, meaning the network is trained on paired data consisting of an input image (e.g. a sketch) and its ground-truth image (GT). Like all GANs, the model has two parts: a generator and a discriminator.
CGAN: the goal of a conditional GAN (CGAN) is to generate samples that match a given condition. The condition can be a label, partially annotated data, or any other form of auxiliary (possibly multimodal) information. A CGAN steers the data-generation process by feeding the condition into both the generator and the discriminator.
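To make "feeding the condition into the discriminator" concrete, here is a minimal sketch (my addition, not from the tutorial): the condition image is simply concatenated with the candidate image along the channel axis before being handed to the discriminator, which is exactly what the PatchGAN discriminator defined later in this post does with ops.concat.

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

# hypothetical example tensors: a 3-channel condition image x and a 3-channel candidate image y
x = Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))  # observed / condition image
y = Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))  # real or generated image
pair = ops.concat((x, y), axis=1)                                # shape (1, 6, 256, 256)
print(pair.shape)  # the discriminator judges the (condition, image) pair jointly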
Compared with the ordinary adversarial (GAN) loss:
L_{GAN}(G,D)=\mathbb{E}_{y}[\log D(y)]+\mathbb{E}_{(x,z)}[\log(1-D(G(x,z)))]
- x: the observed image data.
- z: the random noise data.
- y = G(x, z): the generator network, which produces a "fake" image from the observed image x and the random noise z, where x comes from the training data rather than from the generator.
- D(x, G(x, z)): the discriminator network, which outputs the probability that an image is real, where x comes from the training data and G(x, z) comes from the generator.
The CGAN adds the condition x taken from the observed image (this is why Pix2Pix is trained in a supervised fashion and needs annotated paired data, such as the Map2Aerial dataset or the Anime Sketch Colorization Pair dataset). The CGAN objective can therefore be written as:
L_{CGAN}(G,D)=\mathbb{E}_{(x,y)}[\log D(x,y)]+\mathbb{E}_{(x,z)}[\log(1-D(x,G(x,z)))]
Pix2Pix also adds an L1 loss, which pushes the generator towards outputs that are structurally close to the ground-truth image; this matters a great deal in image translation tasks:
L_{L1}(G)=\mathbb{E}_{(x,y,z)}[||y-G(x,z)||_{1}]
which leads to the final objective:
\arg\min_{G}\max_{D}L_{CGAN}(G,D)+\lambda L_{L1}(G)
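Split per network, this minimax objective yields the two losses that the training step later in this post optimizes in alternation (a standard decomposition; the extra 0.5 weights used in the code are an implementation choice of the tutorial):

L_{D}=-\mathbb{E}_{(x,y)}[\log D(x,y)]-\mathbb{E}_{(x,z)}[\log(1-D(x,G(x,z)))]

L_{G}=-\mathbb{E}_{(x,z)}[\log D(x,G(x,z))]+\lambda\mathbb{E}_{(x,y,z)}[||y-G(x,z)||_{1}]

Here the generator term uses the common non-saturating form (maximize \log D rather than minimize \log(1-D)), which matches the BCE-against-ones loss in the training code below.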
Image translation is essentially a pixel-to-pixel mapping problem. Pix2Pix uses exactly the same network architecture and objective function for all of the tasks above; only the training dataset changes.
2. U-Net Architecture
U-Net architecture: Pix2Pix uses a U-Net as its generator. On top of the conventional encoder-decoder structure it adds skip connections, which let the network capture both fine image detail and contextual information, making it well suited to image-to-image translation. Compared with a plain encoder-decoder, the skip connections between encoder and decoder also greatly improve gradient flow:
Defining the UNet Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]
            model = down + [submodule] + up
        if dropout:
            model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
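As a quick sanity check (my addition, not from the tutorial), an innermost block can be run on a dummy tensor; because of the skip concatenation, the output has twice as many channels as the input:

import numpy as np
from mindspore import Tensor

# hypothetical smoke test: innermost block with 512 channels in and out
block = UNetSkipConnectionBlock(outer_nc=512, inner_nc=512, innermost=True)
x = Tensor(np.random.randn(1, 512, 4, 4).astype(np.float32))
out = block(x)
print(out.shape)  # expected (1, 1024, 4, 4): the block's 512 output channels concatenated with the 512-channel input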
The Generator
The original CGAN takes two inputs, the condition x and the noise z. The generator here only uses the condition, so by itself it could not produce diverse outputs; Pix2Pix therefore applies dropout both at training and at test time, which injects the randomness needed to generate varied results.
Implementing the U-Net-based generator with MindSpore:
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)
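A minimal smoke test (my addition; assumes 256×256 inputs, the size used by the dataset later) confirms that the generator maps a 3-channel image to a 3-channel image of the same size:

import numpy as np
from mindspore import Tensor

# hypothetical smoke test for the 8-level U-Net generator
gen = UNetGenerator(in_planes=3, out_planes=3, ngf=64, n_layers=8)
dummy = Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))
print(gen(dummy).shape)  # expected (1, 3, 256, 256); values lie in (-1, 1) because of the final Tanh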
3. PatchGAN-based Discriminator
The discriminator uses the PatchGAN structure, which can be viewed as a fully convolutional classifier: it outputs a matrix in which every element corresponds to a small region (patch) of the original image, and the value at each position is used to judge whether that patch is real or fake.
import mindspore.nn as nn
import mindspore.ops as ops

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
                  nn.LeakyReLU(alpha)]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output
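A quick check (my addition; assumes 256×256 inputs as used later) shows the patch-level output: instead of a single scalar, the discriminator returns a 30×30 map, each element scoring one patch of the input pair:

import numpy as np
from mindspore import Tensor

# hypothetical smoke test for the PatchGAN discriminator
disc = Discriminator(in_planes=6, ndf=64, n_layers=3)
x = Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))  # condition image
y = Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))  # real or generated image
print(disc(x, y).shape)  # expected (1, 1, 30, 30): one logit per patch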
4. Initializing the Pix2Pix Generator and Discriminator
Instantiate the Pix2Pix generator and discriminator:
import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'

net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix model network: wraps the generator and discriminator."""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb
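Optionally (my addition, not part of the tutorial), one can spot-check that the initialization took effect: with init_type = 'normal', the convolution weights should have a standard deviation of roughly init_gain = 0.02.

# hypothetical sanity check of the normal(0, 0.02) initialization
for name, cell in net_generator.cells_and_names():
    if isinstance(cell, nn.Conv2d):
        print(name, cell.weight.asnumpy().std())  # should be roughly 0.02
        break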
5. Model Training
Training consists of two main parts: training the discriminator and training the generator.
The discriminator is trained to maximize the probability of correctly classifying images as real or fake.
The generator is trained to produce increasingly convincing fake images.
In both parts the losses are collected during training, and statistics are gathered at the end of each epoch.
Training with MindSpore:
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
import mindspore.dataset as ds

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord",
                         columns_list=["input_images", "target_images"],
                         shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format(
                (delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "/Generator.ckpt")
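Since the per-step losses are collected in g_losses and d_losses, a simple plot (my addition; assumes matplotlib is installed) helps to inspect how the two losses evolve over training:

import matplotlib.pyplot as plt

# plot the recorded generator / discriminator losses over training steps
plt.figure(figsize=(8, 4))
plt.plot(g_losses, label="generator loss")
plt.plot(d_losses, label="discriminator loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()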
6. Model Inference
Load the weights saved during training:
import matplotlib.pyplot as plt
from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "/Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
The image translation results look as follows:
Reference
MindSpore official documentation: Pix2Pix for image translation (昇思官方文档 - Pix2Pix实现图像转换)
昇思MindSpore large model platform (昇思大模型平台)