Skip to content

Commit

Permalink
Merge pull request #39 from clement-moulin-frier/corentin/refactor_co…
Browse files Browse the repository at this point in the history
…de_argparse

Move executable code in separate scripts and add argparse option for …
  • Loading branch information
corentinlger authored Mar 13, 2024
2 parents 5c893bf + 167b32b commit 37499d5
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 69 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ pip install -e .
Use the following command :

```bash
python3 vivarium/simulator/grpc_server/simulator_server.py
python3 scripts/run_server.py
```

### Interact with it from a web interface

And launch the web interface from another terminal :

```bash
panel serve vivarium/interface/panel_app.py --autoreload
panel serve scripts/run_interface.py --autoreload
```

Once this command will have completed, it will output a URL looking like `http://localhost:5006/panel_app`, that you can open in your browser.
Expand Down
5 changes: 5 additions & 0 deletions scripts/run_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from vivarium.interface.panel_app import WindowManager

# Serve the app
wm = WindowManager()
wm.app.servable(title="Vivarium")
69 changes: 69 additions & 0 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import logging
import argparse

import numpy as np

import vivarium.simulator.behaviors as behaviors
from vivarium.controllers.config import SimulatorConfig, AgentConfig, ObjectConfig
from vivarium.simulator.sim_computation import StateType
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.controllers.converters import set_state_from_config_dict
from vivarium.simulator.grpc_server.simulator_server import serve

lg = logging.getLogger(__name__)

def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_objects', type=int, default=2, help='Number of objects')
parser.add_argument('--num_steps-lax', type=int, default=4, help='Number of lax steps per loop')
parser.add_argument('--dt', type=float, default=0.1, help='Time step size')
parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter')
parser.add_argument('--neighbor_radius', type=float, default=100.0, help='Radius for neighbor calculations')
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')

return parser.parse_args()


if __name__ == '__main__':
args = parse_args()

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop
)

agent_configs = [AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)]

object_configs = [ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)]

state = set_state_from_config_dict({StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
})

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)
lg.info('Simulator server started')
serve(simulator)
82 changes: 82 additions & 0 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import argparse
import logging

import numpy as np

from vivarium.simulator import behaviors
from vivarium.simulator.sim_computation import dynamics_rigid, StateType
from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig
from vivarium.controllers import converters
from vivarium.simulator.simulator import Simulator

lg = logging.getLogger(__name__)

def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
# Experiment run arguments
parser.add_argument('--num_loops', type=int, default=10, help='Number of simulation loops')
# Simulator config arguments
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_objects', type=int, default=2, help='Number of objects')
parser.add_argument('--num_steps_lax', type=int, default=4, help='Number of lax steps per loop')
parser.add_argument('--dt', type=float, default=0.1, help='Time step size')
parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter')
parser.add_argument('--neighbor_radius', type=float, default=100.0, help='Radius for neighbor calculations')
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')

return parser.parse_args()


if __name__ == "__main__":
args = parse_args()

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop
)

agent_configs = [
AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)
]

object_configs = [
ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)
]

state = converters.set_state_from_config_dict(
{
StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
}
)


simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

lg.info("Running simulation")

simulator.run(threaded=False, num_loops=10)

lg.info("Simulation complete")
1 change: 0 additions & 1 deletion vivarium/controllers/notebook_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from vivarium.controllers.simulator_controller import SimulatorController
from vivarium.simulator.sim_computation import StateType, EntityType

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

class Entity:
Expand Down
1 change: 0 additions & 1 deletion vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import logging

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

class PanelConfig(Config):
Expand Down
1 change: 0 additions & 1 deletion vivarium/controllers/simulator_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from contextlib import contextmanager
import logging

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

param.Dynamic.time_dependent = True
Expand Down
4 changes: 0 additions & 4 deletions vivarium/interface/panel_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,3 @@ def set_callbacks(self):
self.start_toggle.param.watch(self.start_toggle_cb, "value")
self.update_switch.param.watch(self.update_switch_cb, "value")
self.update_timestep.param.watch(self.update_timestep_cb, "value")

# Serve the app
wm = WindowManager()
wm.app.servable(title="Vivarium")
1 change: 0 additions & 1 deletion vivarium/simulator/behaviors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from jax import vmap
from functools import partial

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

linear_behavior_enum = Enum('matrices', ['FEAR', 'AGGRESSION', 'LOVE', 'SHY'])
Expand Down
30 changes: 0 additions & 30 deletions vivarium/simulator/grpc_server/simulator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from vivarium.simulator.grpc_server.converters import state_to_proto, nve_state_to_proto, agent_state_to_proto, object_state_to_proto
from vivarium.controllers.converters import set_state_from_config_dict

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

Empty = simulator_pb2.google_dot_protobuf_dot_empty__pb2.Empty
Expand Down Expand Up @@ -94,32 +93,3 @@ def serve(simulator):
server.add_insecure_port('[::]:50051')
server.start()
server.wait_for_termination()


if __name__ == '__main__':

simulator_config = SimulatorConfig(to_jit=True)

agent_configs = [AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)]

object_configs = [ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)]

state = set_state_from_config_dict({StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
})

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)
lg.info('Simulator server started')
serve(simulator)



29 changes: 0 additions & 29 deletions vivarium/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import math
import logging

logging.basicConfig(level=logging.INFO)
lg = logging.getLogger(__name__)

class Simulator:
Expand Down Expand Up @@ -181,31 +180,3 @@ def get_change_time(self):

def get_state(self):
return self.state


if __name__ == "__main__":

simulator_config = SimulatorConfig(to_jit=True)

agent_configs = [AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)]

object_configs = [ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)]

state = converters.set_state_from_config_dict({StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
})

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator.run(threaded=False, num_loops=10)


0 comments on commit 37499d5

Please sign in to comment.