From 43eee77fdd98a378fb5398dd3e719e8c32a4a8a9 Mon Sep 17 00:00:00 2001 From: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Date: Sat, 9 Oct 2021 10:45:03 +0800 Subject: [PATCH] [Unittest]Add ops test (#108) * add ncnn test exporter in test_ops.py * add ncnn test exporter in utils.py * add onnxruntime and tensorrt ops test * fix blank line * fix comment add nms ops test * remove nms test * add test sample add dockerstring * remove nms test * fix grid_sample add type hind * fix problem * fix dockerstring * add nms batch_nms multi_level_roi_align * add test data * fix problem * rm pkl file dependent * rm file * add docstring * remove multi_level_dependce * Update test_ops.py --- mmdeploy/utils/test.py | 4 +- tests/test_ops/test_ops.py | 230 +++++++++++++++++++++++++++++++++++++ tests/test_ops/utils.py | 52 +++++---- 3 files changed, 262 insertions(+), 24 deletions(-) diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 2a8b753724..4d1edc1ddf 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -57,8 +57,8 @@ def forward(self, *args, **kwargs): return func(*args, **kwargs) -def assert_allclose(actual: List[Union[torch.Tensor, np.ndarray]], - expected: List[Union[torch.Tensor, np.ndarray]], +def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]], + actual: List[Union[torch.Tensor, np.ndarray]], tolerate_small_mismatch: bool = False): """Determine whether all actual values are closed with the expected values. diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 838dc0f0de..9b8f3b232a 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -195,6 +195,236 @@ def test_instance_norm(backend, save_dir=save_dir) +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize( + 'iou_threshold, score_threshold,max_output_boxes_per_class', + [(0.6, 0.2, 3), (0.4, 0, 4)]) +def test_nms(backend, + iou_threshold, + score_threshold, + max_output_boxes_per_class, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + boxes = torch.tensor([[[291.1746, 316.2263, 343.5029, 347.7312], + [288.4846, 315.0447, 343.7267, 346.5630], + [288.5307, 318.1989, 341.6425, 349.7222], + [918.9102, 83.7463, 933.3920, 164.9041], + [895.5786, 78.2361, 907.8049, 172.0883], + [292.5816, 316.5563, 340.3462, 352.9989], + [609.4592, 83.5447, 631.2532, 144.0749], + [917.7308, 85.5870, 933.2839, 168.4530], + [895.5138, 79.3596, 908.2865, 171.0418], + [291.4747, 318.6987, 347.1208, 349.5754]]]) + scores = torch.tensor([[[ + 0.1790, 0.6798, 0.1875, 0.2358, 0.7146, 0.8325, 0.0760, 0.5339, + 0.0892, 0.9320 + ], + [ + 0.0823, 0.1769, 0.4965, 0.8658, 0.0522, + 0.5388, 0.0811, 0.2998, 0.6442, 0.4870 + ], + [ + 0.4649, 0.2912, 0.1123, 0.9007, 0.1675, + 0.7509, 0.7241, 0.8785, 0.7636, 0.3442 + ], + [ + 0.7217, 0.0010, 0.0481, 0.9762, 0.7573, + 0.2049, 0.4464, 0.3340, 0.2695, 0.6959 + ], + [ + 0.5959, 0.8608, 0.0664, 0.0709, 0.0453, + 0.8023, 0.5779, 0.0068, 0.8733, 0.7946 + ]]]) + else: + boxes = torch.tensor(input_list[0], dtype=torch.float32) + scores = torch.tensor(input_list[1], dtype=torch.float32) + + cfg = dict() + register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11) + from mmdeploy.mmcv.ops import DummyONNXNMSop + + def wrapped_function(torch_bboxes, torch_scores): + return DummyONNXNMSop.apply(torch_bboxes, torch_scores, + max_output_boxes_per_class, iou_threshold, + score_threshold) + + wrapped_model = WrapFunction(wrapped_function).eval() + + backend.run_and_validate( + wrapped_model, [boxes, scores], + 'nms', + input_names=['boxes', 'scores'], + output_names=['nms_bboxes'], + save_dir=save_dir) + + +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize('num_classes,pre_topk,after_topk,iou_threshold,' + 'score_threshold,background_label_id', + [(5, 6, 3, 0.7, 0.1, -1)]) +def test_batched_nms(backend, + num_classes, + pre_topk, + after_topk, + iou_threshold, + score_threshold, + background_label_id, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + nms_boxes = torch.tensor([[[291.1746, 316.2263, 343.5029, 347.7312], + [288.4846, 315.0447, 343.7267, 346.5630], + [288.5307, 318.1989, 341.6425, 349.7222], + [918.9102, 83.7463, 933.3920, 164.9041], + [895.5786, 78.2361, 907.8049, 172.0883], + [292.5816, 316.5563, 340.3462, 352.9989], + [609.4592, 83.5447, 631.2532, 144.0749], + [917.7308, 85.5870, 933.2839, 168.4530], + [895.5138, 79.3596, 908.2865, 171.0418], + [291.4747, 318.6987, 347.1208, 349.5754]]]) + scores = torch.tensor([[[0.9577, 0.9745, 0.3030, 0.6589, 0.2742], + [0.1618, 0.7963, 0.5124, 0.6964, 0.6850], + [0.8425, 0.4843, 0.9489, 0.8068, 0.7340], + [0.7337, 0.4340, 0.9923, 0.0704, 0.4506], + [0.3090, 0.5606, 0.6939, 0.3764, 0.6920], + [0.0044, 0.7986, 0.2221, 0.2782, 0.4378], + [0.7293, 0.2735, 0.8381, 0.0264, 0.6278], + [0.7144, 0.1066, 0.4125, 0.4041, 0.8819], + [0.4963, 0.7891, 0.6908, 0.1499, 0.5584], + [0.4385, 0.6035, 0.0508, 0.0662, 0.5938]]]) + else: + nms_boxes = torch.tensor(input_list[0], dtype=torch.float32) + scores = torch.tensor(input_list[1], dtype=torch.float32) + + from mmdeploy.mmdet.core.post_processing.bbox_nms import _multiclass_nms + expected_result = _multiclass_nms( + nms_boxes, + scores, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_topk + 1, + keep_top_k=after_topk + 1) + + boxes = nms_boxes.unsqueeze(2).tile(num_classes, 1) + + from mmdeploy.mmcv.ops.nms import TRTBatchedNMSop + batched_nms = TRTBatchedNMSop.apply + + def wrapped_function(boxes, scores): + return batched_nms(boxes, scores, num_classes, pre_topk, after_topk, + iou_threshold, score_threshold, background_label_id) + + wrapped_model = WrapFunction(wrapped_function) + + backend.run_and_validate( + wrapped_model, [boxes, scores], + 'batched_nms', + input_names=['boxes', 'scores'], + output_names=['batched_nms_bboxes', 'inds'], + expected_result=expected_result, + save_dir=save_dir) + + +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize('out_size, sampling_ratio,roi_scale_factor,' + ' finest_scale,featmap_strides, aligned', + [(tuple([2, 2]), 2, 1.0, 2, list([2.0, 4.0]), 1)]) +def test_multi_level_roi_align(backend, + out_size, + sampling_ratio, + roi_scale_factor, + finest_scale, + featmap_strides, + aligned, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + input = [ + torch.tensor([[[[0.3014, 0.7334, 0.6502, 0.1689], + [0.3031, 0.3735, 0.6032, 0.1644], + [0.0393, 0.4415, 0.3858, 0.2657], + [0.5766, 0.0211, 0.6384, 0.0016]], + [[0.0811, 0.6255, 0.0247, 0.3471], + [0.1390, 0.9298, 0.6178, 0.6636], + [0.2243, 0.2024, 0.2366, 0.3660], + [0.1050, 0.2301, 0.7489, 0.7506]], + [[0.3868, 0.1706, 0.2390, 0.8494], + [0.2643, 0.9347, 0.0412, 0.5790], + [0.6202, 0.0682, 0.0390, 0.5296], + [0.5383, 0.1221, 0.6344, 0.1514]]]]), + torch.tensor([[[[0.1939, 0.9983, 0.4031, 0.2712], + [0.7929, 0.1504, 0.0946, 0.5030], + [0.1421, 0.7908, 0.9595, 0.4198], + [0.6880, 0.4722, 0.9896, 0.2266]], + [[0.0778, 0.4232, 0.0736, 0.0168], + [0.2887, 0.8461, 0.1140, 0.9582], + [0.5169, 0.4924, 0.8275, 0.5530], + [0.8961, 0.7466, 0.5976, 0.3760]], + [[0.1542, 0.5028, 0.8412, 0.6617], + [0.3751, 0.2798, 0.3835, 0.8640], + [0.5821, 0.6588, 0.1324, 0.7619], + [0.9178, 0.7282, 0.0291, 0.3028]]]]) + ] + rois = torch.tensor([[0., 0., 0., 4., 4.]]) + expected_result = torch.tensor([[[[0.1939, 0.3950], [0.3437, 0.4543]], + [[0.0778, 0.1641], [0.1305, 0.2301]], + [[0.1542, 0.2413], [0.2094, + 0.2688]]]]) + else: + input = input_list[0] + rois = input_list[1] + expected_result = input_list[2] + input_name = [('input_' + str(i)) for i in range(len(featmap_strides))] + input_name.insert(0, 'rois') + + inputs = [ + onnx.helper.make_tensor_value_info( + input_name[i + 1], onnx.TensorProto.FLOAT, shape=input[i].shape) + for i in range(len(input_name) - 1) + ] + inputs.append( + onnx.helper.make_tensor_value_info( + 'rois', onnx.TensorProto.FLOAT, shape=rois.shape)) + outputs = [ + onnx.helper.make_tensor_value_info( + 'bbox_feats', onnx.TensorProto.FLOAT, shape=expected_result.shape) + ] + node = onnx.helper.make_node( + 'MMCVMultiLevelRoiAlign', + input_name, ['bbox_feats'], + 'MMCVMultiLevelRoiAlign_0', + None, + 'mmlab', + aligned=aligned, + featmap_strides=featmap_strides, + finest_scale=finest_scale, + output_height=out_size[0], + output_width=out_size[1], + roi_scale_factor=roi_scale_factor, + sampling_ratio=sampling_ratio) + graph = onnx.helper.make_graph([node], 'torch-jit-export', inputs, outputs) + onnx_model = onnx.helper.make_model( + graph, producer_name='pytorch', producer_version='1.8') + onnx_model.opset_import[0].version = 11 + onnx_model.opset_import.append( + onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmlab', version=1)) + + backend.run_and_validate( + onnx_model, [rois, *input], + 'multi_level_roi_align', + input_names=input_name, + output_names=['bbox_feats'], + expected_result=expected_result, + save_dir=save_dir) + + @pytest.mark.parametrize('backend', [TEST_NCNN]) @pytest.mark.parametrize('k', [1, 3, 5]) @pytest.mark.parametrize('dim', [1, 2, 3]) diff --git a/tests/test_ops/utils.py b/tests/test_ops/utils.py index cf73ac185d..d9b50f82fb 100644 --- a/tests/test_ops/utils.py +++ b/tests/test_ops/utils.py @@ -32,6 +32,7 @@ def run_and_validate(self, dynamic_axes=None, output_names=None, input_names=None, + expected_result=None, save_dir=None): if save_dir is None: @@ -51,10 +52,11 @@ def run_and_validate(self, do_constant_folding=do_constant_folding, dynamic_axes=dynamic_axes, opset_version=11) - - with torch.no_grad(): - model_outputs = model(*input_list) - + if expected_result is None: + with torch.no_grad(): + model_outputs = model(*input_list) + else: + model_outputs = expected_result if isinstance(model_outputs, torch.Tensor): model_outputs = [model_outputs] else: @@ -90,6 +92,7 @@ def run_and_validate(self, dynamic_axes=None, output_names=None, input_names=None, + expected_result=None, save_dir=None): if save_dir is None: onnx_file_path = tempfile.NamedTemporaryFile().name @@ -97,18 +100,21 @@ def run_and_validate(self, else: onnx_file_path = os.path.join(save_dir, model_name + '.onnx') trt_file_path = os.path.join(save_dir, model_name + '.trt') - with torch.no_grad(): - torch.onnx.export( - model, - tuple(input_list), - onnx_file_path, - export_params=True, - keep_initializers_as_inputs=True, - input_names=input_names, - output_names=output_names, - do_constant_folding=do_constant_folding, - dynamic_axes=dynamic_axes, - opset_version=11) + if isinstance(model, onnx.onnx_ml_pb2.ModelProto): + onnx.save(model, onnx_file_path) + else: + with torch.no_grad(): + torch.onnx.export( + model, + tuple(input_list), + onnx_file_path, + export_params=True, + keep_initializers_as_inputs=True, + input_names=input_names, + output_names=output_names, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + opset_version=11) deploy_cfg = mmcv.Config( dict( @@ -135,18 +141,20 @@ def run_and_validate(self, 0, deploy_cfg=deploy_cfg, onnx_model=onnx_model) - - with torch.no_grad(): - model_outputs = model(*input_list) - + if expected_result is None and not isinstance( + model, onnx.onnx_ml_pb2.ModelProto): + with torch.no_grad(): + model_outputs = model(*input_list) + else: + model_outputs = expected_result if isinstance(model_outputs, torch.Tensor): model_outputs = [model_outputs.cpu().float()] else: model_outputs = [data.cpu().float() for data in model_outputs] trt_model = trt_apis.TRTWrapper(trt_file_path) - inputs_list = [data.cuda() for data in input_list] - trt_outputs = trt_model(dict(zip(input_names, inputs_list))) + input_list = [data.cuda() for data in input_list] + trt_outputs = trt_model(dict(zip(input_names, input_list))) trt_outputs = [ trt_outputs[name].cpu().float() for name in output_names ]