Skip to content

Commit

Permalink
Modernize usages of collections.namedtuple in pytype.
Browse files Browse the repository at this point in the history
I mostly replaced collections.namedtuple with dataclasses.dataclass. In one
case, I replaced it with attrs.define in order to use a converter. I used
typing.NamedTuple for a couple classes that need namedtuple's iteration and/or
pickling behaviors.

PiperOrigin-RevId: 454776470
  • Loading branch information
rchen152 committed Jun 14, 2022
1 parent fe58b6c commit 351a2cd
Show file tree
Hide file tree
Showing 23 changed files with 257 additions and 149 deletions.
3 changes: 3 additions & 0 deletions pytype/abstract/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ py_library(
DEPS
pytype.utils
pytype.pyc.pyc
pytype.pytd.pytd
pytype.typegraph.cfg
pytype.typegraph.cfg_utils
)
Expand All @@ -216,9 +217,11 @@ py_library(
SRCS
function.py
DEPS
._base
.abstract_utils
pytype.utils
pytype.pytd.pytd
pytype.typegraph.cfg
pytype.typegraph.cfg_utils
)

Expand Down
4 changes: 2 additions & 2 deletions pytype/abstract/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ def __contains__(self, name):
def _raw_formal_type_parameters(self):
assert isinstance(self._formal_type_parameters,
abstract_utils.LazyFormalTypeParameters)
template, parameters, _ = self._formal_type_parameters
for i, name in enumerate(template):
parameters = self._formal_type_parameters.parameters
for i, name in enumerate(self._formal_type_parameters.template):
# TODO(rechen): A missing parameter should be an error.
yield name, parameters[i] if i < len(parameters) else None

Expand Down
10 changes: 7 additions & 3 deletions pytype/abstract/_pytd_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,10 @@ def compatible_with(new, existing, view):
filtered_mutations = []
errors = collections.defaultdict(dict)

for (obj, name, values), view in all_mutations.items():
for mutation, view in all_mutations.items():
obj = mutation.instance
name = mutation.name
values = mutation.value
if obj.from_annotation:
params = obj.get_instance_type_parameter(name)
ps = {v for v in params.data if should_check(v)}
Expand All @@ -239,12 +242,13 @@ def compatible_with(new, existing, view):
new.append(b.data)
# By updating filtered_mutations only when ps is non-empty, we
# filter out mutations to parameters with type Any.
filtered_mutations.append((obj, name, filtered_values))
filtered_mutations.append(
function.Mutation(obj, name, filtered_values))
if new:
formal = name.split(".")[-1]
errors[obj][formal] = (params, values, obj.from_annotation)
else:
filtered_mutations.append((obj, name, values))
filtered_mutations.append(function.Mutation(obj, name, values))

all_mutations = filtered_mutations

Expand Down
14 changes: 9 additions & 5 deletions pytype/abstract/abstract_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Utilities for abstract.py."""

import collections
import dataclasses
import logging
from typing import Any, Collection, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from pytype import datatypes
from pytype import utils
from pytype.pyc import opcodes
from pytype.pyc import pyc
from pytype.pytd import pytd
from pytype.typegraph import cfg
from pytype.typegraph import cfg_utils

Expand Down Expand Up @@ -102,8 +103,11 @@ class AsReturnValue(AsInstance):


# For lazy evaluation of ParameterizedClass.formal_type_parameters
LazyFormalTypeParameters = collections.namedtuple(
"LazyFormalTypeParameters", ("template", "parameters", "subst"))
@dataclasses.dataclass(eq=True, frozen=True)
class LazyFormalTypeParameters:
template: Sequence[Any]
parameters: Sequence[pytd.Node]
subst: Dict[str, cfg.Variable]


# Sentinel for get_atomic_value
Expand Down Expand Up @@ -241,12 +245,12 @@ def apply_mutations(node, get_mutations):
"""Apply mutations yielded from a get_mutations function."""
log.info("Applying mutations")
num_mutations = 0
for obj, name, value in get_mutations():
for mut in get_mutations():
if not num_mutations:
# mutations warrant creating a new CFG node
node = node.ConnectNew(node.name)
num_mutations += 1
obj.merge_instance_type_parameter(node, name, value)
mut.instance.merge_instance_type_parameter(node, mut.name, mut.value)
log.info("Applied %d mutations", num_mutations)
return node

Expand Down
82 changes: 50 additions & 32 deletions pytype/abstract/function.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""Representation of Python function headers and calls."""

import collections
import dataclasses
import itertools
import logging
from typing import Any, Dict, Optional, Sequence, Tuple

import attrs

from pytype import datatypes
from pytype.abstract import _base
from pytype.abstract import abstract_utils
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.typegraph import cfg
from pytype.typegraph import cfg_utils

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -373,30 +379,28 @@ def get_first_arg(self, callargs):
return callargs.get(self.param_names[0]) if self.param_names else None


class Args(collections.namedtuple(
"Args", ["posargs", "namedargs", "starargs", "starstarargs"])):
"""Represents the parameters of a function call."""
def _convert_namedargs(namedargs):
return {} if namedargs is None else namedargs

def __new__(cls, posargs, namedargs=None, starargs=None, starstarargs=None):
"""Create arguments for a function under analysis.

Args:
posargs: The positional arguments. A tuple of cfg.Variable.
namedargs: The keyword arguments. A dictionary, mapping strings to
cfg.Variable.
starargs: The *args parameter, or None.
starstarargs: The **kwargs parameter, or None.
Returns:
An Args instance.
"""
assert isinstance(posargs, tuple), posargs
cls.replace = cls._replace
return super().__new__(
cls,
posargs=posargs,
namedargs=namedargs or {},
starargs=starargs,
starstarargs=starstarargs)
@attrs.frozen(eq=True)
class Args:
"""Represents the parameters of a function call.
Attributes:
posargs: The positional arguments. A tuple of cfg.Variable.
namedargs: The keyword arguments. A dictionary, mapping strings to
cfg.Variable.
starargs: The *args parameter, or None.
starstarargs: The **kwargs parameter, or None.
"""

posargs: Tuple[cfg.Variable, ...]
namedargs: Dict[str, cfg.Variable] = attrs.field(converter=_convert_namedargs,
default=None)
starargs: Optional[cfg.Variable] = None
starstarargs: Optional[cfg.Variable] = None

def has_namedargs(self):
return bool(self.namedargs)
Expand Down Expand Up @@ -587,16 +591,19 @@ def get_variables(self):

def replace_posarg(self, pos, val):
new_posargs = self.posargs[:pos] + (val,) + self.posargs[pos + 1:]
return self._replace(posargs=new_posargs)
return self.replace(posargs=new_posargs)

def replace_namedarg(self, name, val):
new_namedargs = dict(self.namedargs)
new_namedargs[name] = val
return self._replace(namedargs=new_namedargs)
return self.replace(namedargs=new_namedargs)

def delete_namedarg(self, name):
new_namedargs = {k: v for k, v in self.namedargs.items() if k != name}
return self._replace(namedargs=new_namedargs)
return self.replace(namedargs=new_namedargs)

def replace(self, **kwargs):
return attrs.evolve(self, **kwargs)


class ReturnValueMixin:
Expand Down Expand Up @@ -657,14 +664,19 @@ def __le__(self, other):
return not self.__gt__(other)


BadCall = collections.namedtuple("_", ["sig", "passed_args", "bad_param"])

@dataclasses.dataclass(eq=True, frozen=True)
class BadParam:
name: str
expected: _base.BaseValue
# Should be matcher.ErrorDetails but can't use due to circular dep.
error_details: Optional[Any] = None

class BadParam(
collections.namedtuple("_", ["name", "expected", "error_details"])):

def __new__(cls, name, expected, error_details=None):
return super().__new__(cls, name, expected, error_details)
@dataclasses.dataclass(eq=True, frozen=True)
class BadCall:
sig: Signature
passed_args: Sequence[Tuple[str, _base.BaseValue]]
bad_param: Optional[BadParam]


class InvalidParameters(FailedFunctionCall):
Expand Down Expand Up @@ -728,7 +740,13 @@ def __init__(self, sig, passed_args, ctx, missing_parameter):
# pylint: enable=g-bad-exception-name


class Mutation(collections.namedtuple("_", ["instance", "name", "value"])):
@dataclasses.dataclass(frozen=True)
class Mutation:
"""A type mutation."""

instance: _base.BaseValue
name: str
value: cfg.Variable

def __eq__(self, other):
return (self.instance == other.instance and
Expand Down
4 changes: 3 additions & 1 deletion pytype/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,9 @@ def import_error(self, stack, module_name):

def _invalid_parameters(self, stack, message, bad_call):
"""Log an invalid parameters error."""
sig, passed_args, bad_param = bad_call
sig = bad_call.sig
passed_args = bad_call.passed_args
bad_param = bad_call.bad_param
expected = self._print_args(self._iter_expected(sig, bad_param), bad_param)
literal = "Literal[" in expected
actual = self._print_args(
Expand Down
12 changes: 6 additions & 6 deletions pytype/inspect/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ def _is_constant(val):
color = "white" if val.origins else "red"
self.add_node(val, label=label, fillcolor=color)
self.add_edge(variable, val, arrowhead="none")
for loc, srcsets in val.origins:
if loc == program.entrypoint:
for origin in val.origins:
if origin.where == program.entrypoint:
continue
for srcs in srcsets:
for srcs in origin.source_sets:
self.add_node(srcs, label="")
self.add_edge(val, srcs, color="pink", arrowhead="none", weight=40)
if loc not in ignored:
self.add_edge(
loc, srcs, arrowhead="none", style="dotted", weight=5)
if origin.where not in ignored:
self.add_edge(origin.where, srcs, arrowhead="none",
style="dotted", weight=5)
for src in srcs:
self.add_edge(src, srcs, color="lightblue", weight=2)

Expand Down
9 changes: 6 additions & 3 deletions pytype/load_pytd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Load and link .pyi files."""

import collections
import dataclasses
import logging
import os
import pickle
Expand Down Expand Up @@ -44,8 +44,11 @@ def create_loader(options):
return Loader(options)


ResolvedModule = collections.namedtuple(
"ResolvedModule", ("module_name", "filename", "ast"))
@dataclasses.dataclass(eq=True, frozen=True)
class ResolvedModule:
module_name: str
filename: str
ast: pytd.TypeDeclUnit


class Module:
Expand Down
7 changes: 5 additions & 2 deletions pytype/load_pytd_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for load_pytd.py."""

import collections
import contextlib
import dataclasses
import io
import os
import textwrap
Expand Down Expand Up @@ -613,7 +613,10 @@ class Foo(Generic[AnyStr]): ...
self.assertEqual(pytd_utils.Print(ast.Lookup("b.x").type), "a.Foo[str]")


_Module = collections.namedtuple("_", ["module_name", "file_name"])
@dataclasses.dataclass(eq=True, frozen=True)
class _Module:
module_name: str
file_name: str


class PickledPyiLoaderTest(test_base.UnitTest):
Expand Down
11 changes: 7 additions & 4 deletions pytype/module_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Representation of modules."""

import collections
import dataclasses
import os
from typing import Sequence


class Module(collections.namedtuple("_", "path target name kind")):
@dataclasses.dataclass(eq=True, frozen=True)
class Module:
"""Inferred information about a module.
Attributes:
Expand All @@ -17,8 +18,10 @@ class Module(collections.namedtuple("_", "path target name kind")):
full_path: The full path to the module (path + target).
"""

def __new__(cls, path, target, name, kind=None):
return super(Module, cls).__new__(cls, path, target, name, kind or "Local")
path: str
target: str
name: str
kind: str = "Local"

@property
def full_path(self):
Expand Down
12 changes: 8 additions & 4 deletions pytype/pytd/serialize_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
disk, which is faster to digest than a pyi file.
"""

import collections
from typing import List, NamedTuple, Optional, Set, Tuple

from pytype import utils
from pytype.pyi import parser
Expand Down Expand Up @@ -58,9 +58,12 @@ def VisitLateType(self, node):
return node


SerializableTupleClass = collections.namedtuple(
"_", ["ast", "dependencies", "late_dependencies", "class_type_nodes",
"is_package"])
class SerializableTupleClass(NamedTuple):
ast: pytd.TypeDeclUnit
dependencies: List[Tuple[str, Set[str]]]
late_dependencies: List[Tuple[str, Set[str]]]
class_type_nodes: Optional[List[pytd.ClassType]]
is_package: bool


class SerializableAst(SerializableTupleClass):
Expand All @@ -74,6 +77,7 @@ class SerializableAst(SerializableTupleClass):
Therefore it might be different from the set found by
visitors.CollectDependencies in
load_pytd._load_and_resolve_ast_dependencies.
late_dependencies: This AST's late dependencies.
class_type_nodes: A list of all the ClassType instances in ast or None. If
this list is provided only the ClassType instances in the list will be
visited and have their .cls set. If this attribute is None the whole AST
Expand Down
6 changes: 5 additions & 1 deletion pytype/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import copy
import dataclasses
import io
import re
import sys
Expand All @@ -16,7 +17,10 @@
import unittest


FakeCode = collections.namedtuple("FakeCode", "co_filename co_name")
@dataclasses.dataclass(eq=True, frozen=True)
class FakeCode:
co_filename: str
co_name: str


class FakeOpcode:
Expand Down
Loading

0 comments on commit 351a2cd

Please sign in to comment.