From dfbfd75f3b1ce1d4c53a639eaeda2254d75078d8 Mon Sep 17 00:00:00 2001 From: Andreas Stenius Date: Tue, 3 Jan 2023 23:11:36 -0500 Subject: [PATCH] Support catching `@rule` errors (#17911) * reviving #10954, superseding #17745, closing #17910 Co-authored-by: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> --- .../backend/python/goals/setup_py_test.py | 3 +- src/python/pants/base/exceptions.py | 13 ++ .../pants/engine/internals/engine_test.py | 46 ++++- .../pants/engine/internals/native_engine.pyi | 14 +- .../pants/engine/internals/scheduler_test.py | 171 +++++++++++++++++- .../pants/engine/internals/selectors.py | 62 ++++++- src/python/pants/testutil/rule_runner.py | 35 +++- src/rust/engine/src/externs/mod.rs | 43 ++++- src/rust/engine/src/intrinsics.rs | 2 +- src/rust/engine/src/nodes.rs | 53 ++++-- src/rust/engine/src/python.rs | 86 +++++++-- 11 files changed, 462 insertions(+), 66 deletions(-) diff --git a/src/python/pants/backend/python/goals/setup_py_test.py b/src/python/pants/backend/python/goals/setup_py_test.py index 5e5b59cda9b..beeb5859feb 100644 --- a/src/python/pants/backend/python/goals/setup_py_test.py +++ b/src/python/pants/backend/python/goals/setup_py_test.py @@ -55,6 +55,7 @@ ) from pants.backend.python.util_rules import dists, python_sources from pants.backend.python.util_rules.interpreter_constraints import InterpreterConstraints +from pants.base.exceptions import IntrinsicError from pants.core.goals.package import BuiltPackage from pants.core.target_types import FileTarget, ResourcesGeneratorTarget, ResourceTarget from pants.core.target_types import rules as core_target_types_rules @@ -615,7 +616,7 @@ def test_generate_long_description_field_from_non_existing_file( assert_chroot_error( chroot_rule_runner, Address("src/python/foo", target_name="foo-dist"), - Exception, + IntrinsicError, ) diff --git a/src/python/pants/base/exceptions.py b/src/python/pants/base/exceptions.py index 949ef207881..00c95128ecc 100644 --- a/src/python/pants/base/exceptions.py +++ b/src/python/pants/base/exceptions.py @@ -5,6 +5,12 @@ from typing import TYPE_CHECKING +from pants.engine.internals.native_engine import EngineError as EngineError # noqa: F401 +from pants.engine.internals.native_engine import ( # noqa: F401 + IncorrectProductError as IncorrectProductError, +) +from pants.engine.internals.native_engine import IntrinsicError as IntrinsicError # noqa: F401 + if TYPE_CHECKING: from pants.engine.internals.native_engine import PyFailure @@ -38,6 +44,13 @@ class MappingError(Exception): class NativeEngineFailure(Exception): """A wrapper around a `Failure` instance. + The failure instance being wrapped can come from an exception raised in a rule. When this + failure is returned to a requesting rule it is first unwrapped so the original exception will be + presented in the rule, thus the `NativeEngineFailure` exception will not be seen in rule code. + + This is different from the other `EngineError` based exceptions which doesn't originate from + rule code. + TODO: This type is defined in Python because pyo3 doesn't support declaring Exceptions with additional fields. See https://github.com/PyO3/pyo3/issues/295 """ diff --git a/src/python/pants/engine/internals/engine_test.py b/src/python/pants/engine/internals/engine_test.py index a6113ab7e75..d9c5713eaf3 100644 --- a/src/python/pants/engine/internals/engine_test.py +++ b/src/python/pants/engine/internals/engine_test.py @@ -11,6 +11,7 @@ import pytest from pants.backend.python.target_types import PythonSourcesGeneratorTarget +from pants.base.exceptions import IntrinsicError from pants.base.specs import Specs from pants.base.specs_parser import SpecsParser from pants.engine.engine_aware import EngineAwareParameter, EngineAwareReturnType @@ -21,6 +22,7 @@ Digest, DigestContents, FileContent, + MergeDigests, Snapshot, ) from pants.engine.internals.engine_testutil import ( @@ -40,8 +42,9 @@ from pants.engine.unions import UnionRule, union from pants.goal.run_tracker import RunTracker from pants.testutil.option_util import create_options_bootstrapper -from pants.testutil.rule_runner import QueryRule, RuleRunner +from pants.testutil.rule_runner import QueryRule, RuleRunner, engine_error from pants.util.logging import LogLevel +from pants.util.strutil import softwrap class A: @@ -1006,3 +1009,44 @@ async def for_member() -> str: ) assert "yep" == rule_runner.request(str, []) + + +@dataclass(frozen=True) +class FileInput: + filename: str + + +@dataclass(frozen=True) +class MergedOutput: + digest: Digest + + +class MergeErr(Exception): + pass + + +@rule +async def catch_merge_digests_error(file_input: FileInput) -> MergedOutput: + # Create two separate digests writing different contents to the same file path. + input_1 = CreateDigest((FileContent(path=file_input.filename, content=b"yes"),)) + input_2 = CreateDigest((FileContent(path=file_input.filename, content=b"no"),)) + digests = await MultiGet(Get(Digest, CreateDigest, input_1), Get(Digest, CreateDigest, input_2)) + try: + merged = await Get(Digest, MergeDigests(digests)) + except IntrinsicError as e: + raise MergeErr(f"error merging digests for input {file_input}: {e}") + return MergedOutput(merged) + + +def test_catch_intrinsic_error() -> None: + rule_runner = RuleRunner( + rules=[catch_merge_digests_error, QueryRule(MergedOutput, (FileInput,))] + ) + msg = softwrap( + """\ + error merging digests for input FileInput(filename='some-file.txt'): Can only merge + Directories with no duplicates, but found 2 duplicate entries in : + """ + ) + with engine_error(MergeErr, contains=msg): + rule_runner.request(MergedOutput, (FileInput("some-file.txt"),)) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index d8db2811271..c42ec24400a 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -20,7 +20,8 @@ from pants.engine.process import InteractiveProcess, InteractiveProcessResult # (core) # ------------------------------------------------------------------------------ -class PyFailure: ... +class PyFailure: + def get_error(self) -> Exception | None: ... # ------------------------------------------------------------------------------ # Address (parsing) @@ -478,3 +479,14 @@ class PyThreadLocals: class PollTimeout(Exception): pass + +# Prefer to import these exception types from `pants.base.exceptions` + +class EngineError(Exception): + """Base exception used for errors originating from the native engine.""" + +class IntrinsicError(EngineError): + """Exceptions raised for failures within intrinsic methods implemented in Rust.""" + +class IncorrectProductError(EngineError): + """Exceptions raised when a rule's return value doesn't match its declared type.""" diff --git a/src/python/pants/engine/internals/scheduler_test.py b/src/python/pants/engine/internals/scheduler_test.py index c82e5a6dd6b..0c0bd2ba34d 100644 --- a/src/python/pants/engine/internals/scheduler_test.py +++ b/src/python/pants/engine/internals/scheduler_test.py @@ -9,11 +9,12 @@ import pytest +from pants.base.exceptions import IncorrectProductError from pants.engine.internals.engine_testutil import remove_locations_from_traceback from pants.engine.internals.scheduler import ExecutionError from pants.engine.rules import Get, rule from pants.engine.unions import UnionRule, union -from pants.testutil.rule_runner import QueryRule, RuleRunner +from pants.testutil.rule_runner import QueryRule, RuleRunner, engine_error # ----------------------------------------------------------------------------------------------- # Test params @@ -248,6 +249,115 @@ def test_outlined_get() -> None: ) in str(exc.value.args[0]) +@dataclass(frozen=True) +class SomeInput: + s: str + + +@dataclass(frozen=True) +class SomeOutput: + s: str + + +@rule +def raise_an_exception(some_input: SomeInput) -> SomeOutput: + raise Exception(some_input.s) + + +@dataclass(frozen=True) +class OuterInput: + s: str + + +@rule +async def catch_an_exception(outer_input: OuterInput) -> SomeOutput: + try: + return await Get(SomeOutput, SomeInput(outer_input.s)) + except Exception as e: + return SomeOutput(str(e)) + + +@rule(desc="error chain") +async def catch_and_reraise(outer_input: OuterInput) -> SomeOutput: + """This rule is used in a dedicated test only, so does not conflict with + `catch_an_exception`.""" + try: + return await Get(SomeOutput, SomeInput(outer_input.s)) + except Exception as e: + raise Exception("nested exception!") from e + + +class InputWithNothing: + pass + + +GLOBAL_FLAG: bool = True + + +@rule +def raise_an_exception_upon_global_state(input_with_nothing: InputWithNothing) -> SomeOutput: + if GLOBAL_FLAG: + raise Exception("global flag is set!") + return SomeOutput("asdf") + + +@rule +def return_a_wrong_product_type(input_with_nothing: InputWithNothing) -> A: + return B() # type: ignore[return-value] + + +@rule +async def catch_a_wrong_product_type(input_with_nothing: InputWithNothing) -> B: + try: + _ = await Get(A, InputWithNothing, input_with_nothing) + except IncorrectProductError as e: + raise Exception(f"caught product type error: {e}") + return B() + + +@pytest.fixture +def rule_error_runner() -> RuleRunner: + return RuleRunner( + rules=( + consumes_a_and_b, + QueryRule(str, (A, B)), + transitive_b_c, + QueryRule(str, (A, C)), + transitive_coroutine_rule, + QueryRule(D, (C,)), + boolean_and_int, + QueryRule(A, (int, bool)), + raise_an_exception, + QueryRule(SomeOutput, (SomeInput,)), + catch_an_exception, + QueryRule(SomeOutput, (OuterInput,)), + raise_an_exception_upon_global_state, + QueryRule(SomeOutput, (InputWithNothing,)), + return_a_wrong_product_type, + QueryRule(A, (InputWithNothing,)), + catch_a_wrong_product_type, + QueryRule(B, (InputWithNothing,)), + ) + ) + + +def test_catch_inner_exception(rule_error_runner: RuleRunner) -> None: + assert rule_error_runner.request(SomeOutput, [OuterInput("asdf")]) == SomeOutput("asdf") + + +def test_exceptions_uncached(rule_error_runner: RuleRunner) -> None: + global GLOBAL_FLAG + with engine_error(Exception, contains="global flag is set!"): + rule_error_runner.request(SomeOutput, [InputWithNothing()]) + GLOBAL_FLAG = False + assert rule_error_runner.request(SomeOutput, [InputWithNothing()]) == SomeOutput("asdf") + + +def test_incorrect_product_type(rule_error_runner: RuleRunner) -> None: + with engine_error(Exception, contains="caught product type error"): + rule_error_runner.request(B, [InputWithNothing()]) + + # ----------------------------------------------------------------------------------------------- # Test tracebacks # ----------------------------------------------------------------------------------------------- @@ -264,11 +374,7 @@ def nested_raise() -> A: def test_trace_includes_rule_exception_traceback() -> None: - rule_runner = RuleRunner(rules=[nested_raise, QueryRule(A, [])]) - with pytest.raises(ExecutionError) as exc: - rule_runner.request(A, []) - normalized_traceback = remove_locations_from_traceback(str(exc.value)) - assert normalized_traceback == dedent( + normalized_traceback = dedent( f"""\ 1 Exception encountered: @@ -288,6 +394,59 @@ def test_trace_includes_rule_exception_traceback() -> None: """ ) + rule_runner = RuleRunner(rules=[nested_raise, QueryRule(A, [])]) + with engine_error(Exception, contains=normalized_traceback, normalize_tracebacks=True): + rule_runner.request(A, []) + + +def test_trace_includes_nested_exception_traceback() -> None: + normalized_traceback = dedent( + f"""\ + 1 Exception encountered: + + Engine traceback: + in select + .. + in {__name__}.{catch_and_reraise.__name__} + error chain + in {__name__}.{raise_an_exception.__name__} + .. + + Traceback (most recent call last): + File LOCATION-INFO, in raise_an_exception + raise Exception(some_input.s) + Exception: asdf + + During handling of the above exception, another exception occurred: + + Traceback (most recent call last): + File LOCATION-INFO, in catch_and_reraise + return await Get(SomeOutput, SomeInput(outer_input.s)) + File LOCATION-INFO, in __await__ + result = yield self + Exception: asdf + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + File LOCATION-INFO, in native_engine_generator_send + res = rule.send(arg) if err is None else rule.throw(throw or err) + File LOCATION-INFO, in catch_and_reraise + raise Exception("nested exception!") from e + Exception: nested exception! + """ + ) + + rule_runner = RuleRunner( + rules=[ + raise_an_exception, + catch_and_reraise, + QueryRule(SomeOutput, (OuterInput,)), + ] + ) + with engine_error(Exception, contains=normalized_traceback, normalize_tracebacks=True): + rule_runner.request(SomeOutput, [OuterInput("asdf")]) + # ----------------------------------------------------------------------------------------------- # Test unhashable types diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index e71da986ffd..3cb34653d1f 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -9,16 +9,19 @@ from typing import ( TYPE_CHECKING, Any, + Coroutine, Generator, Generic, Iterable, Sequence, Tuple, TypeVar, + Union, cast, overload, ) +from pants.base.exceptions import NativeEngineFailure from pants.engine.internals.native_engine import ( PyGeneratorResponseBreak, PyGeneratorResponseGet, @@ -590,11 +593,48 @@ def __init__(self, *args: Any) -> None: self.params = tuple(args) +# A specification for how the native engine interacts with @rule coroutines: +# - coroutines may await on any of `Get`, `MultiGet`, `Effect` or other coroutines decorated with `@rule_helper`. +# - we will send back a single `Any` or a tuple of `Any` to the coroutine, depending upon the variant of `Get`. +# - a coroutine will eventually return a single `Any`. +RuleInput = Union[ + # The value used to "start" a Generator. + None, + # A single value requested by a Get. + Any, + # Multiple values requested by a MultiGet. + Tuple[Any, ...], + # An exception to be raised in the Generator. + NativeEngineFailure, +] +RuleOutput = Union[Get, Tuple[Get, ...]] +RuleResult = Any +RuleCoroutine = Coroutine[RuleOutput, RuleInput, RuleResult] +NativeEngineGeneratorResponse = Union[ + PyGeneratorResponseGet, + PyGeneratorResponseGetMulti, + PyGeneratorResponseBreak, +] + + def native_engine_generator_send( - func, arg -) -> PyGeneratorResponseGet | PyGeneratorResponseGetMulti | PyGeneratorResponseBreak: + rule: RuleCoroutine, arg: RuleInput +) -> NativeEngineGeneratorResponse: + err = arg if isinstance(arg, NativeEngineFailure) else None + throw = err and err.failure.get_error() try: - res = func.send(arg) + res = rule.send(arg) if err is None else rule.throw(throw or err) + except StopIteration as e: + return PyGeneratorResponseBreak(e.value) + except Exception as e: + if throw and e.__cause__ is throw: + # Preserve the engine traceback by using the wrapped failure error as cause. The cause + # will be swapped back again in + # `src/rust/engine/src/python.rs:Failure::from_py_err_with_gil()` to preserve the python + # traceback. + e.__cause__ = err + raise + else: # It isn't necessary to differentiate between `Get` and `Effect` here, as the static # analysis of `@rule`s has already validated usage. if isinstance(res, (Get, Effect)): @@ -602,10 +642,12 @@ def native_engine_generator_send( elif type(res) in (tuple, list): return PyGeneratorResponseGetMulti(res) else: - raise ValueError(f"internal engine error: unrecognized coroutine result {res}") - except StopIteration as e: - if not e.args: - raise - # This was a `return` from a coroutine, as opposed to a `StopIteration` raised - # by calling `next()` on an empty iterator. - return PyGeneratorResponseBreak(e.value) + raise ValueError( + softwrap( + f""" + Async @rule error: unrecognized await object + + Expected a rule query such as `Get(..)` or similar, but got: {res!r} + """ + ) + ) diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index a2291414748..0e88a5a4200 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -6,6 +6,7 @@ import dataclasses import functools import os +import re import sys from contextlib import contextmanager from dataclasses import dataclass @@ -113,7 +114,10 @@ def wrapper(*args, **kwargs): @contextmanager def engine_error( - expected_underlying_exception: type[Exception] = Exception, *, contains: str | None = None + expected_underlying_exception: type[Exception] = Exception, + *, + contains: str | None = None, + normalize_tracebacks: bool = False, ) -> Iterator[None]: """A context manager to catch `ExecutionError`s in tests and check that the underlying exception is expected. @@ -124,6 +128,10 @@ def engine_error( rule_runner.request(OutputType, [input]) Will raise AssertionError if no ExecutionError occurred. + + Set `normalize_tracebacks=True` to replace file locations and addresses in the error message + with fixed values for testability, and check `contains` against the `ExecutionError` message + instead of the underlying error only. """ try: yield @@ -143,12 +151,17 @@ def engine_error( f"{type(underlying)} rather than the expected type " f"{expected_underlying_exception}:\n\n{underlying}" ) - if contains is not None and contains not in str(underlying): - raise AssertionError( - "Expected value not found in exception.\n" - f"expected: {contains}\n\n" - f"exception: {underlying}" - ) + if contains is not None: + if normalize_tracebacks: + errmsg = remove_locations_from_traceback(str(exec_error)) + else: + errmsg = str(underlying) + if contains not in errmsg: + raise AssertionError( + "Expected value not found in exception.\n" + f"=> Expected: {contains}\n\n" + f"=> Actual: {errmsg}" + ) else: raise AssertionError( "DID NOT RAISE ExecutionError with underlying exception type " @@ -156,6 +169,14 @@ def engine_error( ) +def remove_locations_from_traceback(trace: str) -> str: + location_pattern = re.compile(r'"/.*", line \d+') + address_pattern = re.compile(r"0x[0-9a-f]+") + new_trace = location_pattern.sub("LOCATION-INFO", trace) + new_trace = address_pattern.sub("0xEEEEEEEEE", new_trace) + return new_trace + + # ----------------------------------------------------------------------------------------------- # `RuleRunner` # ----------------------------------------------------------------------------------------------- diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 4fbb183a10c..e350d861091 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -13,7 +13,7 @@ use lazy_static::lazy_static; use pyo3::exceptions::{PyException, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple, PyType}; -use pyo3::{import_exception, intern}; +use pyo3::{create_exception, import_exception, intern}; use pyo3::{FromPyObject, ToPyObject}; use smallvec::{smallvec, SmallVec}; @@ -35,15 +35,38 @@ mod stdio; pub mod testutil; pub mod workunits; -pub fn register(_py: Python, m: &PyModule) -> PyResult<()> { +pub fn register(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; + m.add("EngineError", py.get_type::())?; + m.add("IntrinsicError", py.get_type::())?; + m.add( + "IncorrectProductError", + py.get_type::(), + )?; + Ok(()) } +create_exception!(native_engine, EngineError, PyException); +create_exception!(native_engine, IntrinsicError, EngineError); +create_exception!(native_engine, IncorrectProductError, EngineError); + #[derive(Clone)] #[pyclass] pub struct PyFailure(pub Failure); +#[pymethods] +impl PyFailure { + fn get_error(&self, py: Python) -> PyErr { + match &self.0 { + Failure::Throw { val, .. } => val.into_py(py), + f @ (Failure::Invalidated | Failure::MissingDigest { .. }) => { + EngineError::new_err(format!("{}", f)) + } + } + } +} + // TODO: We import this exception type because `pyo3` doesn't support declaring exceptions with // additional fields. See https://github.com/PyO3/pyo3/issues/295 import_exception!(pants.base.exceptions, NativeEngineFailure); @@ -226,7 +249,7 @@ pub fn doc_url(py: Python, slug: &str) -> String { } pub fn create_exception(py: Python, msg: String) -> Value { - Value::new(PyException::new_err(msg).into_py(py)) + Value::new(IntrinsicError::new_err(msg).into_py(py)) } pub fn call_function<'py>(func: &'py PyAny, args: &[Value]) -> PyResult<&'py PyAny> { @@ -238,12 +261,22 @@ pub fn call_function<'py>(func: &'py PyAny, args: &[Value]) -> PyResult<&'py PyA pub fn generator_send( py: Python, generator: &Value, - arg: &Value, + arg: Option, + err: Option, ) -> Result { let selectors = py.import("pants.engine.internals.selectors").unwrap(); let native_engine_generator_send = selectors.getattr("native_engine_generator_send").unwrap(); + let py_arg = match (arg, err) { + (Some(arg), None) => arg.to_object(py), + (None, Some(err)) => err.into_py(py), + (None, None) => py.None(), + (Some(arg), Some(err)) => panic!( + "generator_send got both value and error: arg={:?}, err={:?}", + arg, err + ), + }; let response = native_engine_generator_send - .call1((generator.to_object(py), arg.to_object(py))) + .call1((generator.to_object(py), py_arg)) .map_err(|py_err| Failure::from_py_err_with_gil(py, py_err))?; if let Ok(b) = response.extract::>() { diff --git a/src/rust/engine/src/intrinsics.rs b/src/rust/engine/src/intrinsics.rs index 0b28051e713..02784fe4c15 100644 --- a/src/rust/engine/src/intrinsics.rs +++ b/src/rust/engine/src/intrinsics.rs @@ -350,7 +350,7 @@ fn download_file_to_digest( mut args: Vec, ) -> BoxFuture<'static, NodeResult> { async move { - let key = Key::from_value(args.pop().unwrap()).map_err(Failure::from_py_err)?; + let key = Key::from_value(args.pop().unwrap()).map_err(Failure::from)?; let snapshot = context.get(DownloadedFile(key)).await?; let gil = Python::acquire_gil(); let value = Snapshot::store_directory_digest(gil.python(), snapshot.into())?; diff --git a/src/rust/engine/src/nodes.rs b/src/rust/engine/src/nodes.rs index 280c1ca22d0..4414f8578ef 100644 --- a/src/rust/engine/src/nodes.rs +++ b/src/rust/engine/src/nodes.rs @@ -17,7 +17,7 @@ use futures::future::{self, BoxFuture, FutureExt, TryFutureExt}; use grpc_util::prost::MessageExt; use internment::Intern; use protos::gen::pants::cache::{CacheKey, CacheKeyType, ObservedUrl}; -use pyo3::prelude::{Py, PyAny, Python}; +use pyo3::prelude::{Py, PyAny, PyErr, Python}; use pyo3::IntoPy; use url::Url; @@ -1127,7 +1127,7 @@ impl Task { /// /// Given a python generator Value, loop to request the generator's dependencies until - /// it completes with a result Value. + /// it completes with a result Value or fails with an error. /// async fn generate( context: &Context, @@ -1137,22 +1137,40 @@ impl Task { generator: Value, ) -> NodeResult<(Value, TypeId)> { let mut input: Option = None; + let mut err: Option = None; loop { let context = context.clone(); let params = params.clone(); - let response = Python::with_gil(|py| { - let input = input.unwrap_or_else(|| Value::from(py.None())); - externs::generator_send(py, &generator, &input) - })?; + let response = Python::with_gil(|py| externs::generator_send(py, &generator, input, err))?; match response { externs::GeneratorResponse::Get(get) => { - let values = Self::gen_get(&context, workunit, ¶ms, entry, vec![get]).await?; - input = Some(values.into_iter().next().unwrap()); + let result = Self::gen_get(&context, workunit, ¶ms, entry, vec![get]).await; + match result { + Ok(values) => { + input = Some(values.into_iter().next().unwrap()); + err = None; + } + Err(throw @ Failure::Throw { .. }) => { + input = None; + err = Some(PyErr::from(throw)); + } + Err(failure) => break Err(failure), + } } externs::GeneratorResponse::GetMulti(gets) => { - let values = Self::gen_get(&context, workunit, ¶ms, entry, gets).await?; - let gil = Python::acquire_gil(); - input = Some(externs::store_tuple(gil.python(), values)); + let result = Self::gen_get(&context, workunit, ¶ms, entry, gets).await; + match result { + Ok(values) => { + let gil = Python::acquire_gil(); + input = Some(externs::store_tuple(gil.python(), values)); + err = None; + } + Err(throw @ Failure::Throw { .. }) => { + input = None; + err = Some(PyErr::from(throw)); + } + Err(failure) => break Err(failure), + } } externs::GeneratorResponse::Break(val, type_id) => { break Ok((val, type_id)); @@ -1197,7 +1215,7 @@ impl Task { let val = Value::new(res.into_py(py)); (val, type_id) }) - .map_err(Failure::from_py_err) + .map_err(Failure::from) }) }) .await?; @@ -1214,10 +1232,13 @@ impl Task { } if result_type != self.task.product { - return Err(throw(format!( - "{:?} returned a result value that did not satisfy its constraints: {:?}", - self.task.func, result_val - ))); + return Err( + externs::IncorrectProductError::new_err(format!( + "{:?} returned a result value that did not satisfy its constraints: {:?}", + self.task.func, result_val + )) + .into(), + ); } if self.task.engine_aware_return_type { diff --git a/src/rust/engine/src/python.rs b/src/rust/engine/src/python.rs index 74349ea29ef..c1889ded6d4 100644 --- a/src/rust/engine/src/python.rs +++ b/src/rust/engine/src/python.rs @@ -9,7 +9,7 @@ use std::{fmt, hash}; use deepsize::{known_deep_size, DeepSizeOf}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyType}; -use pyo3::{FromPyObject, ToPyObject}; +use pyo3::{FromPyObject, IntoPy, ToPyObject}; use smallvec::SmallVec; use hashing::Digest; @@ -377,6 +377,20 @@ impl From for Value { } } +impl From for Value { + fn from(py_err: PyErr) -> Self { + let gil = Python::acquire_gil(); + let py = gil.python(); + Value::new(py_err.into_py(py)) + } +} + +impl IntoPy for &Value { + fn into_py(self, py: Python) -> PyErr { + PyErr::from_value((*self.0).as_ref(py)) + } +} + /// /// A short required name, and optional human readable description for a single frame of a Failure. /// @@ -436,26 +450,38 @@ impl Failure { } } } -} - -impl Failure { - pub fn from_py_err(py_err: PyErr) -> Failure { - let gil = Python::acquire_gil(); - let py = gil.python(); - Failure::from_py_err_with_gil(py, py_err) - } pub fn from_py_err_with_gil(py: Python, py_err: PyErr) -> Failure { // If this is a wrapped Failure, return it immediately. - if let Ok(n_e_failure) = py_err.value(py).downcast::() { - let failure = n_e_failure - .getattr("failure") - .unwrap() - .extract::() - .unwrap(); - return failure.0; + if let Some(failure) = Failure::from_wrapped_failure(py, &py_err) { + return failure; } + // Propagate the tracebacks from the causing error, if any. + let (previous_ptraceback, engine_traceback) = if let Some(cause) = py_err.cause(py) { + match Failure::from_wrapped_failure(py, &cause) { + Some(Failure::Throw { + val, + engine_traceback, + python_traceback, + }) => { + // Preserve tracebacks (both engine and python) from upstream error by using any existing + // engine traceback and restoring the original python exception cause. + py_err.set_cause(py, Some(PyErr::from_value((*val.0).as_ref(py)))); + ( + format!( + "{}\nDuring handling of the above exception, another exception occurred:\n\n", + python_traceback + ), + engine_traceback, + ) + } + _ => ("".to_string(), Vec::new()), + } + } else { + ("".to_string(), Vec::new()) + }; + let maybe_ptraceback = py_err .traceback(py) .map(|traceback| traceback.to_object(py)); @@ -480,8 +506,8 @@ impl Failure { }; Failure::Throw { val, - python_traceback, - engine_traceback: Vec::new(), + python_traceback: previous_ptraceback + &python_traceback, + engine_traceback, } } @@ -493,6 +519,22 @@ impl Failure { } } +impl Failure { + fn from_wrapped_failure(py: Python, py_err: &PyErr) -> Option { + match py_err.value(py).downcast::() { + Ok(n_e_failure) => { + let failure = n_e_failure + .getattr("failure") + .unwrap() + .extract::() + .unwrap(); + Some(failure.0) + } + _ => None, + } + } +} + impl fmt::Display for Failure { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -541,6 +583,14 @@ impl From for Failure { } } +impl From for Failure { + fn from(py_err: PyErr) -> Self { + let gil = Python::acquire_gil(); + let py = gil.python(); + Failure::from_py_err_with_gil(py, py_err) + } +} + pub fn throw(msg: String) -> Failure { let python_traceback = Failure::native_traceback(&msg); let gil = Python::acquire_gil();