無(wú)線網(wǎng)絡(luò)管理系統(tǒng)長(zhǎng)沙seo網(wǎng)絡(luò)推廣
提示:transformer結(jié)構(gòu)的目標(biāo)檢測(cè)解碼器,包含loss計(jì)算,附有源碼
文章目錄
- 前言
- 一、main函數(shù)代碼解讀
- 1、整體結(jié)構(gòu)認(rèn)識(shí)
- 2、main函數(shù)代碼解讀
- 3、源碼鏈接
- 二、decode模塊代碼解讀
- 1、decoded的TransformerDec模塊代碼解讀
- 2、decoded的TransformerDecoder模塊代碼解讀
- 3、decoded的DecoderLayer模塊代碼解讀
- 三、decode模塊訓(xùn)練demo代碼解讀
- 1、解碼數(shù)據(jù)輸入格式
- 2、解碼訓(xùn)練demo代碼解讀
- 四、decode模塊預(yù)測(cè)demo代碼解讀
- 1、預(yù)測(cè)數(shù)據(jù)輸入格式
- 2、解碼預(yù)測(cè)demo代碼解讀
- 五、losses模塊代碼解讀
- 1、matcher初始化
- 2、二分匹配matcher代碼解讀
- 3、num_classes參數(shù)解讀
- 4、losses的demo代碼解讀
前言
最近重溫DETR模型,越發(fā)感覺(jué)detr模型結(jié)構(gòu)精妙之處,不同于anchor base 與anchor free設(shè)計(jì),直接利用100框給出預(yù)測(cè)結(jié)果,使用可學(xué)習(xí)learn query深度查找,使用二分匹配方式訓(xùn)練模型。為此,我基于detr源碼提取解碼decode、loss計(jì)算等系列模塊,并重構(gòu)、修改、整合一套解碼與loss實(shí)現(xiàn)的框架,該框架可適用任何backbone特征提取接我框架,實(shí)現(xiàn)完整訓(xùn)練與預(yù)測(cè),我也有相應(yīng)demo指導(dǎo)使用我的框架。那么,接下來(lái),我將完整介紹該框架源碼。同時(shí),我將此源碼進(jìn)行開(kāi)源,并上傳github中,供讀者參考。
一、main函數(shù)代碼解讀
1、整體結(jié)構(gòu)認(rèn)識(shí)
在介紹main函數(shù)代碼前,我先說(shuō)下整體框架結(jié)構(gòu),該框架包含2個(gè)文件夾,一個(gè)losses文件夾,用于處理loss計(jì)算,一個(gè)是obj_det文件,用于transformer解碼模塊,該模塊源碼修改于detr模型,也包含main.py,該文件是整體解碼與loss計(jì)算demo示意代碼,如下圖。
2、main函數(shù)代碼解讀
該代碼實(shí)際是我隨機(jī)創(chuàng)造了標(biāo)簽target數(shù)據(jù)與backbone特征提取數(shù)據(jù)及位置編碼數(shù)據(jù),使其能正常運(yùn)行的demo,其代碼如下:
import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterionif __name__ == '__main__':Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)num_classes = 4 # 類(lèi)別+1matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2) # 二分匹配不同任務(wù)分配的權(quán)重losses = ['labels', 'boxes', 'cardinality'] # 計(jì)算loss的任務(wù)weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2} # 為dert最后一個(gè)設(shè)置權(quán)重criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)# 下面使用iter,我構(gòu)造了虛擬模型編碼數(shù)據(jù)與數(shù)據(jù)加載標(biāo)簽數(shù)據(jù)src = torch.rand((391, 2, 256))pos_embed = torch.ones((391, 1, 256))# 創(chuàng)造真實(shí)target數(shù)據(jù)target1 = {'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}target = [target1, target2]res = Model(src, pos_embed)losses = criterion(res, target)print(losses)
如下圖:
3、源碼鏈接
源碼鏈接:點(diǎn)擊這里
二、decode模塊代碼解讀
該模塊主要是使用transform方式對(duì)backbone提取特征的解碼,主要使用learn query等相關(guān)trike與transform解碼方式內(nèi)容。
我主要介紹TransformerDec、TransformerDecoder、DecoderLayer模塊,為依次被包含關(guān)系,或說(shuō)成后者是前者組成部分。
1、decoded的TransformerDec模塊代碼解讀
該類(lèi)大意是包含了learn query嵌入、解碼transform模塊調(diào)用、head頭預(yù)測(cè)logit與boxes等內(nèi)容,是實(shí)現(xiàn)解碼與預(yù)測(cè)內(nèi)容,該模塊參數(shù)或解釋已有注釋?zhuān)x者可自行查看,其代碼如下:
class TransformerDec(nn.Module):'''d_model=512, 使用多少維度表示,實(shí)際為編碼輸出表達(dá)維度nhead=8, 有多少個(gè)頭num_queries=100, 目標(biāo)查詢數(shù)量,可學(xué)習(xí)querynum_decoder_layers=6, 解碼循環(huán)層數(shù)dim_feedforward=2048, 類(lèi)似FFN的2個(gè)nn.Linear變化dropout=0.1,activation="relu",normalize_before=False,解碼結(jié)構(gòu)使用2種方式,默認(rèn)False使用post解碼結(jié)構(gòu)output_intermediate_dec=False, 若為T(mén)rue保存中間層解碼結(jié)果(即:每個(gè)解碼層結(jié)果保存),若False只保存最后一次結(jié)果,訓(xùn)練為T(mén)rue,推理為Falsenum_classes: num_classes數(shù)量與數(shù)據(jù)格式有關(guān),若類(lèi)別id=1表示第一類(lèi),則num_classes=實(shí)際類(lèi)別數(shù)+1,若id=0表示第一個(gè),則num_classes=實(shí)際類(lèi)別數(shù)額外說(shuō)明,coco類(lèi)別id是1開(kāi)始的,假如有三個(gè)類(lèi),名稱為[dog,cat,pig],batch=2,那么參數(shù)num_classes=4,表示3個(gè)類(lèi)+1個(gè)背景,模型輸出src_logits=[2,100,5]會(huì)多出一個(gè)預(yù)測(cè),target_classes設(shè)置為[2,100],其值為4(該值就是背景,而有類(lèi)別值為1、2、3),那么target_classes中沒(méi)有值為0,我理解模型不對(duì)0類(lèi)做任何操作,是個(gè)無(wú)效值,模型只對(duì)1、2、3、4進(jìn)行l(wèi)oss計(jì)算,然4為背景會(huì)比較多,作者使用權(quán)重0.1避免其背景過(guò)度影響。forward return: 返回字典,包含{'pred_logits':[], # 為列表,格式為[b,100,num_classes+2]'pred_boxes':[], # 為列表,格式為[b,100,4]'aux_outputs'[{},...] # 為列表,元素為字典,每個(gè)字典為{'pred_logits':[],'pred_boxes':[]},格式與上相同}'''def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):super().__init__()self.num_queries = num_queriesself.query_embed = nn.Embedding(num_queries, d_model) # 與編碼輸出表達(dá)維度一致self.output_intermediate_dec = output_intermediate_decdecoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers