Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Save and restore stateful callbacks as part of experiment checkpoint #31957

Merged
merged 15 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ py_test(
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_callbacks",
size = "small",
srcs = ["tests/test_callbacks.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_checkpoint_manager",
size = "small",
Expand Down
27 changes: 26 additions & 1 deletion python/ray/tune/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.tune.utils.restorable import Restorable

if TYPE_CHECKING:
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
Expand Down Expand Up @@ -66,7 +67,7 @@ def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:


@PublicAPI(stability="beta")
class Callback(metaclass=_CallbackMeta):
class Callback(Restorable, metaclass=_CallbackMeta):
"""Tune base callback that can be extended and passed to a ``TrialRunner``

Tune callbacks are called from within the ``TrialRunner`` class. There are
Expand Down Expand Up @@ -104,6 +105,8 @@ def train(config):

"""

CKPT_FILE_TMPL = "callback-state-{}.json"

# arguments here match Experiment.public_spec
def setup(
self,
Expand Down Expand Up @@ -284,6 +287,7 @@ class CallbackList(Callback):
"""Call multiple callbacks at once."""

IS_CALLBACK_CONTAINER = True
CKPT_FILE_TMPL = "callback-list-state-{}.json"

def __init__(self, callbacks: List[Callback]):
self._callbacks = callbacks
Expand Down Expand Up @@ -343,3 +347,24 @@ def on_checkpoint(self, **info):
def on_experiment_end(self, **info):
for callback in self._callbacks:
callback.on_experiment_end(**info)

def get_state(self) -> Optional[Dict]:
state = {}
any_stateful_callbacks = False
for i, callback in enumerate(self._callbacks):
callback_state = None
try:
callback_state = callback.get_state()
any_stateful_callbacks = True
except NotImplementedError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check for AttributeError?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have replaced the get_state with a default of returning None (stateless). No need to check for this error anymore. I think an attribute error should get raised immediately instead of caught.

pass
state[i] = callback_state
if not any_stateful_callbacks:
return None
return state

def set_state(self, state: Dict):
for i, callback in enumerate(self._callbacks):
callback_state = state.get(i, None)
if callback_state:
callback.set_state(callback_state)
13 changes: 10 additions & 3 deletions python/ray/tune/execution/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def checkpoint(
trial_runner: "TrialRunner",
trial_executor: RayTrialExecutor,
search_alg: SearchAlgorithm,
callbacks: CallbackList,
force: bool = False,
):
"""Saves execution state to `self._local_checkpoint_dir`.
Expand Down Expand Up @@ -190,6 +191,9 @@ def _serialize_and_write():
search_alg.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str
)
callbacks.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str
)

checkpoint_time_start = time.monotonic()
with out_of_band_serialize_dataset():
Expand Down Expand Up @@ -347,6 +351,7 @@ def __init__(
self._search_alg = search_alg or BasicVariantGenerator()
self._scheduler_alg = scheduler or FIFOScheduler()
self.trial_executor = trial_executor or RayTrialExecutor()
self._callbacks = CallbackList(callbacks or [])
self._insufficient_resources_manager = _InsufficientResourcesManager()
self._pending_trial_queue_times = {}

Expand Down Expand Up @@ -469,8 +474,6 @@ def __init__(
TrialRunner.CKPT_FILE_TMPL.format(self._session_str),
)

self._callbacks = CallbackList(callbacks or [])

if checkpoint_period is None:
checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")

Expand Down Expand Up @@ -751,6 +754,7 @@ def checkpoint(self, force: bool = False):
trial_runner=self,
trial_executor=self.trial_executor,
search_alg=self._search_alg,
callbacks=self._callbacks,
force=force,
)

Expand Down Expand Up @@ -795,10 +799,13 @@ def resume(
# 1. Restore trial runner state
self.__setstate__(runner_state["runner_data"])

# 2. Restore search algorithm state
# 2. Restore search algorithm and callback state
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)

if self._callbacks.can_restore(self._local_checkpoint_dir):
self._callbacks.restore_from_dir(self._local_checkpoint_dir)

# 3. Load trial table from experiment checkpoint
trials = []
for trial_json_state in runner_state["checkpoints"]:
Expand Down
92 changes: 92 additions & 0 deletions python/ray/tune/tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from typing import Dict, Optional

from ray.tune.callback import Callback, CallbackList


class StatefulCallback(Callback):
CKPT_FILE_TMPL = "test-callback-state-{}.json"

def __init__(self):
self.counter = 0

def on_trial_result(self, iteration, trials, trial, result, **info):
self.counter += 1

def get_state(self) -> Optional[Dict]:
return {"counter": self.counter}

def set_state(self, state: Dict):
self.counter = state["counter"]


def test_stateful_callback_save_and_restore(tmp_path):
"""Checks that a stateful callback can be saved to a directory and restored with
the right state."""

callback = StatefulCallback()
for i in range(3):
callback.on_trial_result(i, None, None, None)
callback.save_to_dir(str(tmp_path))
assert list(tmp_path.glob(StatefulCallback.CKPT_FILE_TMPL.format("*")))
assert callback.can_restore(str(tmp_path))

restored_callback = StatefulCallback()
restored_callback.restore_from_dir(str(tmp_path))
assert restored_callback.counter == 3


def test_stateless_callback_save_and_restore(tmp_path):
"""Checks that proper errors are raised/handled when saving/restoring a
stateless callback (i.e. one that doesn't implement get/set_state)."""

class StatelessCallback(Callback):
def handle_save_error(self, error: Exception):
assert isinstance(error, NotImplementedError)

callback = StatelessCallback()
callback.save_to_dir(str(tmp_path))

assert not list(tmp_path.glob(StatelessCallback.CKPT_FILE_TMPL.format("*")))
assert not callback.can_restore(str(tmp_path))
with pytest.raises(RuntimeError):
callback.restore_from_dir(str(tmp_path))


def test_callback_list_with_stateful_callback(tmp_path):
"""Checks that a callback list saves and restores all callbacks contained
inside it."""

callbacks = CallbackList([Callback(), StatefulCallback()])
for i in range(3):
callbacks.on_trial_result(iteration=i, trials=None, trial=None, result=None)

callbacks.save_to_dir(str(tmp_path))

assert list(tmp_path.glob(CallbackList.CKPT_FILE_TMPL.format("*")))
assert callbacks.can_restore(str(tmp_path))

restored_callbacks = CallbackList([Callback(), StatefulCallback()])
restored_callbacks.restore_from_dir(str(tmp_path))

assert restored_callbacks._callbacks[1].counter == 3


def test_callback_list_without_stateful_callback(tmp_path):
"""If no callbacks within a CallbackList are stateful, then nothing
should be saved."""

callbacks = CallbackList([Callback(), Callback()])
callbacks.save_to_dir(str(tmp_path))

assert not list(tmp_path.glob(CallbackList.CKPT_FILE_TMPL.format("*")))
assert not callbacks.can_restore(str(tmp_path))

with pytest.raises(RuntimeError):
callbacks.restore_from_dir(str(tmp_path))


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))
24 changes: 24 additions & 0 deletions python/ray/tune/tests/test_trial_runner_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ray.tune.search.search_generator import SearchGenerator
from ray.tune.syncer import SyncConfig, Syncer
from ray.tune.tests.tune_test_util import TrialResultObserver
from ray.tune.tests.test_callbacks import StatefulCallback


class MyCallbacks(DefaultCallbacks):
Expand Down Expand Up @@ -402,6 +403,29 @@ def num_running_trials():
count = Counter(evaluated)
assert all(v <= 3 for v in count.values())

def testCallbackSaveRestore(self):
"""Check that experiment state save + restore handles stateful callbacks."""
ray.init(num_cpus=2)
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
callbacks=[StatefulCallback()],
trial_executor=RayTrialExecutor(resource_manager=self._resourceManager()),
)
runner.add_trial(Trial("__fake", stub=True))
for i in range(3):
runner._callbacks.on_trial_result(
iteration=i, trials=None, trial=None, result=None
)
runner.checkpoint(force=True)
callback = StatefulCallback()
runner2 = TrialRunner(
local_checkpoint_dir=self.tmpdir,
callbacks=[callback],
)
assert callback.counter == 0
runner2.resume()
assert callback.counter == 3

def testTrialErrorResumeFalse(self):
ray.init(num_cpus=3, local_mode=True, include_dashboard=False)
runner = TrialRunner(
Expand Down
125 changes: 125 additions & 0 deletions python/ray/tune/utils/restorable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import abc
import glob
import os
from typing import Dict, Optional

from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint


class Restorable(abc.ABC):
"""Interface for an object that can save and restore state to a directory.

The object's state will be saved as a file of the form `CKPT_FILE_TMPL`.
When implementing this interface, be sure to change `CKPT_FILE_TMPL` to
something unique to the object being stored.
"""

CKPT_FILE_TMPL = "restorable-state-{}.pkl"

def get_state(self) -> Optional[Dict]:
"""Get the state of the object.

This method should be implemented by subclasses to return a dictionary
representation of the object's current state.

Returns:
state: State of the object. Should be `None` if the object does not
have any state to save.
"""
raise NotImplementedError

def set_state(self, state: Dict):
"""Get the state of the object.

This method should be implemented by subclasses to restore the object's
state based on the given dict state.

Args:
state: State of the object.
"""
raise NotImplementedError

def handle_save_error(self, error: Exception):
"""Handle error occurred during saving.

For example, this can be used to log a warning if `get_state` isn't implemented.

Args:
error: The exception that occurred during saving.
"""
pass

def handle_restore_error(self, error: Exception):
"""Handle error occurred during restoring.

For example, this can be used to log a warning if `set_state` isn't implemented.

Args:
error: The exception that occurred during restoring.
"""
pass

def can_restore(self, checkpoint_dir: str) -> bool:
"""Check if the checkpoint_dir contains the saved state for this object.

Returns:
can_restore: True if the checkpoint_dir contains a file of the
format `CKPT_FILE_TMPL`. False otherwise.
"""
return bool(
glob.glob(os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format("*")))
)

def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
"""Save the state of the object to the checkpoint_dir.

Args:
checkpoint_dir: directory where the checkpoint is stored.
session_str: Unique identifier of the current run session (ex: timestamp).

Raises:
NotImplementedError: if the `get_state` method is not implemented.
"""
try:
state_dict = self.get_state()
except NotImplementedError as e:
self.handle_save_error(e)
state_dict = None

if state_dict:
file_name = self.CKPT_FILE_TMPL.format(session_str)
tmp_file_name = f".tmp-{file_name}"
try:
_atomic_save(
state=state_dict,
checkpoint_dir=checkpoint_dir,
file_name=file_name,
tmp_file_name=tmp_file_name,
)
except Exception as e:
self.handle_save_error(e)

def restore_from_dir(self, checkpoint_dir: str):
"""Restore the state of the object from the checkpoint_dir.

You should check if it's possible to restore with `can_restore`
before calling this method.

Args:
checkpoint_dir: directory where the checkpoint is stored.

Raises:
RuntimeError: if unable to find checkpoint.
NotImplementedError: if the `set_state` method is not implemented.
"""
state_dict = _load_newest_checkpoint(
checkpoint_dir, self.CKPT_FILE_TMPL.format("*")
)
if not state_dict:
raise RuntimeError(
"Unable to find checkpoint in {}.".format(checkpoint_dir)
)
try:
self.set_state(state_dict)
except Exception as e:
self.handle_restore_error(e)