Merge branch '97-custommultimodalmodel' into 'dev'
Resolve "CustomMultimodalModel"

Closes #97

See merge request fireline/reinforcementlearning/simharness!49
afennelly-mitre committed Apr 30, 2024
2 parents 33df887 + b0a6ad2 commit fcfaee7
Showing 9 changed files with 487 additions and 36 deletions.
53 changes: 53 additions & 0 deletions conf/environment/marl_complex_obs.yaml
@@ -0,0 +1,53 @@
defaults:
  - default
  - _self_

env: simharness2.environments.MultiAgentComplexObsReactiveHarness
env_config:
  sim:
    _target_: simfire.sim.simulation.FireSimulation
    config:
      _target_: simfire.utils.config.Config
      config_dict: ${simulation.simfire}
  movements: [none, up, down, left, right]
  interactions: [none, fireline]
  # TODO: Need better way of aligning yaml spec with constant keys in code.
  attributes: [fire_map, agent_pos]
  normalized_attributes: []
  agent_speed: 4
  num_agents: 3
  # agent_initialization_method: automatic
  agent_initialization_method: manual
  initial_agent_positions: [[0, 64], [127, 64], [64, 127]]
  # Defines the class that will be used to perform reward calculation at each timestep.
  # reward_cls_partial:
  #   _target_: simharness2.rewards.base_reward.AreaSavedPropReward
  #   _partial_: true
  benchmark_sim: null
  # Defines the class that will be used to monitor and track `ReactiveHarness`.
  harness_analytics_partial:
    _target_: simharness2.analytics.harness_analytics.ReactiveHarnessAnalytics
    _partial_: true
    # Defines the class that will be used to monitor and track `FireSimulation`.
    sim_analytics_partial:
      _target_: simharness2.analytics.simulation_analytics.FireSimulationAnalytics
      _partial_: true
      # Defines the class that will be used to monitor and track agent behavior.
      agent_analytics_partial:
        _target_: simharness2.analytics.agent_analytics.ReactiveAgentAnalytics
        _partial_: true
        movement_types: ${....movements}
        interaction_types: ${....interactions}

  # Defines the class that will be used to perform reward calculation at each timestep.
  reward_cls_partial:
    _target_: simharness2.rewards.base_reward.SimpleReward
    _partial_: true

  action_space_cls:
    _target_: hydra.utils.get_class
    # path: gymnasium.spaces.MultiDiscrete
    path: gymnasium.spaces.Discrete

  fire_initial_position: ${simulation.fire_initial_position}
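
Note: the `_target_`/`_partial_` keys above follow Hydra's instantiation convention. Below is a minimal sketch of how such nodes resolve, assuming the file is loaded standalone with OmegaConf for demonstration; the real entry point composes the full config and may wire things differently.

# Illustrative sketch (assumption): resolving Hydra `_target_`/`_partial_` nodes from
# the environment config above. Loading the file standalone is for demonstration only.
from functools import partial

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/environment/marl_complex_obs.yaml")

# `_partial_: true` resolves to a functools.partial that the harness finishes
# constructing at runtime (e.g. once the FireSimulation instance exists).
reward_partial = instantiate(cfg.env_config.reward_cls_partial)
assert isinstance(reward_partial, partial)

# `hydra.utils.get_class` resolves the dotted path to the class object itself.
action_space_cls = instantiate(cfg.env_config.action_space_cls)
print(action_space_cls)  # <class 'gymnasium.spaces.discrete.Discrete'>
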
115 changes: 115 additions & 0 deletions conf/test_multimodal_model.yaml
@@ -0,0 +1,115 @@
defaults:
  - config
  - override environment: marl_complex_obs
  - override training: ppo_with_custom_model
  # UNCOMMENT BELOW TO "DISABLE" TUNING, i.e. use tune for a standalone trial.
  # - override tunables: none
  - _self_

algo:
  name: PPO

rollouts:
  # NOTE: MultiAgentComplexObsReactiveHarness DOES NOT normalize returned observations!
  # So, use the MeanStdFilter to ensure observations are normalized.
  # See the MeanStdObservationFilterAgentConnector class for more:
  # https://github.com/ray-project/ray/blob/ad48682b4ec78a5699ba89a7c8a69327c264e47b/rllib/connectors/agent/mean_std_filter.py#L23
  enable_connectors: true
  observation_filter: MeanStdFilter

  batch_mode: complete_episodes # truncate_episodes
  # Scale experiment as desired
  num_rollout_workers: 32 # FIXME!
  # num_envs_per_worker: ??

cli:
  mode: train
  data_dir: ??

hydra:
  run:
    dir: ${cli.data_dir}/debug_experiments/${now:%Y-%m-%d_%H-%M-%S} # FIXME!

simulation:
  fire_initial_position:
    # Configure the generation of candidate initial fire positions.
    generator:
      # If true, then all possible fire initial positions will be included in the
      # generated "dataset".
      make_all_positions: false # FIXME!
      # Number of candidate initial fire positions to generate. Ignored when
      # `make_all_positions == true`.
      output_size: 1024 # FIXME!
      # The root location to save the generated dataset.
      save_path: ${cli.data_dir}/simfire/data/fire_initial_positions
    # To disable usage of the generator, simply set it to null.
    # Configure the sampling of new initial fire positions.
    sampler:
      # Get a (new) random sample of size `sample_size` every `resample_interval` episodes.
      # NOTE: Evaluation scenarios are never resampled; this applies to training ONLY.
      resample_interval: 1 # FIXME!
      # Number of samples to draw from the distribution, i.e. take `sample_size` initial
      # fire positions and distribute them across the respective RolloutWorkers.
      sample_size:
        train: ${rollouts.num_rollout_workers}
        eval: ${evaluation.evaluation_num_workers}
      # Filter the generated "dataset" using the following condition.
      # NOTE: This is applied to the generated dataset, not the sampled positions.
      # For more information on the query string to evaluate, see:
      # - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html#pandas.DataFrame.query
      # Examples:
      # - "A < elapsed_steps", "elapsed_steps < B", "A < elapsed_steps < B"
      # - "(A < elapsed_steps) & (C < percent_area_burned)"
      # query: "50 < elapsed_steps < 150"
      query: "0.35 < percent_area_burned" # FIXME!
      # Number of samples available in the resulting "dataset" after filtering with `query`.
      # NOTE: This allows the user to control the total number of distinct fire positions
      # that will be used to produce sample batches (or trajectories) of experiences. If
      # set to `None`, the population size will be equal to the size of the generated
      # "dataset" after filtering with `query`.
      # - Applies to training ONLY; the eval "dataset" size is equal to `sample_size.eval`.
      population_size: 256 # FIXME!
  simfire:
    fire:
      diagonal_spread: true

# Specify configuration passed to the `AimLoggerCallback`
aim:
  repo: ${cli.data_dir}/aim
  experiment: test-multimodal-model-with-ppo
  system_tracking_interval: 30
  log_hydra_config: false

# Specify configuration used to create the ray.air.CheckpointConfig object
checkpoint:
  # Frequency at which to save checkpoints (in terms of training iterations)
  checkpoint_frequency: ${evaluation.evaluation_interval}
  # Number of checkpoints to keep
  num_to_keep: 20

stop_conditions:
  training_iteration: 1000
  timesteps_total: 2000000000
  episodes_total: 1000000
  episode_reward_mean: 10000000

exploration:
  exploration_config:
    type: EpsilonGreedy
    # FIXME: Below isn't necessarily true for this exp; update accordingly.
    # For 20000 episodes
    # Average 512 timesteps per episode
    initial_epsilon: 1.0
    final_epsilon: 0.05
    warmup_timesteps: 972800
    epsilon_timesteps: 8755200

resources:
  num_gpus: 1

debugging:
  log_level: INFO
  logger_config:
    type:
      _target_: hydra.utils.get_class
      path: ray.tune.logger.UnifiedLogger
    logdir: ${hydra:run.dir}
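
The sampler's `query` option is a standard pandas query string applied to the generated position dataset. A small, self-contained sketch of that filtering step is below; the DataFrame is fabricated purely for demonstration, while the `elapsed_steps` and `percent_area_burned` column names come from the config comments above.

# Illustrative sketch (assumption): filtering candidate fire start positions with the
# same pandas.DataFrame.query mechanism the `query` option refers to. The "position"
# column and all values here are made up for demonstration.
import pandas as pd

dataset = pd.DataFrame(
    {
        "position": [(10, 5), (42, 88), (97, 120)],
        "elapsed_steps": [40, 120, 300],
        "percent_area_burned": [0.12, 0.41, 0.77],
    }
)

# Keep only candidates where the benchmark burn covered more than 35% of the map.
filtered = dataset.query("0.35 < percent_area_burned")
print(filtered)
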
9 changes: 9 additions & 0 deletions conf/training/ppo_with_custom_model.yaml
@@ -0,0 +1,9 @@
defaults:
  - model: screen_size_128

model:
  custom_model: multimodal_network

# PPO parameter specs
# ...
train_batch_size: 1000
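
RLlib resolves `custom_model: multimodal_network` by name through its ModelCatalog; the registration presumably happens when `simharness2.models` is imported in main.py (see the diff below). A minimal sketch of that pattern follows; the model class name is an assumption, since the actual registration code is not shown in this commit.

# Illustrative sketch (assumption): registering a custom model under the name used by
# `custom_model` above. The class name is assumed; the real registration lives in
# simharness2.models.
from ray.rllib.models import ModelCatalog
from simharness2.models.custom_multimodal_torch_model import CustomMultimodalTorchModel

ModelCatalog.register_custom_model("multimodal_network", CustomMultimodalTorchModel)
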
22 changes: 15 additions & 7 deletions main.py
@@ -28,13 +28,17 @@
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls, register_env
from ray.tune.result_grid import ResultGrid
from ray.rllib.env import MultiAgentEnv

from simfire.enums import BurnStatus

# from simharness2.utils.evaluation_fires import get_default_operational_fires
import simharness2.models # noqa
from simharness2.callbacks.render_env import RenderEnv
from simharness2.logger.aim import AimLoggerCallback


# from simharness2.utils.evaluation_fires import get_default_operational_fires
import simharness2.models # noqa

# from simharness2.callbacks.set_env_seeds_callback import SetEnvSeedsCallback

os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -257,13 +261,17 @@ def _build_algo_cfg(cfg: DictConfig) -> Tuple[Algorithm, AlgorithmConfig]:
        .resources(**cfg.resources)
        .debugging(**debug_settings)
        .callbacks(RenderEnv)
        # FIXME: Enable passing multi_agent settings to the algorithm config.
        # .multi_agent(
        #     policies=agent_ids,
        #     policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        # )
    )

    # Add multi agent settings if needed for the specified environment.
    env_module, env_cls = cfg.environment.env.rsplit(".", 1)
    env_cls = getattr(import_module(env_module), env_cls)
    if issubclass(env_cls, MultiAgentEnv):
        algo_cfg = algo_cfg.multi_agent(
            policies=agent_ids,
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )

    return algo_cfg


4 changes: 4 additions & 0 deletions simharness2/environments/__init__.py
@@ -4,6 +4,9 @@
    ReactiveHarness,
)
from simharness2.environments.harness import Harness
from simharness2.environments.multi_agent_complex_harness import (
    MultiAgentComplexObsReactiveHarness,
)
from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness


@@ -13,4 +16,5 @@
    "MultiAgentFireHarness",
    "ReactiveHarness",
    "DamageAwareReactiveHarness",
    "MultiAgentComplexObsReactiveHarness",
]
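
With the export above, the new harness is importable from the package root. A sketch of registering it as an RLlib environment by name is below; the registration key and env creator are illustrative, and main.py may register the environment differently.

# Illustrative sketch (assumption): registering the newly exported harness with RLlib.
# The environment name and env_config handling here are placeholders.
from ray.tune.registry import register_env

from simharness2.environments import MultiAgentComplexObsReactiveHarness

register_env(
    "multi_agent_complex_obs_harness",
    lambda env_cfg: MultiAgentComplexObsReactiveHarness(**env_cfg),
)
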
128 changes: 128 additions & 0 deletions simharness2/environments/multi_agent_complex_harness.py
@@ -0,0 +1,128 @@
import logging
from collections import OrderedDict
from typing import List, TypeVar

import numpy as np
from gymnasium import spaces
from simfire.sim.simulation import FireSimulation

from simharness2.environments.harness import get_unsupported_attributes
from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness
from simharness2.models.custom_multimodal_torch_model import (
    AGENT_POSITION_KEY,
    FIRE_MAP_KEY,
)


logger = logging.getLogger(__name__)

AnyFireSimulation = TypeVar("AnyFireSimulation", bound=FireSimulation)


class MultiAgentComplexObsReactiveHarness(MultiAgentFireHarness[AnyFireSimulation]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # TODO: Make this check more general, ie. not included in every subclass?
        # Validate that the attributes provided are supported by this harness.
        curr_cls = self.__class__
        bad_attributes = get_unsupported_attributes(self.attributes, curr_cls)
        if bad_attributes:
            msg = (
                f"The {curr_cls.__name__} class does not support the "
                f"following attributes: {bad_attributes}."
            )
            raise AssertionError(msg)

    @staticmethod
    def supported_attributes() -> List[str]:
        """Return the full list of attributes supported by the harness."""
        # TODO: Expand to include SimFire data layers, when ready.
        return [FIRE_MAP_KEY, AGENT_POSITION_KEY]

    def get_initial_state(self) -> np.ndarray:
        """TODO."""
        fire_map = self.prepare_fire_map(place_agents=False)
        fire_map = np.expand_dims(fire_map, axis=-1).astype(np.float32)
        # Build MARL obs - position array will be different for each agent.
        marl_obs = {}
        for ag_id in self._agent_ids:
            curr_agent = self.agents[ag_id]
            pos_state = curr_agent.initial_position
            marl_obs[ag_id] = OrderedDict(
                {
                    FIRE_MAP_KEY: fire_map,
                    AGENT_POSITION_KEY: np.asarray(pos_state),
                }
            )
            # marl_obs[ag_id] = {
            #     FIRE_MAP_KEY: fire_map,
            #     AGENT_POSITION_KEY: np.asarray(pos_state),
            # }

        # Note: Returning ordered dict bc spaces.Dict.sample() returns ordered dict.
        # return OrderedDict(marl_obs)
        return marl_obs

    def get_observation_space(self) -> spaces.Space:
        """TODO."""
        # NOTE: We are assuming each agent has the same observation space.
        agent_obs_space = spaces.Dict(
            OrderedDict(
                {
                    FIRE_MAP_KEY: self._get_fire_map_observation_space(),
                    AGENT_POSITION_KEY: self._get_position_observation_space(),
                }
            )
        )

        self._obs_space_in_preferred_format = True
        return spaces.Dict({agent_id: agent_obs_space for agent_id in self._agent_ids})

    def _get_fire_map_observation_space(self) -> spaces.Box:
        """TODO."""
        from simfire.enums import BurnStatus

        # TODO: Refactor to enable easier reuse of this method logic.
        # Prepare user-provided interactions
        interacts = [i.lower() for i in self.interactions]
        if "none" in interacts:
            interacts.pop(interacts.index("none"))
        # Prepare non-interaction disaster categories
        non_interacts = list(self._get_non_interaction_disaster_categories().keys())
        non_interacts = [i.lower() for i in non_interacts]
        cats = interacts + non_interacts
        cat_vals = []
        for status in BurnStatus:
            if status.name.lower() in cats:
                cat_vals.append(status.value)

        low = min(cat_vals)
        high = max(cat_vals)
        obs_shape = self.sim.fire_map.shape + (1,)
        return spaces.Box(low=low, high=high, shape=obs_shape, dtype=np.float32)

    def _get_position_observation_space(self) -> spaces.Box:
        """TODO."""
        row_max, col_max = self.sim.fire_map.shape
        return spaces.Box(low=np.array([0, 0]), high=np.array([row_max - 1, col_max - 1]))

    def _update_state(self):
        """Modify environment's state to contain updates from the current timestep."""
        # Copy the fire map from the simulation so we don't overwrite it.
        fire_map = np.expand_dims(np.copy(self.sim.fire_map), axis=-1).astype(np.float32)

        # Build MARL obs - position array will be different for each agent.
        marl_obs = {}
        for ag_id in self._agent_ids:
            curr_agent = self.agents[ag_id]
            pos_state = curr_agent.current_position
            marl_obs[ag_id] = OrderedDict(
                {
                    FIRE_MAP_KEY: fire_map,
                    AGENT_POSITION_KEY: np.asarray(pos_state),
                }
            )
        # Note: Setting to ordered dict bc spaces.Dict.sample() returns ordered dict.
        # self.state = OrderedDict(marl_obs)
        self.state = marl_obs
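
For a 128x128 fire map (matching `screen_size_128` and the agent positions in marl_complex_obs.yaml), each agent's observation is a Dict containing a (128, 128, 1) float32 fire map plus a length-2 position vector. Below is a sketch of the resulting space; the key strings, Box bounds, and agent ids are placeholders, since the real values depend on FIRE_MAP_KEY / AGENT_POSITION_KEY and the configured interactions and BurnStatus values.

# Illustrative sketch (assumption): shape of the per-agent observation space this
# harness builds. Key names, bounds, and agent ids are placeholders.
from collections import OrderedDict

import numpy as np
from gymnasium import spaces

agent_obs_space = spaces.Dict(
    OrderedDict(
        {
            "fire_map": spaces.Box(low=0.0, high=4.0, shape=(128, 128, 1), dtype=np.float32),
            "agent_pos": spaces.Box(low=np.array([0, 0]), high=np.array([127, 127])),
        }
    )
)
obs_space = spaces.Dict({f"agent_{i}": agent_obs_space for i in range(3)})

sample = obs_space.sample()
print(sample["agent_0"]["fire_map"].shape)  # (128, 128, 1)
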