Qdagger: Reincarnate RL #344

Merged: 33 commits into master on Jun 9, 2023
Conversation

@vwxyzjn (Owner) commented Jan 9, 2023

Description

https://github.com/google-research/reincarnating_rl

Preliminary result: [screenshot of preliminary learning curves]

Need more contributors on this.

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithm variants or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm variant.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format).
    • I have added links to the tracked experiments.
    • I have updated the overview sections in the docs and the repo.
  • I have updated the tests accordingly (if applicable).

@vwxyzjn added the "help wanted" label on Jan 9, 2023

@vwxyzjn (Owner, Author) commented Jan 12, 2023

Starting a thread on this: @richa-verma has expressed interest in helping out with this PR. Welcome, Richa! I will put some information below to help you get started, and I'm happy to help further with anything you need.

The main things we are looking for are 1) a single-file implementation (minimal lines of code), 2) documentation explaining notable implementation details, and 3) benchmarks that match the performance of reference implementations. Please check out our contribution guide for the usual process; #331 is a good example of how new algorithms are contributed end-to-end.

With that said, let me share more detail on the current status of this PR.

model loading: as you know, Reincarnating RL relies on prior (teacher) models for training. Luckily, we already have pre-trained models on huggingface from #292. See the docs for more detail; the colab notebook has a good demo of how to load the models.

jax vs pytorch: we have both JAX- and PyTorch-trained DQN models for Atari (hosted on huggingface, as above). Feel free to work with whichever you prefer.

qdagger: I have implemented qdagger_dqn_atari_jax_impalacnn.py (which uses JAX) as a proof-of-concept. Its rough flow looks as follows:

  1. Load the teacher model from huggingface
    # QDAGGER LOGIC:
    teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari_jax.cleanrl_model")
    teacher_model = TeacherModel(action_dim=envs.single_action_space.n)
    teacher_model_key = jax.random.PRNGKey(args.seed)
    teacher_params = teacher_model.init(teacher_model_key, obs)
    with open(teacher_model_path, "rb") as f:
        teacher_params = flax.serialization.from_bytes(teacher_params, f.read())
    teacher_model.apply = jax.jit(teacher_model.apply)
  2. Evaluate the teacher model to get its average episodic return G_T, which will be useful later
    # evaluate the teacher model
    teacher_episodic_returns = evaluate(
        teacher_model_path,
        make_env,
        args.env_id,
        eval_episodes=10,
        run_name=f"{run_name}-teacher-eval",
        Model=TeacherModel,
        epsilon=0.05,
        capture_video=False,
    )
    writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0)
  3. Then, since the pre-trained models do not contain the replay buffer data, we need to populate the replay buffer for the teacher. See "A.5 Additional ablations for QDagger" in the original paper for more detail.
    # collect teacher data for args.teacher_steps
    # we assume we don't have access to the teacher's replay buffer
    # see Fig. A.19 in Agarwal et al. 2022 for more detail
    teacher_rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        optimize_memory_usage=True,
        handle_timeout_termination=True,
    )
    obs = envs.reset()
    for global_step in track(range(args.teacher_steps), description="filling teacher's replay buffer"):
        epsilon = linear_schedule(args.start_e, args.end_e, args.teacher_steps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = teacher_model.apply(teacher_params, obs)
            actions = q_values.argmax(axis=-1)
            actions = jax.device_get(actions)
        next_obs, rewards, dones, infos = envs.step(actions)
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        teacher_rb.add(obs, real_next_obs, actions, rewards, dones, infos)
        obs = next_obs
  4. The rest is to perform the offline phase (i.e., the "reincarnation" steps); a hedged sketch of the loss that update might compute appears right after this list.
    # offline training phase: train the student model using the qdagger loss
    for global_step in track(range(args.offline_steps), description="offline student training"):
        data = teacher_rb.sample(args.batch_size)
        # perform a gradient-descent step
        loss, q_loss, old_val, distill_loss, q_state = update(
            q_state,
            data.observations.numpy(),
            data.actions.numpy(),
            data.next_observations.numpy(),
            data.rewards.flatten().numpy(),
            data.dones.flatten().numpy(),
            1.0,
        )
        if global_step % 100 == 0:
            writer.add_scalar("charts/offline/loss", jax.device_get(loss), global_step)
            writer.add_scalar("charts/offline/q_loss", jax.device_get(q_loss), global_step)
            writer.add_scalar("charts/offline/distill_loss", jax.device_get(distill_loss), global_step)
        if global_step % 100000 == 0:
            # evaluate the student model
            model_path = f"runs/{run_name}/{args.exp_name}-offline-{global_step}.cleanrl_model"
            with open(model_path, "wb") as f:
                f.write(flax.serialization.to_bytes(q_state.params))
            print(f"model saved to {model_path}")
            episodic_returns = evaluate(
                model_path,
                make_env,
                args.env_id,
                eval_episodes=10,
                run_name=f"{run_name}-eval",
                Model=QNetwork,
                epsilon=0.05,
            )
            print(episodic_returns)
            writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step)
  5. and the online phase
    # TODO: question: do we need to start the student rb from scratch?
    rb = teacher_rb
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, False, run_name)])
    obs = envs.reset()
    episodic_returns = deque(maxlen=10)
    for global_step in track(range(args.total_timesteps), description="online student training"):
        global_step += args.offline_steps
        # ALGO LOGIC: put action logic here
        # TODO: question: do we need to use epsilon greedy here?
        # epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network.apply(q_state.params, obs)
            actions = q_values.argmax(axis=-1)
            actions = jax.device_get(actions)

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                episodic_returns.append(info["episode"]["r"])
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)
                # perform a gradient-descent step
                if len(episodic_returns) < 10:
                    distill_coeff = 1.0
                else:
                    distill_coeff = max(1 - np.mean(episodic_returns) / np.mean(teacher_episodic_returns), 0)
                loss, q_loss, old_val, distill_loss, q_state = update(
                    q_state,
                    data.observations.numpy(),
                    data.actions.numpy(),
                    data.next_observations.numpy(),
                    data.rewards.flatten().numpy(),
                    data.dones.flatten().numpy(),
                    distill_coeff,
                )
                if global_step % 100 == 0:
                    writer.add_scalar("losses/loss", jax.device_get(loss), global_step)
                    writer.add_scalar("losses/td_loss", jax.device_get(q_loss), global_step)
                    writer.add_scalar("losses/distill_loss", jax.device_get(distill_loss), global_step)
                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
                    writer.add_scalar("charts/distill_coeff", distill_coeff, global_step)
                    print(distill_coeff)
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # update the target network
            if global_step % args.target_network_frequency == 0:
                q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
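
For reference, the update function called in steps 4 and 5 is not shown above. A minimal sketch of the QDagger loss it could compute (a standard DQN TD loss plus a distillation term weighted by distill_coeff), based on my reading of Agarwal et al. 2022 rather than the exact proof-of-concept code, might look roughly like this:

    # hedged sketch of a QDagger-style loss; names and the exact loss form are
    # assumptions, not the final implementation
    import jax
    import jax.numpy as jnp

    def qdagger_loss(q_pred, q_next, teacher_q, actions, rewards, dones, distill_coeff, gamma=0.99, tau=1.0):
        # q_pred: student Q-values for obs, shape (batch, num_actions)
        # q_next: target-network Q-values for next_obs, shape (batch, num_actions)
        # teacher_q: teacher Q-values for obs, shape (batch, num_actions)
        # 1-step TD target from the target network (standard DQN)
        td_target = rewards + (1.0 - dones) * gamma * q_next.max(axis=-1)
        old_val = q_pred[jnp.arange(q_pred.shape[0]), actions]
        q_loss = jnp.mean((old_val - td_target) ** 2)

        # distillation term: cross-entropy between the teacher's and the student's
        # softmax policies over Q-values, with temperature tau
        teacher_probs = jax.nn.softmax(teacher_q / tau, axis=-1)
        student_log_probs = jax.nn.log_softmax(q_pred / tau, axis=-1)
        distill_loss = -jnp.mean(jnp.sum(teacher_probs * student_log_probs, axis=-1))

        loss = q_loss + distill_coeff * distill_loss
        return loss, (q_loss, old_val, distill_loss)

In update, q_pred, q_next, and teacher_q would come from q_network.apply(q_state.params, obs), q_network.apply(q_state.target_params, next_obs), and teacher_model.apply(teacher_params, obs), and the whole thing would be wrapped in jax.value_and_grad before applying the optimizer.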

Some further considerations & optimizations:

  1. Atari preprocessing: we have used an older set of preprocessing techniques that does not use sticky actions, but the original paper does (the exact difference, I think, is highlighted here). We have a few options for reproducing this work: 1) keep the current Atari preprocessing and just reproduce the algorithm; 2) run another set of benchmarks with preprocessing aligned with the original paper, save those models to huggingface, then load them for our reproduction (a hedged sketch of enabling sticky actions appears after this list); or 3) try to load the trained checkpoints from the original paper directly, but this is likely extremely ad hoc and I would not recommend it.
  2. Step number 3 could be sped up by leveraging multiple simulation environments.
  3. In step 5, my implementation is to directly substitute the student's replay buffer with the teacher's. Not exactly sure if this is correct... Not sure if we should build the student's replay buffer from scratch.
  4. In step 5, we could optionally add a threshold beyond which we no longer take any distillation from the teacher policy.
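
Regarding the sticky actions mentioned in item 1, a hedged sketch of what enabling them could look like with the ale-py Gym registrations (this is an assumption about the environment-creation route, not what our current preprocessing does):

    import gym

    # the v4 NoFrameskip registrations default to repeat_action_probability=0.0
    # (no sticky actions); passing 0.25 matches the convention in the original paper
    env = gym.make("BreakoutNoFrameskip-v4", repeat_action_probability=0.25)

    # alternatively, the ALE v5 registrations enable sticky actions (0.25) by default
    # and apply a built-in frameskip of 4
    env_v5 = gym.make("ALE/Breakout-v5")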

I know this is throwing a lot at you. Please let me know if you need further clarifications or pair programming :) Thanks for your interest in working on this again.

@sdpkjc (Collaborator) commented May 6, 2023

> 3. In step 5, my implementation is to directly substitute the student's replay buffer with the teacher's. Not exactly sure if this is correct... Not sure if we should build the student's replay buffer from scratch.

In this part, my understanding is that the teacher buffer and the student buffer should be kept separate. Section 4.1 of the original paper uses two distinct buffer symbols, D_T and D_S, and the original paper's code does the same, as can be seen at https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/reincarnating_rl/reincarnation_dqn_agent.py#LL147C1-L186C35 (QDaggerDQNAgent inherits from ReincarnationDQNAgent).

ReincarnationDQNAgent in turn inherits from dopamine.jax.agents.dqn.dqn_agent.JaxDQNAgent, and as you can see from the dopamine repository code, that agent creates only a single buffer: https://github.com/google/dopamine/blob/81f695c1525f2774fbaa205cf19d60946b543bc9/dopamine/jax/agents/dqn/dqn_agent.py#L334
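
For concreteness, a minimal sketch of keeping the two buffers separate, reusing the ReplayBuffer construction and variable names from the step 3/5 snippets above (this is just an illustration of the suggestion, not the merged implementation):

    # student gets its own buffer D_S, instead of `rb = teacher_rb` in step 5
    student_rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        optimize_memory_usage=True,
        handle_timeout_termination=True,
    )

    # offline phase: sample from the teacher's buffer D_T
    offline_data = teacher_rb.sample(args.batch_size)

    # online phase: the student stores its own transitions in D_S and samples from it
    student_rb.add(obs, real_next_obs, actions, rewards, dones, infos)
    online_data = student_rb.sample(args.batch_size)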

@vwxyzjn (Owner, Author) commented May 6, 2023

> In this part, my understanding is that the teacher buffer and the student buffer should be kept separate.

This is correct. I used the same buffer because the teacher's buffer was not saved with the Hugging Face model; instead, we populate the teacher's buffer ourselves, following "A.5 Additional ablations for QDagger".

Would you be interested in taking on this PR?

@sdpkjc (Collaborator) commented May 6, 2023

I would be glad to take on this PR. 😄
My plan is to improve step 5 first: give the student agent its own independent replay buffer, and compare against the original paper's code to get the weaning of the student off the teacher right.

@sdpkjc (Collaborator) commented May 8, 2023

I observed some strange bugs in the latest version of the original code. The initial commit in git looks a bit more correct, so I suggest using the files distillation_dqn_agent.py and persistent_dqn_agent.py from that commit as the reference for our implementation.

@sdpkjc (Collaborator) commented May 8, 2023

> # TODO: question: do we need to use epsilon greedy here?

Yes, we need to use epsilon-greedy here.
The reason is that ReincarnationDQNAgent sets its epsilon_fn to the reincarnation_linearly_decaying_epsilon function.
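
For reference, the linear_schedule helper called in the snippets above is cleanrl's usual linear decay; a sketch is below (the exact schedule used by reincarnation_linearly_decaying_epsilon in the reference code may differ):

    def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
        # linearly interpolate from start_e down to end_e over `duration` steps,
        # then stay at end_e
        slope = (end_e - start_e) / duration
        return max(slope * t + start_e, end_e)

    # e.g., decay epsilon from 1.0 to 0.01 over the first 10% of 10M timesteps
    epsilon = linear_schedule(1.0, 0.01, int(0.10 * 10_000_000), 50_000)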

@sdpkjc (Collaborator) commented May 25, 2023

[benchmark learning-curve plots: jax_step, torch_step]

@vwxyzjn (Owner, Author) commented May 25, 2023

The results look really good! Great job @sdpkjc. I noticed the learning curves look slightly different... any ideas? Maybe it is explained by the teacher models: dqn_atari scores 333.60 +/- 120.61 whereas dqn_atari_jax scores 291.10 +/- 116.43? Also, feel free to test out Pong and BeamRider.

@sdpkjc marked this pull request as ready for review June 9, 2023 11:56

@vwxyzjn (Owner, Author) left a review comment:

Looks really amazing!!!! Feel free to merge :) Thanks so much for the PR!

@vwxyzjn (Owner, Author) commented Jun 9, 2023

> [benchmark learning-curve plots: jax_step, torch_step]

Oh, I guess this is one last thing: could you add these plots to the docs as well? That way users can see the comparison.

@sdpkjc (Collaborator) commented Jun 9, 2023

👌

@vwxyzjn (Owner, Author) commented Jun 9, 2023

Btw, you don't have to generate plots like these by hand (though now that you already have them, it's perfectly fine to leave them there). We used to produce these plots manually, but now we can just use the openrlbenchmark utility to generate them :)

[screenshot: example plot generated with openrlbenchmark]

@sdpkjc (Collaborator) commented Jun 9, 2023

Thanks! I have generated the qdagger vs. dqn plots using openrlbenchmark and added the comparison to our wandb report.

[qdagger vs. dqn comparison plots]

@sdpkjc merged commit 0b976ac into master on Jun 9, 2023