Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Save and restore stateful callbacks as part of experiment checkpoint #31957

Merged
merged 15 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ py_test(
tags = ["team:ml", "exclusive"],
)

# Test target that runs the Tune callback tests in tests/test_callbacks.py.
py_test(
name = "test_callbacks",
size = "small",
srcs = ["tests/test_callbacks.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_checkpoint_manager",
size = "small",
Expand Down
105 changes: 104 additions & 1 deletion python/ray/tune/callback.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from abc import ABCMeta
import glob
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import warnings

from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint

if TYPE_CHECKING:
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
Expand Down Expand Up @@ -278,12 +281,36 @@ def on_experiment_end(self, trials: List["Trial"], **info):
"""
pass

def get_state(self) -> Optional[Dict]:
    """Return a dictionary snapshot of this callback's state.

    Stateful subclasses override this hook and return a dict describing
    their current state. The base implementation is stateless.

    Returns:
        state: The callback's state, or `None` when there is nothing to
            save (the default).
    """
    return None

def set_state(self, state: Dict):
    """Set the state of the callback.

    This method should be implemented by subclasses to restore the callback's
    state based on the given dict state. The base implementation does
    nothing.

    Args:
        state: State of the callback, as a dict.
    """
    pass
justinvyu marked this conversation as resolved.
Show resolved Hide resolved


@DeveloperAPI
class CallbackList(Callback):
"""Call multiple callbacks at once."""

IS_CALLBACK_CONTAINER = True
CKPT_FILE_TMPL = "callback-states-{}.pkl"

def __init__(self, callbacks: List[Callback]):
self._callbacks = callbacks
Expand Down Expand Up @@ -343,3 +370,79 @@ def on_checkpoint(self, **info):
def on_experiment_end(self, **info):
    """Forward the experiment-end event to every contained callback, in order."""
    for cb in self._callbacks:
        cb.on_experiment_end(**info)

def get_state(self) -> Optional[Dict]:
    """Collect the state of every stateful callback in this list.

    The result maps a callback's position in the list to the state it
    returned. Callbacks whose `get_state` returns a falsy value (e.g. `None`)
    are left out. If no callback contributes any state, `None` is returned
    so that no unnecessary callback checkpoint file gets written.
    """
    collected = {}
    for index, cb in enumerate(self._callbacks):
        cb_state = cb.get_state()
        if not cb_state:
            # Stateless callback: nothing to record for this position.
            continue
        collected[index] = cb_state
    return collected if collected else None

def set_state(self, state: Dict):
    """Distribute saved state to the callbacks contained in this list.

    `state` maps a callback's position in the list to its saved state.
    Positions with no entry (or a falsy entry) are skipped; these are the
    stateless callbacks for which `get_state` returned None.
    """
    for index, cb in enumerate(self._callbacks):
        saved = state.get(index, None)
        if not saved:
            continue
        cb.set_state(saved)

def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
    """Write the combined callback state into checkpoint_dir.

    Nothing is written when `get_state` reports no state.

    Args:
        checkpoint_dir: directory where the checkpoint is stored.
        session_str: Unique identifier of the current run session (ex: timestamp).
    """
    callback_states = self.get_state()
    if not callback_states:
        # No stateful callbacks, so skip creating a checkpoint file.
        return

    ckpt_name = self.CKPT_FILE_TMPL.format(session_str)
    _atomic_save(
        state=callback_states,
        checkpoint_dir=checkpoint_dir,
        file_name=ckpt_name,
        tmp_file_name=f".tmp-{ckpt_name}",
    )

def restore_from_dir(self, checkpoint_dir: str):
    """Restore the state of the list of callbacks from the checkpoint_dir.

    Looks up the saved callback state with `_load_newest_checkpoint`, using
    `CKPT_FILE_TMPL` as the file name pattern, and hands it to `set_state`.

    You should check if it's possible to restore with `can_restore`
    before calling this method.

    Args:
        checkpoint_dir: directory where the checkpoint is stored.

    Raises:
        RuntimeError: if no (non-empty) callback state could be loaded
            from `checkpoint_dir`.
    """
    state_dict = _load_newest_checkpoint(
        checkpoint_dir, self.CKPT_FILE_TMPL.format("*")
    )
    if not state_dict:
        raise RuntimeError(f"Unable to find checkpoint in {checkpoint_dir}.")
    # Any exception raised by a contained callback's `set_state` propagates.
    self.set_state(state_dict)

def can_restore(self, checkpoint_dir: str) -> bool:
    """Report whether checkpoint_dir holds saved state for this callback list.

    Returns:
        can_restore: True if at least one file in the checkpoint_dir matches
            the `CKPT_FILE_TMPL` pattern. False otherwise.
    """
    pattern = os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format("*"))
    matches = glob.glob(pattern)
    return len(matches) > 0
13 changes: 10 additions & 3 deletions python/ray/tune/execution/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def checkpoint(
trial_runner: "TrialRunner",
trial_executor: RayTrialExecutor,
search_alg: SearchAlgorithm,
callbacks: CallbackList,
force: bool = False,
):
"""Saves execution state to `self._local_checkpoint_dir`.
Expand Down Expand Up @@ -190,6 +191,9 @@ def _serialize_and_write():
search_alg.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str
)
callbacks.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str
)

checkpoint_time_start = time.monotonic()
with out_of_band_serialize_dataset():
Expand Down Expand Up @@ -347,6 +351,7 @@ def __init__(
self._search_alg = search_alg or BasicVariantGenerator()
self._scheduler_alg = scheduler or FIFOScheduler()
self.trial_executor = trial_executor or RayTrialExecutor()
self._callbacks = CallbackList(callbacks or [])
self._insufficient_resources_manager = _InsufficientResourcesManager()
self._pending_trial_queue_times = {}

Expand Down Expand Up @@ -469,8 +474,6 @@ def __init__(
TrialRunner.CKPT_FILE_TMPL.format(self._session_str),
)

self._callbacks = CallbackList(callbacks or [])

if checkpoint_period is None:
checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")

Expand Down Expand Up @@ -751,6 +754,7 @@ def checkpoint(self, force: bool = False):
trial_runner=self,
trial_executor=self.trial_executor,
search_alg=self._search_alg,
callbacks=self._callbacks,
force=force,
)

Expand Down Expand Up @@ -795,10 +799,13 @@ def resume(
# 1. Restore trial runner state
self.__setstate__(runner_state["runner_data"])

# 2. Restore search algorithm state
# 2. Restore search algorithm and callback state
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)

if self._callbacks.can_restore(self._local_checkpoint_dir):
self._callbacks.restore_from_dir(self._local_checkpoint_dir)

# 3. Load trial table from experiment checkpoint
trials = []
for trial_json_state in runner_state["checkpoints"]:
Expand Down
59 changes: 59 additions & 0 deletions python/ray/tune/tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
from typing import Dict, Optional

from ray.tune.callback import Callback, CallbackList


class StatefulCallback(Callback):
    """Minimal stateful callback for tests: counts the trial results it sees."""

    CKPT_FILE_TMPL = "test-callback-state-{}.json"

    def __init__(self):
        self.counter = 0

    def on_trial_result(self, iteration, trials, trial, result, **info):
        # Every reported result bumps the counter by one.
        self.counter = self.counter + 1

    def get_state(self) -> Optional[Dict]:
        """Return the counter wrapped in a dict."""
        return dict(counter=self.counter)

    def set_state(self, state: Dict):
        """Restore the counter from a dict produced by `get_state`."""
        self.counter = state["counter"]


def test_callback_list_with_stateful_callback(tmp_path):
    """A CallbackList saves and restores the state of the callbacks it holds."""
    checkpoint_dir = str(tmp_path)

    original = CallbackList([Callback(), StatefulCallback()])
    for step in range(3):
        original.on_trial_result(
            iteration=step, trials=None, trial=None, result=None
        )

    original.save_to_dir(checkpoint_dir)

    # A checkpoint file must exist and be discoverable through `can_restore`.
    saved_files = list(tmp_path.glob(CallbackList.CKPT_FILE_TMPL.format("*")))
    assert saved_files
    assert original.can_restore(checkpoint_dir)

    restored = CallbackList([Callback(), StatefulCallback()])
    restored.restore_from_dir(checkpoint_dir)

    stateful = restored._callbacks[1]
    assert stateful.counter == 3


def test_callback_list_without_stateful_callback(tmp_path):
    """A CallbackList holding only stateless callbacks saves nothing."""
    checkpoint_dir = str(tmp_path)

    stateless = CallbackList([Callback(), Callback()])
    stateless.save_to_dir(checkpoint_dir)

    pattern = CallbackList.CKPT_FILE_TMPL.format("*")
    assert not list(tmp_path.glob(pattern))
    assert not stateless.can_restore(checkpoint_dir)

    # With no checkpoint file present, restoring must fail.
    with pytest.raises(RuntimeError):
        stateless.restore_from_dir(checkpoint_dir)


if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely and use pytest's result as the exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
24 changes: 24 additions & 0 deletions python/ray/tune/tests/test_trial_runner_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ray.tune.search.search_generator import SearchGenerator
from ray.tune.syncer import SyncConfig, Syncer
from ray.tune.tests.tune_test_util import TrialResultObserver
from ray.tune.tests.test_callbacks import StatefulCallback


class MyCallbacks(DefaultCallbacks):
Expand Down Expand Up @@ -402,6 +403,29 @@ def num_running_trials():
count = Counter(evaluated)
assert all(v <= 3 for v in count.values())

def testCallbackSaveRestore(self):
    """Check that experiment state save + restore handles stateful callbacks."""
    ray.init(num_cpus=2)
    first_runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir,
        callbacks=[StatefulCallback()],
        trial_executor=RayTrialExecutor(resource_manager=self._resourceManager()),
    )
    first_runner.add_trial(Trial("__fake", stub=True))

    # Drive the callback directly so that its counter reaches 3 before saving.
    for step in range(3):
        first_runner._callbacks.on_trial_result(
            iteration=step, trials=None, trial=None, result=None
        )
    first_runner.checkpoint(force=True)

    # A fresh callback starts at 0 and only picks up the saved count on resume.
    restored_callback = StatefulCallback()
    second_runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir,
        callbacks=[restored_callback],
    )
    assert restored_callback.counter == 0
    second_runner.resume()
    assert restored_callback.counter == 3

def testTrialErrorResumeFalse(self):
ray.init(num_cpus=3, local_mode=True, include_dashboard=False)
runner = TrialRunner(
Expand Down