From 603d5523ed56f3776a37eb7b27cac30fa4e9c08e Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 2 Aug 2024 10:42:26 +0200 Subject: [PATCH] [RLlib] Fix WandB metric overlap after restore from checkpoint. (#46897) --- python/ray/air/integrations/wandb.py | 5 +++-- python/ray/air/tests/mocked_wandb_integration.py | 2 +- .../checkpoints/continue_training_from_checkpoint.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py index 0511b4385f341..fcd683ec0e7f7 100644 --- a/python/ray/air/integrations/wandb.py +++ b/python/ray/air/integrations/wandb.py @@ -15,6 +15,7 @@ from ray._private.storage import _load_class from ray.air import session from ray.air._internal import usage as air_usage +from ray.air.constants import TRAINING_ITERATION from ray.air.util.node import _force_on_current_node from ray.train._internal.syncer import DEFAULT_SYNC_TIMEOUT from ray.tune.experiment import Trial @@ -166,7 +167,7 @@ def _setup_wandb( project = _get_wandb_project(kwargs.pop("project", None)) group = kwargs.pop("group", os.environ.get(WANDB_GROUP_ENV_VAR)) - # remove unpickleable items + # Remove unpickleable items. _config = _clean_log(_config) wandb_init_kwargs = dict( @@ -415,7 +416,7 @@ def run(self): log, config_update = self._handle_result(item_content) try: self._wandb.config.update(config_update, allow_val_change=True) - self._wandb.log(log) + self._wandb.log(log, step=log.get(TRAINING_ITERATION)) except urllib.error.HTTPError as e: # Ignore HTTPError. Missing a few data points is not a # big issue, as long as things eventually recover. diff --git a/python/ray/air/tests/mocked_wandb_integration.py b/python/ray/air/tests/mocked_wandb_integration.py index 323ee25153d44..6ed0983450332 100644 --- a/python/ray/air/tests/mocked_wandb_integration.py +++ b/python/ray/air/tests/mocked_wandb_integration.py @@ -65,7 +65,7 @@ def init(self, *args, **kwargs): return mock - def log(self, data): + def log(self, data, step=None): try: json_dumps_safer(data) except Exception: diff --git a/rllib/examples/checkpoints/continue_training_from_checkpoint.py b/rllib/examples/checkpoints/continue_training_from_checkpoint.py index 6eff1133e772c..6fff3d4328784 100644 --- a/rllib/examples/checkpoints/continue_training_from_checkpoint.py +++ b/rllib/examples/checkpoints/continue_training_from_checkpoint.py @@ -241,7 +241,7 @@ def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs): test_results = test_algo.train() assert ( test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash - ) + ), test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] # Stop the test algorithm again. test_algo.stop()