diff --git a/src/super_gradients/training/losses/yolox_loss.py b/src/super_gradients/training/losses/yolox_loss.py index 460d957544..f5d6696f1f 100644 --- a/src/super_gradients/training/losses/yolox_loss.py +++ b/src/super_gradients/training/losses/yolox_loss.py @@ -4,7 +4,7 @@ """ import logging -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import torch import torch.distributed as dist @@ -112,9 +112,24 @@ class YoloXDetectionLoss(_Loss): :param use_l1: Controls the L_l1 Coef as discussed above (default=False). :param center_sampling_radius: Sampling radius used for center sampling when creating the fg mask (default=2.5). :param iou_type: Iou loss type, one of ["iou","giou"] (deafult="iou"). + :param iou_weight: Weight to apply to the iou loss term. + :param obj_weight: Weight to apply to the obj loss term. + :param cls_weight: Weight to apply to the cls loss term. + :param cls_pos_weight: Class weights for the cls loss. Passed on to torch.nn.BCEWithLogitsLoss """ - def __init__(self, strides: list, num_classes: int, use_l1: bool = False, center_sampling_radius: float = 2.5, iou_type: str = "iou"): + def __init__( + self, + strides: List[int], + num_classes: int, + use_l1: bool = False, + center_sampling_radius: float = 2.5, + iou_type: str = "iou", + iou_weight: float = 5.0, + obj_weight: float = 1.0, + cls_weight: float = 1.0, + cls_pos_weight: Optional[torch.Tensor] = None, + ): super().__init__() self.grids = [torch.zeros(1)] * len(strides) self.strides = strides @@ -123,9 +138,14 @@ def __init__(self, strides: list, num_classes: int, use_l1: bool = False, center self.center_sampling_radius = center_sampling_radius self.use_l1 = use_l1 self.l1_loss = nn.L1Loss(reduction="none") - self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") + self.obj_bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") + self.cls_bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none", pos_weight=cls_pos_weight) self.iou_loss = IOUloss(reduction="none", loss_type=iou_type) + self.iou_weight = 5.0 if iou_weight is None else iou_weight + self.obj_weight = 1.0 if obj_weight is None else obj_weight + self.cls_weight = 1.0 if cls_weight is None else cls_weight + @property def component_names(self) -> List[str]: """ @@ -283,15 +303,14 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor) num_fg = max(num_fg, 1) # loss terms divided by the total number of foregrounds loss_iou = self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets).sum() / num_fg - loss_obj = self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets).sum() / num_fg - loss_cls = self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg + loss_obj = self.obj_bcewithlog_loss(obj_preds.view(-1, 1), obj_targets).sum() / num_fg + loss_cls = self.cls_bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg if self.use_l1: loss_l1 = self.l1_loss(raw_outputs.view(-1, 4)[fg_masks], l1_targets).sum() / num_fg else: loss_l1 = 0.0 - reg_weight = 5.0 - loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1 + loss = self.iou_weight * loss_iou + self.obj_weight * loss_obj + self.cls_weight * loss_cls + loss_l1 return ( loss, @@ -636,9 +655,31 @@ class YoloXFastDetectionLoss(YoloXDetectionLoss): """ def __init__( - self, strides, num_classes, use_l1=False, center_sampling_radius=2.5, iou_type="iou", dynamic_ks_bias=1.1, sync_num_fgs=False, obj_loss_fix=False + self, + strides, + num_classes, + use_l1=False, + center_sampling_radius=2.5, + iou_type="iou", + iou_weight: float = 5.0, + obj_weight: float = 1.0, + cls_weight: float = 1.0, + cls_pos_weight: Optional[torch.Tensor] = None, + dynamic_ks_bias=1.1, + sync_num_fgs=False, + obj_loss_fix=False, ): - super().__init__(strides=strides, num_classes=num_classes, use_l1=use_l1, center_sampling_radius=center_sampling_radius, iou_type=iou_type) + super().__init__( + strides=strides, + num_classes=num_classes, + use_l1=use_l1, + center_sampling_radius=center_sampling_radius, + iou_type=iou_type, + iou_weight=iou_weight, + obj_weight=obj_weight, + cls_weight=cls_weight, + cls_pos_weight=cls_pos_weight, + ) self.dynamic_ks_bias = dynamic_ks_bias self.sync_num_fgs = sync_num_fgs @@ -706,16 +747,15 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor) dist.all_reduce(num_fg, op=torch._C._distributed_c10d.ReduceOp.AVG) loss_iou = self.iou_loss(bbox_preds[matched_img_ids, matched_fg_ids], reg_targets).sum() / num_fg - loss_obj = self.bcewithlog_loss(obj_preds.squeeze(-1), obj_targets).sum() / (total_num_anchors if self.obj_loss_fix else num_fg) - loss_cls = self.bcewithlog_loss(cls_preds[matched_img_ids, matched_fg_ids], cls_targets).sum() / num_fg + loss_obj = self.obj_bcewithlog_loss(obj_preds.squeeze(-1), obj_targets).sum() / (total_num_anchors if self.obj_loss_fix else num_fg) + loss_cls = self.cls_bcewithlog_loss(cls_preds[matched_img_ids, matched_fg_ids], cls_targets).sum() / num_fg if self.use_l1 and num_gts > 0: loss_l1 = self.l1_loss(raw_outputs[matched_img_ids, matched_fg_ids], l1_targets).sum() / num_fg else: loss_l1 = 0.0 - reg_weight = 5.0 - loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1 + loss = self.iou_weight * loss_iou + self.obj_weight * loss_obj + self.cls_weight * loss_cls + loss_l1 return ( loss,