# test_asl.py
# Tests auto-generated by https://www.codium.ai/
# Testing the AsymmetricLoss class; generated against
# https://github.com/Alibaba-MIIL/ASL/blob/b9d01aff9f66ccddab6112e47f3ed0ceb59ad7f5/tests/test_asl.py#L6
import unittest
import torch
from src.loss_functions.losses import AsymmetricLoss
"""
Code Analysis for AsymmetricLoss() class:
- This class is a custom loss function called AsymmetricLoss, which is a subclass of the nn.Module class.
- It is used to calculate the loss between the input logits and the targets (multi-label binarized vector).
- The __init__ method initializes the parameters of the class, such as gamma_neg, gamma_pos, clip, eps, and disable_torch_grad_focal_loss.
- The forward method is used to calculate the loss between the input logits and the targets.
- The forward method first calculates the sigmoid of the input logits and then calculates the positive and negative logits.
- If the clip parameter is not None and greater than 0, the negative logits are clipped to a maximum of 1.
- The loss is then calculated using the positive and negative logits.
- If the gamma_neg and gamma_pos parameters are greater than 0, the loss is weighted using the one_sided_gamma and one_sided_w parameters.
- Finally, the loss is returned as a negative sum.
"""
"""
Test strategies:
- test_init(): tests that the parameters of the class are initialized correctly
- test_forward_positive_logits(): tests that the forward method correctly calculates the loss for positive logits
- test_forward_negative_logits(): tests that the forward method correctly calculates the loss for negative logits
- test_forward_clipped_logits(): tests that the forward method correctly calculates the loss for clipped logits
- test_forward_gamma_pos(): tests that the forward method correctly calculates the loss when gamma_pos is greater than 0
- test_forward_gamma_neg(): tests that the forward method correctly calculates the loss when gamma_neg is greater than 0
- test_forward_eps(): tests that the forward method correctly calculates the loss when eps is greater than 0
- test_forward_disable_torch_grad_focal_loss(): tests that the forward method correctly calculates the loss when disable_torch_grad_focal_loss is set to True
"""
class TestAsymmetricLoss(unittest.TestCase):
    def setUp(self):
        self.loss = AsymmetricLoss()

    def test_init(self):
        self.assertEqual(self.loss.gamma_neg, 4)
        self.assertEqual(self.loss.gamma_pos, 1)
        self.assertEqual(self.loss.clip, 0.05)
        self.assertEqual(self.loss.eps, 1e-08)
        self.assertTrue(self.loss.disable_torch_grad_focal_loss)

    def test_forward_positive_logits(self):
        # With focusing and clipping disabled, the loss reduces to plain BCE over the logits.
        loss = AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0)
        x = torch.tensor([1., 2., 3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        expected_loss = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_negative_logits(self):
        # Same plain-BCE setting, but with negative logits.
        loss = AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0)
        x = torch.tensor([-1., -2., -3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        expected_loss = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_clipped_logits(self):
        # clip shifts the negative probabilities up before the log (and clamps them at 1).
        loss = AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0.05)
        x = torch.tensor([-1., -2., -3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        p_neg = (1 - p + loss.clip).clamp(max=1)
        expected_loss = -(y * torch.log(p) + (1 - y) * torch.log(p_neg)).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_gamma_pos(self):
        # Focus only the positive term: gamma_neg=0 and clip=0 keep the negative term plain BCE.
        loss = AsymmetricLoss(gamma_neg=0, gamma_pos=2, clip=0)
        x = torch.tensor([1., 2., 3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        expected_loss = -(y * torch.log(p) * torch.pow(1 - p, 2) + (1 - y) * torch.log(1 - p)).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_gamma_neg(self):
        # Focus only the negative term: gamma_pos=0; keep the default clip so the
        # focusing weight is computed from the shifted negative probability.
        loss = AsymmetricLoss(gamma_neg=3, gamma_pos=0, clip=0.05)
        x = torch.tensor([-1., -2., -3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        p_neg = (1 - p + loss.clip).clamp(max=1)
        expected_loss = -(y * torch.log(p) + (1 - y) * torch.log(p_neg) * torch.pow(1 - p_neg, 3)).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_eps(self):
        # eps floors the probabilities before the log; a large eps makes the clamp visible.
        loss = AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0.05, eps=0.5)
        x = torch.tensor([-1., -2., -3.])
        y = torch.tensor([1., 0., 1.])
        p = torch.sigmoid(x)
        p_neg = (1 - p + loss.clip).clamp(max=1)
        expected_loss = -(y * torch.log(p.clamp(min=loss.eps)) + (1 - y) * torch.log(p_neg.clamp(min=loss.eps))).sum()
        self.assertTrue(torch.allclose(loss(x, y), expected_loss))

    def test_forward_disable_torch_grad_focal_loss(self):
        # The flag only controls gradient tracking while the focal weights are built;
        # the returned loss value should be identical either way.
        x = torch.tensor([-1., -2., -3.])
        y = torch.tensor([1., 0., 1.])
        loss_enabled = AsymmetricLoss(disable_torch_grad_focal_loss=True)
        loss_disabled = AsymmetricLoss(disable_torch_grad_focal_loss=False)
        self.assertTrue(torch.allclose(loss_enabled(x, y), loss_disabled(x, y)))
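
# Small convenience addition so the file can be run directly with `python test_asl.py`.
if __name__ == "__main__":
    unittest.main()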