深圳網(wǎng)站托管公司谷歌seo新規(guī)則
代碼來(lái)源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
使用的代碼是YOLOv5 6.1版本
參考筆記:YOLOv5改進(jìn)系列(八) 更換NMS非極大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-NMS_diou nms-CSDN博客
yolov5 極大值抑制 nms 代碼詳解 - 金色旭光 - 博客園
https://zhuanlan.zhihu.com/p/511151467
目錄
1.NMS源碼理解
2.更換DIou-NMS
1.NMS源碼理解
YOLOv5中NMS的實(shí)現(xiàn)代碼在utils/general.py的non_max_suppression
#對(duì)推理結(jié)果執(zhí)行NMS
def non_max_suppression(prediction,#模型的預(yù)測(cè)結(jié)果,shape=[batch_size,預(yù)測(cè)框數(shù)量,5+類別數(shù)量=中心x+中心y+w+h+conf+類別數(shù)量]conf_thres=0.25,#置信度閾值,用于NMS,置信度低于此閾值的預(yù)測(cè)框會(huì)被去除iou_thres=0.45,#IoU閾值,用于NMS,去除冗余的預(yù)測(cè)框classes=None,#只對(duì)某些類別作NMS,None則表示所有類別都作NMSagnostic=False,#是否作類別無(wú)關(guān)的NMS,即所有預(yù)測(cè)框不分類別一起作NMS處理,通常不開啟,都是各類別各自作NMSmulti_label=False,labels=(),max_det=300#每張圖片作NMS之后剩余的最多預(yù)測(cè)框數(shù)):'''函數(shù)返回值:返回值output是一個(gè)列表,存放每張圖片的檢測(cè)結(jié)果eg:output[0]即第一張圖片的檢測(cè)結(jié)果,outout[0] shape=[預(yù)測(cè)框數(shù)量,6=xyxy+conf+cls]'''#類別數(shù)量ncnc = prediction.shape[2] - 5#符合置信度閾值的預(yù)測(cè)框bool數(shù)組,xc shape=[batch_size,預(yù)測(cè)框數(shù)量]xc = prediction[..., 4] > conf_thres#檢查置信度、IoU閾值的有效性assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'#設(shè)置參數(shù)min_wh, max_wh = 2, 4096 #框的最小和最大寬高(像素)max_nms = 30000 #每張圖片作NMS之前的最多預(yù)測(cè)框數(shù)time_limit = 10.0 #處理圖片超過(guò)此時(shí)間則退出multi_label &= nc > 1 #沒(méi)啥用t = time.time() #記錄開始時(shí)間output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] #初始化返回值output#遍歷每張圖像的預(yù)測(cè)結(jié)果for xi, x in enumerate(prediction):'''xi:當(dāng)前圖片在batch中的idx:存放當(dāng)前圖片的預(yù)測(cè)框信息,shape=[預(yù)測(cè)框數(shù)量,5+類別數(shù)量]'''#僅保留大于置信度閾值的預(yù)測(cè)框,x shape=[預(yù)測(cè)框數(shù)量,5+類別數(shù)量]x = x[xc[xi]]#如果存在真實(shí)標(biāo)簽,則將其合并到預(yù)測(cè)結(jié)果中(這段代碼不知道有什么用)if labels and len(labels[xi]):l = labels[xi] #真實(shí)標(biāo)簽v = torch.zeros((len(l), nc + 5), device=x.device) # 初始化與真實(shí)標(biāo)簽相同形狀的張量v[:, :4] = l[:, 1:5] # 提取真實(shí)框的坐標(biāo)v[:, 4] = 1.0 # 置信度設(shè)為1.0v[range(len(l)), l[:, 0].long() + 5] = 1.0 # 設(shè)置類別x = torch.cat((x, v), 0) # 合并預(yù)測(cè)框和真實(shí)框#如果預(yù)測(cè)框數(shù)量為0,則處理下一張圖片if not x.shape[0]:continue#重置類別概率=conf置信度*原始類別概率x[:, 5:] *= x[:, 4:5]#將坐標(biāo)值從(中心x, 中心y, w, h)轉(zhuǎn)換為(x1, y1, x2, y2),box shape=[預(yù)測(cè)框數(shù)量,4=xyxy]box = xywh2xyxy(x[:, :4])#通常multi_label為False,執(zhí)行else部分if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T # 確定哪些框符合多標(biāo)簽條件x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) # 合并框信息else:#將最大類別概率作為檢測(cè)框的置信度存放于conf中,并將類別索引存放于j中conf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]#合并xyxy+置信度+類別索引'''conf: shape=[預(yù)測(cè)框數(shù),1=置信度]j: shape=[預(yù)測(cè)框數(shù),1=類別索引]x: shape=[預(yù)測(cè)框數(shù),6=xyxy+置信度+類別索引]'''#利用class進(jìn)行過(guò)濾,篩選出指定的class,nms僅僅對(duì)指定的class進(jìn)行nms;#若classes為None,則所有類別都需要作nmsif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]#僅保留指定類別的預(yù)測(cè)框#預(yù)測(cè)框數(shù)量nn = x.shape[0]#如果沒(méi)有預(yù)測(cè)框,則處理下一張圖片if not n:continueelif n > max_nms: #如果作NMS之前預(yù)測(cè)框的數(shù)量大于max_nms,則按置信度排序并保留前max_nms個(gè)框x = x[x[:, 4].argsort(descending=True)[:max_nms]]#Batches NMS#這行代碼是在多類別中應(yīng)用NMS#多類別NMS的處理策略是為了讓每個(gè)類都能獨(dú)立執(zhí)行NMS,所以給所有預(yù)測(cè)框的坐標(biāo)值添加一個(gè)偏移量#偏移量?jī)H取決于了類別的Id(也就是x[:, 5:6]),并且足夠大,使得不同類的預(yù)測(cè)框不會(huì)重疊c = x[:, 5:6] * (0 if agnostic else max_wh)#創(chuàng)建類別偏移c,即c=原類別索引*max_wh#給每個(gè)預(yù)測(cè)框的坐標(biāo)值加上類別偏移c,boxes shape=[預(yù)測(cè)框數(shù)量,4]boxes = x[:, :4] + c#獲取所有預(yù)測(cè)框的置信度,scores shape=[預(yù)測(cè)框數(shù)量,]scores = x[:, 4]#執(zhí)行NMS,i存放NMS之后的預(yù)測(cè)框id,shape=[NMS后的預(yù)測(cè)框數(shù),]i = torchvision.ops.nms(boxes, scores, iou_thres)#每張圖片NMS之后最多剩余max_det個(gè)預(yù)測(cè)框if i.shape[0] > max_det:i = i[:max_det]#將該圖片的檢測(cè)結(jié)果存儲(chǔ)到輸出output中output[xi] = x[i]#如果處理此圖片超出時(shí)間限制if (time.time() - t) > time_limit:#提示超時(shí)print(f'WARNING: NMS time limit {time_limit}s exceeded')break #超時(shí)退出#返回值output是一個(gè)列表,存放每張圖片的檢測(cè)結(jié)果#eg:output[0]即第一張圖片的檢測(cè)結(jié)果,outout[0] shape=[預(yù)測(cè)框數(shù)量,6=xyxy+conf+cls]return output #返回每張圖片的檢測(cè)結(jié)果
真正作NMS過(guò)濾的代碼是如下幾行代碼:
#Batches NMS
#這行代碼是在多類別中應(yīng)用NMS
#多類別NMS的處理策略是為了讓每個(gè)類都能獨(dú)立執(zhí)行NMS,所以給所有預(yù)測(cè)框的坐標(biāo)值添加一個(gè)偏移量
#偏移量?jī)H取決于了類別的Id(也就是x[:, 5:6]),并且足夠大,使得不同類的預(yù)測(cè)框不會(huì)重疊
c = x[:, 5:6] * (0 if agnostic else max_wh)#創(chuàng)建類別偏移c,即c=類別索引*max_whboxes = x[:, :4] + c#給每個(gè)預(yù)測(cè)框的坐標(biāo)值加上類別偏移c,boxes shape=[預(yù)測(cè)框數(shù)量,4]
scores = x[:, 4]#獲取所有預(yù)測(cè)框的置信度,scores shape=[預(yù)測(cè)框數(shù)量,]#執(zhí)行NMS,i存放NMS之后的預(yù)測(cè)框id,shape=[NMS后的預(yù)測(cè)框數(shù),]
i = torchvision.ops.nms(boxes, scores, iou_thres)
代碼重點(diǎn)是在 '+c’這里,c是偏移量
(1)agnostic參數(shù)為True,表示所有類別一起作NMS處理,偏移量c為0;
(2)agnostic參數(shù)為False,表示按照不同類別分別作NMS處理,c=類別索引*max_wh,對(duì)不同類別的預(yù)測(cè)框做一個(gè)偏移操作,防止不同類別的預(yù)測(cè)框互相影響
注意:源碼中是傳入?yún)?shù)boxes、scores、iou_thres調(diào)用torchvision.ops.nms實(shí)現(xiàn)NMS處理,下面是NMS的代碼實(shí)現(xiàn)??戳讼旅娴腘MS代碼可以發(fā)現(xiàn)上面說(shuō)agnostic為False時(shí)表示按照不同類別分別作NMS處理,但源碼這里應(yīng)該不是特別嚴(yán)格按不同類別作NMS(因?yàn)檫B類別的索引都沒(méi)有用到),添加偏移量c只是算是一種trick把(我個(gè)人的理解,如有錯(cuò)誤請(qǐng)指出)
代碼流程:?
- 將所有預(yù)測(cè)框按置信度從高到低排序,確保置信度高的預(yù)測(cè)框排在前面。order存放排序后的預(yù)測(cè)框索引
- 從置信度最高的框開始(即order[0]),計(jì)算它和剩下所有預(yù)測(cè)框的IoU。剩下的預(yù)測(cè)框中IoU低于設(shè)定的IoU閾值則保留下來(lái),高于IoU閾值的預(yù)測(cè)框則去除(即在order中刪除當(dāng)前預(yù)測(cè)框和IoU大于閾值的預(yù)測(cè)框索引)
- 重復(fù)步驟2,直到遍歷完order中的預(yù)測(cè)框,得到最終篩選出來(lái)的預(yù)測(cè)框
import torch
def NMS(boxes,scores, iou_thres):'''boxes:shape=[預(yù)測(cè)框數(shù)量,4=xyxy],存放預(yù)測(cè)框坐標(biāo)值scores:shape=[預(yù)測(cè)框數(shù)量,],存放預(yù)測(cè)框的置信度iou_thres: IoU閾值'''x1 = boxes[:,0]y1 = boxes[:,1]x2 = boxes[:,2]y2 = boxes[:,3]#計(jì)算所有預(yù)測(cè)框的面積areas = (x2-x1)*(y2-y1)#將預(yù)測(cè)框按置信度從高到低排序,order存放預(yù)測(cè)框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的預(yù)測(cè)框索引keep = []while order.numel() > 0:#循環(huán)條件'''注意:當(dāng)order=tensor([2,0,1,3])時(shí),用order[0]可以正常取出第1個(gè)值2當(dāng)order=tensor([3])時(shí),用order[0]取出第1個(gè)值3會(huì)報(bào)錯(cuò),需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的預(yù)測(cè)框索引keep.append(i)#將預(yù)測(cè)框索引加入keep中#如果只剩余1個(gè)預(yù)測(cè)框,則NMS執(zhí)行結(jié)束if order.numel() == 1:break#計(jì)算當(dāng)前預(yù)測(cè)框與剩下所有預(yù)測(cè)框的IoUxx1 = x1[order[1:]].clamp(min=x1[i])yy1 = y1[order[1:]].clamp(min=y1[i])xx2 = x2[order[1:]].clamp(max=x2[i])yy2 = y2[order[1:]].clamp(max=y2[i])w = (xx2-xx1).clamp(min=0)h = (yy2-yy1).clamp(min=0)inter = w*hovr = inter/(areas[i] + areas[order[1:]] - inter)#當(dāng)前預(yù)測(cè)框與剩下所有預(yù)測(cè)框的IoU值#篩選出IOU小于閾值的預(yù)測(cè)框索引, 過(guò)濾掉所有IOU大于閾值的預(yù)測(cè)框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order數(shù)組,丟棄和當(dāng)前bbox的IOU大于閾值的預(yù)測(cè)框order = order[ids+1]#這里看代碼會(huì)有點(diǎn)懵,可以debug一下#torch.LongTensor(keep)將keep列表轉(zhuǎn)換為tensor,shape:[NMS后預(yù)測(cè)框數(shù)量,]return torch.LongTensor(keep)#實(shí)例
box = torch.tensor([[2, 3.1, 7, 5], [3, 4, 8, 4.8], [4, 4, 5.6, 7], [0.1, 0, 8, 1]])
score = torch.tensor([0.5, 0.3, 0.2, 0.4])
output =NMS(boxes=box, scores=score, iou_thres=0.3)
print(output)
2.更換DIou-NMS
YOLOv5源碼中使用的是IoU-NMS,這里可以作一下改進(jìn),將其替換為DIoU-NMS,因?yàn)?span style="color:#be191c;">DIoU考慮到的要素比IoU更多,應(yīng)用于NMS中,可以使得NMS后得到的結(jié)果更加合理
第1步:編寫DIoU_NMS函數(shù)
def DIoU_NMS(boxes,scores, iou_thres):'''boxes:shape=[預(yù)測(cè)框數(shù)量,4=xyxy],存放預(yù)測(cè)框坐標(biāo)值scores:shape=[預(yù)測(cè)框數(shù)量,],存放預(yù)測(cè)框的置信度iou_thres: DIoU閾值'''#將預(yù)測(cè)框按置信度從高到低排序,order存放預(yù)測(cè)框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的預(yù)測(cè)框索引keep = []while order.numel() > 0:#循環(huán)條件'''注意:當(dāng)order=tensor([2,0,1,3])時(shí),用order[0]可以正常取出第1個(gè)值2當(dāng)order=tensor([3])時(shí),用order[0]取出第1個(gè)值3會(huì)報(bào)錯(cuò),需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的預(yù)測(cè)框索引keep.append(i)#將預(yù)測(cè)框索引加入keep中#如果只剩余1個(gè)預(yù)測(cè)框,則NMS執(zhí)行結(jié)束if order.numel() == 1:break#計(jì)算當(dāng)前預(yù)測(cè)框與剩下所有預(yù)測(cè)框的DIoU#boxes[i,:]為當(dāng)前預(yù)測(cè)框的坐標(biāo)值,shape=[4,]#boxes[order[1:],:]為其他預(yù)測(cè)框的坐標(biāo)值,shape=[n,4]ovr = bbox_iou(boxes[i, :], boxes[order[1:], :], DIoU=True)#篩選出DIoU小于閾值的預(yù)測(cè)框索引, 過(guò)濾掉所有DIoU大于閾值的預(yù)測(cè)框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order數(shù)組,丟棄和當(dāng)前bbox的DIoU大于閾值的預(yù)測(cè)框order = order[ids+1]#這里看代碼會(huì)有點(diǎn)懵,可以debug一下#torch.LongTensor(keep)將keep列表轉(zhuǎn)換為tensor,shape:[NMS后預(yù)測(cè)框數(shù)量,]return torch.LongTensor(keep)
這里計(jì)算DIoU的函數(shù)bbox_iou是直接引用了YOLOv5中的代碼,該函數(shù)的實(shí)現(xiàn)在utils/metrics.py中,此函數(shù)集成了IoU、GIoU、DIoU、CIoU的計(jì)算,其他XIoU_NMS的實(shí)現(xiàn)方法類似。PS:GIoU、DIoU、CIoU用于損失函數(shù)的情況比較多
最后將DIoU_NMS函數(shù)復(fù)制到utils/general.py
第2步:將IoU-NMS更換為DIoU-NMS
將utils/general.py下non_max_suppression函數(shù)的
i = torchvision.ops.nms(boxes, scores, iou_thres)
替換為
i = DIoU_NMS(boxes, scores, iou_thres)
這樣就將IoU-NMS更換為DIoU-NMS了,但是我用幾張圖片作測(cè)試,發(fā)現(xiàn)大多數(shù)時(shí)候使用IoU-NMS和DIoU-NMS的處理結(jié)果是完全一致的。如下:
處理結(jié)果
所以這種改進(jìn)可能實(shí)際意義不大
更換其他XIoU-NMS的方法是一樣的,這里不再贅述