Skip to content

Commit

Permalink
Experiments with the base class to make it a bit more efficient. Ever…
Browse files Browse the repository at this point in the history
…ything failed
  • Loading branch information
i-m-iron-man committed Nov 21, 2024
1 parent ad3c2d6 commit 4149092
Show file tree
Hide file tree
Showing 15 changed files with 572 additions and 34 deletions.
Binary file added .DS_Store
Binary file not shown.
Binary file added docs/.DS_Store
Binary file not shown.
Binary file added docs/assets/.DS_Store
Binary file not shown.
Binary file added examples/.DS_Store
Binary file not shown.
Binary file added examples/basic/.DS_Store
Binary file not shown.
79 changes: 74 additions & 5 deletions examples/basic/hello_world/hello_world.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,11 +31,11 @@
" \n",
" def create_active_agent(key):\n",
" draw = jax.random.randint(key, (1,), 1, 7)\n",
" state_content = {'draw': draw, 'key': key}\n",
" state_content = {'draw': draw, 'key': key, 'input_value': 0}\n",
" return fgx_classes.State(content=state_content)\n",
" \n",
" def create_inactive_agent(key):\n",
" state_content = {'draw': jnp.array([0]), 'key': key}\n",
" state_content = {'draw': jnp.array([0]), 'key': key, 'input_value': 0}\n",
" return fgx_classes.State(content=state_content)\n",
" agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(subkey), lambda _: create_inactive_agent(subkey), None)\n",
" \n",
Expand All @@ -48,9 +48,11 @@
" def step_active_agent(dice_agent):\n",
" old_state = dice_agent.state.content\n",
" key, subkey = jax.random.split(old_state['key'])\n",
" input_value = input.content['value']\n",
" input_value = input_value+1\n",
" \n",
" draw = jax.random.randint(subkey, (1,), 1, 7)\n",
" state_content = {'draw': draw, 'key': subkey}\n",
" state_content = {'draw': draw, 'key': subkey, 'input_value': input_value[0]}\n",
" new_state = fgx_classes.State(content = state_content)\n",
" return dice_agent.replace(state = new_state, age = dice_agent.age + 1.0)\n",
" \n",
Expand Down Expand Up @@ -80,6 +82,73 @@
" return inactive_dice_agent\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AgentSet initialized\n",
"[6 4 6 2 4 0 0 0 0 0]\n"
]
}
],
"source": [
"Dice_set = fgx_classes.Agent_Set(agent = Dice, num_total_agents = 10, num_active_agents = 5, agent_type = 0)\n",
"\n",
"Dice_set.agents = fgx_methods.create_agents(params = None, agent_set = Dice_set, key = jax.random.PRNGKey(0))\n",
"print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1)))"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"input_signal_content = {'value': jnp.array([[1],[2],[3],[4],[5],[6],[7],[8],[9],[10]])}\n",
"input_signal = fgx_classes.Signal(content = input_signal_content)\n",
"step_params = None"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21.8 µs ± 295 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"Dice_set.agents = fgx_methods.step_agents(params = None, agent_set = Dice_set, input=input_signal)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2 3 4 5 6 0 0 0 0 0]\n"
]
}
],
"source": [
"print(Dice_set.agents.state.content['input_value'])"
]
},
{
"cell_type": "code",
"execution_count": 36,
Expand Down
Binary file added foragax/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
20 changes: 7 additions & 13 deletions foragax/base/agent_classes.py → foragax/base/agent_classe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class Policy:
def create_policy(params:Params, active_state:bool, key:jax.random.PRNGKey):
pass
@staticmethod
def step_policy(input:Signal, policy:struct.dataclass):
def step_policy(input:Signal, policy:struct.dataclass, key:jax.random.PRNGKey):
pass
@staticmethod
def reset_policy(input:Signal, policy:struct.dataclass):
def reset_policy(input:Signal, policy:struct.dataclass, key:jax.random.PRNGKey):
pass

@struct.dataclass
Expand All @@ -45,7 +45,7 @@ def create_agent(create_params:Params, unique_id:int, active_state:int, agent_ty
pass

@staticmethod
def step_agent(step_params:Params, input:Signal, agents:struct.dataclass):
def step_agent(step_params:Params, input:Signal, agents:struct.dataclass, key:jax.random.PRNGKey):
pass

@staticmethod
Expand All @@ -57,27 +57,21 @@ def set_agent(set_params:Params, agents:struct.dataclass, idx:int, key:jax.rando
pass

@staticmethod
def reset_agent():
def remove_agent(remove_params:Params, Agents:struct.dataclass, idx:int, key:jax.random.PRNGKey):
pass

@staticmethod
def remove_agent(remove_params:Params, Agents:struct.dataclass, idx:int):
pass


class Agent_Set:
agents:Agent
def __init__(self, agent:Agent, num_total_agents:jnp.int32, num_active_agents:jnp.int32, agent_type:jnp.int32):

self.create_agents = jax.vmap(agent.create_agent, in_axes=(None,0,0,0,0))
#create( params, unique_id, active_state, agent_types, key)
#params is not vmaped over

self.step_agents = jax.jit(jax.vmap(agent.step_agent, in_axes=(None,0,0)))

self.reset_agents = jax.jit(jax.vmap(agent.reset_agent))
self.step_agents = jax.jit(jax.vmap(agent.step_agent, in_axes=(None,0,0,0)))

self.num_total_agents = num_total_agents
self.num_active_agents = num_active_agents
self.num_inactive_agents = num_total_agents - num_active_agents
self.agent_type = agent_type
print("AgentSet initialized")
print("AgentSet initialized")
29 changes: 13 additions & 16 deletions foragax/base/agent_methods.py → foragax/base/agent_method.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
from flax import struct
from foragax.base.agent_classes import *
from agent_classe import *
from foragax.base.space_classes import *


Expand All @@ -17,11 +17,9 @@ def create_agents(params:Params, agent_set:Agent_Set, key:jax.random.PRNGKey):
return agent_set.create_agents(params, uniq_ids, active_states, agent_types, create_keys)


def step_agents(params:Params, input:Signal, agent_set:Agent_Set):
try:
return agent_set.step_agents(params, input, agent_set.agents)
except ValueError:
print("Error in step_agents")
def step_agents(params:Params, input:Signal, agent_set:Agent_Set, key:jax.random.PRNGKey):
return agent_set.step_agents(params, input, agent_set.agents, key)


def add_agents(add_func:callable, num_agents_add:int, add_params:Params, agents:Agent, key:jax.random.PRNGKey):
id_last_active = jnp.sum(agents.active_state, dtype=jnp.int32)
Expand All @@ -37,17 +35,16 @@ def set_data(idx, agents__add_params__key):

jit_add_agents = jax.jit(add_agents, static_argnums=(0,))


def remove_agents(remove_func:callable, num_agents_remove:int, remove_params:Params, agents:Agent):
def remove_agents(remove_func:callable, num_agents_remove:int, remove_params:Params, agents:Agent, key:jax.random.PRNGKey):

def remove_data(idx, agents__remove_params):
agents, remove_params = agents__remove_params
remove_ids = remove_params.content['remove_ids']
new_agent = jax.jit(remove_func)(remove_params, agents, remove_ids[idx])
new_agents = jax.tree_util.tree_map(lambda x,y:x.at[remove_ids[idx]].set(y), agents, new_agent)
return new_agents, remove_params
new_agents, remove_params = jax.lax.fori_loop(0, num_agents_remove, remove_data, (agents, remove_params))
return new_agents
def remove_data(idx, agents__remove_params__key):
agents, remove_params, key = agents__remove_params__key
new_agent, key = jax.jit(remove_func)(remove_params, agents, idx, key)
new_agents = jax.tree_util.tree_map(lambda x,y:x.at[idx].set(y), agents, new_agent)
return new_agents, remove_params, key

new_agents, remove_params, key = jax.lax.fori_loop(0, num_agents_remove, remove_data, (agents, remove_params, key))
return new_agents, key

jit_remove_agents = jax.jit(remove_agents, static_argnums=(0,))

Expand Down
Loading

0 comments on commit 4149092

Please sign in to comment.