[Coverage] AttributeError: 'Infinity' object has no attribute '_mpf_' #3306

Closed · Tracked by #3179
chohk88 opened this issue Nov 28, 2024 · 1 comment
chohk88 commented Nov 28, 2024

Description:

An error occurs when attempting to optimize the MONAI BasicUNet model with torch.compile using the Torch-TensorRT backend. Below is the detailed error log captured during the process:

2024-11-25 12:06:49.595 | INFO     | MainProcess | /usr/local/lib/python3.12/dist-packages/model_navigator/commands/execution_context.py:218 - Command: /usr/bin/python torch/reproduce_correctness-torchtensorrtcompilerunner.py --batch_dim '0' --results_path '/tmp/tmpebl_vh8z' --runner_name 'TorchTensorRTCompile' --input_metadata '{"metadata": [{"name": "input__0", "shape": (-1, 3, 128, 128, 128), "dtype": "float32"}], "pytree_metadata": {"metadata": ("input__0", {}), "tensor_type": "torch"}, "is_legacy": False}' --output_metadata '{"metadata": [{"name": "output__0", "shape": (-1, 1, 128, 128, 128), "dtype": "float16"}], "pytree_metadata": {"metadata": "output__0", "tensor_type": "torch"}, "is_legacy": False}' --runner_config '{"autocast": True, "inference_mode": True, "device": None, "autocast_dtype": None, "custom_args": None}'
2024-11-25 12:06:49.881 | INFO     | MainProcess | /usr/local/lib/python3.12/dist-packages/model_navigator/runners/torch.py:363 - Using torch.compile with config: fullgraph=False, dynamic=None, backend=torch_tensorrt, mode=None, options={'truncate_long_and_double': True, 'enabled_precisions': {torch.float16, torch.float32}, 'timing_cache_path': PosixPath('/root/.cache/model_navigator/global_nvidia_h100_pcie_cuda_12_6_trt_10_6_0.cache')}
2024-11-25 12:06:50.984 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:354 - Node sum_dim_int_list of op type call_function does not have metadata. This could sometimes lead to undefined behavior.
2024-11-25 12:06:50.985 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:363 - Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
2024-11-25 12:12:28.029 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:341 - 3 supported operations detected in subgraph containing 3 computational nodes. Skipping this subgraph, since min_block_size was detected to be 5
2024-11-25 12:12:28.263 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:354 - Node sum_dim_int_list of op type call_function does not have metadata. This could sometimes lead to undefined behavior.
2024-11-25 12:12:28.263 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:363 - Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
2024-11-25 12:12:35.858 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:341 - 3 supported operations detected in subgraph containing 3 computational nodes. Skipping this subgraph, since min_block_size was detected to be 5
2024-11-25 12:12:36.395 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:354 - Node sum_dim_int_list of op type call_function does not have metadata. This could sometimes lead to undefined behavior.
2024-11-25 12:12:36.395 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py:363 - Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
2024-11-25 12:12:36.537 | INFO     | MainProcess | /usr/local/lib/python3.12/dist-packages/model_navigator/pipelines/pipeline.py:128 - backend='torch_tensorrt' raised:
AttributeError: 'Infinity' object has no attribute '_mpf_'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
2024-11-25 12:12:36.541 | WARNING  | MainProcess | /usr/local/lib/python3.12/dist-packages/model_navigator/pipelines/pipeline.py:131 - Command finished with ModelNavigatorUserInputError. The error is considered as external error. Usually caused by incompatibilities between the model and the target formats and/or runtimes. Please review the command output.
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/execution_context.py", line 156, in _execute_function
    fire.Fire(func, unwrapped_args)
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/correctness/correctness_script.py", line 91, in correctness
    with runner:
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/runners/base.py", line 225, in __enter__
    self.activate()
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/runners/base.py", line 261, in activate
    self.activate_impl()
  File "/usr/local/lib/python3.12/dist-packages/model_navigator_custom_runners/torch_trt_compile/runner.py", line 49, in activate_impl
    self._compile_dynamic_shapes(feed_dict)
  File "/usr/local/lib/python3.12/dist-packages/model_navigator_custom_runners/torch_trt_compile/runner.py", line 93, in _compile_dynamic_shapes
    self.infer_impl(feed_dict=dummy_input)
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/runners/torch.py", line 104, in infer_impl
    outputs = self._infer(feed_dict=feed_dict)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/runners/torch.py", line 144, in _infer_v1
    outputs = self._loaded_model(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/monai/networks/nets/basic_unet.py", line 254, in forward
    def forward(self, x: torch.Tensor):
  File "/usr/local/lib/python3.12/dist-packages/monai/networks/nets/basic_unet.py", line 273, in torch_dynamo_resume_in_forward_at_273
    u4 = self.upcat_4(x4, x3)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/monai/networks/nets/basic_unet.py", line 153, in forward
    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1333, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1124, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 528, in __call__
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 948, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 679, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 712, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 221, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 641, in transform
    tracer.run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2766, in run
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 973, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 885, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2957, in RETURN_VALUE
    self._return(inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2942, in _return
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 2280, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/backend/backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/backend/backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/backend/backends.py", line 110, in _pretraced_backend
    trt_compiled = compile_module(
                   ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 434, in compile_module
    submodule_inputs = partitioning.construct_submodule_inputs(submodule)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/partitioning/common.py", line 91, in construct_submodule_inputs
    get_input(input_shape, input_meta.dtype, name=input.name)
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/partitioning/common.py", line 61, in get_input
    return construct_dynamic_input(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/partitioning/common.py", line 32, in construct_dynamic_input
    min_max_opt = extract_var_range_info(dim)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/utils.py", line 345, in extract_var_range_info
    min_val, max_val, opt_val = int(var_range.lower), int(var_range.upper), int(var_val)
                                                      ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sympy/core/expr.py", line 308, in __int__
    r = self.round(2)
        ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sympy/core/expr.py", line 3856, in round
    digits_to_decimal = _mag(x)  # _mag(12) = 2, _mag(.012) = -1
                        ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sympy/core/expr.py", line 4037, in _mag
    mag_first_dig = int(ceil(Float(mpf_log(xpos._mpf_, 53))/log(10)))
                                           ^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
AttributeError: 'Infinity' object has no attribute '_mpf_'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/pipelines/pipeline.py", line 121, in _execute_unit
    command_output = execution_unit.command().run(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/base.py", line 127, in run
    output = self._run(*args, **_filter_dict_for_func(kwargs, self._run))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/correctness/correctness.py", line 150, in _run
    context.execute_python_script(
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/execution_context.py", line 142, in execute_python_script
    self._execute_function(func, unwrapped_args, allow_failure, cmd)
  File "/usr/local/lib/python3.12/dist-packages/model_navigator/commands/execution_context.py", line 168, in _execute_function
    raise ModelNavigatorUserInputError(cmd_to_reproduce_error) from e
model_navigator.exceptions.ModelNavigatorUserInputError: Command to reproduce error: /bin/bash torch/reproduce_correctness-torchtensorrtcompilerunner.sh
monai.networks.nets.basic_unet.BasicUNet: Validating model torch on TorchTensorRTCompile backend FAIL
2024-11-25 12:12:36.542 | INFO     | MainProcess | /usr/local/lib/python3.12/dist-packages/model_navigator/pipelines/pipeline.py:148 - Execution time: 346.95[s]

Steps to Reproduce:

Below is the code used to reproduce the issue:

import torch
from monai.networks.nets import BasicUNet
import torch_tensorrt

device = "cuda:0"

# Create BasicUNet model
model = BasicUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    features=(32, 64, 128, 256, 512, 32)
)

# Move model to GPU
model = model.to(device)

# Generate random input tensor
input_tensor = torch.randn(1, 1, 128, 128, 128, device=device).half()

# Compile with Torch-TensorRT backend
backend = "torch_tensorrt"

model = torch.compile(
    model.half().eval(),
    backend=backend,
    options={
        "truncate_long_and_double": False,
        "enabled_precisions": {torch.float16, torch.float32},
    },
    dynamic=False,
)

# Run inference
with torch.no_grad():
    output = model(input_tensor)

print(output)
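
As a stopgap rather than a fix, the error message in the log above suggests falling back to eager execution when the backend fails. A minimal sketch of that workaround (it hides the compilation failure instead of resolving it):

# Suppress Dynamo backend errors and silently fall back to eager execution,
# as suggested by the error message in the log above. This is a workaround,
# not a fix for the underlying Torch-TensorRT issue.
import torch._dynamo
torch._dynamo.config.suppress_errors = True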
chohk88 commented Nov 28, 2024

The AttributeError: 'Infinity' object has no attribute '_mpf_' was traced to the extract_var_range_info function in torch_tensorrt/dynamo/utils.py, which fails while retrieving opt_val:

def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:

This has been fixed in #3279, which adds proper exception handling. After applying the PR, the error no longer occurs.
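
For illustration only (not the change made in the PR), a guarded conversion along these lines avoids calling int() on an unbounded sympy value such as sympy.oo, which appears to be what triggers the '_mpf_' error inside sympy's rounding code in the traceback above:

# Hypothetical helper, sketched only to illustrate the failure mode; the real
# fix lives in torch_tensorrt/dynamo/utils.py via PR #3279.
import sympy

def safe_int(value: sympy.Expr, fallback: int) -> int:
    # sympy.oo has no finite representation, so int(value) fails inside
    # sympy's rounding machinery; substitute a caller-provided bound instead.
    if value.is_finite is not True:
        return fallback
    return int(value)

print(safe_int(sympy.Integer(128), 2**31 - 1))  # 128
print(safe_int(sympy.oo, 2**31 - 1))            # 2147483647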

With this fix, I am closing the issue.
