Skip to content

Commit

Permalink
Add type annotations to mixin.py and fix up all places that their typ…
Browse files Browse the repository at this point in the history
…e annotation needed fixes after the change.

PiperOrigin-RevId: 690586638
  • Loading branch information
h-joo authored and copybara-github committed Nov 8, 2024
1 parent 62a983d commit 4977c36
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pytype/abstract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_special_attribute(
self,
unused_node: "cfg.CFGNode",
name: str,
unused_valself: "cfg.Variable",
unused_valself: "cfg.Variable | None",
) -> "cfg.Variable | None":
"""Fetch a special attribute (e.g., __get__, __iter__)."""
if name == "__class__":
Expand Down
14 changes: 8 additions & 6 deletions pytype/abstract/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def bases(self) -> list[cfg.Variable]: # pytype: disable=signature-mismatch

def load_lazy_attribute(
self, name: str, subst: str | None = None, store: bool = True
) -> cfg.Variable:
) -> cfg.Variable | None:
try:
return super().load_lazy_attribute(name, subst, store)
except self.ctx.convert.TypeParameterError as e:
Expand Down Expand Up @@ -737,7 +737,8 @@ def __init__(
self,
base_cls: PyTDClass | InterpreterClass,
formal_type_parameters: (
abstract_utils.LazyFormalTypeParameters | dict[str, _base.BaseValue]
abstract_utils.LazyFormalTypeParameters
| dict[str | int, _base.BaseValue]
),
ctx: "context.Context",
template: tuple["_typing.TypeParameter", ...] | None = None,
Expand Down Expand Up @@ -954,15 +955,16 @@ def get_formal_type_parameter(self, t):
def get_inner_types(self) -> ItemsView[int | str, _base.BaseValue]:
return self.formal_type_parameters.items()

def update_inner_type(self, key: str, typ: _base.BaseValue) -> None:
def update_inner_type(self, key: str | int, typ: _base.BaseValue) -> None:
self.formal_type_parameters[key] = typ

def replace(
self,
inner_types: (
abstract_utils.LazyFormalTypeParameters | dict[str, _base.BaseValue]
abstract_utils.LazyFormalTypeParameters
| Sequence[tuple[int, _base.BaseValue]]
),
) -> "ParameterizedClass":
) -> "ParameterizedClass | LiteralClass":
inner_types = dict(inner_types)
if isinstance(self, LiteralClass):
if inner_types == self.formal_type_parameters:
Expand Down Expand Up @@ -1270,7 +1272,7 @@ def getitem_slot(
)

def get_special_attribute(
self, node: cfg.CFGNode, name: str, valself: cfg.Variable
self, node: cfg.CFGNode, name: str, valself: cfg.Variable | None
) -> cfg.Variable | None:
if (
valself
Expand Down
6 changes: 3 additions & 3 deletions pytype/abstract/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Constructs related to type annotations."""

from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
from collections.abc import Iterable, Mapping, Sequence, Set
import dataclasses
import logging
from typing import Any, Literal, TYPE_CHECKING, cast
Expand Down Expand Up @@ -792,7 +792,7 @@ def get_formal_type_parameter(self, t):
]
return Union(new_options, self.ctx)

def get_inner_types(self) -> Iterator[tuple[int, _base.BaseValue]]:
def get_inner_types(self) -> Iterable[tuple[int, _base.BaseValue]]:
return enumerate(self.options)

def update_inner_type(self, key: int, typ: _base.BaseValue) -> None:
Expand Down Expand Up @@ -984,7 +984,7 @@ def instantiate(
return instance.to_variable(node)

def get_special_attribute(
self, node: "cfg.CFGNode", name: str, valself: "cfg.Variable"
self, node: "cfg.CFGNode", name: str, valself: "cfg.Variable | None"
) -> "cfg.Variable | None":
if name == "__getitem__" and not self.resolved:
container = _base.BaseValue.to_annotation_container(self) # pytype: disable=wrong-arg-types
Expand Down
4 changes: 3 additions & 1 deletion pytype/abstract/class_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
_abstract = abstract_utils._abstract # pylint: disable=protected-access


FunctionMapType = Mapping[str, Sequence["_interpreter_function.InterpreterFunction"]]
FunctionMapType = Mapping[
str, Sequence["_interpreter_function.InterpreterFunction"]
]
log: logging.Logger = logging.getLogger(__name__)


Expand Down
78 changes: 48 additions & 30 deletions pytype/abstract/mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Mixins for abstract.py."""

from collections.abc import Callable, Iterable, Sequence
import logging
from typing import Any, TYPE_CHECKING
from typing import Any, Self, TYPE_CHECKING

from pytype.abstract import abstract_utils
from pytype.abstract import function
Expand All @@ -11,11 +12,14 @@
from pytype.types import types

if TYPE_CHECKING:
from pytype import datatypes # pylint: disable=g-import-not-at-top, g-bad-import-order
from pytype.abstract import abstract as _abstract # pylint: disable=g-import-not-at-top, g-bad-import-order
from pytype.abstract import _base # pylint: disable=g-import-not-at-top, g-bad-import-order
from pytype.abstract import _function_base # pylint: disable=g-import-not-at-top, g-bad-import-order
else:
_abstract = abstract_utils._abstract # pylint: disable=protected-access

log = logging.getLogger(__name__)
log: logging.Logger = logging.getLogger(__name__)


class MixinMeta(type):
Expand All @@ -24,7 +28,7 @@ class MixinMeta(type):
__mixin_overloads__: dict[str, type[Any]]
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(cls, name, superclasses, *args, **kwargs):
def __init__(cls, name: str, superclasses, *args, **kwargs) -> None:
super().__init__(name, superclasses, *args, **kwargs)
for sup in superclasses:
if "overloads" in sup.__dict__:
Expand All @@ -38,7 +42,7 @@ def __init__(cls, name, superclasses, *args, **kwargs):
else:
setattr(cls, "__mixin_overloads__", {method: sup})

def super(cls, method):
def super(cls: Self, method):
"""Imitate super() in a mix-in.
This method is a substitute for
Expand Down Expand Up @@ -81,15 +85,22 @@ class PythonConstant(types.PythonConstant, metaclass=MixinMeta):
"r" etc.).
"""

overloads = ("__repr__",)
overloads: tuple[str, ...] = ("__repr__",)

def init_mixin(self, pyval):
def init_mixin(
self,
# TODO: b/350643999 - the type here is too complex and non-sensical
# probably this indicates that this codes need refactoring or either
# the type here is truly intended to be "Any" which also is bad.
# Fix the type.
pyval: "_base.BaseValue | datatypes.MonitorDict[Any, cfg.Variable] | dict[str, cfg.Variable] | Sequence[cfg.Variable] | None",
) -> None:
"""Mix-in equivalent of __init__."""
self.pyval = pyval
self.is_concrete = True
self._printing = False

def str_of_constant(self, printer):
def str_of_constant(self, printer: "Callable[[_base.BaseValue], str]") -> str:
"""Get a string representation of this constant.
Args:
Expand All @@ -102,7 +113,7 @@ def str_of_constant(self, printer):
del printer
return repr(self.pyval)

def __repr__(self):
def __repr__(self) -> str:
if self._printing: # recursion detected
const = "[...]"
else:
Expand All @@ -119,13 +130,13 @@ class HasSlots(metaclass=MixinMeta):
handling of some magic methods (__setitem__ etc.)
"""

overloads = ("get_special_attribute",)
overloads: tuple[str, ...] = ("get_special_attribute",)

def init_mixin(self):
def init_mixin(self) -> None:
self._slots = {}
self._super = {}

def set_slot(self, name, slot):
def set_slot(self, name: str, slot: "_function_base.Function") -> None:
"""Add a new slot to this value."""
assert name not in self._slots, f"slot {name} already occupied"
# For getting a slot value, we don't need a ParameterizedClass's type
Expand All @@ -142,11 +153,13 @@ def set_slot(self, name, slot):
self._super[name] = attr
self._slots[name] = slot

def set_native_slot(self, name, method):
def set_native_slot(self, name, method) -> None:
"""Add a new NativeFunction slot to this value."""
self.set_slot(name, _abstract.NativeFunction(name, method, self.ctx))

def call_pytd(self, node, name, *args):
def call_pytd(
self, node: cfg.CFGNode, name: str, *args
) -> tuple[cfg.CFGNode, cfg.Variable]:
"""Call the (original) pytd version of a method we overwrote."""
return function.call_function(
self.ctx,
Expand All @@ -156,11 +169,16 @@ def call_pytd(self, node, name, *args):
fallback_to_unsolvable=False,
)

def get_special_attribute(self, node, name, valself):
def get_special_attribute(
self, node: cfg.CFGNode, name: str, valself: cfg.Variable | None
) -> cfg.Variable | None:
if name not in self._slots:
return HasSlots.super(self.get_special_attribute)(node, name, valself)
if valself:
slot = self._slots[name].property_get(valself.variable)
# TODO: b/350643999 - Type here seems to be correct on all callsites
# but the type checker rejects this attribute access. Figure out what this
# code is truely doing
slot = self._slots[name].property_get(valself.variable) # pytype: disable=attribute-error
attr = self.ctx.program.NewVariable([slot], [valself], node)
else:
attr = self.ctx.program.NewVariable([self._slots[name]], [], node)
Expand All @@ -182,15 +200,15 @@ class NestedAnnotation(metaclass=MixinMeta):
one but with the given inner types, again as a (key, typ) sequence.
"""

overloads = ("formal",)
overloads: tuple[str, ...] = ("formal",)

def init_mixin(self):
def init_mixin(self) -> None:
self.processed = False
self._seen_for_formal = False # for calculating the 'formal' property
self._formal = None

@property
def formal(self):
def formal(self) -> bool:
"""See BaseValue.formal."""
# We can't compute self.formal in __init__ because doing so would force
# evaluation of our type parameters during initialization, possibly
Expand All @@ -207,13 +225,13 @@ def formal(self):
self._formal = formal
return formal

def get_inner_types(self):
def get_inner_types(self) -> "Iterable[tuple[int | str, _base.BaseValue]]":
raise NotImplementedError()

def update_inner_type(self, key, typ):
def update_inner_type(self, key: int, typ: "_base.BaseValue"):
raise NotImplementedError()

def replace(self, inner_types):
def replace(self, inner_types: "Sequence[tuple[int, _base.BaseValue]]"):
raise NotImplementedError()


Expand All @@ -235,10 +253,10 @@ class LazyMembers(metaclass=MixinMeta):

members: dict[str, cfg.Variable]

def init_mixin(self, member_map):
def init_mixin(self, member_map: dict[str, cfg.Variable]) -> None:
self._member_map = member_map

def _convert_member(self, name, member, subst=None):
def _convert_member(self, name: str, member, subst=None) -> cfg.Variable:
raise NotImplementedError()

def load_lazy_attribute(self, name, subst=None, store=True):
Expand Down Expand Up @@ -273,7 +291,7 @@ class PythonDict(PythonConstant):
# More methods can be implemented by adding the name to `overloads` and
# defining the delegating method.

overloads = PythonConstant.overloads + (
overloads: Sequence[str] = PythonConstant.overloads + (
"__getitem__",
"get",
"__contains__",
Expand All @@ -290,20 +308,20 @@ def __getitem__(self, key):
def get(self, key, default=None):
return self.pyval.get(key, default)

def __contains__(self, key):
def __contains__(self, key) -> bool:
return key in self.pyval

def copy(self):
return self.pyval.copy()
def copy(self) -> "_base.BaseValue | None":
return self.pyval.copy() # pytype: disable=attribute-error

def __iter__(self):
return iter(self.pyval)

def items(self):
return self.pyval.items()
return self.pyval.items() # pytype: disable=attribute-error

def keys(self):
return self.pyval.keys()
return self.pyval.keys() # pytype: disable=attribute-error

def values(self):
return self.pyval.values()
return self.pyval.values() # pytype: disable=attribute-error
7 changes: 5 additions & 2 deletions pytype/types/instances.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
"""Basic datatypes for instances."""

from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING

from pytype.types import base

if TYPE_CHECKING:
from pytype.abstract import _base # pylint: disable=g-import-not-at-top, g-bad-import-order


class Module:
name: str


class PythonConstant:
pyval: Any
pyval: "_base.BaseValue | None"
is_concrete: bool

def str_of_constant(self, printer: Callable[[base.BaseValue], str]) -> str:
Expand Down

0 comments on commit 4977c36

Please sign in to comment.