Skip to content

Commit

Permalink
[data] Add ExecutionCallback interface (ray-project#49205)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Add an `ExecutionCallback` interface that lets custom logic be hooked into
Dataset execution events: before execution starts, after it succeeds, and
after it fails. This can be useful, for example, for optimization rules.

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
raulchen authored and ujjawal-khare committed Dec 17, 2024
1 parent 366e0ec commit 87e6665
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 0 deletions.
40 changes: 40 additions & 0 deletions python/ray/data/_internal/execution/execution_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List

from ray.data.context import DataContext

EXECUTION_CALLBACKS_CONFIG_KEY = "execution_callbacks"


class ExecutionCallback:
    """Base class for hooks that run around a Dataset execution.

    Every hook is a no-op by default; subclasses override only the
    events they are interested in.
    """

    def before_execution_starts(self):
        """Hook invoked before the Dataset execution starts."""

    def after_execution_succeeds(self):
        """Hook invoked after the Dataset execution succeeds."""

    def after_execution_fails(self, error: Exception):
        """Hook invoked after the Dataset execution fails.

        Args:
            error: The exception that caused the execution to fail.
        """


def get_execution_callbacks(context: DataContext) -> List[ExecutionCallback]:
    """Return the ExecutionCallbacks stored on the given DataContext.

    The callbacks are read from the ``EXECUTION_CALLBACKS_CONFIG_KEY``
    config entry; an empty list is used as the default when the entry
    is not set.
    """
    registered = context.get_config(EXECUTION_CALLBACKS_CONFIG_KEY, [])
    return registered


def add_execution_callback(callback: ExecutionCallback, context: DataContext):
    """Register an ExecutionCallback on the given DataContext.

    Args:
        callback: The callback to append to the registered callbacks.
        context: The DataContext whose config holds the callback list.
    """
    config_key = EXECUTION_CALLBACKS_CONFIG_KEY
    # Start from whatever is already stored (or an empty list), then
    # write the extended list back under the same key.
    registered = context.get_config(config_key, [])
    registered.append(callback)
    context.set_config(config_key, registered)


def remove_execution_callback(callback: ExecutionCallback, context: DataContext):
    """Unregister an ExecutionCallback from the given DataContext.

    Args:
        callback: The callback to drop from the registered callbacks.
        context: The DataContext whose config holds the callback list.

    Raises:
        ValueError: Propagated from ``list.remove`` when ``callback`` is
            not in the stored list.
    """
    config_key = EXECUTION_CALLBACKS_CONFIG_KEY
    registered = context.get_config(config_key, [])
    registered.remove(callback)
    # Store the shortened list back under the same key.
    context.set_config(config_key, registered)
8 changes: 8 additions & 0 deletions python/ray/data/_internal/execution/streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BackpressurePolicy,
get_backpressure_policies,
)
from ray.data._internal.execution.execution_callback import get_execution_callbacks
from ray.data._internal.execution.interfaces import (
ExecutionResources,
Executor,
Expand Down Expand Up @@ -141,6 +142,9 @@ def execute(
self._dataset_tag,
self._get_operator_tags(),
)
for callback in get_execution_callbacks(self._data_context):
callback.before_execution_starts()

self.start()
self._execution_started = True

Expand Down Expand Up @@ -233,9 +237,13 @@ def run(self):
)
if not continue_sched or self._shutdown:
break
for callback in get_execution_callbacks(self._data_context):
callback.after_execution_succeeds()
except Exception as e:
# Propagate it to the result iterator.
self._output_node.mark_finished(e)
for callback in get_execution_callbacks(self._data_context):
callback.after_execution_fails(e)
finally:
# Signal end of results.
self._output_node.mark_finished()
Expand Down
59 changes: 59 additions & 0 deletions python/ray/data/tests/test_streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

import ray
from ray._private.test_utils import run_string_as_driver_nonblocking
from ray.data._internal.execution.execution_callback import (
ExecutionCallback,
add_execution_callback,
get_execution_callbacks,
remove_execution_callback,
)
from ray.data._internal.execution.interfaces import (
ExecutionOptions,
ExecutionResources,
Expand Down Expand Up @@ -640,6 +646,59 @@ def test_time_scheduling():
assert 0 < ds_stats.streaming_exec_schedule_s.get() < 1


def test_executor_callbacks():
    """Verify ExecutionCallback hooks fire on successful and failed runs."""

    class RecordingCallback(ExecutionCallback):
        """Remembers which hooks were invoked and with what error."""

        def __init__(self):
            self.started = False
            self.succeeded = False
            self.error = None

        def before_execution_starts(self):
            self.started = True

        def after_execution_succeeds(self):
            self.succeeded = True

        def after_execution_fails(self, error: Exception):
            self.error = error

    # --- Success path: the start and success hooks fire, no error recorded.
    ds = ray.data.range(10)
    ctx = ds.context
    callback = RecordingCallback()
    add_execution_callback(callback, ctx)
    assert get_execution_callbacks(ctx) == [callback]

    ds.take_all()

    assert callback.started
    assert callback.succeeded
    assert callback.error is None

    # Unregistering leaves no callbacks on the context.
    remove_execution_callback(callback, ctx)
    assert get_execution_callbacks(ctx) == []

    # --- Failure path: the start and failure hooks fire, success does not.
    ds = ray.data.range(10)
    ctx = ds.context
    ctx.raise_original_map_exception = True
    callback = RecordingCallback()
    add_execution_callback(callback, ctx)

    def map_fn(_):
        raise ValueError("")

    with pytest.raises(ValueError):
        ds.map(map_fn).take_all()

    assert callback.started
    assert not callback.succeeded
    assert isinstance(callback.error, ValueError), callback.error


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 87e6665

Please sign in to comment.