Add support for memoroids (linear recurrent models) #91

Status: Open. Wants to merge 41 commits into base: main.

Commits (41)
7fe570c
Convert ffm memoroid to flax, it seems to work
smorad Jun 16, 2024
229e7f1
temporary work
EdanToledo Jun 21, 2024
c2f3e9a
temporary work - hacky solution to get running
EdanToledo Jun 21, 2024
a8b9621
Merge branch 'main' into pr/smorad/91
EdanToledo Jun 22, 2024
5344323
chore: cleanup and edit stacked_rnn arch
EdanToledo Jun 22, 2024
73b5a83
chore: move around files
EdanToledo Jun 22, 2024
77e4a21
large ffm refactor, runs but possibly introduced bugs
smorad Jun 22, 2024
46f2d26
comments, further cleanup, and only return final state
smorad Jun 22, 2024
369b1cf
factor out recurrent associative scan
smorad Jun 22, 2024
72e4b7a
Add simplified FFM and stacked simplified FFM
smorad Jun 22, 2024
2faac71
chore: separate files - add baseclass etc
EdanToledo Jun 22, 2024
dc25d6c
chore: edit comments
EdanToledo Jun 22, 2024
da03f00
modify ffm edan to work for batched/vmapped memoroids
smorad Jun 22, 2024
9eacd4f
feat: make edits to ffm to make batch dim after sequence dim and add …
EdanToledo Jun 23, 2024
a2e0c75
Merge branch 'main' into pr/smorad/91
EdanToledo Jun 23, 2024
15473a4
chore: change scanned memorid to expect non sequence dimension carry …
EdanToledo Jun 24, 2024
74e35b3
feat: add explicit batch dimension and network config - rec_ppo now w…
EdanToledo Jun 24, 2024
2446b9c
chore: remove reliance on start variable being inside recurrent state
EdanToledo Jun 24, 2024
aa76325
fix reset zero error
smorad Jun 24, 2024
c5816de
better tests
smorad Jun 24, 2024
8fc2640
feat: add more popjym configs
EdanToledo Jun 24, 2024
ca03d7f
Merge branch 'memoroid' of https://github.com/smorad/stoix into pr/sm…
EdanToledo Jun 24, 2024
d8c845b
fix dummy state
smorad Jun 25, 2024
7feaa3f
better reset tests
smorad Jun 25, 2024
564dc81
add simple training test
smorad Jun 25, 2024
4752cb8
add simple training test
smorad Jun 25, 2024
78e5fad
add simple training test
smorad Jun 25, 2024
0b210ce
oops wrong y, pls pull
smorad Jun 25, 2024
acf9a29
feat: add demos
EdanToledo Jun 25, 2024
6d93baa
chore: edit working demo
EdanToledo Jun 25, 2024
f79febc
fix: add required popjym wrapper
EdanToledo Jun 26, 2024
e90e270
chore: move all current work into memoroids file to be contained
EdanToledo Jun 30, 2024
5f41a03
small hparam tweaks for memoroid
smorad Jul 1, 2024
9916c79
chore: reorganize code and add start to lru
EdanToledo Jul 3, 2024
b21d5e2
chore: slight config change
EdanToledo Jul 3, 2024
0cfb832
chore: more editing and add s5
EdanToledo Jul 4, 2024
232b7ca
chore: config and network edits
EdanToledo Jul 4, 2024
6dd59c5
feat: add stacked model
EdanToledo Jul 7, 2024
ab3de3f
chore: refactor slightly
EdanToledo Jul 7, 2024
e278bb5
chore: clean up code
EdanToledo Jul 7, 2024
26d1fc4
chore: change configs
EdanToledo Jul 7, 2024
4 changes: 2 additions & 2 deletions stoix/configs/arch/anakin.yaml
@@ -9,10 +9,10 @@ total_timesteps: 1e7 # Set the total environment steps.
 num_updates: ~ # Number of updates

 # --- Evaluation ---
-evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
+evaluation_greedy: True # Evaluate the policy greedily. If True the policy will select
 # an action which corresponds to the greatest logit. If false, the policy will sample
 # from the logits.
 num_eval_episodes: 128 # Number of episodes to evaluate per evaluation.
-num_evaluation: 50 # Number of evenly spaced evaluations to perform during training.
+num_evaluation: 19 # Number of evenly spaced evaluations to perform during training.
 absolute_metric: True # Whether the absolute metric should be computed. For more details
 # on the absolute metric please see: https://arxiv.org/abs/2209.10485
4 changes: 2 additions & 2 deletions stoix/configs/default_rec_ppo.yaml
@@ -2,6 +2,6 @@ defaults:
   - logger: base_logger
   - arch: anakin
   - system: rec_ppo
-  - network: rnn
-  - env: gymnax/cartpole
+  - network: memoroid
+  - env: popjym/repeat_first_easy
   - _self_
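The defaults list above is standard Hydra composition, so the previous behaviour stays one override away. A minimal sketch using Hydra's compose API; the config_path is an assumption based on this repo's layout:

    from hydra import compose, initialize

    with initialize(version_base=None, config_path="stoix/configs"):
        cfg = compose(
            config_name="default_rec_ppo",
            overrides=["network=rnn", "env=gymnax/cartpole"],  # restore the old defaults
        )
    print(cfg.network.rnn_layer._target_)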
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/auto_encode_easy.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: AutoencodeEasy
  task_name: auto_encode_easy

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/auto_encode_medium.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: AutoencodeMedium
  task_name: auto_encode_medium

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/count_recall_easy.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: CountRecallEasy
  task_name: count_recall_easy

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/count_recall_medium.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: CountRecallMedium
  task_name: count_recall_medium

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/repeat_first_easy.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: RepeatFirstEasy
  task_name: repeat_first_easy

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/repeat_first_hard.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: RepeatFirstHard
  task_name: repeat_first_hard

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
12 changes: 12 additions & 0 deletions stoix/configs/env/popjym/repeat_first_medium.yaml
@@ -0,0 +1,12 @@
# ---Environment Configs---
env_name: popjym

scenario:
  name: RepeatFirstMedium
  task_name: repeat_first_medium

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
4 changes: 4 additions & 0 deletions stoix/configs/env/popjym/stateless_cartpole_easy.yaml
@@ -6,3 +6,7 @@ scenario:
   task_name: stateless_cartpole_easy

 kwargs: {}
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
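All eight popjym scenario files above follow the same schema (env_name, scenario.name, scenario.task_name, kwargs, eval_metric). A quick way to sanity-check one of them, using OmegaConf (already a Hydra dependency); the path assumes the repo root as the working directory:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("stoix/configs/env/popjym/repeat_first_easy.yaml")
    assert cfg.env_name == "popjym"
    assert cfg.scenario.name == "RepeatFirstEasy"
    print(cfg.eval_metric)  # episode_return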
4 changes: 2 additions & 2 deletions stoix/configs/logger/base_logger.yaml
@@ -4,12 +4,12 @@ base_exp_path: results # Base path for logging.
 use_console: True # Whether to log to stdout.
 use_tb: False # Whether to use tensorboard logging.
 use_json: False # Whether to log marl-eval style to json files.
-use_neptune: False # Whether to log to neptune.ai.
+use_neptune: True # Whether to log to neptune.ai.
 use_wandb: False # Whether to log to wandb.ai.

 # --- Other logger kwargs ---
 kwargs:
-  project: ~ # Project name in neptune.ai or wandb.ai.
+  project: e.toledo/Stoix # Project name in neptune.ai or wandb.ai.
   tags: [stoix] # Tags to add to the experiment.
   detailed_logging: False # having mean/std/min/max can clutter neptune/wandb so we make it optional
   json_path: ~ # If set, json files will be logged to a set path so that multiple experiments can
3 changes: 1 addition & 2 deletions stoix/configs/network/muzero.yaml
@@ -30,8 +30,7 @@ wm_network:
   activation: silu

   # This can be seen as the dynamics network.
-  rnn_size: 256
-  num_stacked_rnn_layers: 2
+  rnn_sizes: [256, 256]
   rnn_cell_type: "gru"
   recurrent_activation: "sigmoid"
28 changes: 14 additions & 14 deletions stoix/configs/network/rnn.yaml
@@ -3,35 +3,35 @@
 actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [128]
-    use_layer_norm: False
-    activation: silu
+    layer_sizes: [256]
+    use_layer_norm: True
+    activation: leaky_relu
   rnn_layer:
-    _target_: stoix.networks.base.ScannedRNN
+    _target_: stoix.networks.recurrent.ScannedRNN
     cell_type: gru
-    hidden_state_dim: 128
+    hidden_state_dim: 256
   post_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [128]
+    layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: leaky_relu
   action_head:
     _target_: stoix.networks.heads.CategoricalHead

 critic_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [128]
-    use_layer_norm: False
-    activation: silu
+    layer_sizes: [256]
+    use_layer_norm: True
+    activation: leaky_relu
   rnn_layer:
-    _target_: stoix.networks.base.ScannedRNN
+    _target_: stoix.networks.recurrent.ScannedRNN
     cell_type: gru
-    hidden_state_dim: 128
+    hidden_state_dim: 256
   post_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [128]
+    layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: leaky_relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
41 changes: 41 additions & 0 deletions stoix/configs/network/stacked_lrm.yaml
@@ -0,0 +1,41 @@
# ---Recurrent Structure Networks for PPO ---

actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256]
    use_layer_norm: False
    activation: leaky_relu
  rnn_layer:
    _target_: stoix.networks.lrm.shared.StackedLRM
    num_cells: 2
    lrm_cell_type: lru
    cell_kwargs:
      hidden_state_dim: 256
  post_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256]
    use_layer_norm: False
    activation: leaky_relu
  action_head:
    _target_: stoix.networks.heads.CategoricalHead

critic_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256]
    use_layer_norm: False
    activation: leaky_relu
  rnn_layer:
    _target_: stoix.networks.lrm.shared.StackedLRM
    num_cells: 2
    lrm_cell_type: lru
    cell_kwargs:
      hidden_state_dim: 256
  post_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256]
    use_layer_norm: False
    activation: leaky_relu
  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
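Blocks like rnn_layer above are materialised by Hydra via their _target_ key. A minimal sketch; hydra.utils.instantiate is standard Hydra API, but StackedLRM's constructor signature is inferred from the config keys and not verified:

    import hydra
    from omegaconf import OmegaConf

    rnn_cfg = OmegaConf.create(
        {
            "_target_": "stoix.networks.lrm.shared.StackedLRM",
            "num_cells": 2,
            "lrm_cell_type": "lru",
            "cell_kwargs": {"hidden_state_dim": 256},
        }
    )
    rnn_layer = hydra.utils.instantiate(rnn_cfg)  # -> StackedLRM(num_cells=2, ...)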
12 changes: 6 additions & 6 deletions stoix/configs/system/rec_ppo.yaml
@@ -3,18 +3,18 @@
 system_name: rec_ppo # Name of the system.

 # --- RL hyperparameters ---
-actor_lr: 2.5e-4 # Learning rate for actor network
-critic_lr: 2.5e-4 # Learning rate for critic network
-rollout_length: 128 # Number of environment steps per vectorised environment.
-epochs: 4 # Number of ppo epochs per training data batch.
-num_minibatches: 2 # Number of minibatches per ppo epoch.
+actor_lr: 3e-5 # Learning rate for actor network
+critic_lr: 3e-5 # Learning rate for critic network
+rollout_length: 64 # Number of environment steps per vectorised environment.
+epochs: 10 # Number of ppo epochs per training data batch.
+num_minibatches: 64 # Number of minibatches per ppo epoch.
 gamma: 0.99 # Discounting factor.
 gae_lambda: 0.95 # Lambda value for GAE computation.
 clip_eps: 0.2 # Clipping value for PPO updates and value function.
 ent_coef: 0.01 # Entropy regularisation term for loss function.
 vf_coef: 0.5 # Critic weight in
 max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
-decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+decay_learning_rates: True # Whether learning rates should be linearly decayed during training.
 standardize_advantages: True # Whether to standardize the advantages.

 # --- Recurrent hyperparameters ---
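Turning on decay_learning_rates pairs naturally with the lower 3e-5 learning rate. Below is a sketch of the linear decay this flag typically enables, written with optax; the exact wiring inside Stoix is not shown in this diff, and num_updates is a stand-in value:

    import optax

    num_updates = 1_000  # assumption: derived from total_timesteps elsewhere
    lr_schedule = optax.linear_schedule(
        init_value=3e-5, end_value=0.0, transition_steps=num_updates
    )
    actor_optim = optax.chain(
        optax.clip_by_global_norm(0.5),         # max_grad_norm from this config
        optax.adam(learning_rate=lr_schedule),  # schedule in place of a fixed lr
    )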
56 changes: 4 additions & 52 deletions stoix/networks/base.py
@@ -1,16 +1,12 @@
-import functools
 from typing import Sequence, Tuple, Union

 import chex
 import distrax
 import jax
 import jax.numpy as jnp
-import numpy as np
 from flax import linen as nn

 from stoix.base_types import Observation, RNNObservation
 from stoix.networks.inputs import ObservationInput
-from stoix.networks.utils import parse_rnn_cell


 class FeedForwardActor(nn.Module):
@@ -83,51 +79,12 @@ def __call__(
         return concatenated


-class ScannedRNN(nn.Module):
-    hidden_state_dim: int
-    cell_type: str
-
-    @functools.partial(
-        nn.scan,
-        variable_broadcast="params",
-        in_axes=0,
-        out_axes=0,
-        split_rngs={"params": False},
-    )
-    @nn.compact
-    def __call__(self, rnn_state: chex.Array, x: chex.Array) -> Tuple[chex.Array, chex.Array]:
-        """Applies the module."""
-        ins, resets = x
-        hidden_state_reset_fn = lambda reset_state, current_state: jnp.where(
-            resets[:, np.newaxis],
-            reset_state,
-            current_state,
-        )
-        rnn_state = jax.tree_util.tree_map(
-            hidden_state_reset_fn,
-            self.initialize_carry(ins.shape[0]),
-            rnn_state,
-        )
-        new_rnn_state, y = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim)(
-            rnn_state, ins
-        )
-        return new_rnn_state, y
-
-    @nn.nowrap
-    def initialize_carry(self, batch_size: int) -> chex.Array:
-        """Initializes the carry state."""
-        # Use a dummy key since the default state init fn is just zeros.
-        cell = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim)
-        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, self.hidden_state_dim))
-
-
 class RecurrentActor(nn.Module):
     """Recurrent Actor Network."""

     action_head: nn.Module
     post_torso: nn.Module
-    hidden_state_dim: int
-    cell_type: str
+    rnn: nn.Module
     pre_torso: nn.Module
     input_layer: nn.Module = ObservationInput()

@@ -143,9 +100,7 @@ def __call__(
         observation = self.input_layer(observation)
         policy_embedding = self.pre_torso(observation)
         policy_rnn_input = (policy_embedding, done)
-        policy_hidden_state, policy_embedding = ScannedRNN(self.hidden_state_dim, self.cell_type)(
-            policy_hidden_state, policy_rnn_input
-        )
+        policy_hidden_state, policy_embedding = self.rnn(policy_hidden_state, policy_rnn_input)
         actor_logits = self.post_torso(policy_embedding)
         pi = self.action_head(actor_logits)

@@ -157,8 +112,7 @@ class RecurrentCritic(nn.Module):

     critic_head: nn.Module
     post_torso: nn.Module
-    hidden_state_dim: int
-    cell_type: str
+    rnn: nn.Module
     pre_torso: nn.Module
     input_layer: nn.Module = ObservationInput()

@@ -175,9 +129,7 @@

         critic_embedding = self.pre_torso(observation)
         critic_rnn_input = (critic_embedding, done)
-        critic_hidden_state, critic_embedding = ScannedRNN(self.hidden_state_dim, self.cell_type)(
-            critic_hidden_state, critic_rnn_input
-        )
+        critic_hidden_state, critic_embedding = self.rnn(critic_hidden_state, critic_rnn_input)
         critic_output = self.post_torso(critic_embedding)
         critic_output = self.critic_head(critic_output)
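The net effect of this refactor: RecurrentActor and RecurrentCritic no longer hard-code ScannedRNN. Any module with a (carry, (inputs, done)) -> (carry, outputs) call convention can be injected, which is what lets the memoroid/LRM cells slot in. A hedged construction sketch; how CategoricalHead receives action_dim is an assumption:

    import hydra
    from omegaconf import OmegaConf
    from stoix.networks.base import RecurrentActor

    cfg = OmegaConf.load("stoix/configs/network/rnn.yaml").actor_network
    actor = RecurrentActor(
        pre_torso=hydra.utils.instantiate(cfg.pre_torso),
        rnn=hydra.utils.instantiate(cfg.rnn_layer),  # injected recurrent core
        post_torso=hydra.utils.instantiate(cfg.post_torso),
        action_head=hydra.utils.instantiate(cfg.action_head, action_dim=4),
    )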
12 changes: 5 additions & 7 deletions stoix/networks/layers.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple

 import chex
 import jax
@@ -24,19 +24,17 @@ class StackedRNN(nn.Module):
         activation_fn (str): The activation function to use in each RNN cell (default is "tanh").
     """

-    rnn_size: int
+    rnn_sizes: Sequence[int]
     rnn_cls: nn.Module
-    num_layers: int
     activation_fn: str = "sigmoid"

     def setup(self) -> None:
         """Set up the RNN cells for the stacked RNN."""
         self.cells = [
-            self.rnn_cls(
-                features=self.rnn_size, activation_fn=parse_activation_fn(self.activation_fn)
-            )
-            for _ in range(self.num_layers)
+            self.rnn_cls(features=size, activation_fn=parse_activation_fn(self.activation_fn))
+            for size in self.rnn_sizes
         ]
+        self.num_layers = len(self.cells)

     def __call__(
         self, all_rnn_states: List[chex.ArrayTree], x: chex.Array
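This interface change replaces the single rnn_size + num_layers pair with an explicit rnn_sizes sequence, matching the muzero.yaml edit earlier in this PR. A small usage sketch: flax.linen.GRUCell does accept features and activation_fn as in the list comprehension above, but the zero-carry shapes and the behaviour of the truncated __call__ body are assumptions:

    import jax
    import jax.numpy as jnp
    from flax import linen as nn
    from stoix.networks.layers import StackedRNN

    stack = StackedRNN(rnn_sizes=[256, 128], rnn_cls=nn.GRUCell, activation_fn="tanh")
    states = [jnp.zeros((8, size)) for size in (256, 128)]  # one carry per layer
    x = jnp.zeros((8, 64))
    params = stack.init(jax.random.PRNGKey(0), states, x)
    new_states, y = stack.apply(params, states, x)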