30
30
import qcore
31
31
import ast
32
32
import builtins
33
+ import inspect
33
34
from collections .abc import Callable , Iterable
35
+ import textwrap
34
36
from typing import (
35
37
Any ,
36
38
Container ,
46
48
Union ,
47
49
TYPE_CHECKING ,
48
50
)
49
- from typing_extensions import ParamSpec
51
+ from typing_extensions import ParamSpec , TypedDict
50
52
51
53
from .error_code import ErrorCode
52
54
from .extensions import (
57
59
NoReturnGuard ,
58
60
ParameterTypeGuard ,
59
61
TypeGuard ,
62
+ get_type_evaluation ,
60
63
)
61
64
from .find_unused import used
65
+ from .functions import FunctionDefNode
66
+ from .node_visitor import ErrorContext
62
67
from .signature import ELLIPSIS_PARAM , SigParameter , Signature , ParameterKind
63
68
from .safe import is_typing_name , is_instance_of_typing_name
64
69
from . import type_evaluation
67
72
AnySource ,
68
73
AnyValue ,
69
74
CallableValue ,
70
- CanAssignContext ,
71
75
CustomCheckExtension ,
72
76
Extension ,
73
77
HasAttrGuardExtension ,
@@ -152,11 +156,9 @@ def get_name_from_globals(self, name: str, globals: Mapping[str, Any]) -> Value:
152
156
153
157
154
158
@dataclass
155
- class TypeEvaluationContext (Context , type_evaluation .Context ):
156
- variables : type_evaluation .VarMap
157
- positions : Mapping [str , type_evaluation .Position ]
158
- can_assign_context : CanAssignContext
159
+ class RuntimeEvaluator (type_evaluation .Evaluator , Context ):
159
160
globals : Mapping [str , object ] = field (repr = False )
161
+ func : typing .Callable [..., Any ]
160
162
161
163
def evaluate_type (self , node : ast .AST ) -> Value :
162
164
return type_from_ast (node , ctx = self )
@@ -168,36 +170,58 @@ def get_name(self, node: ast.Name) -> Value:
168
170
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
169
171
return self .get_name_from_globals (node .id , self .globals )
170
172
173
+ @classmethod
174
+ def get_for (cls , func : typing .Callable [..., Any ]) -> Optional ["RuntimeEvaluator" ]:
175
+ try :
176
+ key = f"{ func .__module__ } .{ func .__qualname__ } "
177
+ except AttributeError :
178
+ return None
179
+ evaluation_func = get_type_evaluation (key )
180
+ if evaluation_func is None or not hasattr (evaluation_func , "__globals__" ):
181
+ return None
182
+ lines , _ = inspect .getsourcelines (evaluation_func )
183
+ code = textwrap .dedent ("" .join (lines ))
184
+ body = ast .parse (code )
185
+ if not body .body :
186
+ return None
187
+ evaluator = body .body [0 ]
188
+ if not isinstance (evaluator , ast .FunctionDef ):
189
+ return None
190
+ return RuntimeEvaluator (evaluator , evaluation_func .__globals__ , evaluation_func )
191
+
171
192
172
193
@dataclass
173
- class EvaluatorValidationContext (Context , type_evaluation .Context ):
174
- variables : type_evaluation .VarMap
175
- positions : Mapping [str , type_evaluation .Position ]
176
- can_assign_context : "NameCheckVisitor"
177
- node : ast .AST
194
+ class SyntheticEvaluator (type_evaluation .Evaluator ):
195
+ error_ctx : ErrorContext
196
+ annotations_context : Context
178
197
179
198
def show_error (
180
199
self ,
181
200
message : str ,
182
201
error_code : ErrorCode = ErrorCode .invalid_annotation ,
183
202
node : Optional [ast .AST ] = None ,
184
203
) -> None :
185
- self .can_assign_context .show_error (
186
- node or self .node , message , error_code = error_code
187
- )
204
+ self .error_ctx .show_error (node or self .node , message , error_code = error_code )
188
205
189
206
def evaluate_type (self , node : ast .AST ) -> Value :
190
- return type_from_ast (node , ctx = self )
207
+ return type_from_ast (node , ctx = self . annotations_context )
191
208
192
209
def evaluate_value (self , node : ast .AST ) -> Value :
193
- return value_from_ast (node , ctx = self , error_on_unrecognized = False )
210
+ return value_from_ast (
211
+ node , ctx = self .annotations_context , error_on_unrecognized = False
212
+ )
194
213
195
214
def get_name (self , node : ast .Name ) -> Value :
196
215
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
197
- val , _ = self .can_assign_context .resolve_name (
198
- node , suppress_errors = self .should_suppress_undefined_names
216
+ return self .annotations_context .get_name (node )
217
+
218
+ @classmethod
219
+ def from_visitor (
220
+ cls , node : FunctionDefNode , visitor : "NameCheckVisitor"
221
+ ) -> "SyntheticEvaluator" :
222
+ return cls (
223
+ node , visitor , _DefaultContext (visitor , node , use_name_node_for_error = True )
199
224
)
200
- return val
201
225
202
226
203
227
@used # part of an API
@@ -223,6 +247,25 @@ def type_from_ast(
223
247
return _type_from_ast (ast_node , ctx )
224
248
225
249
250
+ def type_from_annotations (
251
+ annotations : Mapping [str , object ],
252
+ key : str ,
253
+ * ,
254
+ globals : Optional [Mapping [str , object ]] = None ,
255
+ ctx : Optional [Context ] = None ,
256
+ ) -> Optional [Value ]:
257
+ try :
258
+ annotation = annotations [key ]
259
+ except Exception :
260
+ # Malformed __annotations__
261
+ return None
262
+ else :
263
+ maybe_val = type_from_runtime (annotation , globals = globals , ctx = ctx )
264
+ if maybe_val != AnyValue (AnySource .incomplete_annotation ):
265
+ return maybe_val
266
+ return None
267
+
268
+
226
269
def type_from_runtime (
227
270
val : object ,
228
271
visitor : Optional ["NameCheckVisitor" ] = None ,
@@ -490,11 +533,15 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
490
533
elif is_instance_of_typing_name (val , "_MaybeRequired" ):
491
534
required = is_instance_of_typing_name (val , "_Required" )
492
535
if is_typeddict :
493
- return _Pep655Value (required , _type_from_runtime (val .__type__ , ctx ))
536
+ return Pep655Value (required , _type_from_runtime (val .__type__ , ctx ))
494
537
else :
495
538
cls = "Required" if required else "NotRequired"
496
539
ctx .show_error (f"{ cls } [] used in unsupported context" )
497
540
return AnyValue (AnySource .error )
541
+ elif is_typing_name (val , "TypeAlias" ):
542
+ return AnyValue (AnySource .incomplete_annotation )
543
+ elif is_typing_name (val , "TypedDict" ):
544
+ return KnownValue (TypedDict )
498
545
else :
499
546
origin = get_origin (val )
500
547
if isinstance (origin , type ):
@@ -571,7 +618,7 @@ def _get_typeddict_value(
571
618
total : bool ,
572
619
) -> Tuple [bool , Value ]:
573
620
val = _type_from_runtime (value , ctx , is_typeddict = True )
574
- if isinstance (val , _Pep655Value ):
621
+ if isinstance (val , Pep655Value ):
575
622
return (val .required , val .value )
576
623
if required_keys is None :
577
624
required = total
@@ -605,6 +652,8 @@ def _type_from_value(value: Value, ctx: Context, is_typeddict: bool = False) ->
605
652
)
606
653
elif isinstance (value , AnyValue ):
607
654
return value
655
+ elif isinstance (value , SubclassValue ) and value .exactly :
656
+ return value .typ
608
657
elif isinstance (value , TypedValue ) and isinstance (value .typ , str ):
609
658
# Synthetic type
610
659
return value
@@ -634,6 +683,15 @@ def _type_from_subscripted_value(
634
683
for subval in root .vals
635
684
]
636
685
)
686
+ if (
687
+ isinstance (root , SubclassValue )
688
+ and root .exactly
689
+ and isinstance (root .typ , TypedValue )
690
+ ):
691
+ return GenericValue (
692
+ root .typ .typ , [_type_from_value (elt , ctx ) for elt in members ]
693
+ )
694
+
637
695
if isinstance (root , TypedValue ) and isinstance (root .typ , str ):
638
696
return GenericValue (root .typ , [_type_from_value (elt , ctx ) for elt in members ])
639
697
@@ -691,15 +749,15 @@ def _type_from_subscripted_value(
691
749
if len (members ) != 1 :
692
750
ctx .show_error ("Required[] requires a single argument" )
693
751
return AnyValue (AnySource .error )
694
- return _Pep655Value (True , _type_from_value (members [0 ], ctx ))
752
+ return Pep655Value (True , _type_from_value (members [0 ], ctx ))
695
753
elif is_typing_name (root , "NotRequired" ):
696
754
if not is_typeddict :
697
755
ctx .show_error ("NotRequired[] used in unsupported context" )
698
756
return AnyValue (AnySource .error )
699
757
if len (members ) != 1 :
700
758
ctx .show_error ("NotRequired[] requires a single argument" )
701
759
return AnyValue (AnySource .error )
702
- return _Pep655Value (False , _type_from_value (members [0 ], ctx ))
760
+ return Pep655Value (False , _type_from_value (members [0 ], ctx ))
703
761
elif root is Callable or root is typing .Callable :
704
762
if len (members ) == 2 :
705
763
args , return_value = members
@@ -739,11 +797,13 @@ def __init__(
739
797
visitor : "NameCheckVisitor" ,
740
798
node : Optional [ast .AST ],
741
799
globals : Optional [Mapping [str , object ]] = None ,
800
+ use_name_node_for_error : bool = False ,
742
801
) -> None :
743
802
super ().__init__ ()
744
803
self .visitor = visitor
745
804
self .node = node
746
805
self .globals = globals
806
+ self .use_name_node_for_error = use_name_node_for_error
747
807
748
808
def show_error (
749
809
self ,
@@ -760,7 +820,7 @@ def get_name(self, node: ast.Name) -> Value:
760
820
if self .visitor is not None :
761
821
val , _ = self .visitor .resolve_name (
762
822
node ,
763
- error_node = self .node ,
823
+ error_node = node if self . use_name_node_for_error else self .node ,
764
824
suppress_errors = self .should_suppress_undefined_names ,
765
825
)
766
826
return val
@@ -786,7 +846,7 @@ class _SubscriptedValue(Value):
786
846
787
847
788
848
@dataclass
789
- class _Pep655Value (Value ):
849
+ class Pep655Value (Value ):
790
850
required : bool
791
851
value : Value
792
852
@@ -901,7 +961,6 @@ def visit_Call(self, node: ast.Call) -> Optional[Value]:
901
961
if isinstance (kwarg_value , KnownValue ):
902
962
kwargs [name ] = kwarg_value .val
903
963
else :
904
- print (kwarg_value )
905
964
return None
906
965
return KnownValue (func .val (* args , ** kwargs ))
907
966
elif func .val == TypeVar :
@@ -1038,15 +1097,15 @@ def _value_of_origin_args(
1038
1097
if len (args ) != 1 :
1039
1098
ctx .show_error ("Required[] requires a single argument" )
1040
1099
return AnyValue (AnySource .error )
1041
- return _Pep655Value (True , _type_from_runtime (args [0 ], ctx ))
1100
+ return Pep655Value (True , _type_from_runtime (args [0 ], ctx ))
1042
1101
elif is_typing_name (origin , "NotRequired" ):
1043
1102
if not is_typeddict :
1044
1103
ctx .show_error ("NotRequired[] used in unsupported context" )
1045
1104
return AnyValue (AnySource .error )
1046
1105
if len (args ) != 1 :
1047
1106
ctx .show_error ("NotRequired[] requires a single argument" )
1048
1107
return AnyValue (AnySource .error )
1049
- return _Pep655Value (False , _type_from_runtime (args [0 ], ctx ))
1108
+ return Pep655Value (False , _type_from_runtime (args [0 ], ctx ))
1050
1109
elif origin is None and isinstance (val , type ):
1051
1110
# This happens for SupportsInt in 3.7.
1052
1111
return _maybe_typed_value (val )
0 commit comments