Skip to content

Commit dd02ad8

Browse files
committed
add UT for accuracy
1 parent c730122 commit dd02ad8

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/test_models/test_losses/test_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,30 @@ def test_accuracy():
5252
pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
5353
[0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
5454
[0.0, 0.0, 0.99, 0]])
55+
# test for ignore_index
56+
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
57+
accuracy = Accuracy(topk=1, ignore_index=None)
58+
acc = accuracy(pred, true_label)
59+
assert acc.item() == 100
60+
61+
# test for ignore_index with a wrong prediction of that index
62+
true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
63+
accuracy = Accuracy(topk=1, ignore_index=1)
64+
acc = accuracy(pred, true_label)
65+
assert acc.item() == 100
66+
67+
# test for ignore_index 1 with a wrong prediction of other index
68+
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
69+
accuracy = Accuracy(topk=1, ignore_index=1)
70+
acc = accuracy(pred, true_label)
71+
assert acc.item() == 75
72+
73+
# test for ignore_index 4 with a wrong prediction of other index
74+
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
75+
accuracy = Accuracy(topk=1, ignore_index=4)
76+
acc = accuracy(pred, true_label)
77+
assert acc.item() == 80
78+
5579
# test for top1
5680
true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
5781
accuracy = Accuracy(topk=1)

0 commit comments

Comments
 (0)