Skip to content

Commit d2b8274

Browse files
Stub-related improvements (#386)
TypeAlias, TypedDict, eval functions. And runtime TypeAlias.
1 parent fbca993 commit d2b8274

21 files changed

+526
-187
lines changed

docs/changelog.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Unreleased
44

5+
- Support some imports from stub-only modules (#386)
6+
- Support type evaluation functions in stubs (#386)
7+
- Support `TypedDict` in stubs (#386)
8+
- Support `TypeAlias` (PEP 612) (#386)
59
- Small improvements to `ParamSpec` support (#385)
610
- Allow `CustomCheck` to customize what values
711
a value can be assigned to (#383)

pyanalyze/annotations.py

+87-28
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import qcore
3131
import ast
3232
import builtins
33+
import inspect
3334
from collections.abc import Callable, Iterable
35+
import textwrap
3436
from typing import (
3537
Any,
3638
Container,
@@ -46,7 +48,7 @@
4648
Union,
4749
TYPE_CHECKING,
4850
)
49-
from typing_extensions import ParamSpec
51+
from typing_extensions import ParamSpec, TypedDict
5052

5153
from .error_code import ErrorCode
5254
from .extensions import (
@@ -57,8 +59,11 @@
5759
NoReturnGuard,
5860
ParameterTypeGuard,
5961
TypeGuard,
62+
get_type_evaluation,
6063
)
6164
from .find_unused import used
65+
from .functions import FunctionDefNode
66+
from .node_visitor import ErrorContext
6267
from .signature import ELLIPSIS_PARAM, SigParameter, Signature, ParameterKind
6368
from .safe import is_typing_name, is_instance_of_typing_name
6469
from . import type_evaluation
@@ -67,7 +72,6 @@
6772
AnySource,
6873
AnyValue,
6974
CallableValue,
70-
CanAssignContext,
7175
CustomCheckExtension,
7276
Extension,
7377
HasAttrGuardExtension,
@@ -152,11 +156,9 @@ def get_name_from_globals(self, name: str, globals: Mapping[str, Any]) -> Value:
152156

153157

154158
@dataclass
155-
class TypeEvaluationContext(Context, type_evaluation.Context):
156-
variables: type_evaluation.VarMap
157-
positions: Mapping[str, type_evaluation.Position]
158-
can_assign_context: CanAssignContext
159+
class RuntimeEvaluator(type_evaluation.Evaluator, Context):
159160
globals: Mapping[str, object] = field(repr=False)
161+
func: typing.Callable[..., Any]
160162

161163
def evaluate_type(self, node: ast.AST) -> Value:
162164
return type_from_ast(node, ctx=self)
@@ -168,36 +170,58 @@ def get_name(self, node: ast.Name) -> Value:
168170
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
169171
return self.get_name_from_globals(node.id, self.globals)
170172

173+
@classmethod
174+
def get_for(cls, func: typing.Callable[..., Any]) -> Optional["RuntimeEvaluator"]:
175+
try:
176+
key = f"{func.__module__}.{func.__qualname__}"
177+
except AttributeError:
178+
return None
179+
evaluation_func = get_type_evaluation(key)
180+
if evaluation_func is None or not hasattr(evaluation_func, "__globals__"):
181+
return None
182+
lines, _ = inspect.getsourcelines(evaluation_func)
183+
code = textwrap.dedent("".join(lines))
184+
body = ast.parse(code)
185+
if not body.body:
186+
return None
187+
evaluator = body.body[0]
188+
if not isinstance(evaluator, ast.FunctionDef):
189+
return None
190+
return RuntimeEvaluator(evaluator, evaluation_func.__globals__, evaluation_func)
191+
171192

172193
@dataclass
173-
class EvaluatorValidationContext(Context, type_evaluation.Context):
174-
variables: type_evaluation.VarMap
175-
positions: Mapping[str, type_evaluation.Position]
176-
can_assign_context: "NameCheckVisitor"
177-
node: ast.AST
194+
class SyntheticEvaluator(type_evaluation.Evaluator):
195+
error_ctx: ErrorContext
196+
annotations_context: Context
178197

179198
def show_error(
180199
self,
181200
message: str,
182201
error_code: ErrorCode = ErrorCode.invalid_annotation,
183202
node: Optional[ast.AST] = None,
184203
) -> None:
185-
self.can_assign_context.show_error(
186-
node or self.node, message, error_code=error_code
187-
)
204+
self.error_ctx.show_error(node or self.node, message, error_code=error_code)
188205

189206
def evaluate_type(self, node: ast.AST) -> Value:
190-
return type_from_ast(node, ctx=self)
207+
return type_from_ast(node, ctx=self.annotations_context)
191208

192209
def evaluate_value(self, node: ast.AST) -> Value:
193-
return value_from_ast(node, ctx=self, error_on_unrecognized=False)
210+
return value_from_ast(
211+
node, ctx=self.annotations_context, error_on_unrecognized=False
212+
)
194213

195214
def get_name(self, node: ast.Name) -> Value:
196215
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
197-
val, _ = self.can_assign_context.resolve_name(
198-
node, suppress_errors=self.should_suppress_undefined_names
216+
return self.annotations_context.get_name(node)
217+
218+
@classmethod
219+
def from_visitor(
220+
cls, node: FunctionDefNode, visitor: "NameCheckVisitor"
221+
) -> "SyntheticEvaluator":
222+
return cls(
223+
node, visitor, _DefaultContext(visitor, node, use_name_node_for_error=True)
199224
)
200-
return val
201225

202226

203227
@used # part of an API
@@ -223,6 +247,25 @@ def type_from_ast(
223247
return _type_from_ast(ast_node, ctx)
224248

225249

250+
def type_from_annotations(
251+
annotations: Mapping[str, object],
252+
key: str,
253+
*,
254+
globals: Optional[Mapping[str, object]] = None,
255+
ctx: Optional[Context] = None,
256+
) -> Optional[Value]:
257+
try:
258+
annotation = annotations[key]
259+
except Exception:
260+
# Malformed __annotations__
261+
return None
262+
else:
263+
maybe_val = type_from_runtime(annotation, globals=globals, ctx=ctx)
264+
if maybe_val != AnyValue(AnySource.incomplete_annotation):
265+
return maybe_val
266+
return None
267+
268+
226269
def type_from_runtime(
227270
val: object,
228271
visitor: Optional["NameCheckVisitor"] = None,
@@ -490,11 +533,15 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
490533
elif is_instance_of_typing_name(val, "_MaybeRequired"):
491534
required = is_instance_of_typing_name(val, "_Required")
492535
if is_typeddict:
493-
return _Pep655Value(required, _type_from_runtime(val.__type__, ctx))
536+
return Pep655Value(required, _type_from_runtime(val.__type__, ctx))
494537
else:
495538
cls = "Required" if required else "NotRequired"
496539
ctx.show_error(f"{cls}[] used in unsupported context")
497540
return AnyValue(AnySource.error)
541+
elif is_typing_name(val, "TypeAlias"):
542+
return AnyValue(AnySource.incomplete_annotation)
543+
elif is_typing_name(val, "TypedDict"):
544+
return KnownValue(TypedDict)
498545
else:
499546
origin = get_origin(val)
500547
if isinstance(origin, type):
@@ -571,7 +618,7 @@ def _get_typeddict_value(
571618
total: bool,
572619
) -> Tuple[bool, Value]:
573620
val = _type_from_runtime(value, ctx, is_typeddict=True)
574-
if isinstance(val, _Pep655Value):
621+
if isinstance(val, Pep655Value):
575622
return (val.required, val.value)
576623
if required_keys is None:
577624
required = total
@@ -605,6 +652,8 @@ def _type_from_value(value: Value, ctx: Context, is_typeddict: bool = False) ->
605652
)
606653
elif isinstance(value, AnyValue):
607654
return value
655+
elif isinstance(value, SubclassValue) and value.exactly:
656+
return value.typ
608657
elif isinstance(value, TypedValue) and isinstance(value.typ, str):
609658
# Synthetic type
610659
return value
@@ -634,6 +683,15 @@ def _type_from_subscripted_value(
634683
for subval in root.vals
635684
]
636685
)
686+
if (
687+
isinstance(root, SubclassValue)
688+
and root.exactly
689+
and isinstance(root.typ, TypedValue)
690+
):
691+
return GenericValue(
692+
root.typ.typ, [_type_from_value(elt, ctx) for elt in members]
693+
)
694+
637695
if isinstance(root, TypedValue) and isinstance(root.typ, str):
638696
return GenericValue(root.typ, [_type_from_value(elt, ctx) for elt in members])
639697

@@ -691,15 +749,15 @@ def _type_from_subscripted_value(
691749
if len(members) != 1:
692750
ctx.show_error("Required[] requires a single argument")
693751
return AnyValue(AnySource.error)
694-
return _Pep655Value(True, _type_from_value(members[0], ctx))
752+
return Pep655Value(True, _type_from_value(members[0], ctx))
695753
elif is_typing_name(root, "NotRequired"):
696754
if not is_typeddict:
697755
ctx.show_error("NotRequired[] used in unsupported context")
698756
return AnyValue(AnySource.error)
699757
if len(members) != 1:
700758
ctx.show_error("NotRequired[] requires a single argument")
701759
return AnyValue(AnySource.error)
702-
return _Pep655Value(False, _type_from_value(members[0], ctx))
760+
return Pep655Value(False, _type_from_value(members[0], ctx))
703761
elif root is Callable or root is typing.Callable:
704762
if len(members) == 2:
705763
args, return_value = members
@@ -739,11 +797,13 @@ def __init__(
739797
visitor: "NameCheckVisitor",
740798
node: Optional[ast.AST],
741799
globals: Optional[Mapping[str, object]] = None,
800+
use_name_node_for_error: bool = False,
742801
) -> None:
743802
super().__init__()
744803
self.visitor = visitor
745804
self.node = node
746805
self.globals = globals
806+
self.use_name_node_for_error = use_name_node_for_error
747807

748808
def show_error(
749809
self,
@@ -760,7 +820,7 @@ def get_name(self, node: ast.Name) -> Value:
760820
if self.visitor is not None:
761821
val, _ = self.visitor.resolve_name(
762822
node,
763-
error_node=self.node,
823+
error_node=node if self.use_name_node_for_error else self.node,
764824
suppress_errors=self.should_suppress_undefined_names,
765825
)
766826
return val
@@ -786,7 +846,7 @@ class _SubscriptedValue(Value):
786846

787847

788848
@dataclass
789-
class _Pep655Value(Value):
849+
class Pep655Value(Value):
790850
required: bool
791851
value: Value
792852

@@ -901,7 +961,6 @@ def visit_Call(self, node: ast.Call) -> Optional[Value]:
901961
if isinstance(kwarg_value, KnownValue):
902962
kwargs[name] = kwarg_value.val
903963
else:
904-
print(kwarg_value)
905964
return None
906965
return KnownValue(func.val(*args, **kwargs))
907966
elif func.val == TypeVar:
@@ -1038,15 +1097,15 @@ def _value_of_origin_args(
10381097
if len(args) != 1:
10391098
ctx.show_error("Required[] requires a single argument")
10401099
return AnyValue(AnySource.error)
1041-
return _Pep655Value(True, _type_from_runtime(args[0], ctx))
1100+
return Pep655Value(True, _type_from_runtime(args[0], ctx))
10421101
elif is_typing_name(origin, "NotRequired"):
10431102
if not is_typeddict:
10441103
ctx.show_error("NotRequired[] used in unsupported context")
10451104
return AnyValue(AnySource.error)
10461105
if len(args) != 1:
10471106
ctx.show_error("NotRequired[] requires a single argument")
10481107
return AnyValue(AnySource.error)
1049-
return _Pep655Value(False, _type_from_runtime(args[0], ctx))
1108+
return Pep655Value(False, _type_from_runtime(args[0], ctx))
10501109
elif origin is None and isinstance(val, type):
10511110
# This happens for SupportsInt in 3.7.
10521111
return _maybe_typed_value(val)

pyanalyze/arg_spec.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .options import Options, PyObjectSequenceOption
88
from .analysis_lib import is_positional_only_arg_name
99
from .extensions import CustomCheck, get_overloads
10-
from .annotations import Context, type_from_runtime
10+
from .annotations import Context, RuntimeEvaluator, type_from_runtime
1111
from .config import Config
1212
from .find_unused import used
1313
from . import implementation
@@ -33,7 +33,6 @@
3333
Signature,
3434
ParameterKind,
3535
)
36-
from .type_evaluation import get_evaluator
3736
from .typeshed import TypeshedFinder
3837
from .value import (
3938
AnySource,
@@ -94,9 +93,9 @@ def _scribe_log_impl(variables, visitor, node):
9493
):
9594
yield
9695
else:
97-
argspec = ArgSpecCache(Options.from_option_list([], Config())).get_argspec(
98-
fn, impl=implementation_fn
99-
)
96+
options = Options.from_option_list([], Config())
97+
tsf = TypeshedFinder.make(options)
98+
argspec = ArgSpecCache(options, tsf).get_argspec(fn, impl=implementation_fn)
10099
if argspec is None:
101100
# builtin or something, just use a generic argspec
102101
argspec = Signature.make(
@@ -288,13 +287,14 @@ class ArgSpecCache:
288287
def __init__(
289288
self,
290289
options: Options,
290+
ts_finder: TypeshedFinder,
291291
*,
292292
vnv_provider: Callable[[str], Optional[Value]] = lambda _: None,
293293
) -> None:
294294
self.vnv_provider = vnv_provider
295295
self.options = options
296296
self.config = options.fallback
297-
self.ts_finder = TypeshedFinder(verbose=False)
297+
self.ts_finder = ts_finder
298298
self.known_argspecs = {}
299299
self.generic_bases_cache = {}
300300
self.default_context = AnnotationsContext(self)
@@ -509,7 +509,7 @@ def _uncached_get_argspec(
509509
]
510510
if all_of_type(sigs, Signature):
511511
return OverloadedSignature(sigs)
512-
evaluator = get_evaluator(obj)
512+
evaluator = RuntimeEvaluator.get_for(obj)
513513
if evaluator is not None:
514514
sig = self._cached_get_argspec(
515515
evaluator.func, impl, is_asynq, in_overload_resolution=True

pyanalyze/attributes.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Generic, Sequence, Tuple, Optional, Union
1515

1616

17-
from .annotations import type_from_runtime, Context
17+
from .annotations import type_from_annotations, type_from_runtime, Context
1818
from .safe import safe_isinstance, safe_issubclass
1919
from .signature import Signature, MaybeSignature
2020
from .stacked_scopes import Composite
@@ -376,13 +376,14 @@ def _get_attribute_from_mro(
376376
pass
377377
elif safe_isinstance(typ, types.ModuleType):
378378
try:
379-
annotation = typ.__annotations__[ctx.attr]
379+
annotations = typ.__annotations__
380380
except Exception:
381-
# Module doesn't have annotations or it's not in there
382381
pass
383382
else:
384-
attr_type = type_from_runtime(annotation, ctx=AnnotationsContext(ctx, typ))
385-
if attr_type != AnyValue(AnySource.incomplete_annotation):
383+
attr_type = type_from_annotations(
384+
annotations, ctx.attr, ctx=AnnotationsContext(ctx, typ)
385+
)
386+
if attr_type is not None:
386387
return (attr_type, typ, False)
387388

388389
try:
@@ -407,15 +408,15 @@ def _get_attribute_from_mro(
407408
try:
408409
# Make sure to use only __annotations__ that are actually on this
409410
# class, not ones inherited from a base class.
410-
annotation = base_cls.__dict__["__annotations__"][ctx.attr]
411+
annotations = base_cls.__dict__["__annotations__"]
411412
except Exception:
412413
# no __annotations__, or it's not a dict, or the attr isn't there
413414
pass
414415
else:
415-
attr_type = type_from_runtime(
416-
annotation, ctx=AnnotationsContext(ctx, base_cls)
416+
attr_type = type_from_annotations(
417+
annotations, ctx.attr, ctx=AnnotationsContext(ctx, base_cls)
417418
)
418-
if attr_type != AnyValue(AnySource.incomplete_annotation):
419+
if attr_type is not None:
419420
return (attr_type, base_cls, False)
420421
try:
421422
# Make sure we use only the object from this class, but do invoke

0 commit comments

Comments
 (0)