Removed all mentions of Hive #11

Merged: 1 commit, Mar 10, 2022
4 changes: 2 additions & 2 deletions README.md
@@ -42,8 -42,8 @@ looks like this:
 Using the components provided with Emote, we can write this as
 
 ```python
-env = HiveGymWrapper(AsyncVectorEnv(10 * [HitTheMiddle]))
-table = DictObsTable(spaces=env.hive_space, maxlen=1000)
+env = DictGymWrapper(AsyncVectorEnv(10 * [HitTheMiddle]))
+table = DictObsTable(spaces=env.dict_space, maxlen=1000)
 memory_proxy = TableMemoryProxy(table)
 dataloader = MemoryLoader(table, 100, 2, "batch_size")
6 changes: 3 additions & 3 deletions emote/memory/memory.py
@@ -13,7 +13,7 @@
 
 from emote.memory.core_types import Matrix
 from emote.memory import Table
-from emote.typing import HiveResponse, HiveObservation, AgentId, EpisodeState
+from emote.typing import DictResponse, DictObservation, AgentId, EpisodeState
 from emote.utils import TimedBlock
 
 
@@ -73,8 +73,8 @@ def is_initial(self, identity):
 
     def add(
         self,
-        observations: Dict[AgentId, HiveObservation],
-        responses: Dict[AgentId, HiveResponse],
+        observations: Dict[AgentId, DictObservation],
+        responses: Dict[AgentId, DictResponse],
     ):
         completed_episodes = {}
         for agent_id, observation in observations.items():
10 changes: 5 additions & 5 deletions emote/proxies.py
@@ -4,16 +4,16 @@
 
 from typing import Protocol, Dict
 
-from emote.typing import AgentId, HiveObservation, HiveResponse
+from emote.typing import AgentId, DictObservation, DictResponse
 
 
 class AgentProxy(Protocol):
     """The interface between the agent in the game and the network used during training."""
 
     def __call__(
         self,
-        obserations: Dict[AgentId, HiveObservation],
-    ) -> Dict[AgentId, HiveResponse]:
+        obserations: Dict[AgentId, DictObservation],
+    ) -> Dict[AgentId, DictResponse]:
         """Take observations for the active agents and returns the relevant network output."""
         ...
 
@@ -23,8 +23,8 @@ class MemoryProxy(Protocol):
 
     def add(
         self,
-        observations: Dict[AgentId, HiveObservation],
-        responses: Dict[AgentId, HiveResponse],
+        observations: Dict[AgentId, DictObservation],
+        responses: Dict[AgentId, DictResponse],
     ):
         """Store episodes in the memory buffer used for training.
 
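As a quick illustration of the renamed protocol (not part of this PR), a minimal AgentProxy only needs to map DictObservation inputs to DictResponse outputs keyed by AgentId. The constant zero action and the "actions" key below are placeholder choices for the sketch.

```python
from typing import Dict

from emote.typing import AgentId, DictObservation, DictResponse


class ZeroActionProxy:
    """Toy agent proxy: replies with a constant zero action for every active agent."""

    def __call__(
        self, observations: Dict[AgentId, DictObservation]
    ) -> Dict[AgentId, DictResponse]:
        # One response per incoming observation, as the AgentProxy protocol requires.
        return {
            agent_id: DictResponse(list_data={"actions": [0.0]}, scalar_data={})
            for agent_id in observations
        }
```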
10 changes: 5 additions & 5 deletions emote/sac.py
@@ -5,7 +5,7 @@
 from torch import nn
 from torch import optim
 
-from emote.typing import AgentId, EpisodeState, HiveObservation, HiveResponse
+from emote.typing import AgentId, EpisodeState, DictObservation, DictResponse
 
 from .callbacks import LoggingCallback, LossCallback
 
@@ -237,7 +237,7 @@ def __init__(
         # TODO(singhblom) Check number of actions
         # self.t_entropy = -np.prod(self.env.action_space.shape).item() # Value from rlkit from Harnouja
         self.t_entropy = (
-            n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps ** 2)) / 2.0
+            n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps**2)) / 2.0
         )
         self.ln_alpha = ln_alpha # This is log(alpha)
 
@@ -273,8 +273,8 @@ def __init__(self, policy: nn.Module):
         self._end_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]
 
     def __call__(
-        self, observations: Dict[AgentId, HiveObservation]
-    ) -> Dict[AgentId, HiveResponse]:
+        self, observations: Dict[AgentId, DictObservation]
+    ) -> Dict[AgentId, DictResponse]:
         """Runs the policy and returns the actions."""
         # The network takes observations of size batch x obs for each observation space.
         assert len(observations) > 0, "Observations must not be empty."
@@ -290,6 +290,6 @@ def __call__(
         )
         actions = self.policy(tensor_obs)[0].detach().numpy()
         return {
-            agent_id: HiveResponse(list_data={"actions": actions[i]}, scalar_data={})
+            agent_id: DictResponse(list_data={"actions": actions[i]}, scalar_data={})
             for i, agent_id in enumerate(active_agents)
         }
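The t_entropy expression in the hunk above (only its whitespace changes in this PR) matches the differential entropy of an n-dimensional isotropic Gaussian with standard deviation entropy_eps, i.e. n/2 * (1 + ln(2π ε²)). A small self-check, my own sketch rather than part of this change, and requiring scipy:

```python
import numpy as np
from scipy.stats import multivariate_normal

n_actions, entropy_eps = 3, 0.1  # arbitrary example values
closed_form = n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps**2)) / 2.0
# Entropy of a Gaussian with covariance entropy_eps**2 * I agrees with the closed form.
reference = multivariate_normal(
    mean=np.zeros(n_actions), cov=entropy_eps**2 * np.eye(n_actions)
).entropy()
assert np.isclose(closed_form, reference)
```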
18 changes: 9 additions & 9 deletions emote/typing.py
@@ -5,23 +5,23 @@
 
 # The AgentId is an application-defined integer
 AgentId = int
-# HiveData is a single ndarray containing correlated data for one agent.
-HiveData = ArrayLike
-# BatchArray is a concatenated set of arrays from multiple agents.
-# The shape of BatchedData will be [Number of Agents, *(shape of HiveData)]
+# SingleAgentData is a single ndarray containing correlated data for one agent.
+SingleAgentData = ArrayLike
+# BatchedData is a concatenated set of arrays from multiple agents.
+# The shape of BatchedData will be [Number of Agents, *(shape of SingleAgentData)]
 BatchedData = ArrayLike
 
 # Input is a set of named inputs from one agent. We mainly use this for observations.
 InputSpace = str
-Input = Dict[InputSpace, HiveData]
+Input = Dict[InputSpace, SingleAgentData]
 # Input gathers inputs from multiple agents
 InputGroup = Dict[AgentId, Input]
 # InputBatch is the result of merging an InputGroup based on input name.
 InputBatch = Dict[InputSpace, BatchedData]
 
 # Output is a set of named outputs for one agent
 OutputSpace = str
-Output = Dict[OutputSpace, HiveData]
+Output = Dict[OutputSpace, SingleAgentData]
 # Input gathers inputs from multiple agents
 OutputGroup = Dict[AgentId, Output]
 # OutputBatch is the result of evaluating the neural network on an input batch, before unmerging.
@@ -55,14 +55,14 @@ class MetaData:
 
 
 @dataclass
-class HiveObservation:
+class DictObservation:
     rewards: Dict[str, float]
     episode_state: EpisodeState
-    array_data: Dict[str, HiveData]
+    array_data: Dict[str, SingleAgentData]
     metadata: MetaData = None
 
 
 @dataclass
-class HiveResponse:
+class DictResponse:
     list_data: Dict[str, FloatList]
     scalar_data: Dict[str, float]
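For reference, constructing the renamed dataclasses for a single agent looks roughly like this. Illustration only: the "obs" key, the observation shape, and the single-element action list are arbitrary example values, not prescribed by the PR.

```python
import numpy as np

from emote.typing import AgentId, DictObservation, DictResponse, EpisodeState

agent_id: AgentId = 0
observation = DictObservation(
    rewards={"reward": 0.0},
    episode_state=EpisodeState.INITIAL,
    array_data={"obs": np.zeros(2, dtype=np.float32)},  # metadata defaults to None
)
response = DictResponse(list_data={"actions": [0.0]}, scalar_data={})
```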
4 changes: 2 additions & 2 deletions tests/gym/__init__.py
@@ -1,9 +1,9 @@
 from .hit_the_middle import HitTheMiddle
 from .collector import SimpleGymCollector
-from .hive_gym_wrapper import HiveGymWrapper
+from .dict_gym_wrapper import DictGymWrapper
 
 __all__ = [
     "HitTheMiddle",
     "SimpleGymCollector",
-    "HiveGymWrapper",
+    "DictGymWrapper",
 ]
14 changes: 7 additions & 7 deletions tests/gym/collector.py
@@ -7,15 +7,15 @@
 
 from emote.callback import Callback
 from emote.proxies import AgentProxy, MemoryProxy
-from tests.gym.hive_gym_wrapper import HiveGymWrapper
+from tests.gym.dict_gym_wrapper import DictGymWrapper
 
 
 class GymCollector(Callback):
     MAX_NUMBER_REWARDS = 1000
 
     def __init__(
         self,
-        env: HiveGymWrapper,
+        env: DictGymWrapper,
         agent: AgentProxy,
         memory: MemoryProxy,
         render: bool = True,
@@ -31,7 +31,7 @@ def __init__(
     def collect_data(self):
         """Collect a single rollout"""
         actions = self._agent(self._obs)
-        next_obs = self._env.hive_step(actions)
+        next_obs = self._env.dict_step(actions)
         self._memory.add(self._obs, actions)
         self._obs = next_obs
 
@@ -45,16 +45,16 @@ def collect_multiple(self, count: int):
 
     def begin_training(self):
         "Runs through the init, step cycle once on main thread to make sure all envs work."
-        self._obs = self._env.hive_reset()
+        self._obs = self._env.dict_reset()
         actions = self._agent(self._obs)
         _ = self._env.step(actions)
-        self._obs = self._env.hive_reset()
+        self._obs = self._env.dict_reset()
 
 
 class ThreadedGymCollector(GymCollector):
     def __init__(
         self,
-        env: HiveGymWrapper,
+        env: DictGymWrapper,
         agent: AgentProxy,
         memory: MemoryProxy,
         render=True,
@@ -99,7 +99,7 @@ def end_training(self):
 class SimpleGymCollector(GymCollector):
     def __init__(
         self,
-        env: HiveGymWrapper,
+        env: DictGymWrapper,
         agent: AgentProxy,
         memory: MemoryProxy,
         render=True,
22 changes: 11 additions & 11 deletions tests/gym/hive_gym_wrapper.py → tests/gym/dict_gym_wrapper.py
@@ -4,11 +4,11 @@
 import gym.spaces
 
 from gym.vector import VectorEnvWrapper, VectorEnv
-from emote.typing import EpisodeState, HiveObservation, AgentId, HiveResponse
+from emote.typing import EpisodeState, DictObservation, AgentId, DictResponse
 from emote.utils.spaces import BoxSpace, DictSpace, MDPSpace
 
 
-class HiveGymWrapper(VectorEnvWrapper):
+class DictGymWrapper(VectorEnvWrapper):
     def __init__(self, env: VectorEnv):
         super().__init__(env)
         self._next_agent = count()
@@ -17,7 +17,7 @@ def __init__(self, env: VectorEnv):
         ]
         assert isinstance(env.single_observation_space, gym.spaces.Box)
         os: gym.spaces.Box = env.single_observation_space
-        self.hive_space = MDPSpace(
+        self.dict_space = MDPSpace(
             BoxSpace(np.float32, (1,)),
             BoxSpace(env.single_action_space.dtype, env.single_action_space.shape),
             DictSpace({"obs": BoxSpace(os.dtype, os.shape)}),
@@ -26,9 +26,9 @@ def __init__(self, env: VectorEnv):
     def render(self):
         self.env.envs[0].render()
 
-    def hive_step(
-        self, actions: Dict[AgentId, HiveResponse]
-    ) -> Dict[AgentId, HiveObservation]:
+    def dict_step(
+        self, actions: Dict[AgentId, DictResponse]
+    ) -> Dict[AgentId, DictObservation]:
         batched_actions = np.stack(
             [actions[agent].list_data["actions"] for agent in self._agent_ids]
         )
@@ -38,13 +38,13 @@ def hive_step(
         results = {}
         for env_id, done in enumerate(dones):
             if done:
-                results[self._agent_ids[env_id]] = HiveObservation(
+                results[self._agent_ids[env_id]] = DictObservation(
                     episode_state=EpisodeState.TERMINAL,
                     array_data={"obs": next_obs[env_id]},
                     rewards={"reward": rewards[env_id]},
                 )
                 new_agent = next(self._next_agent)
-                results[new_agent] = HiveObservation(
+                results[new_agent] = DictObservation(
                     episode_state=EpisodeState.INITIAL,
                     array_data={"obs": next_obs[env_id]},
                     rewards={"reward": 0.0},
@@ -54,7 +54,7 @@
 
         results.update(
             {
-                agent_id: HiveObservation(
+                agent_id: DictObservation(
                     episode_state=EpisodeState.RUNNING,
                     array_data={"obs": next_obs[env_id]},
                     rewards={"reward": rewards[env_id]},
@@ -65,12 +65,12 @@
         )
         return results
 
-    def hive_reset(self) -> Dict[AgentId, HiveObservation]:
+    def dict_reset(self) -> Dict[AgentId, DictObservation]:
         self._agent_ids = [next(self._next_agent) for i in range(self.num_envs)]
         self.reset_async()
         obs = self.reset_wait()
         return {
-            agent_id: HiveObservation(
+            agent_id: DictObservation(
                 episode_state=EpisodeState.INITIAL,
                 array_data={"obs": obs[i]},
                 rewards={"reward": 0.0},
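The renamed wrapper can be driven directly along these lines. This is a sketch under assumptions, none of it in the PR itself: zero actions of shape (1,) stand in for a real policy, the "actions" key mirrors what dict_step reads, and tests.gym is assumed importable from the caller's environment.

```python
import numpy as np
from gym.vector import AsyncVectorEnv

from emote.typing import DictResponse
from tests.gym import DictGymWrapper, HitTheMiddle

env = DictGymWrapper(AsyncVectorEnv(2 * [HitTheMiddle]))
observations = env.dict_reset()
for _ in range(5):
    # Send a placeholder zero action for every agent currently reporting an observation.
    actions = {
        agent_id: DictResponse(list_data={"actions": np.zeros(1)}, scalar_data={})
        for agent_id in observations
    }
    observations = env.dict_step(actions)
```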
File renamed without changes.
6 changes: 3 additions & 3 deletions tests/test_htm.py
@@ -19,7 +19,7 @@
 )
 from emote.memory import TableMemoryProxy, MemoryLoader
 
-from .gym import SimpleGymCollector, HitTheMiddle, HiveGymWrapper
+from .gym import SimpleGymCollector, HitTheMiddle, DictGymWrapper
 
 
 N_HIDDEN = 10
@@ -58,8 +58,8 @@ def forward(self, obs):
 
 def test_htm():
 
-    env = HiveGymWrapper(AsyncVectorEnv(10 * [HitTheMiddle]))
-    table = DictObsTable(spaces=env.hive_space, maxlen=1000)
+    env = DictGymWrapper(AsyncVectorEnv(10 * [HitTheMiddle]))
+    table = DictObsTable(spaces=env.dict_space, maxlen=1000)
     memory_proxy = TableMemoryProxy(table)
     dataloader = MemoryLoader(table, 100, 2, "batch_size")