diff --git a/src/python/pants/backend/awslambda/python/target_types_test.py b/src/python/pants/backend/awslambda/python/target_types_test.py index 5127c2dbf0e..4c84588b64d 100644 --- a/src/python/pants/backend/awslambda/python/target_types_test.py +++ b/src/python/pants/backend/awslambda/python/target_types_test.py @@ -19,9 +19,8 @@ from pants.backend.python.target_types import PythonLibrary, PythonRequirementLibrary from pants.backend.python.target_types_rules import rules as python_target_types_rules from pants.build_graph.address import Address -from pants.engine.internals.scheduler import ExecutionError from pants.engine.target import InjectedDependencies, InvalidFieldException -from pants.testutil.rule_runner import QueryRule, RuleRunner +from pants.testutil.rule_runner import QueryRule, RuleRunner, engine_error @pytest.fixture @@ -91,10 +90,10 @@ def assert_resolved(handler: str, *, expected: str, is_file: bool) -> None: assert_resolved("path.to.lambda:func", expected="path.to.lambda:func", is_file=False) assert_resolved("lambda.py:func", expected="project.lambda:func", is_file=True) - with pytest.raises(ExecutionError): + with engine_error(contains="Unmatched glob"): assert_resolved("doesnt_exist.py:func", expected="doesnt matter", is_file=True) # Resolving >1 file is an error. - with pytest.raises(ExecutionError): + with engine_error(InvalidFieldException): assert_resolved("*.py:func", expected="doesnt matter", is_file=True) diff --git a/src/python/pants/backend/go/util_rules/external_module_test.py b/src/python/pants/backend/go/util_rules/external_module_test.py index 0ff4d22e3ce..217d4ee07b0 100644 --- a/src/python/pants/backend/go/util_rules/external_module_test.py +++ b/src/python/pants/backend/go/util_rules/external_module_test.py @@ -22,10 +22,9 @@ from pants.backend.go.util_rules.go_pkg import ResolvedGoPackage from pants.engine.addresses import Address from pants.engine.fs import Digest, PathGlobs, Snapshot -from pants.engine.internals.scheduler import ExecutionError from pants.engine.process import ProcessExecutionFailure from pants.engine.rules import QueryRule -from pants.testutil.rule_runner import RuleRunner +from pants.testutil.rule_runner import RuleRunner, engine_error @pytest.fixture @@ -127,16 +126,13 @@ def assert_module(module: str, version: str, sample_file: str) -> None: def test_download_modules_missing_module(rule_runner: RuleRunner) -> None: input_digest = rule_runner.make_snapshot({"go.mod": GO_MOD, "go.sum": GO_SUM}).digest - with pytest.raises(ExecutionError) as exc: + with engine_error( + AssertionError, contains="The module some_project.org/project@v1.1 was not downloaded" + ): rule_runner.request( DownloadedModule, [DownloadedModuleRequest("some_project.org/project", "v1.1", input_digest)], ) - underlying_exception = exc.value.wrapped_exceptions[0] - assert isinstance(underlying_exception, AssertionError) - assert "The module some_project.org/project@v1.1 was not downloaded" in str( - underlying_exception - ) def test_download_modules_invalid_go_sum(rule_runner: RuleRunner) -> None: @@ -157,11 +153,8 @@ def test_download_modules_invalid_go_sum(rule_runner: RuleRunner) -> None: ), } ).digest - with pytest.raises(ExecutionError) as exc: + with engine_error(ProcessExecutionFailure, contains="SECURITY ERROR"): rule_runner.request(AllDownloadedModules, [AllDownloadedModulesRequest(input_digest)]) - underlying_exception = exc.value.wrapped_exceptions[0] - assert isinstance(underlying_exception, ProcessExecutionFailure) - assert "SECURITY ERROR" in str(underlying_exception) def test_download_modules_missing_go_sum(rule_runner: RuleRunner) -> None: @@ -183,10 +176,8 @@ def test_download_modules_missing_go_sum(rule_runner: RuleRunner) -> None: ), } ).digest - with pytest.raises(ExecutionError) as exc: + with engine_error(contains="`go.mod` and/or `go.sum` changed!"): rule_runner.request(AllDownloadedModules, [AllDownloadedModulesRequest(input_digest)]) - underlying_exception = exc.value.wrapped_exceptions[0] - assert "`go.mod` and/or `go.sum` changed!" in str(underlying_exception) def test_determine_external_package_info(rule_runner: RuleRunner) -> None: diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index e166c3a8eca..aceb6b8be79 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -27,7 +27,7 @@ from pants.engine.goal import Goal from pants.engine.internals import native_engine from pants.engine.internals.native_engine import PyExecutor -from pants.engine.internals.scheduler import SchedulerSession +from pants.engine.internals.scheduler import ExecutionError, SchedulerSession from pants.engine.internals.selectors import Get, Params from pants.engine.internals.session import SessionValues from pants.engine.process import InteractiveRunner @@ -75,6 +75,46 @@ def wrapper(*args, **kwargs): return wrapper +@contextmanager +def engine_error( + expected_underlying_exception: type[Exception] = Exception, *, contains: str | None = None +) -> Iterator[None]: + """A context manager to catch `ExecutionError`s in tests and check that the underlying exception + is expected. + + Use like this: + + with engine_error(ValueError, contains="foo"): + rule_runner.request(OutputType, [input]) + + Will raise AssertionError if no ExecutionError occurred. + """ + try: + yield + except ExecutionError as exec_error: + if not len(exec_error.wrapped_exceptions) == 1: + formatted_errors = "\n\n".join(repr(e) for e in exec_error.wrapped_exceptions) + raise ValueError( + "Multiple underlying exceptions, but this helper function expected only one. " + "Use `with pytest.raises(ExecutionError) as exc` directly and inspect " + "`exc.value.wrapped_exceptions`.\n\n" + f"Errors: {formatted_errors}" + ) + underlying = exec_error.wrapped_exceptions[0] + if not isinstance(underlying, expected_underlying_exception): + raise AssertionError( + "ExecutionError occurred as expected, but the underlying exception had type " + f"{type(underlying)} rather than the expected type " + f"{expected_underlying_exception}." + ) + if contains is not None and contains not in str(underlying): + raise AssertionError( + "Expected value not found in exception.\n" + f"expected: {contains}\n\n" + f"exception: {underlying}" + ) + + # ----------------------------------------------------------------------------------------------- # `RuleRunner` # -----------------------------------------------------------------------------------------------