Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ACT runs, check reprod next #135

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.6.0
rev: 22.10.0
hooks:
- id: black
language_version: python3.9
- repo: https://gitlab.com/pycqa/flake8
rev: '3.9.2'
- repo: https://github.com/pycqa/flake8
rev: '5.0.4'
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear]
Expand Down
128 changes: 128 additions & 0 deletions pax/agents/act/act_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Any, Dict, NamedTuple, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import optax

from pax.agents.agent import AgentInterface
from pax import utils
from pax.utils import Logger, MemoryState, TrainingState, get_advantages
from pax.agents.act.networks import make_act_network

class ActAgent(AgentInterface):
def __init__(
self,
network: NamedTuple,
optimizer: optax.GradientTransformation,
random_key: jnp.ndarray,
obs_spec: Tuple,
num_envs: int = 4,
entropy_coeff_start: float = 0.1,
player_id: int = 0,
):
@jax.jit
def policy(
state: TrainingState, observation: jnp.ndarray, mem: MemoryState
):
"""Agent policy to select actions and calculate agent specific information"""
values = network.apply(state.params, observation)
mem.extras["values"] = values
mem = mem._replace(extras=mem.extras)
return values, state, mem

def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=obs_spec)

dummy_obs = utils.add_batch_dim(dummy_obs)
initial_params = network.init(subkey, dummy_obs)
initial_opt_state = optimizer.init(initial_params)
return TrainingState(
random_key=key,
params=initial_params,
opt_state=initial_opt_state,
timesteps=0,
), MemoryState(
hidden=jnp.zeros((num_envs, 1)),
extras={
"values": jnp.zeros((num_envs, 2)),
"log_probs": jnp.zeros(num_envs),
},
)
self.make_initial_state = make_initial_state
self._state, self._mem = make_initial_state(random_key, jnp.zeros(1))

# Set up counters and logger
self._logger = Logger()
self._total_steps = 0
self._until_sgd = 0
self._logger.metrics = {
"total_steps": 0,
"sgd_steps": 0,
"loss_total": 0,
"loss_policy": 0,
"loss_value": 0,
"loss_entropy": 0,
"entropy_cost": entropy_coeff_start,
}

# Initialize functions
self._policy = policy
self.player_id = player_id

# Other useful hyperparameters
self._num_envs = num_envs # number of environments

def reset_memory(self, memory, eval=False) -> MemoryState:
num_envs = 1 if eval else self._num_envs
memory = memory._replace(
extras={
"values": jnp.zeros((num_envs, 2)),
"log_probs": jnp.zeros(num_envs),
},
)
return memory

def make_act_agent(
args,
obs_spec,
action_spec,
seed: int,
player_id: int,
tabular=False,
):
"""Make PPO agent"""
if args.runner == "act_evo":
network = make_act_network(action_spec)
else:
raise NotImplementedError

# Optimizer
batch_size = int(args.num_envs * args.num_steps)
transition_steps = (
args.total_timesteps
/ batch_size
* args.ppo.num_epochs
* args.ppo.num_minibatches
)

optimizer = optax.chain(
optax.clip_by_global_norm(args.ppo.max_gradient_norm),
optax.scale_by_adam(eps=args.ppo.adam_epsilon),
optax.scale(-args.ppo.learning_rate),
)

random_key = jax.random.PRNGKey(seed=seed)

agent = ActAgent(
network=network,
optimizer=optimizer,
random_key=random_key,
obs_spec=obs_spec,
num_envs=args.num_envs,
entropy_coeff_start=args.ppo.entropy_coeff_start,
player_id=player_id,
)
return agent
54 changes: 54 additions & 0 deletions pax/agents/act/networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp

from pax import utils

class DeterministicFunction(hk.Module):
"""Network head that produces a categorical distribution and value."""

def __init__(
self,
num_values: int,
name: Optional[str] = None):
super().__init__(name=name)
self._act_body = hk.nets.MLP(
[64, 64],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jax.nn.relu,
)
self._act_output = hk.nets.MLP(
[num_values],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jnp.tanh,
)


def __call__(self, inputs: jnp.ndarray):
output = self._act_body(inputs)
output = self._act_output(output)

return output


def make_act_network(num_actions: int):
"""Creates a hk network using the baseline hyperparameters from OpenAI"""

def forward_fn(inputs):
layers = []
layers.extend(
[
DeterministicFunction(num_values=num_actions)
]
)
act_network = hk.Sequential(layers)
return act_network(inputs)

network = hk.without_apply_rng(hk.transform(forward_fn))
return network
2 changes: 1 addition & 1 deletion pax/agents/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def make_agent(
tabular=False,
):
"""Make PPO agent"""
if args.runner == "sarl":
if args.runner in ["sarl", "sarl_eval"]:
network = make_sarl_network(action_spec)
elif args.env_id == "coin_game":
print(f"Making network for {args.env_id}")
Expand Down
35 changes: 35 additions & 0 deletions pax/agents/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,41 @@ def reset_memory(self, mem, *args) -> MemoryState:
def make_initial_state(self, _unused, *args) -> TrainingState:
return self._state, self._mem

class RandomACT(AgentInterface):
def __init__(self, num_actions: int, num_envs: int):
self.make_initial_state = initial_state_fun(num_envs)
self._state, self._mem = self.make_initial_state(None, None)
self.reset_memory = reset_mem_fun(num_envs)
self._logger = Logger()
self._logger.metrics = {}
self._num_actions = num_actions
print('self num actions', self._num_actions)

def _policy(
state: NamedTuple,
obs: jnp.array,
mem: NamedTuple,
) -> jnp.ndarray:
# state is [batch x time_step x num_players]
# return [batch]
batch_size = obs.shape[0]
new_key, _ = jax.random.split(state.random_key)
action = jnp.zeros((batch_size, num_actions))
# action = jax.random.uniform(new_key, (batch_size, num_actions), dtype=jnp.float32, minval=0.0, maxval=1.0)
state = state._replace(random_key=new_key)
return action, state, mem

self._policy = jax.jit(_policy)

def update(self, unused0, unused1, state, mem) -> None:
return state, mem, {}

def reset_memory(self, mem, *args) -> MemoryState:
return self._mem

def make_initial_state(self, _unused, *args) -> TrainingState:
return self._state, self._mem


class Stay(AgentInterface):
def __init__(self, num_actions: int, num_envs: int):
Expand Down
113 changes: 113 additions & 0 deletions pax/agents/synq/networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp

from pax import utils

class CategoricalValueHeadSeparate(hk.Module):
"""Network head that produces a categorical distribution and value."""

def __init__(
self,
num_values: int,
name: Optional[str] = None,
):
super().__init__(name=name)
self._action_body = hk.nets.MLP(
[64, 64],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jnp.tanh,
)
self._logit_layer = hk.Linear(
num_values,
w_init=hk.initializers.Orthogonal(1.0),
b_init=hk.initializers.Constant(0),
)
self._value_body = hk.nets.MLP(
[64, 64],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jnp.tanh,
)
self._value_layer = hk.Linear(
1,
w_init=hk.initializers.Orthogonal(0.01),
b_init=hk.initializers.Constant(0),
)

def __call__(self, inputs: jnp.ndarray):
# action_output, value_output = inputs
logits = self._action_body(inputs)
logits = self._logit_layer(logits)

value = self._value_body(inputs)
value = jnp.squeeze(self._value_layer(value), axis=-1)
return (distrax.Categorical(logits=logits), value)

class DeterministicFunction(hk.Module):
"""Network head that produces a categorical distribution and value."""

def __init__(
self,
num_values: int,
name: Optional[str] = None):
super().__init__(name=name)
self._act_body = hk.nets.MLP(
[64, 64],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jax.nn.relu,
)
self._act_output = hk.nets.MLP(
[num_values],
w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
b_init=hk.initializers.Constant(0),
activate_final=True,
activation=jnp.tanh,
)


def __call__(self, inputs: jnp.ndarray):
output = self._act_body(inputs)
output = self._act_output(output)

return output


def make_synq_network(num_actions: int):
"""Creates a hk network using the baseline hyperparameters from OpenAI"""

def forward_fn(inputs):
layers = []
layers.extend(
[
DeterministicFunction(num_values=num_actions)
]
)
act_network = hk.Sequential(layers)
return act_network(inputs)

network = hk.without_apply_rng(hk.transform(forward_fn))
return network

def make_policy_network(num_actions: int):
"""Creates a hk network using the baseline hyperparameters from OpenAI"""

def forward_fn(inputs):
layers = []
layers.extend(
[
CategoricalValueHeadSeparate(num_values=num_actions)
]
)
policy_value_network = hk.Sequential(layers)
return policy_value_network(inputs)

network = hk.without_apply_rng(hk.transform(forward_fn))
return network
Loading