@@ -52,6 +52,30 @@ def test_accuracy():
52
52
pred = torch .Tensor ([[0.2 , 0.3 , 0.6 , 0.5 ], [0.1 , 0.1 , 0.2 , 0.6 ],
53
53
[0.9 , 0.0 , 0.0 , 0.1 ], [0.4 , 0.7 , 0.1 , 0.1 ],
54
54
[0.0 , 0.0 , 0.99 , 0 ]])
55
+ # test for ignore_index
56
+ true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
57
+ accuracy = Accuracy (topk = 1 , ignore_index = None )
58
+ acc = accuracy (pred , true_label )
59
+ assert acc .item () == 100
60
+
61
+ # test for ignore_index with a wrong prediction of that index
62
+ true_label = torch .Tensor ([2 , 3 , 1 , 1 , 2 ]).long ()
63
+ accuracy = Accuracy (topk = 1 , ignore_index = 1 )
64
+ acc = accuracy (pred , true_label )
65
+ assert acc .item () == 100
66
+
67
+ # test for ignore_index 1 with a wrong prediction of other index
68
+ true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
69
+ accuracy = Accuracy (topk = 1 , ignore_index = 1 )
70
+ acc = accuracy (pred , true_label )
71
+ assert acc .item () == 75
72
+
73
+ # test for ignore_index 4 with a wrong prediction of other index
74
+ true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
75
+ accuracy = Accuracy (topk = 1 , ignore_index = 4 )
76
+ acc = accuracy (pred , true_label )
77
+ assert acc .item () == 80
78
+
55
79
# test for top1
56
80
true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
57
81
accuracy = Accuracy (topk = 1 )
0 commit comments