Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message for invalid input in Get()s #11081

Merged
merged 5 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 81 additions & 109 deletions src/python/pants/engine/internals/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from textwrap import dedent
from typing import Any

from pants.engine.internals.engine_testutil import (
assert_equal_with_printing,
remove_locations_from_traceback,
)
import pytest

from pants.engine.internals.engine_testutil import remove_locations_from_traceback
from pants.engine.internals.scheduler import ExecutionError
from pants.engine.rules import Get, rule
from pants.engine.unions import UnionRule, union
from pants.testutil.rule_runner import QueryRule
from pants.testutil.rule_runner import QueryRule, RuleRunner
from pants.testutil.test_base import TestBase


Expand All @@ -28,16 +27,6 @@ class B:
pass


def fn_raises(x):
raise Exception(f"An exception for {type(x).__name__}")


@rule(desc="Nested raise")
def nested_raise(b: B) -> A:
fn_raises(b)
return A()


@rule
def consumes_a_and_b(a: A, b: B) -> str:
return str(f"{a} and {b}")
Expand Down Expand Up @@ -122,37 +111,6 @@ async def error_msg_test_rule(union_wrapper: UnionWrapper) -> UnionX:
raise AssertionError("The statement above this one should have failed!")


class TypeCheckFailWrapper:
"""This object wraps another object which will be used to demonstrate a type check failure when
the engine processes an `await Get(...)` statement."""

def __init__(self, inner):
self.inner = inner


@rule
async def a_typecheck_fail_test(wrapper: TypeCheckFailWrapper) -> A:
# This `await` would use the `nested_raise` rule, but it won't get to the point of raising since
# the type check will fail at the Get.
_ = await Get(A, B, wrapper.inner) # noqa: F841
return A()


@dataclass(frozen=True)
class CollectionType:
# NB: We pass an unhashable type when we want this to fail at the root, and a hashable type
# when we'd like it to succeed.
items: Any


@rule
async def c_unhashable(_: CollectionType) -> C:
# This `await` would use the `nested_raise` rule, but it won't get to the point of raising since
# the hashability check will fail.
_result = await Get(A, B, list()) # noqa: F841
return C()


@rule
def boolean_and_int(i: int, b: bool) -> A:
return A()
Expand Down Expand Up @@ -247,67 +205,81 @@ def test_union_rules_no_docstring(self):
self.request(UnionX, [UnionWrapper(UnionA())])


class SchedulerWithNestedRaiseTest(TestBase):
@classmethod
def rules(cls):
return (
*super().rules(),
a_typecheck_fail_test,
c_unhashable,
nested_raise,
QueryRule(A, (TypeCheckFailWrapper,)),
QueryRule(A, (B,)),
QueryRule(C, (CollectionType,)),
)
# -----------------------------------------------------------------------------------------------
# Test tracebacks.
# -----------------------------------------------------------------------------------------------

def test_get_type_match_failure(self):
"""Test that Get(...)s are now type-checked during rule execution, to allow for union
types."""

with self.assertRaises(ExecutionError) as cm:
# `a_typecheck_fail_test` above expects `wrapper.inner` to be a `B`.
self.request(A, [TypeCheckFailWrapper(A())])

expected_regex = "WithDeps.*did not declare a dependency on JustGet"
self.assertRegex(str(cm.exception), expected_regex)

def test_unhashable_root_params_failure(self):
"""Test that unhashable root params result in a structured error."""
# This will fail at the rust boundary, before even entering the engine.
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
self.request(C, [CollectionType([1, 2, 3])])

def test_unhashable_get_params_failure(self):
"""Test that unhashable Get(...) params result in a structured error."""
# This will fail inside of `c_unhashable_dataclass`.
with self.assertRaisesRegex(ExecutionError, "unhashable type: 'list'"):
self.request(C, [CollectionType(tuple())])

def test_trace_includes_rule_exception_traceback(self):
# Execute a request that will trigger the nested raise, and then directly inspect its trace.
request = self.scheduler.execution_request([A], [B()])
_, throws = self.scheduler.execute(request)

with self.assertRaises(ExecutionError) as cm:
self.scheduler._raise_on_error([t for _, t in throws])

trace = remove_locations_from_traceback(str(cm.exception))
assert_equal_with_printing(
self,
dedent(
f"""\
1 Exception encountered:

Engine traceback:
in select
in {self.__module__}.{nested_raise.__name__}
Traceback (most recent call last):
File LOCATION-INFO, in nested_raise
fn_raises(b)
File LOCATION-INFO, in fn_raises
raise Exception(f"An exception for {{type(x).__name__}}")
Exception: An exception for B
"""
),
trace,
)

def fn_raises():
raise Exception("An exception!")


@rule(desc="Nested raise")
def nested_raise() -> A:
fn_raises()
return A()


def test_trace_includes_rule_exception_traceback() -> None:
rule_runner = RuleRunner(rules=[nested_raise, QueryRule(A, [])])
with pytest.raises(ExecutionError) as exc:
rule_runner.request(A, [])
normalized_traceback = remove_locations_from_traceback(str(exc.value))
assert normalized_traceback == dedent(
f"""\
1 Exception encountered:

Engine traceback:
in select
in {__name__}.{nested_raise.__name__}
Traceback (most recent call last):
File LOCATION-INFO, in nested_raise
fn_raises()
File LOCATION-INFO, in fn_raises
raise Exception("An exception!")
Exception: An exception!
"""
)


# -----------------------------------------------------------------------------------------------
# Test unhashable types.
# -----------------------------------------------------------------------------------------------


@dataclass(frozen=True)
class MaybeHashableWrapper:
maybe_hashable: Any


@rule
async def unhashable(_: MaybeHashableWrapper) -> B:
return B()


@rule
async def call_unhashable_with_invalid_input() -> C:
_ = await Get(B, MaybeHashableWrapper([1, 2]))
return C()


def test_unhashable_types_failure() -> None:
rule_runner = RuleRunner(
rules=[
unhashable,
call_unhashable_with_invalid_input,
QueryRule(B, [MaybeHashableWrapper]),
QueryRule(C, []),
]
)

# Succeed if an argument to a rule is hashable.
assert rule_runner.request(B, [MaybeHashableWrapper((1, 2))]) == B()
# But fail if an argument to a rule is unhashable. This is a TypeError because it fails while
# hashing as part of FFI.
with pytest.raises(TypeError, match="unhashable type: 'list'"):
rule_runner.request(B, [MaybeHashableWrapper([1, 2])])

# Fail if the `input` in a `Get` is not hashable.
with pytest.raises(ExecutionError, match="unhashable type: 'list'"):
rule_runner.request(C, [])
57 changes: 41 additions & 16 deletions src/python/pants/engine/internals/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
overload,
)

from pants.engine.unions import union
from pants.util.meta import frozen_after_init

_Output = TypeVar("_Output")
Expand Down Expand Up @@ -149,33 +150,57 @@ def __init__(
input_arg0: Union[Type[_Input], _Input],
input_arg1: Optional[_Input] = None,
) -> None:
self.output_type = output_type
self.input_type = self._validate_input_type(
input_arg0 if input_arg1 is not None else type(input_arg0)
)
self.input = self._validate_input(input_arg1 if input_arg1 is not None else input_arg0)

self._validate_output_type()
self.output_type = self._validate_output_type(output_type)
if input_arg1 is None:
self.input_type = type(input_arg0)
self.input = self._validate_input(input_arg0, shorthand_form=True)
else:
self.input_type = self._validate_explicit_input_type(input_arg0)
self.input = self._validate_input(input_arg1, shorthand_form=False)

def _validate_output_type(self) -> None:
if not isinstance(self.output_type, type):
@staticmethod
def _validate_output_type(output_type: Any) -> Type[_Output]:
if not isinstance(output_type, type):
raise TypeError(
f"The output type must be a type, but given {self.output_type} of type "
f"{type(self.output_type)}."
"Invalid Get. The first argument (the output type) must be a type, but given "
f"`{output_type}` with type {type(output_type)}."
)
return cast(Type[_Output], output_type)

@staticmethod
def _validate_input_type(input_type: Any) -> Type[_Input]:
def _validate_explicit_input_type(input_type: Any) -> Type[_Input]:
if not isinstance(input_type, type):
raise TypeError(
f"The input type must be a type, but given {input_type} of type {type(input_type)}."
"Invalid Get. Because you are using the longhand form Get(OutputType, InputType, "
f"input), the second argument must be a type, but given `{input_type}` of type "
f"{type(input_type)}."
)
return cast(Type[_Input], input_type)

@staticmethod
def _validate_input(input_: Any) -> _Input:
def _validate_input(self, input_: Any, *, shorthand_form: bool) -> _Input:
if isinstance(input_, type):
raise TypeError(f"The input argument cannot be a type, but given {input_}.")
if shorthand_form:
raise TypeError(
"Invalid Get. Because you are using the shorthand form "
"Get(OutputType, InputType(constructor args)), the second argument should be "
f"a constructor call, rather than a type, but given {input_}."
)
else:
raise TypeError(
"Invalid Get. Because you are using the longhand form "
"Get(OutputType, InputType, input), the third argument should be "
f"an object, rather than a type, but given {input_}."
)
# If the input_type is not annotated with `@union`, then we validate that the input is
# exactly the same type as the input_type. (Why not check unions? We don't have access to
# `UnionMembership` to know if it's a valid union member. The engine will check that.)
if not union.is_instance(self.input_type) and type(input_) != self.input_type:
# We can assume we're using the longhand form because the shorthand form guarantees
# that the `input_type` is the same as `input`.
raise TypeError(
f"Invalid Get. The third argument `{input_}` must have the exact same type as the "
f"second argument, {self.input_type}, but had the type {type(input_)}."
)
return cast(_Input, input_)

def __await__(
Expand Down
Loading