Which models support batch_size>1? #880

Closed
avickars opened this issue Aug 8, 2022 · 12 comments

Labels: help wanted (Extra attention is needed)

avickars commented Aug 8, 2022

I can't seem to find a list of the models that support batch_size > 1. Just wondering if such a list exists somewhere?

tpoisonooo commented Aug 9, 2022

Maybe you need #868.

avickars commented Aug 9, 2022

@tpoisonooo Thanks, but I just need to know which ones actually support it. I'm currently using Mask R-CNN with a Swin backbone, and I can't tell whether it is supposed to support batch_size > 1 or not (it doesn't appear to, as it always errors out when I test it with batch_size > 1, with both ONNX and TensorRT), but I would like to confirm.

tpoisonooo commented:

cc @RunningLeon

RunningLeon commented Aug 10, 2022

> @tpoisonooo Thanks, but I just need to know which ones actually support it. I'm currently using Mask R-CNN with a Swin backbone, and I can't tell whether it is supposed to support batch_size > 1 or not (it doesn't appear to, as it always errors out when I test it with batch_size > 1, with both ONNX and TensorRT), but I would like to confirm.

@avickars Try changing the deploy config to make it support multi-batch, for example:

_base_ = [
    '../_base_/base_instance-seg_static.py',
    '../../_base_/backends/tensorrt.py'
]

onnx_config = dict(
    input_shape=(512, 512),
    dynamic_axes={
        'input': {0: 'batch'},
        'dets': {0: 'batch'},
        'labels': {0: 'batch'},
        'masks': {0: 'batch', 1: 'num_dets'},
    })

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 512, 512],
                    opt_shape=[1, 3, 512, 512],
                    max_shape=[2, 3, 512, 512])))  # max batch size 2
    ])
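
For reference, a quick way to confirm the exported ONNX actually picked up the dynamic batch axis is to inspect the graph inputs. This is a minimal sketch (it assumes the re-exported model is at mmdeploy/work_dir/end2end.onnx, as used elsewhere in this thread); with the dynamic_axes above, dim 0 of 'input' should show a symbolic name such as 'batch' rather than a fixed 1:

import onnx

# Minimal check (assumed path: mmdeploy/work_dir/end2end.onnx): print each graph
# input with its dimensions. A dynamic batch axis appears as the symbolic name
# 'batch' instead of the literal value 1.
model = onnx.load('mmdeploy/work_dir/end2end.onnx')
for inp in model.graph.input:
    dims = [d.dim_param if d.dim_param else d.dim_value
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)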


avickars commented Aug 10, 2022

@RunningLeon So for ONNX, my understanding from the config is that it should support batch_size > 1. However, I am getting this error:

File "mmdeploy/tools/test.py", line 150, in <module> main() File "mmdeploy/tools/test.py", line 143, in main args.show_dir) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmdeploy/codebase/base/task.py", line 139, in single_gpu_test out_dir, **kwargs) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmdeploy/codebase/mmdet/deploy/mmdetection.py", line 142, in single_gpu_test outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmdet/apis/test.py", line 27, in single_gpu_test for i, data in enumerate(data_loader): File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__ data = self._next_data() File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1224, in _next_data return self._process_data(data) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1250, in _process_data data.reraise() File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/_utils.py", line 457, in reraise raise exception RuntimeError: Caught RuntimeError in DataLoader worker process 0. Original Traceback (most recent call last): File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop data = fetcher.fetch(index) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch return self.collate_fn(data) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmcv/parallel/collate.py", line 81, in collate for key in batch[0] File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmcv/parallel/collate.py", line 81, in <dictcomp> for key in batch[0] File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmcv/parallel/collate.py", line 77, in collate return [collate(samples, samples_per_gpu) for samples in transposed] File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmcv/parallel/collate.py", line 77, in <listcomp> return [collate(samples, samples_per_gpu) for samples in transposed] File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/mmcv/parallel/collate.py", line 84, in collate return default_collate(batch) File "/home/aidan/Programs/anaconda3/envs/mmdeploy2/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 137, in default_collate out = elem.new(storage).resize_(len(batch), *list(elem.size())) RuntimeError: Trying to resize storage that is not resizable
When running:

python mmdeploy/tools/test.py \
    mmdeploy/configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py \
    mmdetection/configs/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco.py \
    --model mmdeploy/work_dir/end2end.onnx \
    --metrics bbox segm \
    --batch-size 2
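
One plausible reading of this failure (an assumption, not something confirmed in this thread): with a dynamic deploy config the test pipeline can keep per-image sizes, so two images in one batch may have different shapes and the default collate cannot stack them into a single tensor. A minimal sketch of that failure mode:

import torch

# Hypothetical illustration: two preprocessed images of different spatial sizes,
# as a dynamic-shape test pipeline can produce for two different inputs.
a = torch.zeros(3, 800, 1216)
b = torch.zeros(3, 800, 1344)

# Stacking them into one batch tensor raises a RuntimeError; inside a DataLoader
# worker the same size mismatch can surface as
# "Trying to resize storage that is not resizable".
batch = torch.stack([a, b], dim=0)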

avickars commented Aug 10, 2022

FYI, I am running this model: https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/configs/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco.py.

Also, test.py works fine with "--batch-size 1". It errors out with any batch size > 1.

I will test your solution for TensorRT shortly.

avickars commented Aug 10, 2022

Sadly, it doesn't work. I got this error:
(mmdeploy) aidan@aidan-ubuntu:~/Git-Repositories$ python mmdeploy/tools/test.py mmdeploy/configs/mmdet/instance-seg/instance-seg_tensorrt_static-800x1344.py mmdetection/configs/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco.py --model mmdeploy/work_dir/end2end.engine --metrics bbox segm --device cuda:0 --batch-size 2
/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdet/datasets/utils.py:66: UserWarning: "ImageToTensor" pipeline is replaced by "DefaultFormatBundle" for batch inference. It is recommended to manually replace it in the test data pipeline in your config file.
  warnings.warn(
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
2022-08-10 16:14:17,496 - mmdeploy - INFO - Successfully loaded tensorrt plugins from /home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/lib/libmmdeploy_tensorrt_ops.so
2022-08-10 16:14:17,496 - mmdeploy - INFO - Successfully loaded tensorrt plugins from /home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/lib/libmmdeploy_tensorrt_ops.so
[08/10/2022-16:14:18] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.5 but loaded cuBLAS/cuBLAS LT 11.5.1
[08/10/2022-16:14:18] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.6.5 but loaded cuBLAS/cuBLAS LT 11.5.1
[                              ] 0/605, elapsed: 0s, ETA:
[08/10/2022-16:14:21] [TRT] [E] 7: [shapeMachine.cpp::execute::565] Error Code 7: Internal Error (IShuffleLayer Reshape_197: reshaping failed for tensor: onnx::Reshape_658 reshape would change volume Instruction: RESHAPE_ZERO_IS_PLACEHOLDER{1 58 7 48 7 96} {1 203 336 96} )
Traceback (most recent call last):
  File "mmdeploy/tools/test.py", line 150, in <module>
    main()
  File "mmdeploy/tools/test.py", line 142, in main
    outputs = task_processor.single_gpu_test(model, data_loader, args.show,
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/codebase/base/task.py", line 138, in single_gpu_test
    return self.codebase_class.single_gpu_test(model, data_loader, show,
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/codebase/mmdet/deploy/mmdetection.py", line 132, in single_gpu_test
    outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdet/apis/test.py", line 29, in single_gpu_test
    result = model(return_loss=False, rescale=True, **data)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 50, in forward
    return super().forward(*inputs, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 198, in forward
    outputs = self.forward_test(input_img, img_metas, *args, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 274, in forward_test
    outputs = self.wrapper({self.input_name: imgs})
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/aidan/Programs/anaconda3/envs/mmdeploy/lib/python3.8/site-packages/mmdeploy/backend/tensorrt/wrapper.py", line 158, in forward
    shape = tuple(self.context.get_binding_shape(idx))
ValueError: __len__() should return >= 0

This is when running the engine file with batch_size 2.

RunningLeon commented:

@avickars It fails on my side too, even with only the batch dim dynamic. Maybe the Swin backbone does not support multi-batch; cc @AllentDan.

AllentDan commented:

Hi @avickars. It seems Mask R-CNN + Swin models cannot be exported to TensorRT with batch size > 1. Once we set the opt_shape to [2, 3, 800, 1344], TensorRT raises an error as follows:

[E] 4: [shapeCompiler.cpp::evaluateShapeChecks::911] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: reshape would change volume. IShuffleLayer Reshape_806: reshaping failed for tensor: 1752)
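
For context, that corresponds roughly to the following change in the TensorRT profile of the deploy config (a sketch only; the shapes follow the static 800x1344 config, and only the batch dimension of opt_shape/max_shape is raised to 2):

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 800, 1344],
                    # raising the batch dimension here is what triggers the
                    # shape-constraint error above
                    opt_shape=[2, 3, 800, 1344],
                    max_shape=[2, 3, 800, 1344])))
    ])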

avickars commented:

@AllentDan Do you know whether batch_size > 1 should work for ONNX? I would just like to confirm, that's all.

AllentDan commented:

> @AllentDan Do you know whether batch_size > 1 should work for ONNX? I would just like to confirm, that's all.

In my testing with the ONNXRuntime backend, it failed with batch size > 1.
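
For anyone who wants to reproduce that check directly against the exported model, something along these lines should do it (a sketch; the custom-op library path and the 512x512 input size are assumptions, not values from this thread):

import numpy as np
import onnxruntime as ort

# Register the MMDeploy custom ops before loading the model.
# The .so path below is an assumed install location; adjust as needed.
opts = ort.SessionOptions()
opts.register_custom_ops_library(
    '/path/to/mmdeploy/lib/libmmdeploy_onnxruntime_ops.so')

sess = ort.InferenceSession('mmdeploy/work_dir/end2end.onnx', opts,
                            providers=['CPUExecutionProvider'])

# Feed a dummy batch of 2 images; if batch > 1 is unsupported, this call
# is where it fails.
dummy = np.zeros((2, 3, 512, 512), dtype=np.float32)
outputs = sess.run(None, {'input': dummy})
for meta, out in zip(sess.get_outputs(), outputs):
    print(meta.name, out.shape)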

tpoisonooo commented:

If you have any questions, please reopen this issue.
