Skip to content

Commit

Permalink
Update yolox_loss.py (#1265)
Browse files Browse the repository at this point in the history
* Update yolox_loss.py

* Update yolox_loss.py

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
eran-deci and BloodAxe authored Jul 20, 2023
1 parent c1587c5 commit f58fb95
Showing 1 changed file with 53 additions and 13 deletions.
66 changes: 53 additions & 13 deletions src/super_gradients/training/losses/yolox_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import logging
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -112,9 +112,24 @@ class YoloXDetectionLoss(_Loss):
:param use_l1: Controls the L_l1 Coef as discussed above (default=False).
:param center_sampling_radius: Sampling radius used for center sampling when creating the fg mask (default=2.5).
:param iou_type: Iou loss type, one of ["iou","giou"] (deafult="iou").
:param iou_weight: Weight to apply to the iou loss term.
:param obj_weight: Weight to apply to the obj loss term.
:param cls_weight: Weight to apply to the cls loss term.
:param cls_pos_weight: Class weights for the cls loss. Passed on to torch.nn.BCEWithLogitsLoss
"""

def __init__(self, strides: list, num_classes: int, use_l1: bool = False, center_sampling_radius: float = 2.5, iou_type: str = "iou"):
def __init__(
self,
strides: List[int],
num_classes: int,
use_l1: bool = False,
center_sampling_radius: float = 2.5,
iou_type: str = "iou",
iou_weight: float = 5.0,
obj_weight: float = 1.0,
cls_weight: float = 1.0,
cls_pos_weight: Optional[torch.Tensor] = None,
):
super().__init__()
self.grids = [torch.zeros(1)] * len(strides)
self.strides = strides
Expand All @@ -123,9 +138,14 @@ def __init__(self, strides: list, num_classes: int, use_l1: bool = False, center
self.center_sampling_radius = center_sampling_radius
self.use_l1 = use_l1
self.l1_loss = nn.L1Loss(reduction="none")
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
self.obj_bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
self.cls_bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none", pos_weight=cls_pos_weight)
self.iou_loss = IOUloss(reduction="none", loss_type=iou_type)

self.iou_weight = 5.0 if iou_weight is None else iou_weight
self.obj_weight = 1.0 if obj_weight is None else obj_weight
self.cls_weight = 1.0 if cls_weight is None else cls_weight

@property
def component_names(self) -> List[str]:
"""
Expand Down Expand Up @@ -283,15 +303,14 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor)
num_fg = max(num_fg, 1)
# loss terms divided by the total number of foregrounds
loss_iou = self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets).sum() / num_fg
loss_obj = self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets).sum() / num_fg
loss_cls = self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg
loss_obj = self.obj_bcewithlog_loss(obj_preds.view(-1, 1), obj_targets).sum() / num_fg
loss_cls = self.cls_bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg
if self.use_l1:
loss_l1 = self.l1_loss(raw_outputs.view(-1, 4)[fg_masks], l1_targets).sum() / num_fg
else:
loss_l1 = 0.0

reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
loss = self.iou_weight * loss_iou + self.obj_weight * loss_obj + self.cls_weight * loss_cls + loss_l1

return (
loss,
Expand Down Expand Up @@ -636,9 +655,31 @@ class YoloXFastDetectionLoss(YoloXDetectionLoss):
"""

def __init__(
self, strides, num_classes, use_l1=False, center_sampling_radius=2.5, iou_type="iou", dynamic_ks_bias=1.1, sync_num_fgs=False, obj_loss_fix=False
self,
strides,
num_classes,
use_l1=False,
center_sampling_radius=2.5,
iou_type="iou",
iou_weight: float = 5.0,
obj_weight: float = 1.0,
cls_weight: float = 1.0,
cls_pos_weight: Optional[torch.Tensor] = None,
dynamic_ks_bias=1.1,
sync_num_fgs=False,
obj_loss_fix=False,
):
super().__init__(strides=strides, num_classes=num_classes, use_l1=use_l1, center_sampling_radius=center_sampling_radius, iou_type=iou_type)
super().__init__(
strides=strides,
num_classes=num_classes,
use_l1=use_l1,
center_sampling_radius=center_sampling_radius,
iou_type=iou_type,
iou_weight=iou_weight,
obj_weight=obj_weight,
cls_weight=cls_weight,
cls_pos_weight=cls_pos_weight,
)

self.dynamic_ks_bias = dynamic_ks_bias
self.sync_num_fgs = sync_num_fgs
Expand Down Expand Up @@ -706,16 +747,15 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor)
dist.all_reduce(num_fg, op=torch._C._distributed_c10d.ReduceOp.AVG)

loss_iou = self.iou_loss(bbox_preds[matched_img_ids, matched_fg_ids], reg_targets).sum() / num_fg
loss_obj = self.bcewithlog_loss(obj_preds.squeeze(-1), obj_targets).sum() / (total_num_anchors if self.obj_loss_fix else num_fg)
loss_cls = self.bcewithlog_loss(cls_preds[matched_img_ids, matched_fg_ids], cls_targets).sum() / num_fg
loss_obj = self.obj_bcewithlog_loss(obj_preds.squeeze(-1), obj_targets).sum() / (total_num_anchors if self.obj_loss_fix else num_fg)
loss_cls = self.cls_bcewithlog_loss(cls_preds[matched_img_ids, matched_fg_ids], cls_targets).sum() / num_fg

if self.use_l1 and num_gts > 0:
loss_l1 = self.l1_loss(raw_outputs[matched_img_ids, matched_fg_ids], l1_targets).sum() / num_fg
else:
loss_l1 = 0.0

reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
loss = self.iou_weight * loss_iou + self.obj_weight * loss_obj + self.cls_weight * loss_cls + loss_l1

return (
loss,
Expand Down

0 comments on commit f58fb95

Please sign in to comment.