[RLlib] Fix WandB metric overlap after restore from checkpoint. (#46897)
sven1977 authored Aug 2, 2024
1 parent 1df0c8e commit 603d552
Showing 3 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/ray/air/integrations/wandb.py
@@ -15,6 +15,7 @@
 from ray._private.storage import _load_class
 from ray.air import session
 from ray.air._internal import usage as air_usage
+from ray.air.constants import TRAINING_ITERATION
 from ray.air.util.node import _force_on_current_node
 from ray.train._internal.syncer import DEFAULT_SYNC_TIMEOUT
 from ray.tune.experiment import Trial
@@ -166,7 +167,7 @@ def _setup_wandb(
     project = _get_wandb_project(kwargs.pop("project", None))
     group = kwargs.pop("group", os.environ.get(WANDB_GROUP_ENV_VAR))

-    # remove unpickleable items
+    # Remove unpickleable items.
     _config = _clean_log(_config)

     wandb_init_kwargs = dict(
@@ -415,7 +416,7 @@ def run(self):
             log, config_update = self._handle_result(item_content)
             try:
                 self._wandb.config.update(config_update, allow_val_change=True)
-                self._wandb.log(log)
+                self._wandb.log(log, step=log.get(TRAINING_ITERATION))
             except urllib.error.HTTPError as e:
                 # Ignore HTTPError. Missing a few data points is not a
                 # big issue, as long as things eventually recover.
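For context, here is a minimal sketch (not part of this commit) of the logging pattern the last hunk adopts. Without an explicit `step`, wandb auto-increments its own counter on every `log()` call, and that counter no longer lines up with the trainer's iteration count after a restore from checkpoint, so restored runs can log overlapping points. Pinning the step to the result's `training_iteration` keeps the x-axis consistent across restores. The project name and metric values below are placeholders:

import wandb

# Value of ray.air.constants.TRAINING_ITERATION.
TRAINING_ITERATION = "training_iteration"

run = wandb.init(project="step-demo")  # placeholder project name
for result in (
    {TRAINING_ITERATION: 101, "episode_return_mean": 12.3},
    {TRAINING_ITERATION: 102, "episode_return_mean": 14.8},
):
    # If the result carries no iteration counter, step is None and wandb
    # falls back to its auto-incremented internal step.
    run.log(result, step=result.get(TRAINING_ITERATION))
run.finish()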
2 changes: 1 addition & 1 deletion python/ray/air/tests/mocked_wandb_integration.py
@@ -65,7 +65,7 @@ def init(self, *args, **kwargs):

         return mock

-    def log(self, data):
+    def log(self, data, step=None):
         try:
             json_dumps_safer(data)
         except Exception:
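A note on the test helper change above: the mock's `log()` must mirror the real `wandb.log()` signature, otherwise the new `step=` keyword would raise a TypeError under test. An illustrative version of that pattern (hypothetical class, not the actual helper in `mocked_wandb_integration.py`):

class FakeWandb:
    def __init__(self):
        self.logged = []

    def log(self, data, step=None):
        # Record payload and explicit step so a test can assert that
        # steps stay monotonic after a restore.
        self.logged.append((step, dict(data)))

fake = FakeWandb()
fake.log({"loss": 0.5}, step=7)
assert fake.logged == [(7, {"loss": 0.5})]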
2 changes: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs):
     test_results = test_algo.train()
     assert (
         test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash
-    )
+    ), test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
     # Stop the test algorithm again.
     test_algo.stop()
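The change above is a test-ergonomics tweak: Python's `assert condition, expression` form attaches the evaluated expression to the AssertionError, so a failing run now reports the observed episode return mean instead of a bare assertion failure. A tiny sketch with hypothetical values:

episode_return_mean = 42.5  # hypothetical observed value
stop_reward_crash = 100.0   # hypothetical threshold
assert episode_return_mean >= stop_reward_crash, episode_return_mean
# AssertionError: 42.5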
