Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use global_step as the x-axis for wandb #558

Merged
merged 15 commits into from
Mar 6, 2022

Conversation

vwxyzjn
Copy link
Contributor

@vwxyzjn vwxyzjn commented Mar 5, 2022

  • I have marked all applicable categories:
    • exception-raising fix
    • algorithm implementation fix
    • documentation modification
    • new feature
  • I have reformatted the code using make format (required)
  • I have checked the code using make commit-checks (required)
  • If applicable, I have mentioned the relevant/related issue(s)
  • If applicable, I have listed every items in this Pull Request below

Tianshou already supports W&B logging via #426. The current logging solution uses two custom x-axises train/env_step and test/env_step. Such usage might be less desirable because

  1. train/env_step and test/env_step share virtually the same values, so we should use the same key such as global_step; with global_step as the x-axis we can still see the train/reward and test_reward as the y-axis,
  2. it's hard to compare tianshou's experiments with those from SB3 and CleanRL, which have adopted global_step as the common x-axis (see Support experiment tracking with W&B DLR-RM/rl-baselines3-zoo#213).

To help address this issue, this PR uses global_step as the x-axis for wandb logging. Additionally, this PR allows the users to override the default wandb project via environment variables like:

WANDB_PROJECT=myproject python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100 --logger wandb

Alternatives considered

An alternative plan is to remove the WandbLogger altogether and instead use wandb's tensorboard integration like

wandb.init(..., sync_tensorboard=True)

While this is possible, WandbLogger currently does more such as resume training, so removing it is a bit more complicated.

@Trinkle23897
Copy link
Collaborator

Trinkle23897 commented Mar 5, 2022

WandbLogger currently does more such as resume training

TensorboardLogger also has this function. So I think that's fine to create something like

class WandbLogger(TensorboardLogger):
  def __init__(self, *args, **kwargs):
    wandb.init(..., sync_tensorboard=True)
    super().__init__(*args, **kwargs)

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

WandbLogger currently does more such as resume training

TensorboardLogger also has this function. So I think that's fine to create something like

class WandbLogger(TensorboardLogger):
  def __init__(self, *args, **kwargs):
    wandb.init(..., sync_tensorboard=True)
    super().__init__(*args, **kwargs)

There is an issue with this approach. The SummaryWriter needs to be initialized after the wandb.init(..., sync_tensorboard=True), which requires the refactoring from the TensorboardLogger. Maybe we should revert the changes back?

@Trinkle23897
Copy link
Collaborator

Trinkle23897 commented Mar 6, 2022

How about this:

# logger/wandb_init.py
import wandb
wandb.init(..., sync_tensorboard=True)
# logger/wandb.py
# from tianshou.utils.logger import wandb_init
from tianshou.utils.logger.tensorboard import TensorboardLogger

class WandbLogger(TensorboardLogger):
  pass
# utils/__init__.py
do not import wandb_init here

and in main.py:

...
from tianshou.utils.logger import wandb_init, WandbLogger
...
if __name__ == "__main__":
  ...

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

@Trinkle23897
Copy link
Collaborator

Yeah, I mean we can replace this functionality with a simple import.

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

Yeah, I mean we can replace this functionality with a simple import.

XD didn't complete my message. I was writing but it can't be easily applied here because the logger has other utilities like save and resume data.

How about something like

    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name,
            run_id=args.resume_id,
            config=args,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "wandb":
        logger.load(writer)

and in the logger.load we basically load the TensoboardLogger

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

Per conversation with @Trinkle23897, the latest code adopts the following style.

    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=f"{args.task}__{log_name}__{args.seed}__{int(time.time())}",
            run_id=args.resume_id,
            config=args,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    if args.logger == "wandb":
        logger.load(writer)

https://wandb.ai/costa-huang/tianshou/runs/uktkei7h?workspace=user-costa-huang tracks this run.

@codecov-commenter
Copy link

codecov-commenter commented Mar 6, 2022

Codecov Report

Merging #558 (ac68423) into master (2377f2f) will decrease coverage by 0.03%.
The diff coverage is 81.81%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #558      +/-   ##
==========================================
- Coverage   93.88%   93.85%   -0.04%     
==========================================
  Files          64       64              
  Lines        4368     4376       +8     
==========================================
+ Hits         4101     4107       +6     
- Misses        267      269       +2     
Flag Coverage Δ
unittests 93.85% <81.81%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
tianshou/utils/logger/wandb.py 51.92% <81.81%> (+4.19%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 2377f2f...ac68423. Read the comment docs.

Trinkle23897
Trinkle23897 previously approved these changes Mar 6, 2022
Copy link
Contributor Author

@vwxyzjn vwxyzjn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything looks good except I think we should add algo_name to the args variable. Also, do you have a tracked run?

examples/atari/atari_ppo.py Outdated Show resolved Hide resolved
@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

image

It works now.

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented Mar 6, 2022

All good on my end

@Trinkle23897 Trinkle23897 merged commit df3d7f5 into thu-ml:master Mar 6, 2022
@vwxyzjn vwxyzjn deleted the new-wandb branch March 6, 2022 23:56
BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
* Use `global_step` as the x-axis for wandb
* Use Tensorboard SummaryWritter as core with `wandb.init(..., sync_tensorboard=True)`
* Update all atari examples with wandb

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants