diff --git a/fcos_core/layers/sigmoid_focal_loss.py b/fcos_core/layers/sigmoid_focal_loss.py index 3fe9bc2b..099cf9ff 100644 --- a/fcos_core/layers/sigmoid_focal_loss.py +++ b/fcos_core/layers/sigmoid_focal_loss.py @@ -49,7 +49,8 @@ def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha): p = torch.sigmoid(logits) term1 = (1 - p) ** gamma * torch.log(p) term2 = p ** gamma * torch.log(1 - p) - return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha) +# return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha) + return -(t == class_range).float() * term1 * alpha - ((t >= 0) & (t != class_range)).float() * term2 * (1 - alpha) class SigmoidFocalLoss(nn.Module):