Feature/dreamer multiple env #64

Merged · 17 commits · Jul 26, 2023
23 changes: 17 additions & 6 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -377,13 +377,24 @@ def __init__(

self.init_states()

def init_states(self) -> None:
"""
Initialize the states and the actions for the ended environments.
def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
"""Initialize the states and the actions for the ended environments.

Args:
reset_envs (Optional[Sequence[int]], optional): which environments' states to reset.
If None, then all environments' states are reset.
Defaults to None.
"""
self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device)
self.stochastic_state = torch.zeros(1, self.num_envs, self.stochastic_size, device=self.device)
self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device)
if reset_envs is None or len(reset_envs) == 0:
self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device)
self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device)
self.stochastic_state = torch.zeros(
1, self.num_envs, self.stochastic_size, device=self.device
)
else:
self.actions[:, reset_envs] = torch.zeros_like(self.actions[:, reset_envs])
self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs])
self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

def get_exploration_action(self, obs: Tensor, is_continuous: bool) -> Tensor:
"""
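To make the new selective reset concrete, here is a small standalone sketch (tensor sizes invented, not taken from the PR) of the indexing used in the else branch above: only the state slices of the environments listed in reset_envs are zeroed, while the others keep their values across the automatic reset.

import torch

# toy recurrent state with layout [1, num_envs, recurrent_state_size]
num_envs, recurrent_state_size = 4, 3
recurrent_state = torch.arange(num_envs * recurrent_state_size, dtype=torch.float32).view(1, num_envs, recurrent_state_size)

reset_envs = [1, 3]  # environments whose episodes just ended
recurrent_state[:, reset_envs] = torch.zeros_like(recurrent_state[:, reset_envs])

print(recurrent_state[0, 0])  # unchanged: tensor([0., 1., 2.])
print(recurrent_state[0, 1])  # zeroed:    tensor([0., 0., 0.])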
122 changes: 77 additions & 45 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -1,3 +1,4 @@
import copy
import os
import pathlib
import time
@@ -23,8 +24,9 @@
from sheeprl.algos.dreamer_v1.agent import Player, WorldModel, build_models
from sheeprl.algos.dreamer_v1.args import DreamerV1Args
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import cnn_forward, make_env, test
from sheeprl.data.buffers import SequentialReplayBuffer
from sheeprl.algos.dreamer_v1.utils import cnn_forward, test
from sheeprl.algos.dreamer_v2.utils import make_env
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.utils.callback import CheckpointCallback
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.parser import HfArgumentParser
@@ -340,7 +342,6 @@ def train(
def main():
parser = HfArgumentParser(DreamerV1Args)
args: DreamerV1Args = parser.parse_args_into_dataclasses()[0]
args.num_envs = 1
torch.set_num_threads(1)

# Initialize Fabric
@@ -396,23 +397,31 @@ def main():
log_dir = data[0]
os.makedirs(log_dir, exist_ok=True)

env: gym.Env = make_env(
args.env_id,
args.seed + rank * args.num_envs,
rank,
args,
logger.log_dir if rank == 0 else None,
"train",
# Environment setup
vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv
envs = vectorized_env(
[
make_env(
args.env_id,
args.seed + rank * args.num_envs,
rank,
args,
logger.log_dir if rank == 0 else None,
"train",
)
for i in range(args.num_envs)
]
)
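For reference, a minimal self-contained sketch of the gymnasium vectorization pattern the block above relies on; the env id, seeds, and the four copies are placeholders, and the make_env calls in the PR return comparable zero-argument thunks.

import gymnasium as gym

def make_thunk(env_id: str, seed: int):
    # SyncVectorEnv / AsyncVectorEnv expect a list of zero-argument callables
    def thunk() -> gym.Env:
        env = gym.make(env_id)
        env.action_space.seed(seed)
        return env
    return thunk

num_envs = 4
vectorized_env = gym.vector.SyncVectorEnv  # or gym.vector.AsyncVectorEnv for subprocess workers
envs = vectorized_env([make_thunk("CartPole-v1", seed=5 + i) for i in range(num_envs)])

obs, info = envs.reset(seed=5)   # observations are batched over the first (num_envs) dimension
print(envs.single_action_space)  # per-environment space, as read right below in the PR
envs.close()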

is_continuous = isinstance(env.action_space, gym.spaces.Box)
is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
action_space = envs.single_action_space
observation_space = envs.single_observation_space

is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
actions_dim = (
env.action_space.shape
if is_continuous
else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
observation_shape = env.observation_space.shape
observation_shape = observation_space["rgb"].shape
clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r
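As a quick illustration of the actions_dim expression above (the spaces are invented examples, not from the PR): a Box space keeps its shape, a MultiDiscrete space contributes one entry per sub-action, and a plain Discrete space becomes a single-element list.

import gymnasium as gym

box = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,))
multi = gym.spaces.MultiDiscrete([3, 4])
disc = gym.spaces.Discrete(5)

print(box.shape)            # (6,)   -> continuous: one dimension per action component
print(multi.nvec.tolist())  # [3, 4] -> multidiscrete: one entry per sub-action
print([disc.n])             # [5]    -> discrete: a single entry with the number of actions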

world_model, actor, critic = build_models(
@@ -473,23 +482,24 @@ def main():
"Grads/critic": MeanMetric(sync_on_compute=False),
}
)
aggregator.to(fabric.device)
aggregator.to(fabric.device)

# Local data
buffer_size = (
args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 2
)
rb = SequentialReplayBuffer(
rb = AsyncReplayBuffer(
buffer_size,
args.num_envs,
device=fabric.device if args.memmap_buffer else "cpu",
memmap=args.memmap_buffer,
memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"),
sequential=True,
)
if args.checkpoint_path and args.checkpoint_buffer:
if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]):
rb = state["rb"][fabric.global_rank]
elif isinstance(state["rb"], SequentialReplayBuffer):
elif isinstance(state["rb"], AsyncReplayBuffer):
rb = state["rb"]
else:
raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated")
@@ -499,11 +509,12 @@
# Global variables
start_time = time.perf_counter()
start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1
step_before_training = args.train_every // (fabric.world_size * args.action_repeat) if not args.dry_run else 0
num_updates = int(args.total_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 1
learning_starts = (args.learning_starts // (fabric.world_size * args.action_repeat)) if not args.dry_run else 0
single_global_step = int(args.num_envs * fabric.world_size * args.action_repeat)
step_before_training = args.train_every // single_global_step if not args.dry_run else 0
num_updates = int(args.total_steps // single_global_step) if not args.dry_run else 1
learning_starts = (args.learning_starts // single_global_step) if not args.dry_run else 0
if args.checkpoint_path and not args.checkpoint_buffer:
learning_starts = start_step + args.learning_starts // int(fabric.world_size * args.action_repeat)
learning_starts = start_step + args.learning_starts // single_global_step
max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size)
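A quick arithmetic check of the new step bookkeeping (the numbers below are invented for illustration): the per-iteration environment step count now scales with the number of parallel environments, so every derived quantity shrinks accordingly.

num_envs, world_size, action_repeat = 4, 1, 2
total_steps, train_every, learning_starts = 1_000_000, 1_000, 5_000

single_global_step = num_envs * world_size * action_repeat  # 8 environment steps per policy iteration
print(total_steps // single_global_step)      # 125000 -> num_updates
print(train_every // single_global_step)      # 125    -> step_before_training
print(learning_starts // single_global_step)  # 625    -> iterations of random actions before training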
if args.checkpoint_path:
player.expl_amount = polynomial_decay(
@@ -514,7 +525,9 @@
)

# Get the first environment observation and start the optimization
obs = torch.from_numpy(env.reset(seed=args.seed)[0]).view(args.num_envs, *observation_shape) # [N_envs, N_obs]
obs = torch.from_numpy(envs.reset(seed=args.seed)[0]["rgb"]).view(
args.num_envs, *observation_shape
) # [N_envs, N_obs]
step_data["dones"] = torch.zeros(args.num_envs, 1)
step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim))
step_data["rewards"] = torch.zeros(args.num_envs, 1)
@@ -524,13 +537,13 @@

for global_step in range(start_step, num_updates + 1):
# Sample an action given the observation received by the environment
if global_step < learning_starts and args.checkpoint_path is None:
real_actions = actions = np.array(env.action_space.sample())
if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id:
real_actions = actions = np.array(envs.action_space.sample())
if not is_continuous:
actions = np.concatenate(
[
F.one_hot(torch.tensor(act), act_dim).numpy()
for act, act_dim in zip(actions.reshape(len(actions_dim)), actions_dim)
for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim)
],
axis=-1,
)
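A small worked example of the one-hot concatenation used above for randomly sampled non-continuous actions (the dims and the sampled values are invented): with actions_dim = [3, 4] and a sampled action [2, 1], the stored action vector is a 3-way one-hot followed by a 4-way one-hot.

import numpy as np
import torch
import torch.nn.functional as F

actions_dim = [3, 4]        # two sub-actions, e.g. a MultiDiscrete([3, 4]) space
actions = np.array([2, 1])  # one sampled value per sub-action, single environment

one_hot = np.concatenate(
    [
        F.one_hot(torch.tensor(act), act_dim).numpy()
        for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim)
    ],
    axis=-1,
)
print(one_hot)  # [[0 0 1 0 1 0 0]] -> one-hot of 2 over 3 classes, then one-hot of 1 over 4 classes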
@@ -543,14 +556,28 @@
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions])
next_obs, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape))
real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions])
next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape))
next_obs = next_obs["rgb"]
dones = np.logical_or(dones, truncated)

if (dones or truncated) and "episode" in infos:
fabric.print(f"Rank-0: global_step={global_step}, reward_env_{0}={infos['episode']['r'][0]}")
aggregator.update("Rewards/rew_avg", infos["episode"]["r"][0])
aggregator.update("Game/ep_len_avg", infos["episode"]["l"][0])
if "final_info" in infos:
for i, agent_final_info in enumerate(infos["final_info"]):
if agent_final_info is not None and "episode" in agent_final_info:
fabric.print(
f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
)
aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0])
aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0])

# Save the real next observation
real_next_obs = next_obs.copy()
if "final_observation" in infos:
for idx, final_obs in enumerate(infos["final_observation"]):
if final_obs is not None:
for k, v in final_obs.items():
if k == "rgb":
real_next_obs[idx] = v
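For context, a minimal standalone sketch of the gymnasium autoreset convention (pre-1.0 vector API) that the two blocks above rely on: when a sub-environment finishes, the vector env resets it immediately and exposes the true last observation and the episode statistics through infos["final_observation"] and infos["final_info"]; the env id and the loop length are placeholders.

import gymnasium as gym

def make_thunk():
    def thunk() -> gym.Env:
        # RecordEpisodeStatistics is what makes an "episode" dict appear in the final info
        return gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1"))
    return thunk

envs = gym.vector.SyncVectorEnv([make_thunk() for _ in range(2)])
obs, infos = envs.reset(seed=0)
for _ in range(200):
    obs, rewards, terminated, truncated, infos = envs.step(envs.action_space.sample())
    if "final_info" in infos:  # at least one sub-env just finished and was auto-reset
        for i, final_info in enumerate(infos["final_info"]):
            if final_info is not None and "episode" in final_info:
                print(f"env {i} finished:", final_info["episode"])  # true return/length of that episode
    if "final_observation" in infos:
        for i, final_obs in enumerate(infos["final_observation"]):
            if final_obs is not None:
                pass  # last observation of the finished episode; obs[i] already belongs to the new one
envs.close()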

next_obs = torch.from_numpy(next_obs).view(args.num_envs, *observation_shape)
actions = torch.from_numpy(actions).view(args.num_envs, -1).float()
@@ -562,20 +589,25 @@

step_data["dones"] = dones
step_data["actions"] = actions
step_data["observations"] = obs
step_data["observations"] = real_next_obs
step_data["rewards"] = clip_rewards_fn(rewards)
rb.add(step_data[None, ...])

if dones or truncated:
obs = torch.from_numpy(env.reset(seed=args.seed)[0]).view(
args.num_envs, *observation_shape
) # [N_envs, N_obs]
step_data["dones"] = torch.zeros(args.num_envs, 1)
step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim))
step_data["rewards"] = torch.zeros(args.num_envs, 1)
step_data["observations"] = obs
rb.add(step_data[None, ...])
player.init_states()
# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero(as_tuple=True)[0].tolist()
reset_envs = len(dones_idxes)
if reset_envs > 0:
reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu")
reset_data["observations"] = next_obs[dones_idxes]
reset_data["dones"] = torch.zeros(reset_envs, 1)
reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim))
reset_data["rewards"] = torch.zeros(reset_envs, 1)
rb.add(reset_data[None, ...], dones_idxes)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][d] = torch.zeros_like(step_data["dones"][d])
# Reset internal agent states
player.init_states(dones_idxes)
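To make the bookkeeping above concrete, a tiny standalone example (flag values invented) of how the per-step done flags become the list of environment indices that receive a reset row in the buffer and a state reset in the player:

import torch

dones = torch.tensor([[0.0], [1.0], [0.0], [1.0]])  # shape [num_envs, 1], as after the torch conversion above
dones_idxes = dones.nonzero(as_tuple=True)[0].tolist()

print(dones_idxes)       # [1, 3] -> only these envs get a reset_data row and player.init_states(...)
print(len(dones_idxes))  # 2      -> the batch_size used to build the reset_data TensorDict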

step_before_training -= 1

@@ -642,7 +674,7 @@ def main():
replay_buffer=rb if args.checkpoint_buffer else None,
)

env.close()
envs.close()
if fabric.is_global_zero:
test(player, fabric, args)
