@@ -82,16 +82,30 @@ def to_numpy(tensor):
82
82
raise
83
83
84
84
def test_nms (self ):
85
- boxes = torch .rand (5 , 4 )
86
- boxes [:, 2 :] += torch .rand (5 , 2 )
87
- scores = torch .randn (5 )
85
+ num_boxes = 100
86
+ boxes = torch .rand (num_boxes , 4 )
87
+ boxes [:, 2 :] += boxes [:, :2 ]
88
+ scores = torch .randn (num_boxes )
88
89
89
90
class Module (torch .nn .Module ):
90
91
def forward (self , boxes , scores ):
91
92
return ops .nms (boxes , scores , 0.5 )
92
93
93
94
self .run_model (Module (), [(boxes , scores )])
94
95
96
+ def test_batched_nms (self ):
97
+ num_boxes = 100
98
+ boxes = torch .rand (num_boxes , 4 )
99
+ boxes [:, 2 :] += boxes [:, :2 ]
100
+ scores = torch .randn (num_boxes )
101
+ idxs = torch .randint (0 , 5 , size = (num_boxes ,))
102
+
103
+ class Module (torch .nn .Module ):
104
+ def forward (self , boxes , scores , idxs ):
105
+ return ops .batched_nms (boxes , scores , idxs , 0.5 )
106
+
107
+ self .run_model (Module (), [(boxes , scores , idxs )])
108
+
95
109
def test_clip_boxes_to_image (self ):
96
110
boxes = torch .randn (5 , 4 ) * 500
97
111
boxes [:, 2 :] += boxes [:, :2 ]
0 commit comments