Skip to content

Commit

Permalink
Merge pull request #20 from clement-moulin-frier/clement/refactor-pro…
Browse files Browse the repository at this point in the history
…x-computation

Add a mask in proximiter computation to only perceive existing entities
  • Loading branch information
corentinlger authored Mar 7, 2024
2 parents 3ba3ffd + 6598f57 commit ccd17c5
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 45 deletions.
4 changes: 2 additions & 2 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AgentConfig(Config):
proxs_cos_min = param.Number(0., bounds=(-1., 1.))
color = param.Color('blue')
friction = param.Number(1e-1)
visible = param.Boolean(True)
exists = param.Boolean(True)

def __init__(self, **params):
super().__init__(**params)
Expand All @@ -66,7 +66,7 @@ class ObjectConfig(Config):
diameter = param.Number(5.)
color = param.Color('red')
friction = param.Number(10.)
visible = param.Boolean(True)
exists = param.Boolean(True)

def __init__(self, **params):
super().__init__(**params)
Expand Down
8 changes: 4 additions & 4 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class StateFieldInfo:
color_c_to_s = lambda x: mcolors.to_rgb(x)
mass_center_s_to_c = lambda x, typ: typ(x)
mass_center_c_to_s = lambda x: [x]
visible_c_to_s = lambda x: int(x)
exists_c_to_s = lambda x: int(x)


agent_configs_to_state_dict = {'x_position': StateFieldInfo(('nve_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
Expand All @@ -69,7 +69,7 @@ class StateFieldInfo:
'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s),
'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'visible': StateFieldInfo(('nve_state', 'visible'), None, identity_s_to_c, visible_c_to_s)
'exists': StateFieldInfo(('nve_state', 'exists'), None, identity_s_to_c, exists_c_to_s)
}

agent_configs_to_state_dict.update({f: StateFieldInfo(('agent_state', f,), None, identity_s_to_c, identity_c_to_s) for f in agent_common_fields if f not in agent_configs_to_state_dict})
Expand All @@ -83,7 +83,7 @@ class StateFieldInfo:
'friction': StateFieldInfo(('nve_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
'color': StateFieldInfo(('object_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('object_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'visible': StateFieldInfo(('nve_state', 'visible'), None, identity_s_to_c, visible_c_to_s)
'exists': StateFieldInfo(('nve_state', 'exists'), None, identity_s_to_c, exists_c_to_s)

}

Expand Down Expand Up @@ -115,7 +115,7 @@ def get_default_state(n_entities_dict):
entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))),
diameter=jnp.zeros(n_entities),
friction=jnp.zeros(n_entities),
visible=jnp.ones(n_entities, dtype=int)
exists=jnp.ones(n_entities, dtype=int)
),
agent_state=AgentState(nve_idx=jnp.zeros(n_agents, dtype=int),
prox=jnp.zeros((n_agents, 2)),
Expand Down
4 changes: 2 additions & 2 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def proto_to_nve_state(nve_state):
entity_idx=proto_to_ndarray(nve_state.entity_idx).astype(int),
diameter=proto_to_ndarray(nve_state.diameter).astype(float),
friction=proto_to_ndarray(nve_state.friction).astype(float),
visible=proto_to_ndarray(nve_state.visible).astype(int)
exists=proto_to_ndarray(nve_state.exists).astype(int)
)


Expand Down Expand Up @@ -98,7 +98,7 @@ def nve_state_to_proto(nve_state):
entity_idx=ndarray_to_proto(nve_state.entity_idx),
diameter=ndarray_to_proto(nve_state.diameter),
friction=ndarray_to_proto(nve_state.friction),
visible=ndarray_to_proto(nve_state.visible)
exists=ndarray_to_proto(nve_state.exists)
)


Expand Down
2 changes: 1 addition & 1 deletion vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ message NVEState {
NDArray entity_type = 6;
NDArray entity_idx = 7;
NDArray friction = 8;
NDArray visible = 9;
NDArray exists = 9;
}

message AgentState {
Expand Down
32 changes: 16 additions & 16 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions vivarium/simulator/grpc_server/simulator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class SimulatorState(_message.Message):
def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., n_agents: _Optional[_Union[NDArray, _Mapping]] = ..., n_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class NVEState(_message.Message):
__slots__ = ["position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "visible"]
__slots__ = ["position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists"]
POSITION_FIELD_NUMBER: _ClassVar[int]
MOMENTUM_FIELD_NUMBER: _ClassVar[int]
FORCE_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -60,7 +60,7 @@ class NVEState(_message.Message):
ENTITY_TYPE_FIELD_NUMBER: _ClassVar[int]
ENTITY_IDX_FIELD_NUMBER: _ClassVar[int]
FRICTION_FIELD_NUMBER: _ClassVar[int]
VISIBLE_FIELD_NUMBER: _ClassVar[int]
EXISTS_FIELD_NUMBER: _ClassVar[int]
position: RigidBody
momentum: RigidBody
force: RigidBody
Expand All @@ -69,8 +69,8 @@ class NVEState(_message.Message):
entity_type: NDArray
entity_idx: NDArray
friction: NDArray
visible: NDArray
def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., visible: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...
exists: NDArray
def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class AgentState(_message.Message):
__slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color"]
Expand Down
Loading

0 comments on commit ccd17c5

Please sign in to comment.