Skip to content

Commit c991db8

Browse files
authored
[OPS, TEST] Add onnx test for batched_nms (#3483)
1 parent 668927e commit c991db8

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

test/test_onnx.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,30 @@ def to_numpy(tensor):
8282
raise
8383

8484
def test_nms(self):
85-
boxes = torch.rand(5, 4)
86-
boxes[:, 2:] += torch.rand(5, 2)
87-
scores = torch.randn(5)
85+
num_boxes = 100
86+
boxes = torch.rand(num_boxes, 4)
87+
boxes[:, 2:] += boxes[:, :2]
88+
scores = torch.randn(num_boxes)
8889

8990
class Module(torch.nn.Module):
9091
def forward(self, boxes, scores):
9192
return ops.nms(boxes, scores, 0.5)
9293

9394
self.run_model(Module(), [(boxes, scores)])
9495

96+
def test_batched_nms(self):
97+
num_boxes = 100
98+
boxes = torch.rand(num_boxes, 4)
99+
boxes[:, 2:] += boxes[:, :2]
100+
scores = torch.randn(num_boxes)
101+
idxs = torch.randint(0, 5, size=(num_boxes,))
102+
103+
class Module(torch.nn.Module):
104+
def forward(self, boxes, scores, idxs):
105+
return ops.batched_nms(boxes, scores, idxs, 0.5)
106+
107+
self.run_model(Module(), [(boxes, scores, idxs)])
108+
95109
def test_clip_boxes_to_image(self):
96110
boxes = torch.randn(5, 4) * 500
97111
boxes[:, 2:] += boxes[:, :2]

0 commit comments

Comments
 (0)