集趣網(wǎng)站怎么做兼職百度seo收錄
[GAN] 使用GAN網(wǎng)絡(luò)進(jìn)行圖片生成的“煉丹人”日志——生成向日葵圖片
文章目錄
- [GAN] 使用GAN網(wǎng)絡(luò)進(jìn)行圖片生成的“煉丹人”日志——生成向日葵圖片
- 1. 寫(xiě)在前面:
- 1.1 應(yīng)用場(chǎng)景:
- 1.2 數(shù)據(jù)集情況:
- 1.3 實(shí)驗(yàn)原理講解和分析(簡(jiǎn)化版,到時(shí)候可以出一期深入的PaperReading)
- 1.4 一些必要的介紹
- 2. 重要實(shí)驗(yàn)代碼:
- 2.1 一些相關(guān)的數(shù)據(jù)預(yù)處理
- 2.2 生成器和判別器
- 2.3 損失函數(shù)計(jì)算
- 2.4 訓(xùn)練和反向傳播
- 3. 實(shí)驗(yàn)結(jié)果分析:
- 3.0 baseline
- 3.0.1 損失函數(shù):
- 3.0.2 last picture:
- 3.0.3 gif picture:
- 3.1 epoch不變的情況下提高學(xué)習(xí)率:
- 3.1.1 損失函數(shù):
- 3.1.2 last picture:
- 3.1.3 gif picture:
- 3.2 試試增加epoch?:
- 3.2.1 損失函數(shù):
- 3.2.2 last picture:
- 3.2.3 gif picture:
- 4. 目前比較不錯(cuò)的效果展示
- 5. 一些其它問(wèn)題和小小的總結(jié)
- 參考資料
1. 寫(xiě)在前面:
1.1 應(yīng)用場(chǎng)景:
為了支撐人工智能落地,為人們的生活帶來(lái)更多的便利,充足的數(shù)據(jù)尤為重要。而在實(shí)際的應(yīng)用中常常會(huì)面臨專(zhuān)業(yè)數(shù)據(jù)匱乏,數(shù)據(jù)不均衡的問(wèn)題,所以利用神經(jīng)網(wǎng)絡(luò)根據(jù)已有的數(shù)據(jù)生成新的數(shù)據(jù),進(jìn)行數(shù)據(jù)擴(kuò)充,成為了助力人工智能落地的新思路。
1.2 數(shù)據(jù)集情況:
我所使用的數(shù)據(jù)集是總量為256張的彩色的向日葵的圖片。
1.3 實(shí)驗(yàn)原理講解和分析(簡(jiǎn)化版,到時(shí)候可以出一期深入的PaperReading)
- GAN網(wǎng)絡(luò)俗稱(chēng)生成式對(duì)抗網(wǎng)絡(luò),該網(wǎng)絡(luò)訓(xùn)練了兩個(gè)模型(即生成器G和判別器D)來(lái)進(jìn)行相互博弈,而博弈的目的是為了得到一個(gè)性能較好的可以用于生成我們想要的圖片的生成器G。
- 其中生成器網(wǎng)絡(luò)G是為了生成可以用來(lái)迷惑判別器網(wǎng)絡(luò)D的"假"圖像。按數(shù)學(xué)語(yǔ)言來(lái)理解就是要最大化判別器D犯錯(cuò)的概率。
- 而判別器網(wǎng)絡(luò)D則是為了判別一個(gè)樣本是不是來(lái)自于真實(shí)數(shù)據(jù)。按數(shù)學(xué)語(yǔ)言來(lái)理解就是它用于估計(jì)出一個(gè)樣本是來(lái)源于真實(shí)的數(shù)據(jù)而非來(lái)自于G的概率。
- 因此,不難得出這個(gè)模型的訓(xùn)練的過(guò)程大抵就是一個(gè)生成器G和判別器D之間的左右互博的過(guò)程。
- 不過(guò),值得注意的是這里對(duì)G和D的模型的構(gòu)建使用的是多層感知機(jī)MLP(Multilayer perceptrons),也就是在網(wǎng)絡(luò)上主要是使用全連接層。
- 從這里我們可以看到GAN網(wǎng)絡(luò)的損失函數(shù)為:
- 這個(gè)估值函數(shù)中由兩個(gè)部分的數(shù)學(xué)期望所組成,第一部分是當(dāng)輸入是來(lái)自真實(shí)樣本數(shù)據(jù)的期望,而第二部分則是當(dāng)輸入是來(lái)自生成器生成的樣本時(shí)的期望。
- 判別器輸出的值是一個(gè)概率值,這個(gè)概率表示輸出值是來(lái)自真實(shí)數(shù)據(jù)而非來(lái)自生成器的程度。
- 這個(gè)值越接近1就越表明當(dāng)前的輸入來(lái)自真實(shí)數(shù)據(jù),而越接近0就表示這個(gè)輸入來(lái)自生成器。
- 這樣們就可以理解D(x)的目的是為了更好地區(qū)分二者,這樣能是的D函數(shù)輸出的值是合理的(更接近1或0)。
- 而G的目的是為了讓G(z)更像數(shù)據(jù)樣本,這樣可以使得第二個(gè)期望中的D(G(z))能被誤判為1,這樣就可以達(dá)到讓第二個(gè)期望的值盡可能小的效果。
- 再反過(guò)來(lái)看D的訓(xùn)練,D能更好判別真假,就更加使得第二個(gè)期望中的D(G(z))能被正確判為0,這樣就可以達(dá)到讓第二個(gè)期望的值盡可能大的效果。
- 所以綜合地來(lái)看,判別器D就是為了讓整個(gè)損失(價(jià)值)函數(shù)盡量大,而生成器則反之,它想讓損失函數(shù)足夠小。這樣也就符合我們訓(xùn)練一個(gè)網(wǎng)絡(luò)的指標(biāo)是讓損失值減小,而我們也就可以沿著想辦法讓損失減小的方向去優(yōu)化我們的模型從而達(dá)到訓(xùn)練出一個(gè)較好的生成器。
1.4 一些必要的介紹
- 在我個(gè)人的實(shí)踐中,我所使用的深度學(xué)習(xí)框架為華為昇騰AI系列的
mindspore-1.9
深度學(xué)習(xí)框架。 - 所使用的筆記本的操作系統(tǒng)為Windows10
- 我使用的是AMD的CPU來(lái)進(jìn)行訓(xùn)練,因?yàn)楸旧碓揹emo的數(shù)據(jù)量并不是很大。
2. 重要實(shí)驗(yàn)代碼:
2.1 一些相關(guān)的數(shù)據(jù)預(yù)處理
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image # 一個(gè)讀取圖片和對(duì)圖片做基礎(chǔ)操作的類(lèi)
# 數(shù)據(jù)轉(zhuǎn)換
image_size = 64
input_images = np.asarray([np.asarray # 將Python的數(shù)組轉(zhuǎn)化成npArray(Image.open(input_data_dir + "/" + file).resize((image_size, image_size)) # 將圖片的尺寸轉(zhuǎn)化為 64* 64.convert("L")) # 將圖片轉(zhuǎn)化為灰度圖,這樣就簡(jiǎn)化了運(yùn)算,只需要考慮一個(gè)顏色通道了。(可拓展點(diǎn)對(duì)RGB三個(gè)顏色的通道都進(jìn)行處理。)for file in filename])
# 數(shù)據(jù)預(yù)處理
input_images = input_images.reshape(256, 4096) # 將256張圖片展平為一維向量
# input_images = input_images.astype('float32')/255 # 把圖片的值放縮到(0,1)之間
input_images = (input_images.astype('float32') - 127.5) / 127.5 # 把圖片的值放縮到(-1,1)之間
# input_images = (input_images.astype('float32')-mean)/std # 把數(shù)據(jù)樣本轉(zhuǎn)化為均值為0,方差為1的標(biāo)準(zhǔn)化數(shù)據(jù)(未完成)
2.2 生成器和判別器
# 構(gòu)建生成器
img_size = 64 # 訓(xùn)練圖像長(zhǎng)(寬)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 輸入一個(gè)100維的0~1之間的高斯分布,然后通過(guò)第一層線性變換將其映射到256維self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 4096]# 經(jīng)過(guò)線性變換將其變成4096維self.model.append(nn.Dense(1024, img_size * img_size))# 經(jīng)過(guò)Tanh激活函數(shù)是希望生成的假的圖片數(shù)據(jù)分布能夠在-1~1之間self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 64, 64))latent_size = 100 # 隱碼的長(zhǎng)度
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
# 構(gòu)建判別器class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 4096] -> [N, 1024]self.model.append(nn.Dense(img_size * img_size, 1024)) # 輸入特征數(shù)為4096,輸出為1024self.model.append(nn.LeakyReLU()) # 默認(rèn)斜率為0.2的非線性映射激活函數(shù)# [N, 1024] -> [N, 256]self.model.append(nn.Dense(1024, 256)) # 進(jìn)行一個(gè)線性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid()) # 二分類(lèi)激活函數(shù),將實(shí)數(shù)映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')
2.3 損失函數(shù)計(jì)算
# 損失函數(shù)
adversarial_loss = nn.BCELoss(reduction='mean')# 損失及梯度計(jì)算函數(shù)
# 生成器計(jì)算損失過(guò)程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判別器計(jì)算損失過(guò)程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d
2.4 訓(xùn)練和反向傳播
def train_step(real_data, latent_code):# 計(jì)算判別器損失和梯度# 前向計(jì)算 => 得到損失函數(shù)和梯度參數(shù)# 反向傳播 => 使用梯度參數(shù)進(jìn)行權(quán)重參數(shù)更新loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g
3. 實(shí)驗(yàn)結(jié)果分析:
- 寫(xiě)在前面——在正式進(jìn)行實(shí)驗(yàn)前還有一些隨機(jī)性的探索。
其中值得一提的是,比起直接把
256
張照片一整個(gè)當(dāng)成一個(gè)批次epoch來(lái)訓(xùn)練的話,在一個(gè)epoch
內(nèi)將整個(gè)數(shù)據(jù)集分成幾個(gè)batch
效果會(huì)好得多,下面的所有的實(shí)驗(yàn)都是在這種情況下進(jìn)行的訓(xùn)練。
3.0 baseline
- 以下是使用
SGD優(yōu)化器
在學(xué)習(xí)率lr=0.01
并且訓(xùn)練100個(gè)epoch
后的結(jié)果。
3.0.1 損失函數(shù):
3.0.2 last picture:
3.0.3 gif picture:
- 學(xué)習(xí)率是我們進(jìn)行超參數(shù)調(diào)節(jié)中非常經(jīng)常用來(lái)調(diào)節(jié)的一個(gè)參數(shù),而
lr=0.01
是一個(gè)很常用的經(jīng)驗(yàn)值,所以這次我們就i用這個(gè)值來(lái)作為一個(gè)實(shí)驗(yàn)的起始的參考值。- 從上面的損失函數(shù)的趨勢(shì)可以看出,在一個(gè)數(shù)值比較小的
lr
下,損失函數(shù)的曲線是相對(duì)很平滑的。- 從上面的損失函數(shù)的曲線我們也可以看到一個(gè)健康的GAN網(wǎng)絡(luò)訓(xùn)練的過(guò)程生成器G的損失和判別器D的損失一般是呈現(xiàn)為在某個(gè)區(qū)間內(nèi)相互對(duì)峙波動(dòng)發(fā)展的過(guò)程。
- 而從上面的結(jié)果圖來(lái)看,現(xiàn)在當(dāng)前的模型是尚未收斂的狀態(tài),需要 “ 去做更多的學(xué)習(xí)來(lái)讓自己收斂。 ”
- 那么怎么往下去學(xué)得更多呢?
- 我們知道學(xué)習(xí)的過(guò)程是一個(gè)反向傳播的過(guò)程,而控制這個(gè)過(guò)程的一個(gè)重要的參數(shù)是學(xué)習(xí)率,也就是說(shuō),我們可以考慮讓學(xué)習(xí)率高一些,這樣就可以學(xué)得更快一些。
- 從另外一個(gè)角度來(lái)說(shuō)我們也可以考慮“學(xué)得久一些”,比如增大
epoch
看看效果會(huì)怎么樣? - 而這就是我們本文所研究的兩條調(diào)參路線。
3.1 epoch不變的情況下提高學(xué)習(xí)率:
3.1.1 損失函數(shù):
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.20
3.1.2 last picture:
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.20
3.1.3 gif picture:
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.20
- 從上面的部分結(jié)果來(lái)看的話,在只變動(dòng)學(xué)習(xí)率的情況下,對(duì)于當(dāng)前的例子,使用更大的學(xué)習(xí)率確實(shí)能夠加速模型的收斂,讓生成器最后的效果呈現(xiàn)出一種比較不錯(cuò)的效果,至少整個(gè)圖片看起來(lái)已經(jīng)是很像一張向日葵的圖片。這個(gè)是一個(gè)不錯(cuò)的進(jìn)步。
- 但是依然產(chǎn)生了一些新的問(wèn)題,比如因?yàn)?strong>學(xué)習(xí)率變大,雖然收斂的速度變快了,但是損失函數(shù)卻不是很平滑,充滿(mǎn)了各種爆炸的毛刺的氣息,這讓我想到了過(guò)擬合和不穩(wěn)定。
3.2 試試增加epoch?:
3.2.1 損失函數(shù):
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.20
3.2.2 last picture:
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.20
3.2.3 gif picture:
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.05
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.10
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.20
- 從最后的效果來(lái)看,把epoch增多,最后生成的照片的細(xì)膩程度遠(yuǎn)比
僅有100個(gè)epoch
的最后的成片的效果好了很多。由此可見(jiàn),在學(xué)習(xí)率合理的情況下,去增大訓(xùn)練的epoch量也確實(shí)是能比較不錯(cuò)地提升GAN網(wǎng)絡(luò)最后生成的圖片的效果。- 不過(guò)也產(chǎn)生了許多新的問(wèn)題,從上面的這些損失函數(shù)可以找到一個(gè)共性,那就是
在初期的epoch中,生成器G的損失值是在判別器的損失值的之下的,而隨著訓(xùn)練的epoch的量足夠大之后,在中后期,會(huì)出現(xiàn)判別器D的損失值不斷下降,而生成器的損失值則開(kāi)始上升的情況。這其實(shí)直接說(shuō)明了在這些階段中繼續(xù)增大epoch可能并不能很好地朝著我們想要的訓(xùn)練出一個(gè)效果更好的生成器的方向演變了。
- 從部分實(shí)驗(yàn)結(jié)果中我們可以發(fā)現(xiàn):
當(dāng)判別器D的能力相比生成器G更強(qiáng)的時(shí)候,G為了能夠繼續(xù)優(yōu)化,往往就會(huì)向模式崩塌的方向走去,它會(huì)開(kāi)始投機(jī)取巧,使得最后生成出來(lái)的圖片會(huì)普遍有某種類(lèi)似,在個(gè)性上就不夠有好效果了。我們稱(chēng)其為泛化能力不夠。
- 這里我以我訓(xùn)練了
500個(gè)epoch
的一些過(guò)程性的截圖來(lái)展示: SGD優(yōu)化器
,1個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,50個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,100個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,150個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,200個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,250個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,300個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,350個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,400個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,450個(gè)epoch
,學(xué)習(xí)率lr=0.25
SGD優(yōu)化器
,500個(gè)epoch
,學(xué)習(xí)率lr=0.25
- 特別指出這個(gè)例子的原因是我發(fā)現(xiàn)epoch增大越到后期,生成出來(lái)的向日葵就基本都是
懟臉向日葵
居多,而前面還能看到的苗條向日葵
,則其實(shí)基本偏少了,更不用說(shuō)其他更有特性
的向日葵了。- 當(dāng)我返回去看這
256張向日葵的數(shù)據(jù)集
的時(shí)候,我發(fā)現(xiàn)其實(shí)原始的相冊(cè)中,其實(shí)居多的也主要是懟臉向日葵
,其次是苗條向日葵
,最后是一些零散的各類(lèi)較有個(gè)性的向日葵。- 尤次可見(jiàn),最后的最后,我們導(dǎo)向的結(jié)果依然是
最后影響一個(gè)模型的質(zhì)量的,還是回到了訓(xùn)練這個(gè)模型的數(shù)據(jù)集的質(zhì)量。
高質(zhì)量的數(shù)據(jù)處理對(duì)模型的訓(xùn)練是非常非常非常重要的!
- 數(shù)據(jù)集照片情況概覽:
4. 目前比較不錯(cuò)的效果展示
- 以下是使用
SGD
優(yōu)化器,學(xué)習(xí)率為0.25
,訓(xùn)練了500個(gè)epoch
的一個(gè)演變效果。
5. 一些其它問(wèn)題和小小的總結(jié)
- 總得來(lái)說(shuō)經(jīng)過(guò)本次實(shí)驗(yàn)的探究,其實(shí)我所在對(duì)抗的主要是兩個(gè)問(wèn)題:
- "生成的圖片不像我的目的圖像"的問(wèn)題。(欠擬合,未收斂)
- ”生成的圖片大多長(zhǎng)得類(lèi)似,或者甚至一模一樣!“(過(guò)擬合,模式崩塌)
- 結(jié)合做了以上那么多的實(shí)驗(yàn)來(lái)看,我現(xiàn)在對(duì)GAN網(wǎng)絡(luò)的兩個(gè)模型的損失函數(shù)的理解是正常的情況G和D應(yīng)該是兩條有波動(dòng),但整體上是對(duì)峙者推進(jìn)的一上一下的趨勢(shì),其中最好是G在下,而D在上。這樣的狀態(tài)持續(xù)得越多個(gè)epoch,最終我們得到的生成器的綜合效果就會(huì)越佳,而一旦打破了這個(gè)
平衡
,生成器的質(zhì)量就會(huì)往某一個(gè)方向偏移,一般是模式崩塌即判別器不斷在進(jìn)化,使得判別器太強(qiáng),而生成器只能通過(guò)投機(jī)取巧的方式來(lái)精學(xué)某一類(lèi)
來(lái)保持它能繼續(xù)保持能騙過(guò)生成器。所以如何達(dá)到平衡
是一個(gè)值得深入研究的方向。
參考資料
- [1] GOODFELLOW I, POUGET-ABADIE J, MIRZA M, et al. Generative Adversarial Nets[J/OL]. Journal of Japan Society for Fuzzy Theory and Intelligent Informatics, 2017: 177-177. http://dx.doi.org/10.3156/jsoft.29.5_177_2. DOI:10.3156/jsoft.29.5_177_2.
- GAN圖像生成-mindspore