Skip to content

Commit

Permalink
Support "is None" constraints from if statements during inference (#1189
Browse files Browse the repository at this point in the history
)


Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
Co-authored-by: Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 6, 2023
1 parent f476ebc commit 21880dd
Show file tree
Hide file tree
Showing 6 changed files with 770 additions and 5 deletions.
9 changes: 9 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,15 @@ Release date: 2022-07-09

Refs PyCQA/pylint#7109

* Support "is None" constraints from if statements during inference.

Ref #791
Ref PyCQA/pylint#157
Ref PyCQA/pylint#1472
Ref PyCQA/pylint#2016
Ref PyCQA/pylint#2631
Ref PyCQA/pylint#2880

What's New in astroid 2.11.7?
=============================
Release date: 2022-07-09
Expand Down
24 changes: 20 additions & 4 deletions astroid/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import collections.abc
import sys
from collections.abc import Sequence
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar

from astroid import decorators, nodes
from astroid.const import PY310_PLUS
Expand All @@ -35,6 +35,9 @@
else:
from typing_extensions import Literal

if TYPE_CHECKING:
from astroid.constraint import Constraint

objectmodel = lazy_import("interpreter.objectmodel")
helpers = lazy_import("helpers")
manager = lazy_import("manager")
Expand Down Expand Up @@ -146,11 +149,14 @@ def _infer_stmts(
) -> collections.abc.Generator[InferenceResult, None, None]:
"""Return an iterator on statements inferred by each statement in *stmts*."""
inferred = False
constraint_failed = False
if context is not None:
name = context.lookupname
context = context.clone()
constraints = context.constraints.get(name, {})
else:
name = None
constraints = {}
context = InferenceContext()

for stmt in stmts:
Expand All @@ -161,16 +167,26 @@ def _infer_stmts(
# 'context' is always InferenceContext and Instances get '_infer_name' from ClassDef
context.lookupname = stmt._infer_name(frame, name) # type: ignore[union-attr]
try:
stmt_constraints: set[Constraint] = set()
for constraint_stmt, potential_constraints in constraints.items():
if not constraint_stmt.parent_of(stmt):
stmt_constraints.update(potential_constraints)
# Mypy doesn't recognize that 'stmt' can't be Uninferable
for inf in stmt.infer(context=context): # type: ignore[union-attr]
yield inf
inferred = True
if all(constraint.satisfied_by(inf) for constraint in stmt_constraints):
yield inf
inferred = True
else:
constraint_failed = True
except NameInferenceError:
continue
except InferenceError:
yield Uninferable
inferred = True
if not inferred:

if not inferred and constraint_failed:
yield Uninferable
elif not inferred:
raise InferenceError(
"Inference failed for all members of {stmts!r}.",
stmts=stmts,
Expand Down
137 changes: 137 additions & 0 deletions astroid/constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt

"""Classes representing different types of constraints on inference values."""
from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Union

from astroid import bases, nodes, util
from astroid.typing import InferenceResult

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name]


class Constraint(ABC):
"""Represents a single constraint on a variable."""

def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
self.node = node
"""The node that this constraint applies to."""
self.negate = negate
"""True if this constraint is negated. E.g., "is not" instead of "is"."""

@classmethod
@abstractmethod
def match(
cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
) -> Self | None:
"""Return a new constraint for node matched from expr, if expr matches
the constraint pattern.
If negate is True, negate the constraint.
"""

@abstractmethod
def satisfied_by(self, inferred: InferenceResult) -> bool:
"""Return True if this constraint is satisfied by the given inferred value."""


class NoneConstraint(Constraint):
"""Represents an "is None" or "is not None" constraint."""

CONST_NONE: nodes.Const = nodes.Const(None)

@classmethod
def match(
cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
) -> Self | None:
"""Return a new constraint for node matched from expr, if expr matches
the constraint pattern.
Negate the constraint based on the value of negate.
"""
if isinstance(expr, nodes.Compare) and len(expr.ops) == 1:
left = expr.left
op, right = expr.ops[0]
if op in {"is", "is not"} and (
_matches(left, node) and _matches(right, cls.CONST_NONE)
):
negate = (op == "is" and negate) or (op == "is not" and not negate)
return cls(node=node, negate=negate)

return None

def satisfied_by(self, inferred: InferenceResult) -> bool:
"""Return True if this constraint is satisfied by the given inferred value."""
# Assume true if uninferable
if inferred is util.Uninferable:
return True

# Return the XOR of self.negate and matches(inferred, self.CONST_NONE)
return self.negate ^ _matches(inferred, self.CONST_NONE)


def get_constraints(
expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
"""Returns the constraints for the given expression.
The returned dictionary maps the node where the constraint was generated to the
corresponding constraint(s).
Constraints are computed statically by analysing the code surrounding expr.
Currently this only supports constraints generated from if conditions.
"""
current_node: nodes.NodeNG | None = expr
constraints_mapping: dict[nodes.If, set[Constraint]] = {}
while current_node is not None and current_node is not frame:
parent = current_node.parent
if isinstance(parent, nodes.If):
branch, _ = parent.locate_child(current_node)
constraints: set[Constraint] | None = None
if branch == "body":
constraints = set(_match_constraint(expr, parent.test))
elif branch == "orelse":
constraints = set(_match_constraint(expr, parent.test, invert=True))

if constraints:
constraints_mapping[parent] = constraints
current_node = parent

return constraints_mapping


ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,))
"""All supported constraint types."""


def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
"""Returns True if the two nodes match."""
if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
return node1.name == node2.name
if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)
if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
return node1.value == node2.value

return False


def _match_constraint(
node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
"""Yields all constraint patterns for node that match."""
for constraint_cls in ALL_CONSTRAINT_CLASSES:
constraint = constraint_cls.match(node, expr, invert)
if constraint:
yield constraint
6 changes: 6 additions & 0 deletions astroid/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple

if TYPE_CHECKING:
from astroid import constraint, nodes
from astroid.nodes.node_classes import Keyword, NodeNG

_InferenceCache = Dict[
Expand All @@ -37,6 +38,7 @@ class InferenceContext:
"callcontext",
"boundnode",
"extra_context",
"constraints",
"_nodes_inferred",
)

Expand Down Expand Up @@ -85,6 +87,9 @@ def __init__(
for call arguments
"""

self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {}
"""The constraints on nodes."""

@property
def nodes_inferred(self) -> int:
"""
Expand Down Expand Up @@ -134,6 +139,7 @@ def clone(self) -> InferenceContext:
clone.callcontext = self.callcontext
clone.boundnode = self.boundnode
clone.extra_context = self.extra_context
clone.constraints = self.constraints.copy()
return clone

@contextlib.contextmanager
Expand Down
9 changes: 8 additions & 1 deletion astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Callable, Generator, Iterable, Iterator
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

from astroid import bases, decorators, helpers, nodes, protocols, util
from astroid import bases, constraint, decorators, helpers, nodes, protocols, util
from astroid.context import (
CallContext,
InferenceContext,
Expand Down Expand Up @@ -242,6 +242,8 @@ def infer_name(
)
context = copy_context(context)
context.lookupname = self.name
context.constraints[self.name] = constraint.get_constraints(self, frame)

return bases._infer_stmts(stmts, context, frame)


Expand Down Expand Up @@ -362,6 +364,11 @@ def infer_attribute(
old_boundnode = context.boundnode
try:
context.boundnode = owner
if isinstance(owner, (nodes.ClassDef, bases.Instance)):
frame = owner if isinstance(owner, nodes.ClassDef) else owner._proxied
context.constraints[self.attrname] = constraint.get_constraints(
self, frame=frame
)
yield from owner.igetattr(self.attrname, context)
except (
AttributeInferenceError,
Expand Down
Loading

0 comments on commit 21880dd

Please sign in to comment.