Skip to content

Commit

Permalink
Fix/minedojo (#286)
Browse files Browse the repository at this point in the history
* fix: multiple envs

* fix: multi-discrete actions

* fix: remove debug prints

* fix: remove debug prints

* fix: removed MINEDOJO_HEADLESS
  • Loading branch information
michele-milesi authored May 13, 2024
1 parent 4003506 commit 419c7ce
Show file tree
Hide file tree
Showing 24 changed files with 119 additions and 44 deletions.
4 changes: 2 additions & 2 deletions howto/register_external_algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
}
actions = agent.module(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
Expand Down
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
}
actions = player.get_actions(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
Expand Down
4 changes: 2 additions & 2 deletions notebooks/dreamer_v3_imagination.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@
" real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)\n",
" actions = torch.cat(actions, -1).cpu().numpy()\n",
" if is_continuous:\n",
" real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n",
" real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()\n",
" else:\n",
" real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()\n",
" real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()\n",
"\n",
" step_data[\"actions\"] = actions.reshape((1, cfg.env.num_envs, -1))\n",
" rb_initial.add(step_data, validate_args=cfg.buffer.validate_args)\n",
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
Expand Down
15 changes: 11 additions & 4 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def train(
# One step of dynamic learning, which take the posterior state, the recurrent state, the action
# and the observation and compute the next recurrent, prior and posterior states
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
posterior,
recurrent_state,
data["actions"][i : i + 1],
embedded_obs[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
Expand Down Expand Up @@ -344,7 +348,10 @@ def train(
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
critic_grads = fabric.clip_gradients(
module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False
module=critic,
optimizer=critic_optimizer,
max_norm=cfg.algo.critic.clip_gradients,
error_if_nonfinite=False,
)
critic_optimizer.step()

Expand Down Expand Up @@ -606,10 +613,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def test(
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def forward(
if sampled_action == 15: # Craft action
logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf
elif i == 2:
mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits)
mask["mask_destroy"] = mask["mask_destroy"].expand_as(logits)
mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits)
for t in range(functional_action.shape[0]):
for b in range(functional_action.shape[1]):
Expand Down
9 changes: 6 additions & 3 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ def train(
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
critic_grads = fabric.clip_gradients(
module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False
module=critic,
optimizer=critic_optimizer,
max_norm=cfg.algo.critic.clip_gradients,
error_if_nonfinite=False,
)
critic_optimizer.step()

Expand Down Expand Up @@ -573,10 +576,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def test(
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)
next_obs, rewards, terminated, truncated, infos = envs.step(
real_actions.reshape(envs.action_space.shape)
Expand Down
10 changes: 7 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def train(

for i in range(0, sequence_length):
recurrent_state, posterior, prior, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
posterior,
recurrent_state,
data["actions"][i : i + 1],
embedded_obs[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors[i] = prior
Expand Down Expand Up @@ -742,10 +746,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,10 +814,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)
actions, logprobs, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def player(
torch_obs = prepare_obs(fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
actions, logprobs, values = agent(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states
)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
torch_actions = torch.cat(actions, dim=-1)
actions = torch_actions.cpu().numpy()

Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d
# Act greedly through the environment
actions, state = agent.get_actions(torch_obs, actions, state, greedy=True)
if agent.actor.is_continuous:
real_actions = torch.cat(actions, -1)
real_actions = torch.stack(actions, -1)
actions = torch.cat(actions, dim=-1).view(1, 1, -1)
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1)
actions = torch.cat([act for act in actions], dim=-1).view(1, 1, -1)

# Single environment step
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def player(
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
actions = actor(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
rewards = rewards.reshape(cfg.env.num_envs, -1)

if cfg.metric.log_level > 0 and "final_info" in infos:
Expand Down
57 changes: 57 additions & 0 deletions sheeprl/configs/exp/dreamer_v3_minedojo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# @package _global_

defaults:
- dreamer_v3
- override /algo: dreamer_v3_XS
- override /env: minedojo
- _self_

# Experiment
seed: 5
total_steps: 50000000

# Environment
env:
num_envs: 2
id: harvest_milk
reward_as_observation: True

# Checkpoint
checkpoint:
every: 100000

# Buffer
buffer:
checkpoint: True

# Algorithm
algo:
replay_ratio: 0.015625
learning_starts: 65536
actor:
cls: sheeprl.algos.dreamer_v3.agent.MinedojoActor
cnn_keys:
encoder:
- rgb
mlp_keys:
encoder:
- equipment
- inventory
- inventory_delta
- inventory_max
- life_stats
- mask_action_type
- mask_craft_smelt
- mask_destroy
- mask_equip_place
- reward
decoder:
- equipment
- inventory
- inventory_delta
- inventory_max
- life_stats
- mask_action_type
- mask_craft_smelt
- mask_destroy
- mask_equip_place
Loading

0 comments on commit 419c7ce

Please sign in to comment.