
🐛 [Bug] aten.mean.dim converter throws error despite being supported #1742


Closed
gs-olive opened this issue Mar 16, 2023 · 1 comment · Fixed by #1810
Labels
bug Something isn't working component: fx

@gs-olive
Collaborator

gs-olive commented Mar 16, 2023

Bug Description

The aten.mean.dim converter throws the following error when compiling the model shown below:

##### MODEL:
import torch

class Sample(torch.nn.Module):
    def __init__(self):
        super(Sample, self).__init__()

    def forward(self, x):
        return torch.mean(x, dim=1)

##### ERROR:
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 57, in aten_ops_adaptive_avg_poolnd
    raise RuntimeError(f"We do not support {target} has dim={args[1]}")
RuntimeError: We do not support aten.mean.dim has dim=[1]
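For context, aten.mean.dim is simply a mean reduction along the given dimension. A plain-Python sketch of that semantics (the function name and list-based layout here are illustrative only, not the Torch-TensorRT implementation):

```python
# Sketch of what aten.mean.dim computes, using plain Python nested lists
# as a stand-in for a 2-D tensor.
def mean_dim(matrix, dim):
    """Mean-reduce a 2-D nested list along dimension 0 or 1."""
    if dim == 0:
        # Average down each column: one value per column.
        return [sum(col) / len(col) for col in zip(*matrix)]
    # dim == 1: average across each row, as in torch.mean(x, dim=1).
    return [sum(row) / len(row) for row in matrix]

x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(mean_dim(x, 1))  # one value per row: [2.0, 5.0]
```

This is the reduction the converter is expected to lower (e.g. to a TensorRT reduce layer) rather than rejecting the dim argument outright.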

To Reproduce

Steps to reproduce the behavior:

  1. Initialize the model as above: model = Sample().eval().cuda()
  2. Initialize one input tensor, for example: input_ = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
  3. Compile the model using FX: torch_tensorrt.fx.compile(model, [input_], min_acc_module_size=1, is_aten=True)

Expected behavior

The model should either compile via the FX path or report the operator as unsupported.

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version (e.g. 1.0.0): fce0a01
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230313+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7

Additional Context

Solving this issue will also resolve the error encountered in #1740.

@gs-olive gs-olive added bug Something isn't working component: fx labels Mar 16, 2023
@gs-olive
Collaborator Author

Addressed by #1657, but the following error message appears when using that PR:

Error Message
Got 1 acc subgraphs and 0 non-acc subgraphs
Traceback (most recent call last):
  File "case_test.py", line 22, in <module>
    main()
  File "case_test.py", line 19, in main
    trt_model = torch_tensorrt.fx.compile(model, [input_], min_acc_module_size=1, is_aten=True)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 202, in lower_func
    lowered_module = self._lower_func(
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 178, in lower_pass
    interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 130, in __call__
    interp_result: TRTInterpreterResult = interpreter.run(
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 204, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 137, in run
    self.env[node] = self.run_node(node)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 275, in run_node
    trt_node = super().run_node(n)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 179, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 328, in call_function
    return converter(self.network, target, args, kwargs, self._cur_node_name)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/fx2trt_ops_converter.py", line 81, in convert_avg_pool
    input_trt = add_missing_trt_tensors(network, [input])[0]
  File "~/TensorRT/py/torch_tensorrt/fx/converters/fx2trt_ops_converter_utils.py", line 19, in add_missing_trt_tensors
    dtype = check_torch_dtype(*tensors)
  File "~/TensorRT/py/torch_tensorrt/fx/converters/fx2trt_ops_converter_utils.py", line 88, in check_torch_dtype
    assert (
AssertionError: While executing %mean_dim : [#users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0, [1]), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f48305059f0>: None}})
Original traceback:
  File "case_test.py", line 13, in forward
    return torch.mean(x, dim=1)
