@@ -573,6 +573,62 @@ def perplexity_loss(
573
573
return np .mean (perp_losses )
574
574
575
575
576
+ def smooth_l1_loss (y_true : np .ndarray , y_pred : np .ndarray , beta : float = 1.0 ) -> float :
577
+ """
578
+ Calculate the Smooth L1 Loss between y_true and y_pred.
579
+
580
+ The Smooth L1 Loss is less sensitive to outliers than the L2 Loss and is often used
581
+ in regression problems, such as object detection.
582
+
583
+ Smooth L1 Loss =
584
+ 0.5 * (x - y)^2 / beta, if |x - y| < beta
585
+ |x - y| - 0.5 * beta, otherwise
586
+
587
+ Reference:
588
+ https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html
589
+
590
+ Args:
591
+ y_true: Array of true values.
592
+ y_pred: Array of predicted values.
593
+ beta: Specifies the threshold at which to change between L1 and L2 loss.
594
+
595
+ Returns:
596
+ The calculated Smooth L1 Loss between y_true and y_pred.
597
+
598
+ Raises:
599
+ ValueError: If the length of the two arrays is not the same.
600
+
601
+ >>> y_true = np.array([3, 5, 2, 7])
602
+ >>> y_pred = np.array([2.9, 4.8, 2.1, 7.2])
603
+ >>> smooth_l1_loss(y_true, y_pred, 1.0)
604
+ 0.012500000000000022
605
+
606
+ >>> y_true = np.array([2, 4, 6])
607
+ >>> y_pred = np.array([1, 5, 7])
608
+ >>> smooth_l1_loss(y_true, y_pred, 1.0)
609
+ 0.5
610
+
611
+ >>> y_true = np.array([1, 3, 5, 7])
612
+ >>> y_pred = np.array([1, 3, 5, 7])
613
+ >>> smooth_l1_loss(y_true, y_pred, 1.0)
614
+ 0.0
615
+
616
+ >>> y_true = np.array([1, 3, 5])
617
+ >>> y_pred = np.array([1, 3, 5, 7])
618
+ >>> smooth_l1_loss(y_true, y_pred, 1.0)
619
+ Traceback (most recent call last):
620
+ ...
621
+ ValueError: The length of the two arrays should be the same.
622
+ """
623
+
624
+ if len (y_true ) != len (y_pred ):
625
+ raise ValueError ("The length of the two arrays should be the same." )
626
+
627
+ diff = np .abs (y_true - y_pred )
628
+ loss = np .where (diff < beta , 0.5 * diff ** 2 / beta , diff - 0.5 * beta )
629
+ return np .mean (loss )
630
+
631
+
576
632
if __name__ == "__main__" :
577
633
import doctest
578
634
0 commit comments