[WIP] Revert unhandled exception PR for 1.3.0 temporarily #15287

Merged
10 changes: 0 additions & 10 deletions BUILD.bazel
@@ -706,16 +706,6 @@ cc_test(
     ],
 )
 
-cc_test(
-    name = "memory_store_test",
-    srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
-    copts = COPTS,
-    deps = [
-        ":core_worker_lib",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
 cc_test(
     name = "direct_actor_transport_test",
     srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
4 changes: 2 additions & 2 deletions ci/asan_tests/ray-project/requirements.txt
@@ -3,7 +3,7 @@ blist
 boto3
 cython==0.29.0
 dataclasses; python_version < '3.7'
-dm-tree
+dm-tree==0.1.5
 feather-format
 flask
 grpcio
@@ -34,4 +34,4 @@ torch
 torchvision
 uvicorn
 werkzeug
-xlrd
+xlrd
2 changes: 1 addition & 1 deletion ci/travis/install-dependencies.sh
@@ -273,7 +273,7 @@ install_dependencies() {
   fi
 
   # Install modules needed in all jobs.
-  pip install --no-clean dm-tree  # --no-clean is due to: https://github.com/deepmind/tree/issues/5
+  pip install --no-clean dm-tree==0.1.5  # --no-clean is due to: https://github.com/deepmind/tree/issues/5
 
   if [ -n "${PYTHON-}" ]; then
     # Remove this entire section once RLlib and Serve dependencies are fixed.
25 changes: 3 additions & 22 deletions python/ray/_raylet.pyx
@@ -771,20 +771,6 @@ cdef void delete_spilled_objects_handler(
             job_id=None)
 
 
-cdef void unhandled_exception_handler(const CRayObject& error) nogil:
-    with gil:
-        worker = ray.worker.global_worker
-        data = None
-        metadata = None
-        if error.HasData():
-            data = Buffer.make(error.GetData())
-        if error.HasMetadata():
-            metadata = Buffer.make(error.GetMetadata()).to_pybytes()
-        # TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
-        object_ids = [None]
-        worker.raise_errors([(data, metadata)], object_ids)
-
-
 # This function introduces ~2-7us of overhead per call (i.e., it can be called
 # up to hundreds of thousands of times per second).
 cdef void get_py_stack(c_string* stack_out) nogil:
@@ -906,7 +892,6 @@ cdef class CoreWorker:
         options.restore_spilled_objects = restore_spilled_objects_handler
         options.delete_spilled_objects = delete_spilled_objects_handler
         options.run_on_util_worker_handler = run_on_util_worker_handler
-        options.unhandled_exception_handler = unhandled_exception_handler
         options.get_lang_stack = get_py_stack
         options.ref_counting_enabled = True
         options.is_local_mode = local_mode
@@ -1526,13 +1511,9 @@ cdef class CoreWorker:
                 object_ref.native())
 
     def remove_object_ref_reference(self, ObjectRef object_ref):
-        cdef:
-            CObjectID c_object_id = object_ref.native()
-        # We need to release the gil since object destruction may call the
-        # unhandled exception handler.
-        with nogil:
-            CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
-                c_object_id)
+        # Note: faster to not release GIL for short-running op.
+        CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
+            object_ref.native())
 
     def serialize_and_promote_object_ref(self, ObjectRef object_ref):
         cdef:
1 change: 0 additions & 1 deletion python/ray/includes/libcoreworker.pxd
@@ -269,7 +269,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         (void(
             const c_string&,
             const c_vector[c_string]&) nogil) run_on_util_worker_handler
-        (void(const CRayObject&) nogil) unhandled_exception_handler
         (void(c_string *stack_out) nogil) get_lang_stack
         c_bool ref_counting_enabled
         c_bool is_local_mode
91 changes: 53 additions & 38 deletions python/ray/tests/test_failure.py
@@ -23,50 +23,65 @@
     get_error_message, Semaphore)
 
 
-def test_unhandled_errors(ray_start_regular):
+def test_failed_task(ray_start_regular, error_pubsub):
     @ray.remote
-    def f():
-        raise ValueError()
+    def throw_exception_fct1():
+        raise Exception("Test function 1 intentionally failed.")
 
     @ray.remote
-    class Actor:
-        def f(self):
-            raise ValueError()
+    def throw_exception_fct2():
+        raise Exception("Test function 2 intentionally failed.")
 
-    a = Actor.remote()
-    num_exceptions = 0
-
-    def interceptor(e):
-        nonlocal num_exceptions
-        num_exceptions += 1
-
-    # Test we report unhandled exceptions.
-    ray.worker._unhandled_error_handler = interceptor
-    x1 = f.remote()
-    x2 = a.f.remote()
-    del x1
-    del x2
-    wait_for_condition(lambda: num_exceptions == 2)
-
-    # Test we don't report handled exceptions.
-    x1 = f.remote()
-    x2 = a.f.remote()
-    with pytest.raises(ray.exceptions.RayError) as err:  # noqa
-        ray.get([x1, x2])
-    del x1
-    del x2
-    time.sleep(1)
-    assert num_exceptions == 2, num_exceptions
+    @ray.remote(num_returns=3)
+    def throw_exception_fct3(x):
+        raise Exception("Test function 3 intentionally failed.")
+
+    p = error_pubsub
+
+    throw_exception_fct1.remote()
+    throw_exception_fct1.remote()
+
+    msgs = get_error_message(p, 2, ray_constants.TASK_PUSH_ERROR)
+    assert len(msgs) == 2
+    for msg in msgs:
+        assert "Test function 1 intentionally failed." in msg.error_message
 
-    # Test suppression with env var works.
+    x = throw_exception_fct2.remote()
     try:
-        os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
-        x1 = f.remote()
-        del x1
-        time.sleep(1)
-        assert num_exceptions == 2, num_exceptions
-    finally:
-        del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
+        ray.get(x)
+    except Exception as e:
+        assert "Test function 2 intentionally failed." in str(e)
+    else:
+        # ray.get should throw an exception.
+        assert False
+
+    x, y, z = throw_exception_fct3.remote(1.0)
+    for ref in [x, y, z]:
+        try:
+            ray.get(ref)
+        except Exception as e:
+            assert "Test function 3 intentionally failed." in str(e)
+        else:
+            # ray.get should throw an exception.
+            assert False
+
+    class CustomException(ValueError):
+        pass
+
+    @ray.remote
+    def f():
+        raise CustomException("This function failed.")
+
+    try:
+        ray.get(f.remote())
+    except Exception as e:
+        assert "This function failed." in str(e)
+        assert isinstance(e, CustomException)
+        assert isinstance(e, ray.exceptions.RayTaskError)
+        assert "RayTaskError(CustomException)" in repr(e)
+    else:
+        # ray.get should throw an exception.
+        assert False
 
 
 def test_push_error_to_driver_through_redis(ray_start_regular, error_pubsub):
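The last assertions in the restored test check that the raised error is an instance of both the user's exception class and ray.exceptions.RayTaskError, and that the dual type shows up in repr(). A minimal standalone sketch of one way an error type can satisfy both isinstance checks; the classes below are hypothetical stand-ins, not Ray's actual implementation:

# Minimal sketch, not Ray code: derive a class from both the wrapper type
# and the cause's type so isinstance() checks against either one pass.
class RayTaskError(Exception):
    """Hypothetical stand-in for ray.exceptions.RayTaskError."""

    def __init__(self, cause):
        super().__init__(str(cause))
        self.cause = cause

    def as_instanceof_cause(self):
        cause_cls = type(self.cause)
        # Dynamically create a subclass of both types; its name is what
        # surfaces in repr() of the raised error.
        dual_cls = type(
            "RayTaskError({})".format(cause_cls.__name__),
            (RayTaskError, cause_cls),
            {},
        )
        return dual_cls(self.cause)


class CustomException(ValueError):
    pass


e = RayTaskError(CustomException("This function failed.")).as_instanceof_cause()
assert isinstance(e, CustomException)
assert isinstance(e, RayTaskError)
assert "RayTaskError(CustomException)" in repr(e)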
79 changes: 60 additions & 19 deletions python/ray/worker.py
@@ -9,6 +9,7 @@
 import logging
 import os
 import redis
+from six.moves import queue
 import sys
 import threading
 import time
@@ -72,12 +73,6 @@
 logger = logging.getLogger(__name__)
 
 
-# Visible for testing.
-def _unhandled_error_handler(e: Exception):
-    logger.error("Unhandled error (suppress with "
-                 "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
-
-
 class Worker:
     """A class used to define the control flow of a worker process.
 
@@ -274,14 +269,6 @@ def put_object(self, value, object_ref=None):
             self.core_worker.put_serialized_object(
                 serialized_value, object_ref=object_ref))
 
-    def raise_errors(self, data_metadata_pairs, object_refs):
-        context = self.get_serialization_context()
-        out = context.deserialize_objects(data_metadata_pairs, object_refs)
-        if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
-            return
-        for e in out:
-            _unhandled_error_handler(e)
-
     def deserialize_objects(self, data_metadata_pairs, object_refs):
         context = self.get_serialization_context()
         return context.deserialize_objects(data_metadata_pairs, object_refs)
@@ -877,6 +864,13 @@ def custom_excepthook(type, value, tb):
 
 sys.excepthook = custom_excepthook
 
+# The last time we raised a TaskError in this process. We use this value to
+# suppress redundant error messages pushed from the workers.
+last_task_error_raise_time = 0
+
+# The max amount of seconds to wait before printing out an uncaught error.
+UNCAUGHT_ERROR_GRACE_PERIOD = 5
+
 
 def print_logs(redis_client, threads_stopped, job_id):
     """Prints log messages from workers on all of the nodes.
@@ -1027,14 +1021,51 @@ def color_for(data: Dict[str, str]) -> str:
             file=print_file)
 
 
-def listen_error_messages_raylet(worker, threads_stopped):
+def print_error_messages_raylet(task_error_queue, threads_stopped):
+    """Prints message received in the given output queue.
+
+    This checks periodically if any un-raised errors occurred in the
+    background.
+
+    Args:
+        task_error_queue (queue.Queue): A queue used to receive errors from the
+            thread that listens to Redis.
+        threads_stopped (threading.Event): A threading event used to signal to
+            the thread that it should exit.
+    """
+
+    while True:
+        # Exit if we received a signal that we should stop.
+        if threads_stopped.is_set():
+            return
+
+        try:
+            error, t = task_error_queue.get(block=False)
+        except queue.Empty:
+            threads_stopped.wait(timeout=0.01)
+            continue
+        # Delay errors a little bit of time to attempt to suppress redundant
+        # messages originating from the worker.
+        while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
+            threads_stopped.wait(timeout=1)
+            if threads_stopped.is_set():
+                break
+        if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
+            logger.debug(f"Suppressing error from worker: {error}")
+        else:
+            logger.error(f"Possible unhandled error from worker: {error}")
+
+
+def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
     """Listen to error messages in the background on the driver.
 
     This runs in a separate thread on the driver and pushes (error, time)
     tuples to the output queue.
 
     Args:
         worker: The worker class that this thread belongs to.
+        task_error_queue (queue.Queue): A queue used to communicate with the
+            thread that prints the errors found by this thread.
         threads_stopped (threading.Event): A threading event used to signal to
             the thread that it should exit.
     """
@@ -1078,9 +1109,8 @@ def listen_error_messages_raylet(worker, threads_stopped):
 
             error_message = error_data.error_message
             if (error_data.type == ray_constants.TASK_PUSH_ERROR):
-                # TODO(ekl) remove task push errors entirely now that we have
-                # the separate unhandled exception handler.
-                pass
+                # Delay it a bit to see if we can suppress it
+                task_error_queue.put((error_message, time.time()))
             else:
                 logger.warning(error_message)
         except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1260,12 +1290,19 @@ def connect(node,
     # temporarily using this implementation which constantly queries the
     # scheduler for new error messages.
     if mode == SCRIPT_MODE:
+        q = queue.Queue()
         worker.listener_thread = threading.Thread(
             target=listen_error_messages_raylet,
             name="ray_listen_error_messages",
-            args=(worker, worker.threads_stopped))
+            args=(worker, q, worker.threads_stopped))
+        worker.printer_thread = threading.Thread(
+            target=print_error_messages_raylet,
+            name="ray_print_error_messages",
+            args=(q, worker.threads_stopped))
         worker.listener_thread.daemon = True
         worker.listener_thread.start()
+        worker.printer_thread.daemon = True
+        worker.printer_thread.start()
         if log_to_driver:
             global_worker_stdstream_dispatcher.add_handler(
                 "ray_print_logs", print_to_stdstream)
@@ -1318,6 +1355,8 @@ def disconnect(exiting_interpreter=False):
         worker.import_thread.join_import_thread()
     if hasattr(worker, "listener_thread"):
         worker.listener_thread.join()
+    if hasattr(worker, "printer_thread"):
+        worker.printer_thread.join()
     if hasattr(worker, "logger_thread"):
         worker.logger_thread.join()
     worker.threads_stopped.clear()
@@ -1429,11 +1468,13 @@ def get(object_refs, *, timeout=None):
         raise ValueError("'object_refs' must either be an object ref "
                          "or a list of object refs.")
 
+    global last_task_error_raise_time
     # TODO(ujvl): Consider how to allow user to retrieve the ready objects.
     values, debugger_breakpoint = worker.get_objects(
         object_refs, timeout=timeout)
     for i, value in enumerate(values):
         if isinstance(value, RayError):
+            last_task_error_raise_time = time.time()
             if isinstance(value, ray.exceptions.ObjectLostError):
                 worker.core_worker.dump_object_store_memory_usage()
             if isinstance(value, RayTaskError):
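Restated outside of Ray for clarity, the listener/printer pair reintroduced above is a standard producer/consumer setup: one thread enqueues (error, timestamp) tuples, and a second thread drains the queue, waiting out a grace period so that errors already raised to the caller (tracked via last_task_error_raise_time) can be suppressed. A runnable self-contained sketch of that pattern, using the same constant but otherwise assumed names:

import queue
import threading
import time

UNCAUGHT_ERROR_GRACE_PERIOD = 5  # seconds, matching the constant above
last_task_error_raise_time = 0


def print_error_messages(task_error_queue, threads_stopped):
    # Consumer: drain (error, timestamp) tuples, suppressing errors that
    # were already surfaced to the caller within the grace period.
    while not threads_stopped.is_set():
        try:
            error, t = task_error_queue.get(block=False)
        except queue.Empty:
            threads_stopped.wait(timeout=0.01)
            continue
        # Give the caller a grace period to raise the error itself.
        while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
            threads_stopped.wait(timeout=1)
            if threads_stopped.is_set():
                return
        if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
            print("Suppressing error already raised to the caller:", error)
        else:
            print("Possible unhandled error:", error)


stop = threading.Event()
q = queue.Queue()
printer = threading.Thread(target=print_error_messages, args=(q, stop))
printer.daemon = True
printer.start()

# Producer side: in Ray this is the Redis listener thread.
q.put(("Test function 1 intentionally failed.", time.time()))
time.sleep(UNCAUGHT_ERROR_GRACE_PERIOD + 1)  # let the grace period elapse
stop.set()
printer.join()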
8 changes: 0 additions & 8 deletions src/ray/common/ray_object.h
@@ -97,20 +97,12 @@ class RayObject {
   /// large to return directly as part of a gRPC response).
   bool IsInPlasmaError() const;
 
-  /// Mark this object as accessed before.
-  void SetAccessed() { accessed_ = true; };
-
-  /// Check if this object was accessed before.
-  bool WasAccessed() const { return accessed_; }
-
  private:
   std::shared_ptr<Buffer> data_;
   std::shared_ptr<Buffer> metadata_;
   const std::vector<ObjectID> nested_ids_;
   /// Whether this class holds a data copy.
   bool has_data_copy_;
-  /// Whether this object was accessed.
-  bool accessed_ = false;
 };
 
 }  // namespace ray
13 changes: 1 addition & 12 deletions src/ray/core_worker/core_worker.cc
@@ -476,18 +476,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
         return Status::OK();
       },
       options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
-      options_.check_signals,
-      [this](const RayObject &obj) {
-        // Run this on the event loop to avoid calling back into the language runtime
-        // from the middle of user operations.
-        io_service_.post(
-            [this, obj]() {
-              if (options_.unhandled_exception_handler != nullptr) {
-                options_.unhandled_exception_handler(obj);
-              }
-            },
-            "CoreWorker.HandleException");
-      }));
+      options_.check_signals));
 
   auto check_node_alive_fn = [this](const NodeID &node_id) {
     auto node = gcs_client_->Nodes().Get(node_id);
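The deleted C++ block above posted the handler onto the worker's event loop rather than invoking it inline, so the language runtime was never re-entered in the middle of an internal operation. A rough Python analogue of that defer-to-event-loop idea, with hypothetical names standing in for CoreWorker internals:

import asyncio
import threading
import time

# A dedicated event-loop thread standing in for CoreWorker's io_service_.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()


def unhandled_exception_handler(obj):
    print("unhandled exception object:", obj)


def on_unhandled_error(obj):
    # Called from arbitrary internal code paths; schedule the callback on
    # the event loop instead of running it inline, so the handler never
    # executes re-entrantly inside the operation that detected the error.
    loop.call_soon_threadsafe(unhandled_exception_handler, obj)


on_unhandled_error(ValueError("task failed"))
time.sleep(0.1)  # give the loop thread a moment to run the callback
loop.call_soon_threadsafe(loop.stop)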