Commit: Merge branch '97-custommultimodalmodel' into 'dev'

Resolve "CustomMultimodalModel". Closes #97. See merge request fireline/reinforcementlearning/simharness!49.
Showing 9 changed files with 487 additions and 36 deletions.
@@ -0,0 +1,53 @@
defaults:
  - default
  - _self_

env: simharness2.environments.MultiAgentComplexObsReactiveHarness
env_config:
  sim:
    _target_: simfire.sim.simulation.FireSimulation
    config:
      _target_: simfire.utils.config.Config
      config_dict: ${simulation.simfire}
  movements: [none, up, down, left, right]
  interactions: [none, fireline]
  # TODO: Need a better way of aligning the YAML spec with constant keys in code.
  attributes: [fire_map, agent_pos]
  normalized_attributes: []
  agent_speed: 4
  num_agents: 3
  # agent_initialization_method: automatic
  agent_initialization_method: manual
  initial_agent_positions: [[0, 64], [127, 64], [64, 127]]
  # Defines the class that will be used to perform reward calculation at each timestep.
  # reward_cls_partial:
  #   _target_: simharness2.rewards.base_reward.AreaSavedPropReward
  #   _partial_: true
  benchmark_sim: null
  # Defines the class that will be used to monitor and track `ReactiveHarness`.
  harness_analytics_partial:
    _target_: simharness2.analytics.harness_analytics.ReactiveHarnessAnalytics
    _partial_: true
    # Defines the class that will be used to monitor and track `FireSimulation`.
    sim_analytics_partial:
      _target_: simharness2.analytics.simulation_analytics.FireSimulationAnalytics
      _partial_: true
      # Defines the class that will be used to monitor and track agent behavior.
      agent_analytics_partial:
        _target_: simharness2.analytics.agent_analytics.ReactiveAgentAnalytics
        _partial_: true
        movement_types: ${....movements}
        interaction_types: ${....interactions}

  # Defines the class that will be used to perform reward calculation at each timestep.
  reward_cls_partial:
    _target_: simharness2.rewards.base_reward.SimpleReward
    _partial_: true

  action_space_cls:
    _target_: hydra.utils.get_class
    # path: gymnasium.spaces.MultiDiscrete
    path: gymnasium.spaces.Discrete

  fire_initial_position: ${simulation.fire_initial_position}
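The `*_partial` entries above rely on Hydra's partial instantiation: with `_partial_: true`, `hydra.utils.instantiate` returns a `functools.partial` instead of a constructed object, so the harness can supply runtime-only arguments when it finally builds the class. A minimal sketch of that mechanism (assumes hydra-core >= 1.2 and that `simharness2` is importable; the config snippet is abbreviated from the file above):

import functools

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "reward_cls_partial": {
            "_target_": "simharness2.rewards.base_reward.SimpleReward",
            "_partial_": True,
        }
    }
)

# With `_partial_: true`, instantiate() returns a functools.partial rather
# than a constructed object; the harness can pass remaining constructor
# arguments (e.g., its analytics tracker) later, at runtime.
reward_partial = instantiate(cfg.reward_cls_partial)
assert isinstance(reward_partial, functools.partial)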
@@ -0,0 +1,115 @@
defaults:
  - config
  - override environment: marl_complex_obs
  - override training: ppo_with_custom_model
  # UNCOMMENT BELOW TO "DISABLE" TUNING, i.e., use tune for a standalone trial.
  # - override tunables: none
  - _self_

algo:
  name: PPO

rollouts:
  # NOTE: MultiAgentComplexObsHarness DOES NOT normalize returned observations!
  # So, use the MeanStdFilter to ensure observations are normalized.
  # See the MeanStdObservationFilterAgentConnector class for more:
  # https://github.com/ray-project/ray/blob/ad48682b4ec78a5699ba89a7c8a69327c264e47b/rllib/connectors/agent/mean_std_filter.py#L23
  enable_connectors: true
  observation_filter: MeanStdFilter

  batch_mode: complete_episodes  # truncate_episodes
  # Scale the experiment as desired.
  num_rollout_workers: 32  # FIXME!
  # num_envs_per_worker: ??

cli:
  mode: train
  data_dir: ???

hydra:
  run:
    dir: ${cli.data_dir}/debug_experiments/${now:%Y-%m-%d_%H-%M-%S}  # FIXME!

simulation:
  fire_initial_position:
    # Configure the generation of candidate initial fire positions.
    # To disable usage of the generator, simply set it to null.
    generator:
      # If true, all possible initial fire positions will be included in the
      # generated "dataset".
      make_all_positions: false  # FIXME!
      # Number of candidate initial fire positions to generate. Ignored when
      # `make_all_positions == true`.
      output_size: 1024  # FIXME!
      # The root location to save the generated dataset.
      save_path: ${cli.data_dir}/simfire/data/fire_initial_positions
    # Configure the sampling of new initial fire positions.
    sampler:
      # Draw a (new) random sample of size `sample_size` every `resample_interval`
      # episodes.
      # NOTE: Evaluation scenarios are never resampled; this applies to training ONLY.
      resample_interval: 1  # FIXME!
      # Number of samples to draw from the distribution, i.e., take `sample_size`
      # initial fire positions and distribute them across the respective RolloutWorkers.
      sample_size:
        train: ${rollouts.num_rollout_workers}
        eval: ${evaluation.evaluation_num_workers}
      # Filter the generated "dataset" using the following condition.
      # NOTE: This is applied to the generated dataset, not the sampled positions.
      # For more information on the query string to evaluate, see:
      # - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html#pandas.DataFrame.query
      # Examples:
      # - "A < elapsed_steps", "elapsed_steps < B", "A < elapsed_steps < B"
      # - "(A < elapsed_steps) & (C < percent_area_burned)"
      # query: "50 < elapsed_steps < 150"
      query: "0.35 < percent_area_burned"  # FIXME!
      # Number of samples available in the resulting "dataset" after filtering with
      # `query`.
      # NOTE: This allows the user to control the total number of distinct fire
      # positions that will be used to produce sample batches (or trajectories) of
      # experiences. If set to `None`, the population size will be equal to the size
      # of the generated "dataset" after filtering with `query`.
      # - Applies to training ONLY; the eval "dataset" size equals `sample_size.eval`.
      population_size: 256  # FIXME!
  simfire:
    fire:
      diagonal_spread: true

# Specify the configuration passed to the `AimLoggerCallback`.
aim:
  repo: ${cli.data_dir}/aim
  experiment: test-multimodal-model-with-ppo
  system_tracking_interval: 30
  log_hydra_config: false

# Specify the configuration used to create the ray.air.CheckpointConfig object.
checkpoint:
  # Frequency at which to save checkpoints (in terms of training iterations).
  checkpoint_frequency: ${evaluation.evaluation_interval}
  # Number of checkpoints to keep.
  num_to_keep: 20

stop_conditions:
  training_iteration: 1000
  timesteps_total: 2000000000
  episodes_total: 1000000
  episode_reward_mean: 10000000

exploration:
  exploration_config:
    type: EpsilonGreedy
    # FIXME: The numbers below aren't necessarily right for this experiment; update
    # accordingly. Sized for 20000 episodes at an average of 512 timesteps per episode.
    initial_epsilon: 1.0
    final_epsilon: 0.05
    warmup_timesteps: 972800
    epsilon_timesteps: 8755200

resources:
  num_gpus: 1

debugging:
  log_level: INFO
  logger_config:
    type:
      _target_: hydra.utils.get_class
      path: ray.tune.logger.UnifiedLogger
    logdir: ${hydra:run.dir}
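Two of the comments above benefit from a concrete illustration. First, the `observation_filter: MeanStdFilter` setting: because the harness does not normalize observations, RLlib's filter keeps running per-element mean/variance statistics and standardizes each observation before it reaches the model. A toy sketch of that idea (not Ray's actual implementation):

import numpy as np


class RunningMeanStd:
    """Toy running mean/variance normalizer, illustrating what MeanStdFilter does."""

    def __init__(self, shape):
        self.count = 0
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)

    def __call__(self, obs):
        # Incrementally update the running statistics, then standardize.
        self.count += 1
        delta = obs - self.mean
        self.mean += delta / self.count
        self.var += (delta * (obs - self.mean) - self.var) / self.count
        return (obs - self.mean) / (np.sqrt(self.var) + 1e-8)


normalize = RunningMeanStd(shape=(2,))
for obs in np.random.default_rng(0).normal(5.0, 2.0, size=(1000, 2)):
    normalized = normalize(obs)  # roughly zero-mean, unit-variance over time

Second, the sampler's `query` string is evaluated with `pandas.DataFrame.query` against the generated dataset of candidate fire positions. A hedged sketch of that filtering step (the DataFrame contents are fabricated for illustration; the column names come from the example queries above):

import pandas as pd

# Fabricated stand-in for the generated "dataset" of candidate positions.
positions = pd.DataFrame(
    {
        "position": [(10, 20), (64, 64), (100, 5)],
        "elapsed_steps": [42, 180, 97],
        "percent_area_burned": [0.12, 0.51, 0.38],
    }
)

# Same semantics as the `query` value above: keep fires that burn >35% of the area.
filtered = positions.query("0.35 < percent_area_burned")
# Then draw one initial fire position per rollout worker (here, 2 for illustration).
sample = filtered.sample(n=2, random_state=0)
print(sample["position"].tolist())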
@@ -0,0 +1,9 @@
defaults:
  - model: screen_size_128

model:
  custom_model: multimodal_network

# PPO parameter specs
# ...
train_batch_size: 1000
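For `custom_model: multimodal_network` to resolve, RLlib requires that name to be registered with the `ModelCatalog` before the algorithm is built. A sketch of the registration; the module path matches the import used by the harness below, but the class name `CustomMultimodalTorchModel` is an assumption, since the diff does not show it:

from ray.rllib.models import ModelCatalog

from simharness2.models.custom_multimodal_torch_model import (
    CustomMultimodalTorchModel,  # assumed class name; not shown in this diff
)

# Makes the string "multimodal_network" resolvable from the model config above.
ModelCatalog.register_custom_model("multimodal_network", CustomMultimodalTorchModel)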
simharness2/environments/multi_agent_complex_harness.py
128 changes: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import logging
from collections import OrderedDict
from typing import Dict, List, TypeVar

import numpy as np
from gymnasium import spaces
from simfire.sim.simulation import FireSimulation

from simharness2.environments.harness import get_unsupported_attributes
from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness
from simharness2.models.custom_multimodal_torch_model import (
    AGENT_POSITION_KEY,
    FIRE_MAP_KEY,
)


logger = logging.getLogger(__name__)

AnyFireSimulation = TypeVar("AnyFireSimulation", bound=FireSimulation)


class MultiAgentComplexObsReactiveHarness(MultiAgentFireHarness[AnyFireSimulation]):
    """Multi-agent harness whose observations pair the fire map with agent positions."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # TODO: Make this check more general, i.e., not included in every subclass?
        # Validate that the attributes provided are supported by this harness.
        curr_cls = self.__class__
        bad_attributes = get_unsupported_attributes(self.attributes, curr_cls)
        if bad_attributes:
            msg = (
                f"The {curr_cls.__name__} class does not support the "
                f"following attributes: {bad_attributes}."
            )
            raise AssertionError(msg)

    @staticmethod
    def supported_attributes() -> List[str]:
        """Return the full list of attributes supported by the harness."""
        # TODO: Expand to include SimFire data layers, when ready.
        return [FIRE_MAP_KEY, AGENT_POSITION_KEY]

    def get_initial_state(self) -> Dict[str, OrderedDict]:
        """Return the initial observation for each agent in the environment."""
        fire_map = self.prepare_fire_map(place_agents=False)
        fire_map = np.expand_dims(fire_map, axis=-1).astype(np.float32)
        # Build MARL obs - the position array will be different for each agent.
        marl_obs = {}
        for ag_id in self._agent_ids:
            curr_agent = self.agents[ag_id]
            pos_state = curr_agent.initial_position
            marl_obs[ag_id] = OrderedDict(
                {
                    FIRE_MAP_KEY: fire_map,
                    AGENT_POSITION_KEY: np.asarray(pos_state),
                }
            )
            # marl_obs[ag_id] = {
            #     FIRE_MAP_KEY: fire_map,
            #     AGENT_POSITION_KEY: np.asarray(pos_state),
            # }

        # Note: Returning an ordered dict bc spaces.Dict.sample() returns an ordered dict.
        # return OrderedDict(marl_obs)
        return marl_obs

    def get_observation_space(self) -> spaces.Space:
        """Build the per-agent Dict observation space."""
        # NOTE: We are assuming each agent has the same observation space.
        agent_obs_space = spaces.Dict(
            OrderedDict(
                {
                    FIRE_MAP_KEY: self._get_fire_map_observation_space(),
                    AGENT_POSITION_KEY: self._get_position_observation_space(),
                }
            )
        )

        self._obs_space_in_preferred_format = True
        return spaces.Dict({agent_id: agent_obs_space for agent_id in self._agent_ids})

    def _get_fire_map_observation_space(self) -> spaces.Box:
        """Construct the fire map Box space, bounded by the BurnStatus values in use."""
        from simfire.enums import BurnStatus

        # TODO: Refactor to enable easier reuse of this method logic.
        # Prepare user-provided interactions.
        interacts = [i.lower() for i in self.interactions]
        if "none" in interacts:
            interacts.pop(interacts.index("none"))
        # Prepare non-interaction disaster categories.
        non_interacts = list(self._get_non_interaction_disaster_categories().keys())
        non_interacts = [i.lower() for i in non_interacts]
        cats = interacts + non_interacts
        cat_vals = []
        for status in BurnStatus:
            if status.name.lower() in cats:
                cat_vals.append(status.value)

        low = min(cat_vals)
        high = max(cat_vals)
        obs_shape = self.sim.fire_map.shape + (1,)
        return spaces.Box(low=low, high=high, shape=obs_shape, dtype=np.float32)

    def _get_position_observation_space(self) -> spaces.Box:
        """Construct the Box space for an agent's (row, col) position."""
        row_max, col_max = self.sim.fire_map.shape
        return spaces.Box(low=np.array([0, 0]), high=np.array([row_max - 1, col_max - 1]))

    def _update_state(self):
        """Modify the environment's state to contain updates from the current timestep."""
        # Copy the fire map from the simulation so we don't overwrite it.
        fire_map = np.expand_dims(np.copy(self.sim.fire_map), axis=-1).astype(np.float32)

        # Build MARL obs - the position array will be different for each agent.
        marl_obs = {}
        for ag_id in self._agent_ids:
            curr_agent = self.agents[ag_id]
            pos_state = curr_agent.current_position
            marl_obs[ag_id] = OrderedDict(
                {
                    FIRE_MAP_KEY: fire_map,
                    AGENT_POSITION_KEY: np.asarray(pos_state),
                }
            )
        # Note: Setting to an ordered dict bc spaces.Dict.sample() returns an ordered dict.
        # self.state = OrderedDict(marl_obs)
        self.state = marl_obs
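For reference, the per-agent observation this harness emits is a two-branch Dict: an image-like fire map and a length-2 position vector, matching the `attributes: [fire_map, agent_pos]` entry in the environment config above. A self-contained sketch of that space follows; the key values, the BurnStatus upper bound of 6, and the 128x128 map size are assumptions for illustration (the real values come from the imported constants, the simfire enum, and the `screen_size_128` model default):

import numpy as np
from gymnasium import spaces

FIRE_MAP_KEY = "fire_map"        # assumed values of the imported constants
AGENT_POSITION_KEY = "agent_pos"

agent_obs_space = spaces.Dict(
    {
        # Single-channel fire map; bounds span the assumed BurnStatus values.
        FIRE_MAP_KEY: spaces.Box(low=0.0, high=6.0, shape=(128, 128, 1), dtype=np.float32),
        # (row, col) position, bounded by the map edges.
        AGENT_POSITION_KEY: spaces.Box(low=np.array([0, 0]), high=np.array([127, 127])),
    }
)

obs = agent_obs_space.sample()
print(obs[FIRE_MAP_KEY].shape)   # (128, 128, 1) -> conv branch of the custom model
print(obs[AGENT_POSITION_KEY])   # length-2 vector -> dense branch of the custom model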