Fix/minedojo #286 (merged 7 commits on May 13, 2024)
4 changes: 2 additions & 2 deletions howto/register_external_algorithm.md
@@ -605,9 +605,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
}
actions = agent.module(torch_obs)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
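Note for readers of this diff: the recurring change from torch.cat to torch.stack is about tensor layout. With multi-discrete actions, each act.argmax(dim=-1) returns one index per parallel environment, so torch.cat merges those 1-D tensors into a single flat vector, whereas torch.stack keeps one row of sub-action indices per environment. A minimal sketch of the difference (the two heads, their sizes, and the environment count below are illustrative, not taken from the repository):

import torch

# Two hypothetical discrete action heads evaluated for 2 parallel environments.
num_envs = 2
logits_head_a = torch.randn(num_envs, 5)  # 5 choices for the first sub-action
logits_head_b = torch.randn(num_envs, 3)  # 3 choices for the second sub-action
actions = (logits_head_a, logits_head_b)

# argmax over the class dimension yields one index per environment per head.
indices = [act.argmax(dim=-1) for act in actions]  # each tensor has shape (num_envs,)

cat_result = torch.cat(indices, dim=-1)      # shape (4,): environments and heads run together
stack_result = torch.stack(indices, dim=-1)  # shape (2, 2): one row of indices per environment

print(cat_result.shape)    # torch.Size([4])
print(stack_result.shape)  # torch.Size([2, 2])

The stacked layout is the one that can be reshaped to a vectorized environment's action shape without mixing actions from different environments.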
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
@@ -603,9 +603,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
}
actions = player.get_actions(torch_obs)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
4 changes: 2 additions & 2 deletions notebooks/dreamer_v3_imagination.ipynb
@@ -222,9 +222,9 @@
real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
+   real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
-   real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
rb_initial.add(step_data, validate_args=cfg.buffer.validate_args)
4 changes: 2 additions & 2 deletions sheeprl/algos/a2c/a2c.py
@@ -239,9 +239,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
actions, _, values = player(torch_obs)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -581,7 +581,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
        torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
15 changes: 11 additions & 4 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -145,7 +145,11 @@ def train(
# One step of dynamic learning, which take the posterior state, the recurrent state, the action
# and the observation and compute the next recurrent, prior and posterior states
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
-   posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
+   posterior,
+   recurrent_state,
+   data["actions"][i : i + 1],
+   embedded_obs[i : i + 1],
+   data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
@@ -344,7 +348,10 @@ def train(
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
    critic_grads = fabric.clip_gradients(
-       module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False
+       module=critic,
+       optimizer=critic_optimizer,
+       max_norm=cfg.algo.critic.clip_gradients,
+       error_if_nonfinite=False,
    )
critic_optimizer.step()

@@ -606,10 +613,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -150,9 +150,9 @@ def test(
    torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/agent.py
@@ -913,7 +913,7 @@ def forward(
                    if sampled_action == 15:  # Craft action
                        logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf
            elif i == 2:
-               mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits)
+               mask["mask_destroy"] = mask["mask_destroy"].expand_as(logits)
                mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits)
                for t in range(functional_action.shape[0]):
                    for b in range(functional_action.shape[1]):
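For context on the MinedojoActor fix above: expand_as broadcasts the per-head mask to the full shape of the logits, and the expanded tensor has to replace the whole dictionary entry so that the per-(t, b) loop that follows can index into it; writing it into a single [t, b] slot would try to store a full-size tensor in one element. A minimal sketch with illustrative shapes (T, B, and N are assumptions, not values from the repository):

import torch

# Hypothetical shapes: T timesteps, B environments, N choices for one action head.
T, B, N = 2, 3, 8
logits = torch.randn(T, B, N)
mask_destroy = torch.randint(0, 2, (1, 1, N)).bool()

# expand_as broadcasts the (1, 1, N) mask to the (T, B, N) logits shape without copying;
# assigning the result back to the whole entry keeps every [t, b] slice indexable later.
mask_destroy = mask_destroy.expand_as(logits)
print(mask_destroy.shape)  # torch.Size([2, 3, 8])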
9 changes: 6 additions & 3 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -319,7 +319,10 @@ def train(
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
    critic_grads = fabric.clip_gradients(
-       module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False
+       module=critic,
+       optimizer=critic_optimizer,
+       max_norm=cfg.algo.critic.clip_gradients,
+       error_if_nonfinite=False,
    )
critic_optimizer.step()

@@ -573,10 +576,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
+   real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/utils.py
@@ -125,9 +125,9 @@ def test(
    torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -605,7 +605,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
        torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -260,10 +260,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )
next_obs, rewards, terminated, truncated, infos = envs.step(
real_actions.reshape(envs.action_space.shape)
10 changes: 7 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -128,7 +128,11 @@ def train(

for i in range(0, sequence_length):
    recurrent_state, posterior, prior, posterior_logits, prior_logits = world_model.rssm.dynamic(
-       posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
+       posterior,
+       recurrent_state,
+       data["actions"][i : i + 1],
+       embedded_obs[i : i + 1],
+       data["is_first"][i : i + 1],
    )
    recurrent_states[i] = recurrent_state
    priors[i] = prior
@@ -742,10 +746,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -280,10 +280,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, -1).cpu().numpy()
+   real_actions = torch.stack(real_actions, -1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -814,10 +814,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
+   real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -262,10 +262,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
-   real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
+   real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()
else:
    real_actions = (
-       torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
+       torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
    )

step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo.py
@@ -276,9 +276,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)
actions, logprobs, values = player(torch_obs)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -206,9 +206,9 @@ def player(
torch_obs = prepare_obs(fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
actions, logprobs, values = agent(torch_obs)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -299,9 +299,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
    torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states
)
if is_continuous:
-   real_actions = torch.cat(actions, -1).cpu().numpy()
+   real_actions = torch.stack(actions, -1).cpu().numpy()
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
torch_actions = torch.cat(actions, dim=-1)
actions = torch_actions.cpu().numpy()

4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/utils.py
@@ -57,10 +57,10 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d
# Act greedly through the environment
actions, state = agent.get_actions(torch_obs, actions, state, greedy=True)
if agent.actor.is_continuous:
-   real_actions = torch.cat(actions, -1)
+   real_actions = torch.stack(actions, -1)
    actions = torch.cat(actions, dim=-1).view(1, 1, -1)
else:
-   real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
+   real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1)
    actions = torch.cat([act for act in actions], dim=-1).view(1, 1, -1)

# Single environment step
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -189,7 +189,7 @@ def player(
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
actions = actor(torch_obs)
actions = actions.cpu().numpy()
-next_obs, rewards, terminated, truncated, infos = envs.step(actions)
+next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
rewards = rewards.reshape(cfg.env.num_envs, -1)

if cfg.metric.log_level > 0 and "final_info" in infos:
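For context on the sac_decoupled change above: a vectorized environment expects one action row per sub-environment, so reshaping the actor's (possibly flat) numpy output to envs.action_space.shape guarantees that layout before stepping. A minimal sketch under assumed shapes (the environment count and action size are illustrative, not taken from the repository):

import numpy as np

# Hypothetical vectorized setup: 4 environments, each with a 3-dimensional Box action,
# so the vector environment's action_space.shape is (4, 3).
num_envs, act_dim = 4, 3
vector_action_space_shape = (num_envs, act_dim)

# Suppose the actor's output arrives as a flat batch after .cpu().numpy().
flat_actions = np.random.uniform(-1.0, 1.0, size=num_envs * act_dim)

# Reshaping to the vector environment's action_space.shape restores the (num_envs, act_dim)
# layout that envs.step(...) expects, regardless of how the actor batched its output.
actions_for_step = flat_actions.reshape(vector_action_space_shape)
print(actions_for_step.shape)  # (4, 3)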
57 changes: 57 additions & 0 deletions sheeprl/configs/exp/dreamer_v3_minedojo.yaml
@@ -0,0 +1,57 @@
# @package _global_

defaults:
  - dreamer_v3
  - override /algo: dreamer_v3_XS
  - override /env: minedojo
  - _self_

# Experiment
seed: 5
total_steps: 50000000

# Environment
env:
  num_envs: 2
  id: harvest_milk
  reward_as_observation: True

# Checkpoint
checkpoint:
  every: 100000

# Buffer
buffer:
  checkpoint: True

# Algorithm
algo:
  replay_ratio: 0.015625
  learning_starts: 65536
  actor:
    cls: sheeprl.algos.dreamer_v3.agent.MinedojoActor
  cnn_keys:
    encoder:
      - rgb
  mlp_keys:
    encoder:
      - equipment
      - inventory
      - inventory_delta
      - inventory_max
      - life_stats
      - mask_action_type
      - mask_craft_smelt
      - mask_destroy
      - mask_equip_place
      - reward
    decoder:
      - equipment
      - inventory
      - inventory_delta
      - inventory_max
      - life_stats
      - mask_action_type
      - mask_craft_smelt
      - mask_destroy
      - mask_equip_place
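A brief note on the new config's numbers, under the usual reading of these fields (an assumption on my part, not a statement from the PR): replay_ratio: 0.015625 is exactly 1/64, i.e. roughly one gradient step per 64 policy steps if the ratio is defined that way, and learning_starts: 65536 is 2**16 environment steps collected before training begins. A quick check in Python:

# Illustrative arithmetic only; the field semantics are assumed, not taken from the PR.
replay_ratio = 0.015625
learning_starts = 65536
print(replay_ratio == 1 / 64)      # True
print(learning_starts == 2 ** 16)  # True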