Skip to content

Commit

Permalink
Complete cache key for inference tip (#2158)
Browse files Browse the repository at this point in the history
The cache key was lacking the `context` arg.

Co-authored-by: Sylvain Ackermann <sylvain.ackermann@gmail.com>
  • Loading branch information
jacobtylerwalls and kriek authored May 6, 2023
1 parent 06fafc4 commit 0740a0d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 19 deletions.
7 changes: 7 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ Release date: TBA

* Reduce file system access in ``ast_from_file()``.

* Fix incorrect cache keys for inference results, thereby correctly inferring types
for calls instantiating types dynamically.

Closes #1828
Closes pylint-dev/pylint#7464
Closes pylint-dev/pylint#8074

* ``nodes.FunctionDef`` no longer inherits from ``nodes.Lambda``.
This is a breaking change but considered a bug fix as the nodes did not share the same
API and were not interchangeable.
Expand Down
30 changes: 22 additions & 8 deletions astroid/inference_tip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
from collections.abc import Callable, Iterator

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn
Expand All @@ -20,7 +21,11 @@

_P = ParamSpec("_P")

_cache: dict[tuple[InferFn, NodeNG], list[InferenceResult] | None] = {}
_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
] = {}

_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set()


def clear_inference_tip_cache() -> None:
Expand All @@ -35,16 +40,25 @@ def _inference_tip_cached(

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
node = args[0]
try:
result = _cache[func, node]
context = args[1]
partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
# If through recursion we end up trying to infer the same
# func + node we raise here.
if result is None:
raise UseInferenceDefault()
raise UseInferenceDefault
try:
return _cache[func, node, context]
except KeyError:
_cache[func, node] = None
result = _cache[func, node] = list(func(*args, **kwargs))
assert result
# Recursion guard with a partial cache key.
# Using the full key causes a recursion error on PyPy.
# It's a pragmatic compromise to avoid so much recursive inference
# with slightly different contexts while still passing the simple
# test cases included with this commit.
_CURRENTLY_INFERRING.add(partial_cache_key)
result = _cache[func, node, context] = list(func(*args, **kwargs))
# Remove recursion guard.
_CURRENTLY_INFERRING.remove(partial_cache_key)

return iter(result)

return inner
Expand Down
10 changes: 2 additions & 8 deletions tests/brain/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,13 +930,7 @@ class A:
assert inferred.value == 42

def test_typing_cast_multiple_inference_calls(self) -> None:
"""Inference of an outer function should not store the result for cast.
https://github.com/pylint-dev/pylint/issues/8074
Possible solution caused RecursionErrors with Python 3.8 and CPython + PyPy.
https://github.com/pylint-dev/astroid/pull/1982
"""
"""Inference of an outer function should not store the result for cast."""
ast_nodes = builder.extract_node(
"""
from typing import TypeVar, cast
Expand All @@ -954,7 +948,7 @@ def ident(var: T) -> T:

i1 = next(ast_nodes[1].infer())
assert isinstance(i1, nodes.Const)
assert i1.value == 2 # should be "Hello"!
assert i1.value == "Hello"


class ReBrainTest(unittest.TestCase):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_regrtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,27 @@ def d(self):
assert isinstance(inferred, Instance)
assert inferred.qname() == ".A"

def test_inference_context_consideration(self) -> None:
"""https://github.com/PyCQA/astroid/issues/1828"""
code = """
class Base:
def return_type(self):
return type(self)()
class A(Base):
def method(self):
return self.return_type()
class B(Base):
def method(self):
return self.return_type()
A().method() #@
B().method() #@
"""
node1, node2 = extract_node(code)
inferred1 = next(node1.infer())
assert inferred1.qname() == ".A"
inferred2 = next(node2.infer())
assert inferred2.qname() == ".B"


class Whatever:
a = property(lambda x: x, lambda x: x) # type: ignore[misc]
Expand Down
4 changes: 1 addition & 3 deletions tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,9 +1771,7 @@ def __init__(self):
"FinalClass",
"ClassB",
"MixinB",
# We don't recognize what 'cls' is at time of .format() call, only
# what it is at the end.
# "strMixin",
"strMixin",
"ClassA",
"MixinA",
"intMixin",
Expand Down

0 comments on commit 0740a0d

Please sign in to comment.