diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index ffce6db4fb..a0351f5545 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -275,23 +275,6 @@ def opt_log1p(node): else: output_keys = None - if name is None: - # Determine possible file names - source_file = re.sub(r"\.pyc?", ".py", __file__) - compiled_file = source_file + "c" - - stack = tb.extract_stack() - idx = len(stack) - 1 - - last_frame = stack[idx] - if last_frame[0] == source_file or last_frame[0] == compiled_file: - func_frame = stack[idx - 1] - while "pytensor/graph" in func_frame[0] and idx > 0: - idx -= 1 - # This can happen if we call var.eval() - func_frame = stack[idx - 1] - name = func_frame[0] + ":" + str(func_frame[1]) - if updates is None: updates = [] diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index d77f11d84d..aa6eb765c0 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -880,6 +880,21 @@ def _restore_defaults(self): value = value.storage[0] self[i] = value + def add_note_to_invalid_argument_exception(self, e, arg_container, arg): + i = self.input_storage.index(arg_container) + function_name = ( + f"PyTensor function '{self.name}'" if self.name else "PyTensor function" + ) + argument_name = ( + f"argument '{arg.name}'" if getattr(arg, "name", None) else "argument" + ) + where = ( + "" + if config.exception_verbosity == "low" + else get_variable_trace_string(self.maker.inputs[i].variable) + ) + e.add_note(f"\nInvalid {argument_name} to {function_name} at index {i}.{where}") + def __call__(self, *args, output_subset=None, **kwargs): """ Evaluates value of a function on given arguments. @@ -947,34 +962,10 @@ def __call__(self, *args, output_subset=None, **kwargs): strict=arg_container.strict, allow_downcast=arg_container.allow_downcast, ) - except Exception as e: - i = input_storage.index(arg_container) - function_name = "pytensor function" - argument_name = "argument" - if self.name: - function_name += ' with name "' + self.name + '"' - if hasattr(arg, "name") and arg.name: - argument_name += ' with name "' + arg.name + '"' - where = get_variable_trace_string(self.maker.inputs[i].variable) - if len(e.args) == 1: - e.args = ( - "Bad input " - + argument_name - + " to " - + function_name - + f" at index {int(i)} (0-based). {where}" - + e.args[0], - ) - else: - e.args = ( - "Bad input " - + argument_name - + " to " - + function_name - + f" at index {int(i)} (0-based). {where}" - ) + e.args - self._restore_defaults() + self.add_note_to_invalid_argument_exception( + e, arg_container, arg + ) raise arg_container.provided += 1 diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index f99b8240ca..89620bed0c 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -644,15 +644,8 @@ def add_error_and_warning_configvars(): # on all important apply nodes. config.add( "exception_verbosity", - "If 'low', the text of exceptions will generally refer " - "to apply nodes with short names such as " - "Elemwise{add_no_inplace}. If 'high', some exceptions " - "will also refer to apply nodes with long descriptions " - """ like: - A. Elemwise{add_no_inplace} - B. log_likelihood_v_given_h - C. log_likelihood_h""", - EnumStr("low", ["high"]), + "Verbosity of exceptions generated by PyTensor functions.", + EnumStr("low", ["medium", "high"]), in_c_key=False, ) diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 03c4f4eddc..00c1157245 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -313,6 +313,12 @@ def raise_with_op( # print a simple traceback from KeyboardInterrupt raise exc_value.with_traceback(exc_trace) + if verbosity == "low": + exc_value.add_note( + "\nHINT: Set PyTensor `config.exception_verbosity` to `medium` or `high` for more information about the source of the error." + ) + raise exc_value.with_traceback(exc_trace) + trace = getattr(node.outputs[0].tag, "trace", ()) exc_value.__thunk_trace__ = trace diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index d36743192f..2839c6b9b0 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -47,8 +47,9 @@ def test_function_dump(): def test_function_name(): x = vector("x") func = function([x], x + 1.0) - - assert __file__ in func.name + assert func.name is None + func = function([x], x + 1.0, name="my_func") + assert func.name == "my_func" def test_trust_input(): diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index 26058ce429..371a8b3806 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -421,9 +421,13 @@ def make_thunk(self, *args, **kwargs): z = BadOp()(a) - with pytest.raises(Exception, match=r".*Apply node that caused the error.*"): + with pytest.raises(Exception, match=r"bad Op"): function([a], z, mode=Mode(optimizer=None, linker=linker)) + with config.change_flags(exception_verbosity="high"): + with pytest.raises(Exception, match=r".*Apply node that caused the error.*"): + function([a], z, mode=Mode(optimizer=None, linker=linker)) + def test_VM_exception(): class SomeVM(VM):