🐛 [Bug] Regression : Torch-TensorRT now fail to convert due to unsupported negative pad for torch.nn.ConstantPad2d #2079

Closed
fabricecarles opened this issue Jul 6, 2023 · 5 comments
Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters), No Activity

Comments

@fabricecarles

Bug Description

This is a regression that appears with the following versions:

  • torch 2.0.1
  • torch-tensorrt 1.4.0
  • torchvision 0.15.2
  • tensorrt 8.6.1

Conversion with torch_tensorrt.compile() now fails for every network that uses a negative pad inside a ConstantPad layer, whereas conversion worked fine in earlier versions.

Example: torch.nn.ConstantPad2d(-6, float(0.0)) raises RuntimeError: [Error thrown at core/conversion/converters/impl/constant_pad.cpp:35] Expected left >= 0 to be true but got false Unsupported negative pad at index 0
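
For readers unfamiliar with negative padding: in PyTorch's constant padding, a negative pad amount removes elements from that side instead of adding them, which is why it is used here as a crop. The sketch below mimics that behaviour on a plain Python list in one dimension. `constant_pad_1d` is a hypothetical helper written only for illustration; it is not a PyTorch or Torch-TensorRT function.

```python
def constant_pad_1d(values, left, right, fill=0.0):
    """Pad (positive amounts) or crop (negative amounts) a 1-D list.

    Hypothetical helper: it only illustrates the semantics of a
    negative constant pad, it does not call PyTorch.
    """
    out = list(values)
    # Negative amounts crop elements from the corresponding side.
    if left < 0:
        out = out[-left:]
    if right < 0:
        out = out[:max(len(out) + right, 0)]
    # Positive amounts add `fill` elements on the corresponding side.
    if left > 0:
        out = [fill] * left + out
    if right > 0:
        out = out + [fill] * right
    return out


row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
padded = constant_pad_1d(row, 1, 1)     # one fill value added on each side
cropped = constant_pad_1d(row, -2, -2)  # two elements removed on each side
print(padded)
print(cropped)
```

A 2-D pad applies the same rule independently to each padded dimension.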

To Reproduce

Steps to reproduce the behavior:

  1. instantiate a net with a layer that has a negative pad
  2. convert the net with:

       trt_ts_module = torch_tensorrt.compile(
           model.to(device),
           inputs=[torch_tensorrt.Input((1, 1, args.image_size, args.image_size))],
           enabled_precisions={torch.half},  # Run with FP16 enabled_precisions
           workspace_size=1024,
       )

  3. get the message RuntimeError: [Error thrown at core/conversion/converters/impl/constant_pad.cpp:35] Expected left >= 0 to be true but got false Unsupported negative pad at index 0

full stack trace

Traceback (most recent call last):
  File "/home/username/src/dev/deepWork/pytorch/main_test.py", line 446, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/home/username/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 133, in compile
    return torch_tensorrt.ts.compile(
  File "/home/username/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 139, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/constant_pad.cpp:35] Expected left >= 0 to be true but got false
Unsupported negative pad at index 0

Expected behavior

Conversion works fine with earlier versions:
-torch 1.12.1
-torch-tensorrt 1.2.0
-torchvision 0.13.1
-tensorrt 8.0.3.4

Environment

  • Torch-TensorRT Version: 1.4.0
  • PyTorch Version: 2.0.1
  • CPU Architecture: x86_64 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip inside a conda env
  • Python version: 3.10.6
  • CUDA version: 11.8
  • GPU models and configuration: GeForce RTX 2070 Mobile, models on cuda:0
  • Any other relevant information: regression; no bug with earlier versions (torch 1.12.1, torch-tensorrt 1.2.0)

Additional context

Pure TorchScript conversion with torch.jit.trace works fine for all versions!

@fabricecarles fabricecarles added the bug Something isn't working label Jul 6, 2023
@fabricecarles
Author

With torch_tensorrt.compile(..., ir="fx_ts_compat"), as suggested in https://github.com/pytorch/TensorRT/releases, compilation works fine, but the compiled net is slower than with the earlier version (torch-tensorrt 1.2.0).

So I benchmarked the same net compiled in the following environments:

environment 1:
-torch 1.12.1
-torch-tensorrt 1.2.0
-torchvision 0.13.1
-tensorrt 8.0.3.4
-cuda 11.7

environment 2:
-torch 2.0.1
-torch-tensorrt 1.4.0
-torchvision 0.15.2
-tensorrt 8.6.1
-cuda 11.8

Environment 1

Inference 0 Load time 0.3170967102050781, inference time : 1.9924640655517578 total time : 2.3245811462402344[msec]
Inference 1 Load time 0.270843505859375, inference time : 3.055095672607422 total time : 3.338336944580078[msec]
Inference 2 Load time 0.22172927856445312, inference time : 1.9993782043457031 total time : 2.231121063232422[msec]
Inference 3 Load time 0.20837783813476562, inference time : 2.0284652709960938 total time : 2.2461414337158203[msec]
Inference 4 Load time 0.2071857452392578, inference time : 2.4094581604003906 total time : 2.6292800903320312[msec]
Inference 5 Load time 0.28252601623535156, inference time : 2.212047576904297 total time : 2.5081634521484375[msec]
Inference 6 Load time 0.5252361297607422, inference time : 1.6930103302001953 total time : 2.2292137145996094[msec]
Inference 7 Load time 0.2658367156982422, inference time : 2.4471282958984375 total time : 2.726316452026367[msec]
Inference 8 Load time 0.48828125, inference time : 1.6243457794189453 total time : 2.126455307006836[msec]
Inference 9 Load time 0.30994415283203125, inference time : 1.9650459289550781 total time : 2.284526824951172[msec]

Environment 2

Inference 0 Load time 0.1780986785888672, inference time : 0.9043216705322266 total time : 1.0898113250732422[msec]
Inference 1 Load time 0.1976490020751953, inference time : 7.76219367980957 total time : 7.967948913574219[msec]
Inference 2 Load time 0.274658203125, inference time : 5.695819854736328 total time : 5.979061126708984[msec]
Inference 3 Load time 0.2827644348144531, inference time : 5.619049072265625 total time : 5.90968132019043[msec]
Inference 4 Load time 0.26869773864746094, inference time : 5.508184432983398 total time : 5.784511566162109[msec]
Inference 5 Load time 0.25081634521484375, inference time : 5.67936897277832 total time : 5.937337875366211[msec]
Inference 6 Load time 0.23293495178222656, inference time : 6.565570831298828 total time : 6.8073272705078125[msec]
Inference 7 Load time 0.339508056640625, inference time : 6.059169769287109 total time : 6.405830383300781[msec]
Inference 8 Load time 0.36025047302246094, inference time : 6.229639053344727 total time : 6.60395622253418[msec]
Inference 9 Load time 0.7829666137695312, inference time : 5.886077880859375 total time : 6.680011749267578[msec]
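
To compare the runs above with a single number per environment, the log lines can be parsed and summarized. This is a minimal sketch that assumes the exact line format shown above; `parse_log` and `summarize` are hypothetical helper names, and skipping the first iteration as warm-up is my assumption, not something the benchmark script is known to do.

```python
import re
from statistics import mean, median

# Matches lines such as:
# Inference 0 Load time 0.17, inference time : 0.90 total time : 1.08[msec]
LINE_RE = re.compile(
    r"Inference (\d+) Load time ([\d.]+), "
    r"inference time : ([\d.]+) total time : ([\d.]+)\[msec\]"
)


def parse_log(text):
    """Return a list of (index, load_ms, inference_ms, total_ms) tuples."""
    rows = []
    for match in LINE_RE.finditer(text):
        idx, load, infer, total = match.groups()
        rows.append((int(idx), float(load), float(infer), float(total)))
    return rows


def summarize(rows, skip_first=1):
    """Mean and median inference time in ms, ignoring warm-up iterations."""
    times = [row[2] for row in rows[skip_first:]]
    return mean(times), median(times)


# First three lines of the environment 2 log, copied from above.
sample = """\
Inference 0 Load time 0.1780986785888672, inference time : 0.9043216705322266 total time : 1.0898113250732422[msec]
Inference 1 Load time 0.1976490020751953, inference time : 7.76219367980957 total time : 7.967948913574219[msec]
Inference 2 Load time 0.274658203125, inference time : 5.695819854736328 total time : 5.979061126708984[msec]
"""
rows = parse_log(sample)
avg, med = summarize(rows)
print(f"{len(rows)} rows, mean {avg:.3f} ms, median {med:.3f} ms")
```

Feeding the full log of each environment to `parse_log` gives directly comparable means and medians.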

Please note that in environment 1 I can't use ir="fx_ts_compat" (not yet available), and I set dtype=torch.float so that the compiled trt_ts_module expects float32 input (i.e. inputs=[torch_tensorrt.Input((1, 1, args.image_size, args.image_size), dtype=torch.float)]).

In environment 2, however, I am forced to use dtype=torch.half (to prevent RuntimeError: Input type (float) and bias type (c10::Half) should be the same), so the compiled trt_ts_module expects float16 input.

In environment 1 I get many instances of the following warning (but the inference results and inference time are good):

WARNING: [Torch-TensorRT TorchScript Conversion Context] -  - Subnormal FP16 values detected. 
WARNING: [Torch-TensorRT TorchScript Conversion Context] - If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.

In environment 2 I also get a warning:

[07/06/2023-15:26:15] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading

and torch.jit.save(trt_ts_module, path) fails to save my compiled trt_ts_module:

Traceback (most recent call last):
  File "/home/fabrice/src/dev/deepWork/pytorch/main_test.py", line 496, in <module>
    torch.jit.save(trt_ts_module, path)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/jit/_serialization.py", line 80, in save
    m.save(f, _extra_files=_extra_files)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GraphModule' object has no attribute 'save'

Do you have any suggestions for improving inference time when upgrading to torch_tensorrt 1.4.0?
Do you think the issue comes from the new implementation in #1970?

Full log of the environment 2 compilation:

convert net in TorchTrt ...
start compilation... 
use torch_tensorrt version  1.4.0
WARNING:torch_tensorrt.dynamo.fx_ts_compat.lower:For ir=fx_ts_compat backend only the following arguments are supported: {enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x7f1e8ae85480> before/after graph to /tmp/tmpof4on06v, before/after are the same = True, time elapsed = 0:00:00.015313
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x7f1e8ae85240> before/after graph to /tmp/tmphc_ycgxr, before/after are the same = True, time elapsed = 0:00:00.013979
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fix_clamp_numerical_limits_to_fp16 at 0x7f1e8ae85a20> before/after graph to /tmp/tmp6p2a4ob9, before/after are the same = True, time elapsed = 0:00:00.013534

Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})
acc_ops.relu: ((), {'input': torch.float16})
acc_ops.max_pool2d: ((), {'input': torch.float16})
acc_ops.conv_transpose2d: ((), {'input': torch.float16, 'weight': torch.float16})
acc_ops.pad: ((), {'input': torch.float16})

Unsupported node types in the model:
torch.argmax: ((torch.float16,), {})

Got 1 acc subgraphs and 1 non-acc subgraphs
INFO:torch_tensorrt.dynamo.fx_ts_compat.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.dynamo.fx_ts_compat.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=(1, 1, 192, 192), dtype=torch.float16, device=device(type='cpu'), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.dynamo.fx_ts_compat.lower:Timing cache is used!
[07/06/2023-15:46:06] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
INFO:torch_tensorrt.dynamo.fx_ts_compat.fx2trt:TRT INetwork construction elapsed time: 0:00:00.041412
INFO:torch_tensorrt.dynamo.fx_ts_compat.fx2trt:Build TRT engine elapsed time: 0:00:21.165574
INFO:torch_tensorrt.dynamo.fx_ts_compat.fx2trt:TRT Engine uses: 38939648 bytes of Memory
[07/06/2023-15:46:28] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
INFO:torch_tensorrt.dynamo.fx_ts_compat.passes.lower_pass_manager_builder:Lowering submodule _run_on_acc_0 elapsed time 0:00:26.833156
compilation DONE !

@gs-olive gs-olive added the component: converters Issues re: Specific op converters label Jul 6, 2023
@gs-olive
Collaborator

gs-olive commented Jul 6, 2023

Hi @fabricecarles - thanks for the detailed information. The error reported in TorchScript does seem to derive from these lines:

for (size_t i = 0UL; i < padding.size(); i += 2) {
  auto left = padding[i];
  TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
  auto right = padding[i + 1];
  TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
  auto idx = in_rank - ((i / 2) + 1);
  start[idx] = -left;
  total_padding[idx] = left + right;
}
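
To make the index arithmetic in that loop concrete, here is a minimal Python sketch that mirrors it with the two `>= 0` checks left out. `pad_to_start_and_total` is a hypothetical name; the padding list and input shape are the ones reported in this thread. It only shows what values the loop would produce for a negative pad. Whether the rest of the converter accepts those values is not established here.

```python
def pad_to_start_and_total(padding, in_rank):
    """Mirror the loop's arithmetic, without the `>= 0` checks.

    `padding` is ordered (left, right) starting from the last dimension,
    as in the converter loop above.
    """
    start = [0] * in_rank
    total_padding = [0] * in_rank
    for i in range(0, len(padding), 2):
        left = padding[i]
        right = padding[i + 1]
        idx = in_rank - ((i // 2) + 1)  # map pad pair to tensor dimension
        start[idx] = -left
        total_padding[idx] = left + right
    return start, total_padding


# Padding from the aten graph in this thread, applied to a rank-4 tensor.
start, total = pad_to_start_and_total([-32, -32, -32, -32], in_rank=4)
in_shape = [8, 16, 256, 256]
out_shape = [dim + pad for dim, pad in zip(in_shape, total)]
print(start, total, out_shape)
```

With a negative pad, `start` becomes a positive offset into the input and `total_padding` becomes negative, which corresponds to a crop rather than a pad.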

Regarding the slow-down with torch_tensorrt.compile(..., ir="fx_ts_compat") - it may be due to the presence of that unsupported node, which can cause some segmentation in the graph. Could you also try torch_tensorrt.compile(..., ir="fx_ts_compat", is_aten=True)? On the point of saving these models, @peri044 may have some suggestions here.

@fabricecarles
Author

Using torch_tensorrt.compile(..., ir="fx_ts_compat", is_aten=True),
I get RuntimeError: Target aten.convolution.default does not support `transposed=True`

At the tail of my network definition, self.upscore refers to a ConvTranspose2d layer, which is followed by a ConstantPad2d that crops the result from (8, 16, 256, 256) to (8, 16, 192, 192), my target output size.

self.upscore = nn.ConvTranspose2d(num_classes, num_classes,
                                  kernel_size=64, stride=32,
                                  output_padding=0, bias=False)
# 256 -32*2 = 192 (8, 16, 256, 256) -> (8, 16, 192, 192)
self.crop = torch.nn.ConstantPad2d(-32, float(0.0))
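
The shape arithmetic behind these two layers can be checked with the output-size formula from the PyTorch ConvTranspose2d documentation. This is a sketch only: the helper names are hypothetical, and the spatial input size of 7 is an assumption chosen because it yields the 256 reported above (the issue does not state the size fed to self.upscore).

```python
def conv_transpose2d_out(size_in, kernel_size, stride, padding=0,
                         output_padding=0, dilation=1):
    """Output size along one spatial dimension of ConvTranspose2d."""
    return ((size_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def constant_pad_out(size_in, left, right):
    """Output size along one dimension after constant padding."""
    return size_in + left + right


# size_in=7 is an assumed value, not taken from the issue.
upscored = conv_transpose2d_out(7, kernel_size=64, stride=32)
cropped = constant_pad_out(upscored, -32, -32)
print(upscored, cropped)
```

Under that assumption the transposed convolution gives 256 and the pad of -32 on both sides gives 192, matching the comment in the snippet above.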

For debugging purposes I can test other options if you want.
Do you think that torch_tensorrt 1.3.0 + pytorch 1.13 will work?

full stack trace

convert net in TorchTrt ...
start compilation... 
use torch_tensorrt version  1.4.0
WARNING:torch_tensorrt.dynamo.fx_ts_compat.lower:For ir=fx_ts_compat backend only the following arguments are supported: {enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}
[2023-07-07 09:41:35,588] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-07-07 09:41:36,450] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-07-07 09:41:36,453] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
[2023-07-07 09:41:36,453] torch._dynamo.output_graph: [INFO] Step 2: done compiler function dynamo_normalization_capturing_compiler

graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%arg0, %_param_constant0, 
    ... # I cut the graph
    %constant_pad_nd_default : [#users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%convolution_default_12, [-32, -32, -32, -32], 0.0), kwargs = {})
    %argmax_default : [#users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%constant_pad_nd_default, 1), kwargs = {})
    return [argmax_default]
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x7f0b92d4d510> before/after graph to /tmp/tmp_uclsuqt, before/after are the same = True, time elapsed = 0:00:00.027320
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x7f0b92d4d2d0> before/after graph to /tmp/tmp60ly0fv8, before/after are the same = True, time elapsed = 0:00:00.023211

Supported node types in the model:
torch.ops.aten.convolution.default: ((torch.float16, torch.float16, torch.float16), {})
torch.ops.aten.convolution.default: ((torch.float16, torch.float16), {})
torch.ops.aten.relu.default: ((torch.float16,), {})
torch.ops.aten.max_pool2d: ((torch.float16,), {})

Unsupported node types in the model:
torch.ops.aten.constant_pad_nd.default: ((torch.float16,), {})
torch.ops.aten.argmax.default: ((torch.float16,), {})

Got 1 acc subgraphs and 1 non-acc subgraphs
INFO:torch_tensorrt.dynamo.fx_ts_compat.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.dynamo.fx_ts_compat.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=(1, 1, 192, 192), dtype=torch.float16, device=device(type='cpu'), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.dynamo.fx_ts_compat.lower:Timing cache is used!
[07/07/2023-09:41:43] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
Traceback (most recent call last):
  File "/home/fabrice/src/dev/deepWork/pytorch/main_test.py", line 451, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 164, in compile
    return torch_tensorrt.dynamo.fx_ts_compat.compile(
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/lower.py", line 128, in compile
    return lowerer(module, inputs)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/lower.py", line 364, in __call__
    return do_lower(module, inputs)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py", line 148, in pass_with_validation
    return pass_(module, input, *args, **kwargs)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/lower.py", line 361, in do_lower
    lower_result = pm(module)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/fx/passes/pass_manager.py", line 238, in __call__
    out = _pass(out)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py", line 201, in lower_func
    lowered_module = self._lower_func(
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/lower.py", line 224, in lower_pass
    interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/lower.py", line 172, in __call__
    interp_result: TRTInterpreterResult = interpreter.run(
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py", line 206, in run
    super().run()
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py", line 296, in run_node
    trt_node = super().run_node(n)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py", line 349, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "/home/fabrice/bin/anaconda3/envs/django/lib/python3.10/site-packages/torch_tensorrt/fx/converters/aten_ops_converters.py", line 128, in aten_ops_convolution
    raise RuntimeError(f"Target {target} does not support `transposed=True` ")
RuntimeError: Target aten.convolution.default does not support `transposed=True` 

Original traceback:
  File "/home/fabrice/src/dev/deepWork/pytorch/net.py", line 84, in forward
    x = self.upscore(x)

@fabricecarles
Author

I can confirm that there is no issue with
environment 3:
-torch 1.13.1
-torch-tensorrt 1.3.0
-torchvision 0.13.1
-tensorrt 8.5.3.1
-cuda 11.7

and the inference time is good:

Inference 0 Load time 0.20074844360351562, inference time : 1.3387203216552734 total time : 1.54876708984375[msec]
Inference 1 Load time 0.2315044403076172, inference time : 2.9180049896240234 total time : 3.157377243041992[msec]
Inference 2 Load time 0.2434253692626953, inference time : 2.240419387817383 total time : 2.5107860565185547[msec]
Inference 3 Load time 0.33473968505859375, inference time : 1.8868446350097656 total time : 2.230405807495117[msec]
Inference 4 Load time 0.27489662170410156, inference time : 2.260923385620117 total time : 2.545595169067383[msec]
Inference 5 Load time 0.36787986755371094, inference time : 1.8558502197265625 total time : 2.234220504760742[msec]
Inference 6 Load time 0.35452842712402344, inference time : 1.9216537475585938 total time : 2.2852420806884766[msec]
Inference 7 Load time 0.24366378784179688, inference time : 2.1636486053466797 total time : 2.41851806640625[msec]
Inference 8 Load time 0.4229545593261719, inference time : 1.8286705017089844 total time : 2.2597312927246094[msec]
Inference 9 Load time 0.37169456481933594, inference time : 2.0971298217773438 total time : 2.475738525390625[msec]

@github-actions

github-actions bot commented Oct 6, 2023

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days

3 participants