diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a8d2790e9b00..a8556c54d544 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -470,7 +470,7 @@ def prepare_module(self, module_op: Operation): ... def resolve_literal( - self, gni: "GraphNodeImporter", literal: Any + self, gni: "GraphNodeImporter", literal: Any, info: Optional[InputInfo] ) -> Optional[Value]: """User overridable hook to resolve a literal value.""" return None @@ -1826,13 +1826,13 @@ def _convert_type( name=op_name, results=[result_type], operands=operands ).result - def _import_literal(self, py_value: Any) -> Value: + def _import_literal(self, py_value: Any, info: Optional[InputInfo] = None) -> Value: orig_value = None if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool: orig_value = py_value py_value = py_value.to(torch.uint8) # Apply the conversion callback. - user_value = self.fx_importer._hooks.resolve_literal(self, py_value) + user_value = self.fx_importer._hooks.resolve_literal(self, py_value, info) if user_value is not None: assert isinstance(user_value, Value) if orig_value is not None: @@ -1866,7 +1866,7 @@ def _import_input(self, py_value: Any, info: InputInfo) -> Value: raise ValueError( f"Cannot import {info.input_spec} as a literal because it is mutable" ) - return self._import_literal(py_value) + return self._import_literal(py_value, info) def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: tensor_arg = torch.tensor(arg)