Skip to content

Commit

Permalink
[Feat][Core] Support ExceptionGroup (#47887)
Browse files Browse the repository at this point in the history
See the description in the corresponding issue for details.

The issue states that we should support `BaseExceptionGroup`, but I
found that we don't support `BaseException` either, so it doesn't make
sense to support `BaseExceptionGroup`.

For `ExceptionGroup`, we need to override its `__new__` method as well.
See https://docs.python.org/3/library/exceptions.html#ExceptionGroup for
details.

Signed-off-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
Co-authored-by: Ruiyang Wang <56065503+rynewang@users.noreply.github.com>
  • Loading branch information
MortalHappiness and rynewang authored Oct 16, 2024
1 parent 4dfa033 commit 0056097
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 6 deletions.
52 changes: 46 additions & 6 deletions python/ray/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import os
import sys
from traceback import format_exception
from typing import Optional, Type, Union
from typing import Optional, Union

import colorama

Expand Down Expand Up @@ -147,9 +148,21 @@ def __init__(

assert traceback_str is not None

def make_dual_exception_type(self) -> Type:
"""Makes a Type that inherits from both RayTaskError and the type of
def make_dual_exception_instance(self) -> "RayTaskError":
"""Makes a object instance that inherits from both RayTaskError and the type of
`self.cause`. Raises TypeError if the cause class can't be subclassed"""
# For normal user Exceptions, we subclass from both
# RayTaskError and the user exception. For ExceptionGroup,
# we special handle it because it has a different __new__()
# signature from Exception.
# Ref: https://docs.python.org/3/library/exceptions.html#exception-groups
if sys.version_info >= (3, 11) and isinstance(
self.cause, ExceptionGroup # noqa: F821
):
return self._make_exceptiongroup_dual_exception_instance()
return self._make_normal_dual_exception_instance()

def _make_normal_dual_exception_instance(self) -> "RayTaskError":
cause_cls = self.cause.__class__
error_msg = str(self)

Expand All @@ -171,7 +184,35 @@ def __str__(self):
cls.__name__ = name
cls.__qualname__ = name

return cls
return cls(self.cause)

def _make_exceptiongroup_dual_exception_instance(self) -> "RayTaskError":
cause_cls = self.cause.__class__
error_msg = str(self)

class cls(RayTaskError, cause_cls):
def __new__(cls, cause):
self = super().__new__(cls, cause.message, cause.exceptions)
return self

def __init__(self, cause):
self.cause = cause
# BaseException implements a __reduce__ method that returns
# a tuple with the type and the value of self.args.
# https://stackoverflow.com/a/49715949/2213289
self.args = (cause,)

def __getattr__(self, name):
return getattr(self.cause, name)

def __str__(self):
return error_msg

name = f"RayTaskError({cause_cls.__name__})"
cls.__name__ = name
cls.__qualname__ = name

return cls(self.cause)

def as_instanceof_cause(self):
"""Returns an exception that's an instance of the cause's class.
Expand All @@ -187,8 +228,7 @@ def as_instanceof_cause(self):
return self # already satisfied

try:
dual_cls = self.make_dual_exception_type()
return dual_cls(self.cause)
return self.make_dual_exception_instance()
except TypeError as e:
logger.warning(
f"User exception type {type(self.cause)} in RayTaskError can't"
Expand Down
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ py_test_module_list(
"test_debug_tools.py",
"test_distributed_sort.py",
"test_environ.py",
"test_exceptiongroup.py",
"test_get_or_create_actor.py",
"test_ids.py",
"test_list_actors.py",
Expand Down
196 changes: 196 additions & 0 deletions python/ray/tests/test_exceptiongroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import os
import sys
from textwrap import dedent

import pytest

import ray
from ray.exceptions import RayTaskError

pytestmark = pytest.mark.skipif(
sys.version_info < (3, 11),
reason="ExceptionGroup is only available in Python 3.11+",
)


def test_baseexceptiongroup_task(ray_start_regular):
baseexceptiongroup = BaseExceptionGroup( # noqa: F821
"test baseexceptiongroup", [BaseException("abc")]
)

@ray.remote
def task():
raise baseexceptiongroup

with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(task.remote())


def test_baseexceptiongroup_actor(ray_start_regular):
baseexceptiongroup = BaseExceptionGroup( # noqa: F821
"test baseexceptiongroup", [BaseException("abc")]
)

@ray.remote
class Actor:
def f(self):
raise baseexceptiongroup

with pytest.raises(ray.exceptions.ActorDiedError):
a = Actor.remote()
ray.get(a.f.remote())


def test_except_exceptiongroup(ray_start_regular):
exceptiongroup = ExceptionGroup( # noqa: F821
"test exceptiongroup", [ValueError(), TypeError()]
)

@ray.remote
def task():
raise exceptiongroup

@ray.remote
class Actor:
def f(self):
raise exceptiongroup

try:
ray.get(task.remote())
except Exception as ex:
assert isinstance(ex, RayTaskError)
assert isinstance(ex, ExceptionGroup) # noqa: F821
assert len(ex.exceptions) == 2
assert isinstance(ex.exceptions[0], ValueError)
assert isinstance(ex.exceptions[1], TypeError)

try:
a = Actor.remote()
ray.get(a.f.remote())
except Exception as ex:
assert isinstance(ex, RayTaskError)
assert isinstance(ex, ExceptionGroup) # noqa: F821
assert len(ex.exceptions) == 2
assert isinstance(ex.exceptions[0], ValueError)
assert isinstance(ex.exceptions[1], TypeError)


def test_except_star_exception(ray_start_regular):
@ray.remote
def task():
raise ValueError

@ray.remote
class Actor:
def f(self):
raise ValueError

# TODO: Don't use exec() when we only support Python 3.11+
# Here the exec() is used to avoid SyntaxError for except* for Python < 3.11
python_code = dedent(
"""\
try:
ray.get(task.remote())
except* RayTaskError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], RayTaskError)
assert isinstance(ex.exceptions[0], ValueError)
try:
ray.get(task.remote())
except* ValueError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], RayTaskError)
assert isinstance(ex.exceptions[0], ValueError)
try:
a = Actor.remote()
ray.get(a.f.remote())
except* RayTaskError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], RayTaskError)
assert isinstance(ex.exceptions[0], ValueError)
try:
a = Actor.remote()
ray.get(a.f.remote())
except* ValueError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], RayTaskError)
assert isinstance(ex.exceptions[0], ValueError)
"""
)
exec(python_code)


def test_except_star_exceptiongroup(ray_start_regular):
exceptiongroup = ExceptionGroup( # noqa: F821
"test exceptiongroup", [ValueError(), TypeError()]
)

@ray.remote
def task():
raise exceptiongroup

@ray.remote
class Actor:
def f(self):
raise exceptiongroup

# TODO: Don't use exec() when we only support Python 3.11+
# Here the exec() is used to avoid SyntaxError for except* for Python < 3.11
python_code = dedent(
"""\
try:
ray.get(task.remote())
except* RayTaskError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 2
assert isinstance(ex.exceptions[0], ValueError)
assert isinstance(ex.exceptions[1], TypeError)
try:
ray.get(task.remote())
except* ValueError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], ValueError)
except* TypeError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], TypeError)
try:
a = Actor.remote()
ray.get(a.f.remote())
except* RayTaskError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 2
assert isinstance(ex.exceptions[0], ValueError)
assert isinstance(ex.exceptions[1], TypeError)
try:
a = Actor.remote()
ray.get(a.f.remote())
except* ValueError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], ValueError)
except* TypeError as ex:
assert isinstance(ex, ExceptionGroup)
assert len(ex.exceptions) == 1
assert isinstance(ex.exceptions[0], TypeError)
"""
)
exec(python_code)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
20 changes: 20 additions & 0 deletions python/ray/tests/test_failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,26 @@ def foo():
assert isinstance(ex, RayTaskError)


def test_baseexception_task(ray_start_regular):
@ray.remote
def task():
raise BaseException("abc")

with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(task.remote())


def test_baseexception_actor(ray_start_regular):
@ray.remote
class Actor:
def f(self):
raise BaseException("abc")

with pytest.raises(ray.exceptions.ActorDiedError):
a = Actor.remote()
ray.get(a.f.remote())


@pytest.mark.skip("This test does not work yet.")
@pytest.mark.parametrize("ray_start_object_store_memory", [10**6], indirect=True)
def test_put_error1(ray_start_object_store_memory, error_pubsub):
Expand Down

0 comments on commit 0056097

Please sign in to comment.