When v10 came out I wanted to see what was different about its loss design; since the loss parts of v8 and v10 turned out to be basically identical, they are covered together here.
If you have not read my notes on the v10 paper yet, take a look; in an initial trial the latency improvement is real.
A dull pen beats a sharp memory, so better to write this down before it is forgotten. End of preamble!
Code: GitHub - THU-MIG/yolov10: YOLOv10: Real-Time End-to-End Object Detection
Paper: https://arxiv.org/pdf/2405.14458
YOLOv10/v8 switched from Anchor-Based (box anchors) to Anchor-Free (point anchors), and the detection head became a Decoupled Head. This structure speeds up convergence (in a box-anchor scheme I also measured an accuracy gain, at the cost of some extra latency), but it also runs into the misalignment between classification and regression. Some networks assign prediction cells by computing the IoU between each feature-map cell (the box encoded at a point anchor's center) and the ground truth, but the best cell for classification and the best cell for regression usually differ. To address this, TAL (Task Alignment Learning) is introduced to handle positive/negative sample assignment, so that the classification and regression tasks stay well aligned.
The loss in yolov10/v8 breaks into 2 branches and 3 losses:
I. Regression-branch losses:
1. DFL (Distribution Focal Loss), over the distributions of the offsets from the anchor point's center to the box's top-left and bottom-right corners
2. IoU loss, the localization loss, using CIoU and computed only on positive samples
II. Classification loss:
1. BCE loss, computed only on positive samples.
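Schematically, matching the gain factors hyp.box, hyp.cls and hyp.dfl applied at the end of v8DetectionLoss.__call__ below (my own summary, not a formula from the paper):

```latex
L = \lambda_{box}\,L_{CIoU} + \lambda_{cls}\,L_{BCE} + \lambda_{dfl}\,L_{DFL},
\qquad L_{total} = L \cdot \text{batch\_size}
```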
v8DetectionLoss
The biggest difference between the v8 and v10 losses is that v10 has two decoupled heads, one computing a one2one head loss and one computing a one2many head loss. The two heads share the same loss function; only some hyperparameters differ.
```python
class v10DetectLoss:
    def __init__(self, model):
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
```
The one2many head uses topk=10 and the one2one head topk=1. (This part of the code works the same way I write auxiliary supervision.)
```python
class v8DetectionLoss:
    """Criterion class for computing training losses."""

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.no
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    out[j, :n] = targets[matches, 1:]
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2, 3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
```
preprocess in v8DetectionLoss
This function preprocesses the gt: it aligns the ground truths (cls + boxes) of a batch, which have different counts per image, by zero-padding the shorter ones. Suppose the batch size is 2, image1's gt has shape [4, 5] and image2's has [7, 5]; the longest count in the batch, 7, is used to build a [2, 7, 5] tensor in which the first 4 rows of image1's slice carry its gt info and the remaining rows stay all zeros. The sketch below walks through such an example.
The padded tensor is then split into the corresponding gt_labels, gt_bboxes and mask_gt (more on these later).
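A minimal, self-contained sketch of this padding behavior (not the ultralytics code itself; the box values are invented, and the scaling/xywh2xyxy step is omitted):

```python
import torch

# two images: image0 has 2 gts, image1 has 3 gts; rows are (img_idx, cls, x, y, w, h)
targets = torch.tensor([
    [0, 1.0, 0.10, 0.20, 0.30, 0.40],   # image 0
    [0, 2.0, 0.50, 0.50, 0.20, 0.20],   # image 0
    [1, 0.0, 0.30, 0.30, 0.10, 0.10],   # image 1
    [1, 3.0, 0.70, 0.60, 0.25, 0.25],   # image 1
    [1, 5.0, 0.40, 0.80, 0.15, 0.15],   # image 1
])  # column 0 is the image index, like batch["batch_idx"]

batch_size = 2
i = targets[:, 0]
_, counts = i.unique(return_counts=True)
out = torch.zeros(batch_size, int(counts.max()), 5)   # (bs, max_num_obj, 5)
for j in range(batch_size):
    matches = i == j
    n = int(matches.sum())
    if n:
        out[j, :n] = targets[matches, 1:]

mask_gt = out[..., 1:5].sum(2, keepdim=True).gt_(0)   # marks the rows that are real gts
print(out.shape, mask_gt.squeeze(-1))
# torch.Size([2, 3, 5]) tensor([[ True,  True, False], [ True,  True,  True]])
```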
bbox_decode in v8DetectionLoss
This function decodes each anchor point together with its predicted regression parameters into an anchor box via dist2bbox; the resulting boxes are then matched against the gt via IoU.
```python
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    assert distance.shape[dim] == 4
    lt, rb = distance.split([2, 2], dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
```
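A quick worked example of the decoding with made-up numbers, using the dist2bbox above: an anchor point at (10.5, 10.5) on the grid with predicted ltrb distances (2, 3, 4, 1) yields x1y1 = (8.5, 7.5) and x2y2 = (14.5, 11.5).

```python
import torch

anchor_points = torch.tensor([[10.5, 10.5]])      # grid-space anchor center
distance = torch.tensor([[2.0, 3.0, 4.0, 1.0]])   # predicted l, t, r, b
print(dist2bbox(distance, anchor_points, xywh=False))
# tensor([[ 8.5000,  7.5000, 14.5000, 11.5000]])
```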
loss[1] is the classification loss (BCE):

```python
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
```

loss[0] is the IoU loss and loss[2] the DFL loss, both computed by BboxLoss:

```python
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask)
```
BboxLoss is implemented as follows:
```python
class BboxLoss(nn.Module):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

    @staticmethod
    def _df_loss(pred_dist, target):
        """
        Return sum of left and right DFL losses.

        Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
        https://ieeexplore.ieee.org/document/9792391
        """
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)
```
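To see what _df_loss does with a fractional target, a small standalone check with invented numbers: a target distance of 2.7 falls between integer bins, so the left bin tl = 2 gets weight wl = 3 − 2.7 = 0.3 and the right bin tr = 3 gets wr = 0.7, and the two cross-entropies are blended with those weights.

```python
import torch
import torch.nn.functional as F

pred_dist = torch.randn(1, 16)             # logits over reg_max+1 = 16 bins for one side
target = torch.tensor([2.7])
tl, tr = target.long(), target.long() + 1  # neighboring bins 2 and 3
wl, wr = tr - target, target - tl          # 0.3 and 0.7
loss = F.cross_entropy(pred_dist, tl, reduction="none") * wl \
     + F.cross_entropy(pred_dist, tr, reduction="none") * wr
```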
TaskAlignedAssigner
In my opinion this is the centerpiece of the whole loss design.
Because this loss has no foreground/background obj loss like anchor-based algorithms do, the TaskAlignedAssigner must itself decide which anchors are foreground and which are background. So besides producing target_labels, target_bboxes and target_scores, it also has to return the foreground mask, fg_mask.bool().
```python
class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
```
get_pos_mask
```python
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
    """Get in_gts mask, (b, max_num_obj, h*w)."""
    mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)  # whether each anchor center lies inside the corresponding gt box
    # Get anchor_align metric, (b, max_num_obj, h*w)
    align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
    # Get topk_metric mask, (b, max_num_obj, h*w)
    mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
    # Merge all mask to a final mask, (b, max_num_obj, h*w)
    mask_pos = mask_topk * mask_in_gts * mask_gt  # one anchor point is responsible for predicting one gt object
    return mask_pos, align_metric, overlaps
```
It calls select_candidates_in_gts, get_box_metrics and select_topk_candidates; together these three functions pick the positions of the positive anchor points.
```python
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Select the positive anchor center in gt.

    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 4)

    Returns:
        (Tensor): shape(b, n_boxes, h*w)
    """
    n_anchors = xy_centers.shape[0]  # number of anchor centers
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
    # Differences between each anchor center and each gt box's top-left corner, and between the
    # bottom-right corner and each anchor center, concatenated into shape (bs, n_boxes, n_anchors, 4)
    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
    # Minimum over the last dim gives a boolean (b, n_boxes, h*w) tensor that is True when the anchor
    # center lies inside the gt box (all four deltas positive)
    return bbox_deltas.amin(3).gt_(eps)
```
The idea is simple: subtract the gt box's top-left corner from the anchor point's coordinates to get one pair of differences, and subtract the anchor point's coordinates from the gt box's bottom-right corner to get another. If the anchor point lies inside the gt box, all of these differences are positive.
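A tiny numeric check of that idea (made-up coordinates): for gt box (2, 2, 8, 8), the anchor point (5, 5) gives deltas (3, 3, 3, 3), all positive, so it is inside; (9, 5) gives an x2 − x delta of −1, so it is outside.

```python
import torch

gt = torch.tensor([2.0, 2.0, 8.0, 8.0])           # x1, y1, x2, y2
for pt in (torch.tensor([5.0, 5.0]), torch.tensor([9.0, 5.0])):
    deltas = torch.cat([pt - gt[:2], gt[2:] - pt])  # lt deltas, rb deltas
    print(pt.tolist(), deltas.min().item() > 0)     # True, then False
```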
select_candidates_in_gts is the initial filter that keeps the anchor points lying inside a gt box.
In the figure above, green is the gt box and the red anchor points are the candidates select_candidates_in_gts keeps for predicting the object in that gt box; what is returned is a positional mask over these anchor points.
get_box_metrics
Its parameters:
pd_scores: the output of the classification head, typically of shape [bs, 8400, 80] (COCO, 640x640 input)
pd_bboxes: the output of the regression head, typically of shape [bs, 8400, 4]
gt_labels, gt_bboxes, mask_gt: the gt information; since the gt was zero-padded, mask_gt marks the entries that are actually non-zero
mask_in_gts * mask_gt: the mask of candidate anchor positions that actually have a gt label
```python
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
    """Compute alignment metric given predicted and ground truth bounding boxes."""
    na = pd_bboxes.shape[-2]
    mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
    overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)  # holds the IoUs
    bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)  # holds the box scores

    ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj (e.g. torch.Size([2, 2, 7]))
    ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # batch indices 0..bs-1, expanded to (bs, n_max_boxes)
    ind[1] = gt_labels.squeeze(-1)  # class indices, squeezed to (bs, n_max_boxes)
    # Get the scores of each grid for each gt cls: gather, under the gt mask, each cell's predicted
    # score for its gt's class and store it in bbox_scores, (b, max_num_obj, h*w)
    bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]

    # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
    pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
    gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
    # For every position passing the gt mask, compute the IoU between the predicted box (pd_boxes)
    # and the gt box (gt_boxes) and store it in overlaps
    overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

    # align_metric = bbox_scores^alpha * overlaps^beta, where alpha and beta are hyperparameters
    align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
    return align_metric, overlaps
```
The decoded predicted boxes are matched against the gt boxes with IoU to get overlaps. Each anchor point carries prediction scores for 80 classes; indexing by the class label of the corresponding gt box gives bbox_scores. The alignment metric is then computed as align_metric = bbox_scores^alpha * overlaps^beta, which jointly accounts for the classification score and the box overlap.
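In formula form (α and β are the values passed to TaskAlignedAssigner in v8DetectionLoss above; s is the predicted score for the gt's class and u the CIoU between the decoded box and the gt):

```latex
t = s^{\alpha} \cdot u^{\beta}, \qquad \alpha = 0.5,\ \beta = 6.0
```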
select_topk_candidates
This simply takes the align_metric from get_box_metrics and keeps, among all anchors that overlap a gt, the top ten (or top one) by metric.
```python
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
    """
    Select the top-k candidates based on the given metrics.

    Args:
        metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
            max_num_obj is the maximum number of objects, and h*w represents the
            total number of anchor points.
        largest (bool): If True, select the largest values; otherwise, select the smallest values.
        topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
            topk is the number of top candidates to consider. If not provided,
            the top-k values are automatically computed based on the given metrics.

    Returns:
        (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
    """
    # (b, max_num_obj, topk): torch.topk over the last dim returns the selected metric values
    # (topk_metrics) and their indices (topk_idxs)
    topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
    if topk_mask is None:
        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
    # (b, max_num_obj, topk): zero out the index positions not selected by topk_mask (~topk_mask)
    topk_idxs.masked_fill_(~topk_mask, 0)

    # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
    count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
    ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
    for k in range(self.topk):
        # Expand topk_idxs for each value of k and add 1 at the specified positions
        count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
    # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
    # Filter invalid bboxes: zero out positions counted more than once
    count_tensor.masked_fill_(count_tensor > 1, 0)

    return count_tensor.to(metrics.dtype)
```
Take the figure above: it only illustrates the gt on one of the feature maps, so at the other levels more anchor points around the gt may pass the align_metric cut and be kept (don't get hung up on whether exactly 10 appear here). The PAN outputs three feature maps and anchors sit at the center of each feature-map cell; in practice the anchors of all levels are flattened and concatenated into a length of 8400 (see the quick count below), and the top ten are taken over all 8400, so the number of anchors kept per level can differ.
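A quick count of where the 8400 comes from, assuming a 640x640 input and the usual strides 8/16/32:

```python
sizes = [640 // s for s in (8, 16, 32)]   # per-level grid sizes: 80, 40, 20
print(sum(n * n for n in sizes))          # 6400 + 1600 + 400 = 8400
```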
The positions of the kept anchor points are marked 1 and the rest 0; only the top-ten samples by metric are retained as positives.
select_highest_overlaps
```python
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)

    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # (b, n_max_boxes, h*w) -> (b, h*w): summing over the gt dim gives, per anchor, the number of
    # non-background assignments claiming it
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
        # Boolean (b, n_max_boxes, h*w) tensor marking the anchors claimed by more than one gt
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        # Index of the gt with the highest IoU for each anchor; is_max_overlaps has the same shape as
        # mask_pos, with 1 at the highest-IoU gt and 0 elsewhere
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)  # set the max-IoU positions to 1
        # For anchors claimed by multiple gts, overwrite mask_pos with is_max_overlaps so only the
        # highest-IoU assignment survives
        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
        fg_mask = mask_pos.sum(-2)
    # Find each grid serve which gt(index)
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w): index of the assigned gt per anchor
    return target_gt_idx, fg_mask, mask_pos
```
This deduplicates anchors that were assigned multiple gts, producing the foreground mask and, for each anchor point, the index of the gt box with the highest IoU. Suppose the red anchor in the figure above was assigned two gts; after select_highest_overlaps only the gt with the larger IoU against that anchor is kept and predicted by it, while the other gt may be picked up by a neighboring anchor. mask_pos also has to be updated at this point, since the anchor assignment was reworked.
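A toy illustration of the deduplication (made-up numbers; one batch element, 2 gts, 4 anchors, broadcasting instead of the expand used in the real code): anchor 1 is claimed by both gts, and the gt with the larger IoU wins.

```python
import torch

mask_pos = torch.tensor([[[0, 1, 1, 0],
                          [0, 1, 0, 1]]]).float()        # (b=1, n_max_boxes=2, h*w=4)
overlaps = torch.tensor([[[0.0, 0.6, 0.7, 0.0],
                          [0.0, 0.8, 0.0, 0.5]]])
fg_mask = mask_pos.sum(-2)                               # anchor 1 is counted twice
is_max = torch.zeros_like(mask_pos)
is_max.scatter_(1, overlaps.argmax(1, keepdim=True), 1)  # mark the higher-IoU gt per anchor
mask_pos = torch.where(fg_mask.unsqueeze(1) > 1, is_max, mask_pos)
print(mask_pos)  # anchor 1 is now assigned only to gt 1 (IoU 0.8 > 0.6)
```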
Since each anchor is responsible for detecting one object, mask_pos is the mask of the finally selected anchors; the figure below shows the data for one batch element.
In this element, anchors 8240, 8241 and 8242 are the final picks; their index along the dimension marked by the red arrow is 2, which is exactly what target_gt_idx stores for this element.
get_targets
With all of the above information in hand, the targets can now be fetched:
```python
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
    """
    Compute target labels, target bounding boxes, and target scores for the positive anchor points.

    Args:
        gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
            batch size and max_num_obj is the maximum number of objects.
        gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
        target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
            anchor points, with shape (b, h*w), where h*w is the total
            number of anchor points.
        fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
            (foreground) anchor points.

    Returns:
        (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
            - target_labels (Tensor): Shape (b, h*w), containing the target labels for
              positive anchor points.
            - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
              for positive anchor points.
            - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
              for positive anchor points, where num_classes is the number
              of object classes.
    """
    # Assigned target labels, (b, 1)
    batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
    # Add a per-image offset so target_gt_idx indexes into the flattened gt list, (b, h*w)
    target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
    # Flatten gt_labels to (b * max_num_obj) and index with target_gt_idx to get each positive
    # anchor point's target label, (b, h*w)
    target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

    # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
    target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]  # target boxes of the positive anchor points

    # Assigned target scores
    target_labels.clamp_(0)

    # 10x faster than F.one_hot()
    target_scores = torch.zeros(
        (target_labels.shape[0], target_labels.shape[1], self.num_classes),
        dtype=torch.int64,
        device=target_labels.device,
    )  # (b, h*w, 80)
    target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)  # one-hot encode: set each position's target class to 1

    fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
    target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)  # zero out background (non-positive) positions

    return target_labels, target_bboxes, target_scores
```
The key points of this function are annotated in the code above.
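In isolation, the scatter_-based one-hot trick the code prefers over F.one_hot looks like this (tiny made-up shapes):

```python
import torch

target_labels = torch.tensor([[2, 0, 5]])                  # (b=1, h*w=3)
target_scores = torch.zeros(1, 3, 8, dtype=torch.int64)    # (b, h*w, num_classes=8)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)  # one-hot along the class dim
print(target_scores[0].nonzero().tolist())                 # [[0, 2], [1, 0], [2, 5]]
```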
After the targets are obtained, target_scores still undergoes a normalization step (the Normalize block at the end of TaskAlignedAssigner.forward shown earlier).
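In formula form (my reading of that Normalize block; for gt g and anchor i, t is the align metric and u the overlap, with maxima taken over each gt's positive anchors):

```latex
\hat{t}_{g,i} = \frac{t_{g,i} \cdot \max_i u_{g,i}}{\max_i t_{g,i} + \varepsilon}, \qquad
\text{target\_scores}_{i} \leftarrow \text{target\_scores}_{i} \cdot \max_g \hat{t}_{g,i}
```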