Potential bugs in RecordEpisodeStatistics #454

Open
3 tasks done
williamd4112 opened this issue Apr 2, 2024 · 0 comments

williamd4112 commented Apr 2, 2024

Problem Description

Checklist

Current Behavior

In ppo_rnd_envpool.py (and also ppo_atari_envpool.py), RecordEpisodeStatistics keeps accumulating rewards after a time-limit truncation, because self.episode_returns is only masked by info["terminated"]. In Atari, this means the returns of two independent rounds (a round ends when the agent loses all of its lives) are summed if the previous round was reset due to time-limit truncation.

The following is what I observe when training with envpool and max_episode_steps=27000 (the default value in envpool).
Here is how I log (adapted from this line):

for idx, d in enumerate(done):
    log_rewards[idx].append(reward[idx])
    if info["terminated"][idx]:
        avg_returns.append(info["r"][idx])
        print(f"Env {idx} finishes a round with length {info['l'][idx]} and score {info['r'][idx]}")
        log_rewards[idx] = []

These are the logs I got:

Env 0 finishes a round with length 54012 and score 1900
...
Env 0 finishes a round with length 81016 and score 4900

It's problematic since info["l"][idx] should not exceed 27000. I checked that when the timestep hits 27000, the environment will be reset. This means the scores across two rounds are summed up.
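
As a quick sanity check on these numbers, the logged lengths can be split into whole multiples of the time limit plus a remainder. This is only one reading of the logs, under the assumption that the length counter keeps running across each time-limit reset; `MAX_EPISODE_STEPS` and `logged_lengths` are names introduced here for illustration.

```python
# Illustrative arithmetic only; the values are copied from the logs above.
MAX_EPISODE_STEPS = 27000  # envpool default mentioned above
logged_lengths = [54012, 81016]

# divmod gives (number of full time limits, leftover steps).
splits = [divmod(length, MAX_EPISODE_STEPS) for length in logged_lengths]
for length, (full_limits, remainder) in zip(logged_lengths, splits):
    print(f"{length} = {full_limits} * {MAX_EPISODE_STEPS} + {remainder}")
```

Both lengths sit just above a whole multiple of 27000, which would be consistent with the counter surviving several truncations before a final short round ends in termination.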

Expected Behavior

The reported game score should be the sum of rewards over all the lives in one round only.

Possible Solution

Should we change this line to:

self.episode_returns *= 1 - (infos["terminated"] | infos["TimeLimit.truncated"])
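
A minimal sketch of how that mask could be applied to both counters, using plain NumPy arrays in place of a real envpool `infos` dict. The helper name `reset_finished` is hypothetical (it is not part of cleanrl or envpool), the key names follow the ones used in the reproduction script, and the explicit `astype(bool)` is an assumption made here so the bitwise OR behaves the same whether `terminated` arrives as an integer or a boolean array.

```python
import numpy as np


def reset_finished(episode_returns, episode_lengths, infos):
    """Zero the running counters for envs whose episode ended for any reason.

    Hypothetical helper for illustration; modifies the arrays in place.
    """
    terminated = np.asarray(infos["terminated"]).astype(bool)
    truncated = np.asarray(infos["TimeLimit.truncated"]).astype(bool)
    keep = ~(terminated | truncated)  # True where the episode continues
    episode_returns *= keep
    episode_lengths *= keep
    return episode_returns, episode_lengths


# Three fake envs: env 0 hits the time limit, env 1 continues, env 2 terminates.
returns = np.array([10.0, 3.0, 7.0], dtype=np.float32)
lengths = np.array([29, 5, 12], dtype=np.int32)
fake_infos = {
    "terminated": np.array([0, 0, 1], dtype=np.int32),
    "TimeLimit.truncated": np.array([True, False, False]),
}
reset_finished(returns, lengths, fake_infos)
print(returns, lengths)
```

The same mask is applied to the length counter as well, since the logs show the lengths being carried over together with the returns.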

Steps to Reproduce

Run the following script:

import gym
import numpy as np
import envpool

is_legacy_gym = True

# From: https://github.com/sail-sg/envpool/blob/main/examples/cleanrl_examples/ppo_atari_envpool.py
class RecordEpisodeStatistics(gym.Wrapper):

  def __init__(self, env, deque_size=100):
    super(RecordEpisodeStatistics, self).__init__(env)
    self.num_envs = getattr(env, "num_envs", 1)
    self.episode_returns = None
    self.episode_lengths = None
    # get if the env has lives
    self.has_lives = False
    env.reset()
    info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
    if info["lives"].sum() > 0:
      self.has_lives = True
      print("env has lives")

  def reset(self, **kwargs):
    if is_legacy_gym:
      observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
    else:
      observations, _ = super(RecordEpisodeStatistics, self).reset(**kwargs)
    self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
    self.lives = np.zeros(self.num_envs, dtype=np.int32)
    self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
    return observations

  def step(self, action):
    if is_legacy_gym:
      observations, rewards, dones, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
    else:
      observations, rewards, term, trunc, infos = super(
        RecordEpisodeStatistics, self
      ).step(action)
      dones = term + trunc
    self.episode_returns += infos["reward"]
    self.episode_lengths += 1
    self.returned_episode_returns[:] = self.episode_returns
    self.returned_episode_lengths[:] = self.episode_lengths
    all_lives_exhausted = infos["lives"] == 0
    if self.has_lives:
      self.episode_returns *= 1 - all_lives_exhausted
      self.episode_lengths *= 1 - all_lives_exhausted
    else:
      self.episode_returns *= 1 - dones
      self.episode_lengths *= 1 - dones
    infos["r"] = self.returned_episode_returns
    infos["l"] = self.returned_episode_lengths
    return (
      observations,
      rewards,
      dones,
      infos,
    )

# From: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool.py
# class RecordEpisodeStatistics(gym.Wrapper):
#     def __init__(self, env, deque_size=100):
#         super().__init__(env)
#         self.num_envs = getattr(env, "num_envs", 1)
#         self.episode_returns = None
#         self.episode_lengths = None

#     def reset(self, **kwargs):
#         observations = super().reset(**kwargs)
#         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         self.lives = np.zeros(self.num_envs, dtype=np.int32)
#         self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         return observations

#     def step(self, action):
#         observations, rewards, dones, infos = super().step(action)
#         self.episode_returns += infos["reward"]
#         self.episode_lengths += 1
#         self.returned_episode_returns[:] = self.episode_returns
#         self.returned_episode_lengths[:] = self.episode_lengths       
#         self.episode_returns *= 1 - infos["terminated"]
#         self.episode_lengths *= 1 - infos["terminated"]
#         infos["r"] = self.returned_episode_returns
#         infos["l"] = self.returned_episode_lengths
#         return (
#             observations,
#             rewards,
#             dones,
#             infos,
#         )


if __name__ == "__main__":
  
  np.random.seed(1)
  
  envs = envpool.make(
    "UpNDown-v5",
    env_type="gym",
    num_envs=1,
    episodic_life=True,  # Espeholt et al., 2018, Tab. G.1
    repeat_action_probability=0,  # Hessel et al., 2022 (Muesli) Tab. 10
    full_action_space=False,  # Espeholt et al., 2018, Appendix G., "Following related work, experts use game-specific action sets."
    max_episode_steps=30,  # Set to 30 to hit the time limit faster
    reward_clip=True,
    seed=1,
  )
  envs = RecordEpisodeStatistics(envs)
 
  num_episodes = 2

  episode_count = 0
  cur_episode_len = 0
  cur_episode_return = 0

  my_episode_returns = []
  my_episode_lens = []

  # Track episode returns here to compare with the ones recorded with `RecordEpisodeStatistics`
  recorded_episode_returns = []
  recorded_episode_lens = []
  
  obs = envs.reset()
  while episode_count < num_episodes:   
      action = np.random.randint(0, envs.action_space.n, 1)
      obs, reward, done, info = envs.step(action)
      cur_episode_return += info["reward"][0]
      cur_episode_len += 1
      print(f"Ep={episode_count}, EpStep={cur_episode_len}, Return={info['r']}, MyReturn={cur_episode_return}, Terminated={info['terminated']}, Timeout={info['TimeLimit.truncated']}, Lives={info['lives']}")
      
      # info["terminated"] = True: Game over.
      # info["TimeLimit.truncated"] = True: Timeout, the environment will be reset (so the episode return should be reset too)
      if info["terminated"][0] or info["TimeLimit.truncated"][0]:
        recorded_episode_returns.append(info["r"][0]) # Append the episode return recorded in `RecordEpisodeStatistics`
        recorded_episode_lens.append(info["l"][0]) # Append the episode length recorded in `RecordEpisodeStatistics`
        my_episode_returns.append(cur_episode_return)
        my_episode_lens.append(cur_episode_len)
        print(f"Episode {episode_count}'s length is {cur_episode_len} (terminated={info['terminated']}, timeout={info['TimeLimit.truncated']})")
        
        episode_count += 1
        cur_episode_return *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])
        cur_episode_len *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])

  for episode_idx in range(num_episodes):
    print(f"Episode {episode_idx}'s return is supposed to be {my_episode_returns[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_returns[episode_idx]}")
    print(f"Episode {episode_idx}'s len is supposed to be {my_episode_lens[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_lens[episode_idx]}")

You should see the output:

env has lives
Ep=0, EpStep=1, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=2, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=3, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=4, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=5, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=6, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=7, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=8, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=9, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=10, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=11, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=12, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=13, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=14, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=15, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=16, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=17, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=18, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=19, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=20, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=21, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=22, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=23, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=24, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=25, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=26, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=27, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=28, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=2, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=3, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=4, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=5, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=6, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=7, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=8, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=9, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=10, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=11, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=12, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=13, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=14, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=15, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=16, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=17, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=18, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=19, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=20, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=21, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=22, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=23, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=24, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=25, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=26, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=27, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=28, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=29, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=30, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=31, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 1's length is 31 (terminated=[0], timeout=[ True])
Episode 0's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 10.0
Episode 0's len is supposed to be 29, but the wrapper `RecordEpisodeStatistics` gives 29
Episode 1's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 20.0
Episode 1's len is supposed to be 31, but the wrapper `RecordEpisodeStatistics` gives 60

See the above example's output:

Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]

The return in the new episode (Ep=1) is not reset to zero but is carried from the return in the old episode. The expected behavior is to reset the return counter to zero upon timeout.
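
The carry-over can also be replayed without envpool. The sketch below feeds a hand-written trace that mirrors the output above (a reward of 10 at step 20 of the first episode, a timeout after 29 steps, a reward of 10 at step 27 of the second episode, a timeout after 31 steps) through the same accumulate-then-mask logic. `replay`, `make_episode`, and `include_truncation` are hypothetical names; the lives-only mask corresponds to the `has_lives` branch of the wrapper in the script, and extending it with the truncation flag is the change from Possible Solution adapted to that branch, which is an assumption here rather than a tested patch.

```python
def replay(trace, include_truncation):
    """Return the (return, length) pair reported at each episode end.

    trace: iterable of (reward, lives, truncated) tuples, one per step.
    """
    episode_return, episode_length = 0.0, 0
    reported = []
    for reward, lives, truncated in trace:
        episode_return += reward
        episode_length += 1
        if lives == 0 or truncated:
            reported.append((episode_return, episode_length))
        # Lives-only mask, optionally extended with the truncation flag.
        episode_over = lives == 0 or (include_truncation and truncated)
        if episode_over:
            episode_return, episode_length = 0.0, 0
    return reported


def make_episode(n_steps, reward_step):
    """n_steps steps with 5 lives, one reward of 10, truncated on the last step."""
    return [
        (10.0 if step == reward_step else 0.0, 5, step == n_steps)
        for step in range(1, n_steps + 1)
    ]


trace = make_episode(29, 20) + make_episode(31, 27)
lives_only = replay(trace, include_truncation=False)
with_truncation = replay(trace, include_truncation=True)
print(lives_only)       # second entry carries over the first episode
print(with_truncation)  # each entry covers a single episode
```

With the lives-only mask, the second reported pair matches the 20.0 and 60 printed by the wrapper above; with the truncation flag included, it matches the manually tracked 10.0 and 31.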

@vwxyzjn
