PyTorch implementation: git repository
Paper: Neural Discrete Representation Learning
1 Core Ideas of the Paper
- encoder
The image is passed through the encoder to obtain a latent representation.
For example, an input of shape [32,3,32,32]
becomes [32,64,8,8] after the encoder (where 64 is the output dimension).
- quantization codebook
A codebook is first initialized randomly, with the same feature dimension as the encoder output.
Here 512 discrete codes are defined, so the codebook shape is [512,64].
- nearest-neighbor lookup in the codebook
The encoder output of shape [32,64,8,8] is reshaped to [32 * 8 * 8, 64].
Each vector is matched to its closest codebook entry and replaced by it (see the lookup sketch after this list).
The result has shape [2048,64], which is reshaped back to [32,64,8,8].
- decoder
The quantized tensor is fed to the decoder to reconstruct the original image.
- loss
The loss has two parts:
a. the encoder output should be close to the codebook vectors;
b. the reconstruction loss: the reconstructed image should be close to the original image.
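
To make the lookup step concrete, here is a minimal, self-contained sketch of the reshape-and-nearest-neighbor operation on random tensors. All names here are illustrative; `torch.cdist` stands in for the expanded distance formula used in Section 2.2.

```python
import torch

B, C, H, W = 32, 64, 8, 8
K = 512                                    # number of codebook entries

z = torch.randn(B, C, H, W)                # stand-in for the encoder output
codebook = torch.randn(K, C)               # randomly initialized codebook

z_flat = z.permute(0, 2, 3, 1).reshape(-1, C)      # [32*8*8, 64] = [2048, 64]
nearest = torch.cdist(z_flat, codebook).argmin(1)  # closest entry per vector
z_q = codebook[nearest]                            # [2048, 64] quantized vectors
z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2)  # back to [32, 64, 8, 8]
print(z_q.shape)
```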
2 Implementation
2.1 Encoder
The encoder is a standard convolutional network for images.
Input x shape: [32,3,32,32]
Output shape: [32,128,8,8]
(A 1×1 pre-quantization convolution, not quoted in this post, then projects the 128 channels down to the 64 dimensions expected by the codebook.)
```python
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super(Encoder, self).__init__()
        kernel = 4
        stride = 2
        self.conv_stack = nn.Sequential(
            nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,
                      stride=stride, padding=1),
            nn.ReLU(),
            nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,
                      stride=stride, padding=1),
            nn.ReLU(),
            nn.Conv2d(h_dim, h_dim, kernel_size=kernel - 1,
                      stride=stride - 1, padding=1),
            ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers)
        )

    def forward(self, x):
        return self.conv_stack(x)
```
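
The `ResidualStack` referenced above is defined elsewhere in the repository and not quoted in this post. A minimal version consistent with how it is called here might look like the following (a sketch, not necessarily identical to the original):

```python
import torch
import torch.nn as nn


class ResidualLayer(nn.Module):
    """One pre-activation residual block: x + Conv1x1(ReLU(Conv3x3(ReLU(x))))."""

    def __init__(self, in_dim, h_dim, res_h_dim):
        super(ResidualLayer, self).__init__()
        self.res_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_dim, res_h_dim, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(res_h_dim, h_dim, kernel_size=1,
                      stride=1, bias=False),
        )

    def forward(self, x):
        return x + self.res_block(x)


class ResidualStack(nn.Module):
    """n_res_layers residual layers, followed by a final ReLU."""

    def __init__(self, in_dim, h_dim, res_h_dim, n_res_layers):
        super(ResidualStack, self).__init__()
        self.stack = nn.ModuleList(
            [ResidualLayer(in_dim, h_dim, res_h_dim)
             for _ in range(n_res_layers)])

    def forward(self, x):
        for layer in self.stack:
            x = layer(x)
        return torch.relu(x)
```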
2.2 VectorQuantizer (Vector Quantization Layer)
- Input: the encoder output z, of shape [32,64,8,8].
- Codebook comparison: z is reshaped to [2048,64] and compared against the codebook embedding of shape [512,64].
- Similarity: distances to the codebook are computed via the expansion $(x-y)^2 = x^2 + y^2 - 2xy$.
- z_q generation: each encoder vector is then replaced by its most similar codebook vector.
- z_q shape: the resulting z_q has shape [2048,64] and is reshaped to [32,64,8,8], matching the input z.
- Loss: a loss term is built to pull z_q and z toward each other (a condensed sketch of the whole forward pass follows this list).
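
Putting these steps together, below is a condensed sketch of the quantizer's forward pass under the shapes above. Names like `n_e`, `e_dim`, and `beta` are illustrative; the actual class in the repo also computes extra statistics such as the perplexity used later in the training loop.

```python
import torch
import torch.nn as nn


class VectorQuantizer(nn.Module):
    def __init__(self, n_e=512, e_dim=64, beta=0.25):
        super(VectorQuantizer, self).__init__()
        self.embedding = nn.Embedding(n_e, e_dim)   # the [512, 64] codebook
        self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
        self.beta = beta

    def forward(self, z):
        # [32, 64, 8, 8] -> [32, 8, 8, 64] -> [2048, 64]
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flat = z.view(-1, self.embedding.weight.shape[1])

        # (x - y)^2 = x^2 + y^2 - 2xy: pairwise distances without forming x - y
        d = (z_flat ** 2).sum(dim=1, keepdim=True) \
            + (self.embedding.weight ** 2).sum(dim=1) \
            - 2 * z_flat @ self.embedding.weight.t()

        idx = d.argmin(dim=1)                       # nearest codebook entry per vector
        z_q = self.embedding(idx).view(z.shape)     # [32, 8, 8, 64]

        # codebook loss (moves embeddings toward z) + beta * commitment loss
        loss = torch.mean((z_q - z.detach()) ** 2) \
            + self.beta * torch.mean((z_q.detach() - z) ** 2)

        # straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()
        return loss, z_q.permute(0, 3, 1, 2).contiguous()
```

Feeding it a [32,64,8,8] tensor returns the embedding loss plus a quantized tensor of the same shape, ready for the decoder.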
2.3 Decoder
The decoder is straightforward: it mirrors the encoder.
Input x shape: [32,64,8,8]
Output shape: [32,3,32,32]
```python
class Decoder(nn.Module):
    """
    This is the p_phi (x|z) network. Given a latent sample z, p_phi
    maps back to the original space z -> x.

    Inputs:
    - in_dim : the input dimension
    - h_dim : the hidden layer dimension
    - res_h_dim : the hidden dimension of the residual block
    - n_res_layers : number of layers to stack
    """

    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super(Decoder, self).__init__()
        kernel = 4
        stride = 2
        self.inverse_conv_stack = nn.Sequential(
            nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel - 1,
                               stride=stride - 1, padding=1),
            ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),
            nn.ConvTranspose2d(h_dim, h_dim // 2,
                               kernel_size=kernel, stride=stride, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(h_dim // 2, 3, kernel_size=kernel,
                               stride=stride, padding=1)
        )

    def forward(self, x):
        return self.inverse_conv_stack(x)
```
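
Assuming the hypothetical `ResidualStack` sketched earlier, here is a quick shape check of the encoder/decoder pair. The hyperparameters (`n_res_layers=2`, `res_h_dim=32`) and the 1×1 projection are illustrative choices that reproduce the shapes quoted in this post:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 3, 32, 32)

encoder = Encoder(in_dim=3, h_dim=128, n_res_layers=2, res_h_dim=32)
pre_quant = nn.Conv2d(128, 64, kernel_size=1)    # project to the codebook dimension
decoder = Decoder(in_dim=64, h_dim=128, n_res_layers=2, res_h_dim=32)

z = pre_quant(encoder(x))
print(z.shape)               # torch.Size([32, 64, 8, 8])
print(decoder(z).shape)      # torch.Size([32, 3, 32, 32])
```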
2.4 Loss Function
The total loss is the sum of the reconstruction loss and the embedding loss.
- The decoder output is the image reconstruction x_hat.
- The embedding loss pulls the encoder output and the codebook embeddings together.
- Key point: the argmin used to pick the nearest codebook entry is not differentiable, so the reconstruction loss cannot be backpropagated directly through to the encoder. The gradient is therefore copied with the straight-through trick: z_q = z + (z_q - z).detach(). The training loop:
```python
for i in range(args.n_updates):
    (x, _) = next(iter(training_loader))
    x = x.to(device)
    optimizer.zero_grad()

    embedding_loss, x_hat, perplexity = model(x)
    recon_loss = torch.mean((x_hat - x)**2) / x_train_var
    loss = recon_loss + embedding_loss

    loss.backward()
    optimizer.step()
```
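
As a quick sanity check of the gradient-copy trick described above (all tensors here are toy stand-ins): after z_q = z + (z_q - z).detach(), the gradient of any loss on z_q flows entirely into z.

```python
import torch

z = torch.randn(4, 2, requires_grad=True)   # stand-in for the encoder output
z_q = torch.randn(4, 2)                     # stand-in for the quantized vectors

# Forward value equals z_q, but the autograd graph only contains z.
z_q = z + (z_q - z).detach()
z_q.sum().backward()
print(z.grad)                               # all ones: gradient passed straight through
```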