From 68e7e41dc8fd707c8299e2dac55d399d19e19e07 Mon Sep 17 00:00:00 2001
From: Yonghye Kwon
Date: Sun, 26 Jan 2025 10:55:48 +0900
Subject: [PATCH 1/2] feat: Raise ValueError if alpha > 1 in sigmoid_focal_loss

- Restrict alpha to (0, 1) or -1 for ignore, as per the original focal loss settings.
- Add a check to raise ValueError when alpha exceeds 1.
- This helps prevent invalid alpha values from silently causing unexpected behaviors.
---
 torchvision/ops/focal_loss.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py
index 08c282555fc..e1bd05d2a9b 100644
--- a/torchvision/ops/focal_loss.py
+++ b/torchvision/ops/focal_loss.py
@@ -33,6 +33,11 @@ def sigmoid_focal_loss(
     """
     # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
 
+    if alpha > 1:
+        raise ValueError(
+            f"Invalid alpha value: {alpha}. alpha must be in the range (0,1) or -1 for ignore."
+        )
+
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(sigmoid_focal_loss)
     p = torch.sigmoid(inputs)

From e0e11c5975fec9c99fb49e3a4b48c5540fffa120 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 19 Feb 2025 15:04:10 +0000
Subject: [PATCH 2/2] Use brackets instead of parenthesis

---
 torchvision/ops/focal_loss.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py
index e1bd05d2a9b..abe861e8b03 100644
--- a/torchvision/ops/focal_loss.py
+++ b/torchvision/ops/focal_loss.py
@@ -20,7 +20,7 @@ def sigmoid_focal_loss(
         targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
-        alpha (float): Weighting factor in range (0,1) to balance
+        alpha (float): Weighting factor in range [0, 1] to balance
                 positive vs negative examples or -1 for ignore. Default: ``0.25``.
         gamma (float): Exponent of the modulating factor (1 - p_t) to
                 balance easy vs hard examples. Default: ``2``.
@@ -33,10 +33,8 @@
     """
     # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
 
-    if alpha > 1:
-        raise ValueError(
-            f"Invalid alpha value: {alpha}. alpha must be in the range (0,1) or -1 for ignore."
-        )
+    if not (0 <= alpha <= 1) and alpha != -1:
+        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
 
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(sigmoid_focal_loss)
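
As a quick illustration (not part of the patch series), this is how the validated function is expected to behave once the change above is applied; the tensor shapes and example alpha values are arbitrary:

    import torch
    from torchvision.ops import sigmoid_focal_loss

    inputs = torch.randn(8, 4)                      # raw logits
    targets = torch.randint(0, 2, (8, 4)).float()   # binary labels

    # Values inside [0, 1] (including the default 0.25) and the -1 "ignore"
    # sentinel are accepted as before.
    sigmoid_focal_loss(inputs, targets, alpha=0.25)
    sigmoid_focal_loss(inputs, targets, alpha=-1)

    # An out-of-range alpha now fails loudly instead of silently skewing the loss.
    try:
        sigmoid_focal_loss(inputs, targets, alpha=2.0)
    except ValueError as err:
        print(err)  # Invalid alpha value: 2.0. alpha must be in the range [0,1] or -1 for ignore.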