-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tune] Save and restore stateful callbacks as part of experiment checkpoint #31957
Merged
gjoliver
merged 15 commits into
ray-project:master
from
justinvyu:tune/stateful_callbacks
Feb 1, 2023
Merged
Changes from 11 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
884df77
Add `Restorable` interface
justinvyu 5a54410
Make callback restorable
justinvyu 5cd27e3
Add unit tests for callback save/restore
justinvyu 4f72660
Save/restore callback state automatically in Tune loop
justinvyu 3f018d0
Fix lint
justinvyu 15d2f8b
Add unit test for trial runner callback save/resume
justinvyu 91a1872
Add docstrings (with help from chatgpt)
justinvyu b751897
Add to bazel build file
justinvyu 7fc6db9
Fix tests failing
justinvyu 563cfc9
Fix entrypoint for test_callbacks
justinvyu 3ac7c9b
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu c1a07d2
Remove restorable interface
justinvyu 52ac0db
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu 97197e2
Don't allow users to save individual callbacks, only the full list
justinvyu 1dd2d52
Remove irrelevant tests
justinvyu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import pytest | ||
from typing import Dict, Optional | ||
|
||
from ray.tune.callback import Callback, CallbackList | ||
|
||
|
||
class StatefulCallback(Callback): | ||
CKPT_FILE_TMPL = "test-callback-state-{}.json" | ||
|
||
def __init__(self): | ||
self.counter = 0 | ||
|
||
def on_trial_result(self, iteration, trials, trial, result, **info): | ||
self.counter += 1 | ||
|
||
def get_state(self) -> Optional[Dict]: | ||
return {"counter": self.counter} | ||
|
||
def set_state(self, state: Dict): | ||
self.counter = state["counter"] | ||
|
||
|
||
def test_stateful_callback_save_and_restore(tmp_path): | ||
"""Checks that a stateful callback can be saved to a directory and restored with | ||
the right state.""" | ||
|
||
callback = StatefulCallback() | ||
for i in range(3): | ||
callback.on_trial_result(i, None, None, None) | ||
callback.save_to_dir(str(tmp_path)) | ||
assert list(tmp_path.glob(StatefulCallback.CKPT_FILE_TMPL.format("*"))) | ||
assert callback.can_restore(str(tmp_path)) | ||
|
||
restored_callback = StatefulCallback() | ||
restored_callback.restore_from_dir(str(tmp_path)) | ||
assert restored_callback.counter == 3 | ||
|
||
|
||
def test_stateless_callback_save_and_restore(tmp_path): | ||
"""Checks that proper errors are raised/handled when saving/restoring a | ||
stateless callback (i.e. one that doesn't implement get/set_state).""" | ||
|
||
class StatelessCallback(Callback): | ||
def handle_save_error(self, error: Exception): | ||
assert isinstance(error, NotImplementedError) | ||
|
||
callback = StatelessCallback() | ||
callback.save_to_dir(str(tmp_path)) | ||
|
||
assert not list(tmp_path.glob(StatelessCallback.CKPT_FILE_TMPL.format("*"))) | ||
assert not callback.can_restore(str(tmp_path)) | ||
with pytest.raises(RuntimeError): | ||
callback.restore_from_dir(str(tmp_path)) | ||
|
||
|
||
def test_callback_list_with_stateful_callback(tmp_path): | ||
"""Checks that a callback list saves and restores all callbacks contained | ||
inside it.""" | ||
|
||
callbacks = CallbackList([Callback(), StatefulCallback()]) | ||
for i in range(3): | ||
callbacks.on_trial_result(iteration=i, trials=None, trial=None, result=None) | ||
|
||
callbacks.save_to_dir(str(tmp_path)) | ||
|
||
assert list(tmp_path.glob(CallbackList.CKPT_FILE_TMPL.format("*"))) | ||
assert callbacks.can_restore(str(tmp_path)) | ||
|
||
restored_callbacks = CallbackList([Callback(), StatefulCallback()]) | ||
restored_callbacks.restore_from_dir(str(tmp_path)) | ||
|
||
assert restored_callbacks._callbacks[1].counter == 3 | ||
|
||
|
||
def test_callback_list_without_stateful_callback(tmp_path): | ||
"""If no callbacks within a CallbackList are stateful, then nothing | ||
should be saved.""" | ||
|
||
callbacks = CallbackList([Callback(), Callback()]) | ||
callbacks.save_to_dir(str(tmp_path)) | ||
|
||
assert not list(tmp_path.glob(CallbackList.CKPT_FILE_TMPL.format("*"))) | ||
assert not callbacks.can_restore(str(tmp_path)) | ||
|
||
with pytest.raises(RuntimeError): | ||
callbacks.restore_from_dir(str(tmp_path)) | ||
|
||
|
||
if __name__ == "__main__": | ||
import sys | ||
|
||
sys.exit(pytest.main(["-v", __file__])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import abc | ||
import glob | ||
import os | ||
from typing import Dict, Optional | ||
|
||
from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint | ||
|
||
|
||
class Restorable(abc.ABC): | ||
"""Interface for an object that can save and restore state to a directory. | ||
|
||
The object's state will be saved as a file of the form `CKPT_FILE_TMPL`. | ||
When implementing this interface, be sure to change `CKPT_FILE_TMPL` to | ||
something unique to the object being stored. | ||
""" | ||
|
||
CKPT_FILE_TMPL = "restorable-state-{}.pkl" | ||
|
||
def get_state(self) -> Optional[Dict]: | ||
"""Get the state of the object. | ||
|
||
This method should be implemented by subclasses to return a dictionary | ||
representation of the object's current state. | ||
|
||
Returns: | ||
state: State of the object. Should be `None` if the object does not | ||
have any state to save. | ||
""" | ||
raise NotImplementedError | ||
|
||
def set_state(self, state: Dict): | ||
"""Get the state of the object. | ||
|
||
This method should be implemented by subclasses to restore the object's | ||
state based on the given dict state. | ||
|
||
Args: | ||
state: State of the object. | ||
""" | ||
raise NotImplementedError | ||
|
||
def handle_save_error(self, error: Exception): | ||
"""Handle error occurred during saving. | ||
|
||
For example, this can be used to log a warning if `get_state` isn't implemented. | ||
|
||
Args: | ||
error: The exception that occurred during saving. | ||
""" | ||
pass | ||
|
||
def handle_restore_error(self, error: Exception): | ||
"""Handle error occurred during restoring. | ||
|
||
For example, this can be used to log a warning if `set_state` isn't implemented. | ||
|
||
Args: | ||
error: The exception that occurred during restoring. | ||
""" | ||
pass | ||
|
||
def can_restore(self, checkpoint_dir: str) -> bool: | ||
"""Check if the checkpoint_dir contains the saved state for this object. | ||
|
||
Returns: | ||
can_restore: True if the checkpoint_dir contains a file of the | ||
format `CKPT_FILE_TMPL`. False otherwise. | ||
""" | ||
return bool( | ||
glob.glob(os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format("*"))) | ||
) | ||
|
||
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"): | ||
"""Save the state of the object to the checkpoint_dir. | ||
|
||
Args: | ||
checkpoint_dir: directory where the checkpoint is stored. | ||
session_str: Unique identifier of the current run session (ex: timestamp). | ||
|
||
Raises: | ||
NotImplementedError: if the `get_state` method is not implemented. | ||
""" | ||
try: | ||
state_dict = self.get_state() | ||
except NotImplementedError as e: | ||
self.handle_save_error(e) | ||
state_dict = None | ||
|
||
if state_dict: | ||
file_name = self.CKPT_FILE_TMPL.format(session_str) | ||
tmp_file_name = f".tmp-{file_name}" | ||
try: | ||
_atomic_save( | ||
state=state_dict, | ||
checkpoint_dir=checkpoint_dir, | ||
file_name=file_name, | ||
tmp_file_name=tmp_file_name, | ||
) | ||
except Exception as e: | ||
self.handle_save_error(e) | ||
|
||
def restore_from_dir(self, checkpoint_dir: str): | ||
"""Restore the state of the object from the checkpoint_dir. | ||
|
||
You should check if it's possible to restore with `can_restore` | ||
before calling this method. | ||
|
||
Args: | ||
checkpoint_dir: directory where the checkpoint is stored. | ||
|
||
Raises: | ||
RuntimeError: if unable to find checkpoint. | ||
NotImplementedError: if the `set_state` method is not implemented. | ||
""" | ||
state_dict = _load_newest_checkpoint( | ||
checkpoint_dir, self.CKPT_FILE_TMPL.format("*") | ||
) | ||
if not state_dict: | ||
raise RuntimeError( | ||
"Unable to find checkpoint in {}.".format(checkpoint_dir) | ||
) | ||
try: | ||
self.set_state(state_dict) | ||
except Exception as e: | ||
self.handle_restore_error(e) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also check for
AttributeError
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have replaced the
get_state
with a default of returning None (stateless). No need to check for this error anymore. I think an attribute error should get raised immediately instead of caught.