寶雞網(wǎng)站制作公司百度關(guān)鍵詞競價價格
前言 ??
?? 在低照度場景下進行目標檢測任務(wù),常存在圖像RGB特征信息少、提取特征困難、目標識別和定位精度低等問題,給檢測帶來一定的難度。
? ? ?🌻使用圖像增強模塊對原始圖像進行畫質(zhì)提升,恢復(fù)各類圖像信息,再使用目標檢測網(wǎng)絡(luò)對增強圖像進行特定目標檢測,有效提高檢測的精確度。
? ? ? ?本專欄會介紹傳統(tǒng)方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度圖像增強算法。
👑完整代碼已打包上傳至資源→低照度圖像增強代碼匯總
目錄
前言 ??
🚀一、Zero-DCE介紹?
??1.1 Zero-DCE簡介??
🚀二、Zero-DCE網(wǎng)絡(luò)結(jié)構(gòu)及核心代碼
??2.1 網(wǎng)絡(luò)結(jié)構(gòu)
??2.2 核心代碼
🚀三、Zero-DCE損失函數(shù)及核心代碼
??3.1 Lspa—Spatial Consistency Loss(空間一致性損失)
??3.2 Lexp—Exposure Control Loss(曝光控制損失)
??3.3 Lcol—Color Constancy Loss(顏色恒定損失)
??3.4 LtvA—Illumination Smoothness Loss(照明平滑度損失)
🚀四、Zero-DCE代碼復(fù)現(xiàn)
??4.1 環(huán)境配置
??4.2 運行過程
??4.3 運行效果
🚀一、Zero-DCE介紹?
相關(guān)資料:
- 論文題目:《Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement》(用于低光圖像增強的零參考深度曲線估計)
- 原文地址:https://arxiv.org/abs/2001.06826
- 論文精讀:CVPR2020|ZeroDCE《Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement》論文超詳細解讀(翻譯+精讀)
- 源碼地址:項目概覽 - Zero-DCE - GitCode
??1.1 Zero-DCE簡介??
本文發(fā)表在CVPR2020,主要提出了一個零參考深度曲線估計Zero-Reference Deep Curve Estimation(Zero-DCE),將光線增強轉(zhuǎn)換為了一個image-specific曲線估計問題(圖像作為輸入,曲線作為輸出),通過非參考損失函數(shù)實現(xiàn),從而獲得增強圖像。
另外,通過訓(xùn)練一個輕量級的網(wǎng)絡(luò)(DCE-NET),來預(yù)測一個像素級的,高階的曲線,并通過該曲線來調(diào)整圖像。
主要貢獻:
- 是第一個不依賴于成對和非成對訓(xùn)練數(shù)據(jù)的弱光增強網(wǎng)絡(luò),從而避免了過擬合的風(fēng)險。
- 設(shè)計一種特定的曲線,能夠迭代運用于自身來近似像素和高階曲線。這種曲線能夠在動態(tài)范圍內(nèi)有效的進行映射
- 提出了一種無參的損失函數(shù),來直接估計增強圖像的質(zhì)量。?
取得成效:?
- 整個方法在多個數(shù)據(jù)集上都取得了SOTA
-
在黑暗中的人臉檢測取得成效
🚀二、Zero-DCE網(wǎng)絡(luò)結(jié)構(gòu)及核心代碼
??2.1 網(wǎng)絡(luò)結(jié)構(gòu)
- (1)backbone:??DCE-Net包含七個具有對稱跳躍連接的卷積層:conv-ReLU 重復(fù) 6 次 + conv-Than。(注意:它具有對稱的級聯(lián),即第 1/2/3 層輸出和第 6/5/4 層輸出進行通道級聯(lián))
- (2)conv層: 由32個3x3的卷積核組成,stride=1
- (3)參數(shù):??整個網(wǎng)絡(luò)的參數(shù)量為79,416
- (4)Flops:?Flops為5.21G(input 為256x256x3)
??2.2 核心代碼
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#import pytorch_colors as colors
import numpy as npclass enhance_net_nopool(nn.Module):def __init__(self):super(enhance_net_nopool, self).__init__()self.relu = nn.ReLU(inplace=True)# 一共有32個模塊number_f = 32# 7個3*3,padding=1,stride=1的卷積核self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True) self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True) # 最大池化層self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False) # 雙線性插值上采樣層 self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)def forward(self, x):x1 = self.relu(self.e_conv1(x))# p1 = self.maxpool(x1)x2 = self.relu(self.e_conv2(x1))# p2 = self.maxpool(x2)x3 = self.relu(self.e_conv3(x2))# p3 = self.maxpool(x3)x4 = self.relu(self.e_conv4(x3))x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))# x5 = self.upsample(x5)x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))# 通過tanh激活函數(shù)處理得到增強后的圖像enhance_imagex_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))# 通過torch.split將enhance_image分割成8個通道,分別表示不同的增強效果r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)x = x + r1*(torch.pow(x,2)-x)x = x + r2*(torch.pow(x,2)-x)x = x + r3*(torch.pow(x,2)-x)enhance_image_1 = x + r4*(torch.pow(x,2)-x) x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1) x = x + r6*(torch.pow(x,2)-x) x = x + r7*(torch.pow(x,2)-x)enhance_image = x + r8*(torch.pow(x,2)-x)r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)return enhance_image_1,enhance_image,r
?這段代碼平平無奇,就是實現(xiàn)圖像增強操作。具體來說,主要通過多層卷積和連接操作,以及一些激活函數(shù),學(xué)習(xí)圖像的增強信息。
首先,初始化定義了32個模塊,每個模塊由7個3*3,padding=1,stride=1的卷積核組成。
然后,前6個卷積層使用ReLU激活函數(shù),第7層使用tanh激活函數(shù),得到增強后的圖像enhance_image
。
接著,?通過torch.split
將enhance_image
分割成8個通道,分別表示不同的增強效果。
?最后,將這些效果疊加到原始輸入圖像上,得到最終的增強圖像。
🚀三、Zero-DCE損失函數(shù)及核心代碼
其實這四個損失函數(shù),才是本文最大的亮點。
??3.1 Lspa—Spatial Consistency Loss(空間一致性損失)
目的
通過保持輸入圖像與增強圖像相鄰區(qū)域的梯度促進圖像的空間一致性。
方法
-
首先計算輸入圖像和增強圖像在通道維度的平均值(將R、G、B三通道加起來求平均),得到兩個灰度圖像
-
然后分解為若干個4×4patches(不重復(fù),覆蓋全圖)
-
最后計算patch內(nèi)中心i與相鄰j像素差值,求平均
公式
-
:局部區(qū)域的數(shù)量
-
:是以區(qū)域 i為中心的四個相鄰區(qū)域(頂部、下、左、右)
-
:增強版本的局部區(qū)域的平均強度值
-
:輸入版本的局部區(qū)域的平均強度值?
代碼
class L_spa(nn.Module):def __init__(self):super(L_spa, self).__init__()# print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)self.pool = nn.AvgPool2d(4)def forward(self, org , enhance ):b,c,h,w = org.shapeorg_mean = torch.mean(org,1,keepdim=True)enhance_mean = torch.mean(enhance,1,keepdim=True)org_pool = self.pool(org_mean) enhance_pool = self.pool(enhance_mean) weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)D_left = torch.pow(D_org_letf - D_enhance_letf,2)D_right = torch.pow(D_org_right - D_enhance_right,2)D_up = torch.pow(D_org_up - D_enhance_up,2)D_down = torch.pow(D_org_down - D_enhance_down,2)E = (D_left + D_right + D_up +D_down)# E = 25*(D_left + D_right + D_up +D_down)return E
首先,定義了四個卷積核分別用于計算圖像在左、右、上和下方向上的差異。
然后,在向前傳播過程中進行如下計算:?
- ?計算權(quán)重差異
weight_diff
。 - 計算增強圖像的差異
E_1
,該差異受到閾值0.5
的控制。 - 利用卷積運算分別計算原始圖像和增強圖像在四個方向上的梯度差異。
- 計算每個方向上的梯度差異的平方,并將它們相加,得到
E
。?
?最后,返回計算得到的空間損失 E
。
??3.2 Lexp—Exposure Control Loss(曝光控制損失)
目的
抑制曝光不足/過度區(qū)域,控制曝光水平。
方法
測量的是局部區(qū)域的平均強度值與良好曝光水平(E=0.6 ,經(jīng)驗設(shè)置)之間的距離。
-
首先將增強圖像轉(zhuǎn)為灰度圖
-
然后分解為若干 16×16 patches(不重復(fù),覆蓋全圖)
-
最后計算 patch 內(nèi)的平均值
公式
-
:大小為16×16的不重疊局部區(qū)域個數(shù)
-
:增強圖像中某個局部區(qū)域的平均強度值
?代碼
class L_exp(nn.Module):def __init__(self,patch_size,mean_val):super(L_exp, self).__init__()# print(1)self.pool = nn.AvgPool2d(patch_size)self.mean_val = mean_valdef forward(self, x ):b,c,h,w = x.shapex = torch.mean(x,1,keepdim=True)mean = self.pool(x)d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))return d
這段代碼比較簡單,就是通過初始化平均池化層和均值函數(shù),比較輸入圖像的全局均值與指定均值之間的差異。
最后,返回計算得到的亮度損失 d。
??3.3 Lcol—Color Constancy Loss(顏色恒定損失)
目的
用于糾正增強圖像中的潛在色偏,同時也建立了三個調(diào)整通道之間的關(guān)系。
方法
-
首先將提亮圖像分成RGB三通道,計算每個通道的平均亮度
-
然后將不同通道的平均亮度兩兩相減,求平均和
Color Constancy Loss值越小,說明提亮圖像顏色越平衡,損失越大則說明提亮圖像可能有色偏的問題
公式
-
:增強后圖像中p通道的平均強度值
-
:一對顏色通道
?代碼?
class L_color(nn.Module):def __init__(self):super(L_color, self).__init__()def forward(self, x ):b,c,h,w = x.shapemean_rgb = torch.mean(x,[2,3],keepdim=True)mr,mg, mb = torch.split(mean_rgb, 1, dim=1)Drg = torch.pow(mr-mg,2)Drb = torch.pow(mr-mb,2)Dgb = torch.pow(mb-mg,2)k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)return k
?這段代碼也比較簡單,主要進行以下的計算:
- 計算圖像在每個像素位置的RGB均值,這是通過對每個通道在高度和寬度上進行平均計算得到的。
- 將RGB均值分割成單獨的通道(mr、mg、mb)。
- 計算顏色差異,分別為紅綠差異
Drg
、紅藍差異Drb
和綠藍差異Dgb
。
?最后,返回計算得到的最終的顏色損失 k。
??3.4 LtvA—Illumination Smoothness Loss(照明平滑度損失)
目的
保持相鄰像素之間的單調(diào)關(guān)系。
啟發(fā)
將所有通道、所有迭代次數(shù)的 A (也就是網(wǎng)絡(luò)的輸出),其橫豎的梯度平均值應(yīng)該很小。
公式
-
:迭代次數(shù)
-
:水平梯度
-
? :垂直梯度
??代碼?
class L_TV(nn.Module):def __init__(self,TVLoss_weight=1):super(L_TV,self).__init__()self.TVLoss_weight = TVLoss_weightdef forward(self,x):batch_size = x.size()[0]h_x = x.size()[2]w_x = x.size()[3]count_h = (x.size()[2]-1) * x.size()[3]count_w = x.size()[2] * (x.size()[3] - 1)h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
首先,定義了一個 TVLoss_weight
屬性,表示總變差損失的權(quán)重,默認為1。
然后,在向前傳播過程中進行如下計算:?
- 計算圖像在水平方向上的總變差
h_tv
和在垂直方向上的總變差w_tv
。 - 計算總變差損失(包括水平和垂直方向上的總變差),以及權(quán)重調(diào)整。
最后,返回計算得到的總變差損失。
🚀四、Zero-DCE代碼復(fù)現(xiàn)
??4.1 環(huán)境配置
- Python 3.7
- Pytorch 1.0.0
- opencv
- torchvision 0.2.1
- cuda 10.0
??4.2 運行過程
這個運行比較簡單,配好環(huán)境就行。如果有報錯可以參考以下博文:?
【代碼復(fù)現(xiàn)Zero-DCE詳解:Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement】_zerodce代碼解讀-CSDN博客?跑微光圖像增強程序遇到的問題匯總_userwarning: nn.functional.tanh is deprecated. use-CSDN博客
暗光增強——Zero-DCE網(wǎng)絡(luò)推理測試(詳細圖文教程)-CSDN博客
??4.3 運行效果
????
?