diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 141e31bbf910e3..8f2a93be915832 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -74,7 +74,7 @@ def __init_subclass__(cls, /, *args, **kwds): def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): """Evaluate the forward reference and return the value. - If the forward reference is not evaluatable, raise an exception. + If the forward reference cannot be evaluated, raise an exception. """ if self.__forward_evaluated__: return self.__forward_value__ @@ -89,12 +89,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): return value if owner is None: owner = self.__owner__ - if type_params is None and owner is None: - raise TypeError("Either 'type_params' or 'owner' must be provided") - if self.__forward_module__ is not None: + if globals is None and self.__forward_module__ is not None: globals = getattr( - sys.modules.get(self.__forward_module__, None), "__dict__", globals + sys.modules.get(self.__forward_module__, None), "__dict__", None ) if globals is None: globals = self.__globals__ @@ -112,14 +110,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): if locals is None: locals = {} - if isinstance(self.__owner__, type): - locals.update(vars(self.__owner__)) + if isinstance(owner, type): + locals.update(vars(owner)) - if type_params is None and self.__owner__ is not None: + if type_params is None and owner is not None: # "Inject" type parameters into the local namespace # (unless they are shadowed by assignments *in* the local namespace), # as a way of emulating annotation scopes when calling `eval()` - type_params = getattr(self.__owner__, "__type_params__", None) + type_params = getattr(owner, "__type_params__", None) # type parameters require some special handling, # as they exist in their own scope @@ -129,7 +127,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): # but should in turn be overridden by names in the class scope # (which here are called `globalns`!) if type_params is not None: - globals, locals = dict(globals), dict(locals) + if globals is None: + globals = {} + else: + globals = dict(globals) + if locals is None: + locals = {} + else: + locals = dict(locals) for param in type_params: param_name = param.__name__ if not self.__forward_is_class__ or param_name not in globals: diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index e4dcdb6b58d009..db8350c2746983 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -5,7 +5,7 @@ import itertools import pickle import unittest -from annotationlib import Format, get_annotations, get_annotate_function +from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function from typing import Unpack from test.test_inspect import inspect_stock_annotations @@ -250,6 +250,46 @@ def test_special_attrs(self): with self.assertRaises(TypeError): pickle.dumps(fr, proto) + def test_evaluate_with_type_params(self): + class Gen[T]: + alias = int + + with self.assertRaises(NameError): + ForwardRef("T").evaluate() + with self.assertRaises(NameError): + ForwardRef("T").evaluate(type_params=()) + with self.assertRaises(NameError): + ForwardRef("T").evaluate(owner=int) + + T, = Gen.__type_params__ + self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T) + self.assertIs(ForwardRef("T").evaluate(owner=Gen), T) + + with self.assertRaises(NameError): + ForwardRef("alias").evaluate(type_params=Gen.__type_params__) + self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int) + # If you pass custom locals, we don't look at the owner's locals + with self.assertRaises(NameError): + ForwardRef("alias").evaluate(owner=Gen, locals={}) + # But if the name exists in the locals, it works + self.assertIs( + ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str + ) + + def test_fwdref_with_module(self): + self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format) + + with self.assertRaises(NameError): + # If globals are passed explicitly, we don't look at the module dict + ForwardRef("Format", module=annotationlib).evaluate(globals={}) + + def test_fwdref_value_is_cached(self): + fr = ForwardRef("hello") + with self.assertRaises(NameError): + fr.evaluate() + self.assertIs(fr.evaluate(globals={"hello": str}), str) + self.assertIs(fr.evaluate(), str) + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): diff --git a/Lib/typing.py b/Lib/typing.py index 39a14ae6f83c28..bcb7bec23a9aa1 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -474,6 +474,10 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f _deprecation_warning_for_no_type_params_passed("typing._eval_type") type_params = () if isinstance(t, ForwardRef): + # If the forward_ref has __forward_module__ set, evaluate() infers the globals + # from the module, and it will probably pick better than the globals we have here. + if t.__forward_module__ is not None: + globalns = None return evaluate_forward_ref(t, globals=globalns, locals=localns, type_params=type_params, owner=owner, _recursive_guard=recursive_guard, format=format)