diff --git a/conf/environment/marl_complex_obs.yaml b/conf/environment/marl_complex_obs.yaml
new file mode 100644
index 0000000..80eddb9
--- /dev/null
+++ b/conf/environment/marl_complex_obs.yaml
@@ -0,0 +1,53 @@
+defaults:
+  - default
+  - _self_
+
+env: simharness2.environments.MultiAgentComplexObsReactiveHarness
+env_config:
+  sim:
+    _target_: simfire.sim.simulation.FireSimulation
+    config:
+      _target_: simfire.utils.config.Config
+      config_dict: ${simulation.simfire}
+  movements: [none, up, down, left, right]
+  interactions: [none, fireline]
+  # TODO: Need a better way of aligning the yaml spec with constant keys in code.
+  attributes: [fire_map, agent_pos]
+  normalized_attributes: []
+  agent_speed: 4
+  num_agents: 3
+  # agent_initialization_method: automatic
+  agent_initialization_method: manual
+  initial_agent_positions: [[0, 64], [127, 64], [64, 127]]
+  # Defines the class that will be used to perform reward calculation at each timestep.
+  # reward_cls_partial:
+  #   _target_: simharness2.rewards.base_reward.AreaSavedPropReward
+  #   _partial_: true
+  benchmark_sim: null
+  # Defines the class that will be used to monitor and track `ReactiveHarness`.
+  harness_analytics_partial:
+    _target_: simharness2.analytics.harness_analytics.ReactiveHarnessAnalytics
+    _partial_: true
+    # Defines the class that will be used to monitor and track `FireSimulation`.
+    sim_analytics_partial:
+      _target_: simharness2.analytics.simulation_analytics.FireSimulationAnalytics
+      _partial_: true
+      # Defines the class that will be used to monitor and track agent behavior.
+      agent_analytics_partial:
+        _target_: simharness2.analytics.agent_analytics.ReactiveAgentAnalytics
+        _partial_: true
+        movement_types: ${....movements}
+        interaction_types: ${....interactions}
+
+  # Defines the class that will be used to perform reward calculation at each timestep.
+  reward_cls_partial:
+    _target_: simharness2.rewards.base_reward.SimpleReward
+    _partial_: true
+
+  action_space_cls:
+    _target_: hydra.utils.get_class
+    # path: gymnasium.spaces.MultiDiscrete
+    path: gymnasium.spaces.Discrete
+
+
+  fire_initial_position: ${simulation.fire_initial_position}
diff --git a/conf/test_multimodal_model.yaml b/conf/test_multimodal_model.yaml
new file mode 100644
index 0000000..3fe75f7
--- /dev/null
+++ b/conf/test_multimodal_model.yaml
@@ -0,0 +1,115 @@
+defaults:
+  - config
+  - override environment: marl_complex_obs
+  - override training: ppo_with_custom_model
+  # UNCOMMENT BELOW TO "DISABLE" TUNING, i.e. use tune for a standalone trial.
+  # - override tunables: none
+  - _self_
+
+algo:
+  name: PPO
+
+rollouts:
+  # NOTE: MultiAgentComplexObsReactiveHarness DOES NOT normalize returned observations!
+  # So, use the MeanStdFilter to ensure observations are normalized.
+  # See the MeanStdObservationFilterAgentConnector class for more:
+  # https://github.com/ray-project/ray/blob/ad48682b4ec78a5699ba89a7c8a69327c264e47b/rllib/connectors/agent/mean_std_filter.py#L23
+  enable_connectors: true
+  observation_filter: MeanStdFilter
+
+  batch_mode: complete_episodes # truncate_episodes
+  # Scale experiment as desired
+  num_rollout_workers: 32 # FIXME!
+  # num_envs_per_worker: ??
+
+cli:
+  mode: train
+  data_dir: ??
+
+hydra:
+  run:
+    dir: ${cli.data_dir}/debug_experiments/${now:%Y-%m-%d_%H-%M-%S} # FIXME!
+
+simulation:
+  fire_initial_position:
+    # Configure the generation of candidate initial fire positions.
+    generator:
+      # If true, then all possible fire initial positions will be included in the
+      # generated "dataset".
+      make_all_positions: false # FIXME!
+      # Number of candidate initial fire positions to generate. Ignored when
+      # `make_all_positions == true`.
+      output_size: 1024 # FIXME!
+      # The root location to save the generated dataset.
+      save_path: ${cli.data_dir}/simfire/data/fire_initial_positions
+      # To disable usage of generator, simply set to null.
+    # Configure the sampling of new initial fire positions.
+    sampler:
+      # Get (new) random sample of size `sample_size` every `resample_interval` episodes.
+      # NOTE: Evaluation scenarios are never resampled; this applies to training ONLY.
+      resample_interval: 1 # FIXME!
+      # Number of samples to draw from the distribution, i.e. take `sample_size` initial
+      # fire positions and distribute them across the respective RolloutWorkers.
+      sample_size:
+        train: ${rollouts.num_rollout_workers}
+        eval: ${evaluation.evaluation_num_workers}
+      # Filter generated "dataset" using the following condition.
+      # NOTE: This is applied to the generated dataset, not the sampled positions.
+      # For more information on the query string to evaluate, see:
+      # - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html#pandas.DataFrame.query
+      # Examples:
+      # - "A < elapsed_steps", "elapsed_steps < B", "A < elapsed_steps < B"
+      # - "(A < elapsed_steps) & (C < percent_area_burned)"
+      # query: "50 < elapsed_steps < 150"
+      query: "0.35 < percent_area_burned" # FIXME!
+      # Number of samples available in resulting "dataset" after filtering with `query`.
+      # NOTE: This allows the user to control the total number of distinct fire positions
+      # that will be used to produce sample batches (or trajectories) of experiences. If
+      # set to `None`, the population size will be equal to the size of the generated
+      # "dataset" after filtering with `query`.
+      # - Applies to training ONLY; eval "dataset" size is equal to `sample_size.eval`.
+      population_size: 256 # FIXME!
+  simfire:
+    fire:
+      diagonal_spread: true
+# Specify configuration passed to the `AimLoggerCallback`
+aim:
+  repo: ${cli.data_dir}/aim
+  experiment: test-multimodal-model-with-ppo
+  system_tracking_interval: 30
+  log_hydra_config: false
+
+# Specify configuration used to create the ray.air.CheckpointConfig object
+checkpoint:
+  # Frequency at which to save checkpoints (in terms of training iterations)
+  checkpoint_frequency: ${evaluation.evaluation_interval}
+  # Number of checkpoints to keep
+  num_to_keep: 20
+
+stop_conditions:
+  training_iteration: 1000
+  timesteps_total: 2000000000
+  episodes_total: 1000000
+  episode_reward_mean: 10000000
+
+exploration:
+  exploration_config:
+    type: EpsilonGreedy
+    # FIXME: Below isn't necessarily true for this experiment; update accordingly.
+    # For 20000 episodes
+    # Average 512 timesteps per episode
+    initial_epsilon: 1.0
+    final_epsilon: 0.05
+    warmup_timesteps: 972800
+    epsilon_timesteps: 8755200
+
+resources:
+  num_gpus: 1
+
+debugging:
+  log_level: INFO
+  logger_config:
+    type:
+      _target_: hydra.utils.get_class
+      path: ray.tune.logger.UnifiedLogger
+    logdir: ${hydra:run.dir}
diff --git a/conf/training/ppo_with_custom_model.yaml b/conf/training/ppo_with_custom_model.yaml
new file mode 100644
index 0000000..9b41c3c
--- /dev/null
+++ b/conf/training/ppo_with_custom_model.yaml
@@ -0,0 +1,9 @@
+defaults:
+  - model: screen_size_128
+
+model:
+  custom_model: multimodal_network
+
+# PPO parameter specs
+# ...
+train_batch_size: 1000
diff --git a/main.py b/main.py
index e64c049..e836fa0 100644
--- a/main.py
+++ b/main.py
@@ -28,13 +28,17 @@
 from ray.tune.logger import pretty_print
 from ray.tune.registry import get_trainable_cls, register_env
 from ray.tune.result_grid import ResultGrid
 
+from ray.rllib.env import MultiAgentEnv
+
 from simfire.enums import BurnStatus
 
-# from simharness2.utils.evaluation_fires import get_default_operational_fires
-import simharness2.models  # noqa
 from simharness2.callbacks.render_env import RenderEnv
 from simharness2.logger.aim import AimLoggerCallback
+
+# from simharness2.utils.evaluation_fires import get_default_operational_fires
+import simharness2.models  # noqa
+
 # from simharness2.callbacks.set_env_seeds_callback import SetEnvSeedsCallback
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -257,13 +261,17 @@ def _build_algo_cfg(cfg: DictConfig) -> Tuple[Algorithm, AlgorithmConfig]:
         .resources(**cfg.resources)
         .debugging(**debug_settings)
         .callbacks(RenderEnv)
-        # FIXME: Enable passing multi_agent settings to the algorithm config.
-        # .multi_agent(
-        #     policies=agent_ids,
-        #     policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
-        # )
     )
 
+    # Add multi-agent settings if needed for the specified environment.
+    env_module, env_cls = cfg.environment.env.rsplit(".", 1)
+    env_cls = getattr(import_module(env_module), env_cls)
+    if issubclass(env_cls, MultiAgentEnv):
+        algo_cfg = algo_cfg.multi_agent(
+            policies=agent_ids,
+            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
+        )
+
     return algo_cfg
diff --git a/simharness2/environments/__init__.py b/simharness2/environments/__init__.py
index 803d489..37f2017 100644
--- a/simharness2/environments/__init__.py
+++ b/simharness2/environments/__init__.py
@@ -4,6 +4,9 @@
     ReactiveHarness,
 )
 from simharness2.environments.harness import Harness
+from simharness2.environments.multi_agent_complex_harness import (
+    MultiAgentComplexObsReactiveHarness,
+)
 from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness
 
 
@@ -13,4 +16,5 @@
     "MultiAgentFireHarness",
     "ReactiveHarness",
     "DamageAwareReactiveHarness",
+    "MultiAgentComplexObsReactiveHarness",
 ]
diff --git a/simharness2/environments/multi_agent_complex_harness.py b/simharness2/environments/multi_agent_complex_harness.py
new file mode 100644
index 0000000..f2f2376
--- /dev/null
+++ b/simharness2/environments/multi_agent_complex_harness.py
@@ -0,0 +1,128 @@
+import logging
+from collections import OrderedDict
+from typing import List, TypeVar
+
+import numpy as np
+from gymnasium import spaces
+from simfire.sim.simulation import FireSimulation
+
+from simharness2.environments.harness import get_unsupported_attributes
+from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness
+from simharness2.models.custom_multimodal_torch_model import (
+    AGENT_POSITION_KEY,
+    FIRE_MAP_KEY,
+)
+
+
+logger = logging.getLogger(__name__)
+
+AnyFireSimulation = TypeVar("AnyFireSimulation", bound=FireSimulation)
+
+
+class MultiAgentComplexObsReactiveHarness(MultiAgentFireHarness[AnyFireSimulation]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        # TODO: Make this check more general, i.e. not included in every subclass?
+        # Validate that the attributes provided are supported by this harness.
+        curr_cls = self.__class__
+        bad_attributes = get_unsupported_attributes(self.attributes, curr_cls)
+        if bad_attributes:
+            msg = (
+                f"The {curr_cls.__name__} class does not support the "
+                f"following attributes: {bad_attributes}."
+ ) + raise AssertionError(msg) + + @staticmethod + def supported_attributes() -> List[str]: + """Return the full list of attributes supported by the harness.""" + # TODO: Expand to include SimFire data layers, when ready. + return [FIRE_MAP_KEY, AGENT_POSITION_KEY] + + def get_initial_state(self) -> np.ndarray: + """TODO.""" + fire_map = self.prepare_fire_map(place_agents=False) + fire_map = np.expand_dims(fire_map, axis=-1).astype(np.float32) + # Build MARL obs - position array will be different for each agent. + marl_obs = {} + for ag_id in self._agent_ids: + curr_agent = self.agents[ag_id] + pos_state = curr_agent.initial_position + marl_obs[ag_id] = OrderedDict( + { + FIRE_MAP_KEY: fire_map, + AGENT_POSITION_KEY: np.asarray(pos_state), + } + ) + # marl_obs[ag_id] = { + # FIRE_MAP_KEY: fire_map, + # AGENT_POSITION_KEY: np.asarray(pos_state), + # } + + # Note: Returning ordered dict bc spaces.Dict.sample() returns ordered dict. + # return OrderedDict(marl_obs) + return marl_obs + + def get_observation_space(self) -> spaces.Space: + """TODO.""" + # NOTE: We are assuming each agent has the same observation space. + agent_obs_space = spaces.Dict( + OrderedDict( + { + FIRE_MAP_KEY: self._get_fire_map_observation_space(), + AGENT_POSITION_KEY: self._get_position_observation_space(), + } + ) + ) + + self._obs_space_in_preferred_format = True + return spaces.Dict({agent_id: agent_obs_space for agent_id in self._agent_ids}) + + def _get_fire_map_observation_space(self) -> spaces.Box: + """TODO.""" + from simfire.enums import BurnStatus + + # TODO: Refactor to enable easier reuse of this method logic. + # Prepare user-provided interactions + interacts = [i.lower() for i in self.interactions] + if "none" in interacts: + interacts.pop(interacts.index("none")) + # Prepare non-interaction disaster categories + non_interacts = list(self._get_non_interaction_disaster_categories().keys()) + non_interacts = [i.lower() for i in non_interacts] + cats = interacts + non_interacts + cat_vals = [] + for status in BurnStatus: + if status.name.lower() in cats: + cat_vals.append(status.value) + + low = min(cat_vals) + high = max(cat_vals) + obs_shape = self.sim.fire_map.shape + (1,) + return spaces.Box(low=low, high=high, shape=obs_shape, dtype=np.float32) + + def _get_position_observation_space(self) -> spaces.Box: + """TODO.""" + row_max, col_max = self.sim.fire_map.shape + return spaces.Box(low=np.array([0, 0]), high=np.array([row_max - 1, col_max - 1])) + + def _update_state(self): + """Modify environment's state to contain updates from the current timestep.""" + # Copy the fire map from the simulation so we don't overwrite it. + fire_map = np.expand_dims(np.copy(self.sim.fire_map), axis=-1).astype(np.float32) + + # Build MARL obs - position array will be different for each agent. + marl_obs = {} + for ag_id in self._agent_ids: + curr_agent = self.agents[ag_id] + pos_state = curr_agent.current_position + marl_obs[ag_id] = OrderedDict( + { + FIRE_MAP_KEY: fire_map, + AGENT_POSITION_KEY: np.asarray(pos_state), + } + ) + # Note: Setting to ordered dict bc spaces.Dict.sample() returns ordered dict. 
+        # self.state = OrderedDict(marl_obs)
+        self.state = marl_obs
diff --git a/simharness2/environments/multi_agent_fire_harness.py b/simharness2/environments/multi_agent_fire_harness.py
index cc67f9b..1255c6f 100644
--- a/simharness2/environments/multi_agent_fire_harness.py
+++ b/simharness2/environments/multi_agent_fire_harness.py
@@ -97,42 +97,34 @@ def step(
         truncated = self._should_truncate()
         terminated = self._should_terminate()
 
-        # Calculate the reward for the current timestep
-        # TODO pass `terminated` into `get_reward` method
-        # FIXME: Update reward for MARL case!!
-        # TODO: Give each agent the "same" simple reward for now.
-        reward = self.reward_cls.get_reward(self.timesteps, sim_run)
-
-        # Terminate episode early if burn damage in Agent Sim is larger than final bench fire map
-        self._terminate_if_greater_damage = False #FIXME get rid of this line when the fixme in the if statement below is implemented
-        if self.benchmark_sim:
-            if self._terminate_if_greater_damage:
-                total_area = self.sim.fire_map.size
-
-                sim_damaged_total = self.harness_analytics.sim_analytics.data.burned + self.harness_analytics.sim_analytics.data.burning
-                # FIXME Fix this damage calculation if needed to account for damage across all the agent sims
-                benchsim_damaged_total = total_area - self.harness_analytics.benchmark_sim_analytics.data.unburned
-
-                if sim_damaged_total > benchsim_damaged_total:
-                    terminated = True
-                    # TODO potentially add a static negative penalty for making the fire worse
-
-        # TODO account for below updates in the reward_cls.calculate_reward() method
-        # "End of episode" reward
-        if terminated:
-            reward += 10
-
+        # Calculate the timestep reward for each agent.
+        rewards = self.reward_cls.get_reward(
+            timestep=self.timesteps,
+            sim_run=sim_run,
+            done_episode=terminated or truncated,
+            agents=self.agents,
+            agent_speed=self.agent_speed,
+        )
+
+        # FIXME: Refactor logic to ensure all rewards return expected types
+        if np.isscalar(rewards):
+            if self.timesteps < 1:
+                logger.warning("Calculated reward value is scalar. Converting to Dict.")
+            rewards = {agent_id: rewards for agent_id in self.agents}
+
+        # FIXME: We are passing the TIMESTEP reward, not CUMULATIVE reward!!
         if self.harness_analytics:
+            # FIXME: Decide if we should pass all agent rewards. For now, use the sum.
+            cumulative_reward = sum(rewards.values())
             self.harness_analytics.update_after_one_harness_step(
-                sim_run, terminated, reward, timestep=self.timesteps
+                sim_run, terminated, cumulative_reward, timestep=self.timesteps
             )
 
-        rewards, truncateds, terminateds, infos = {}, {}, {}, {}
+        # TODO: Override _should_truncate() etc. to return a Dict instead of a single value.
+        truncateds, terminateds, infos = {}, {}, {}
         truncs = set()
         terms = set()
         for agent_id, agent in self.agents.items():
-            # FIXME: All agents receive the SAME reward !!!
-            rewards[agent_id] = reward
             # FIXME: Trunc/Term logic is the SAME for all agents.
             # We may not always want this, but it's a good starting point.
             truncateds[agent_id] = truncated
diff --git a/simharness2/models/__init__.py b/simharness2/models/__init__.py
index a2d2fe3..4104b19 100644
--- a/simharness2/models/__init__.py
+++ b/simharness2/models/__init__.py
@@ -1,8 +1,12 @@
 from ray.rllib.models.catalog import ModelCatalog
 
 from .custom_dqn_torch_model import CustomDQNTorchVisionNet
+from .custom_multimodal_torch_model import CustomMultimodalTorchModel
+
 
 # Register custom model.
 ModelCatalog.register_custom_model(
     "metric_reporting_vision_network", CustomDQNTorchVisionNet
 )
+
+ModelCatalog.register_custom_model("multimodal_network", CustomMultimodalTorchModel)
diff --git a/simharness2/models/custom_multimodal_torch_model.py b/simharness2/models/custom_multimodal_torch_model.py
new file mode 100644
index 0000000..8020fe8
--- /dev/null
+++ b/simharness2/models/custom_multimodal_torch_model.py
@@ -0,0 +1,138 @@
+"""Custom model for running on multimodal data."""
+
+from typing import Dict, List, Tuple
+
+import torch
+from gymnasium.spaces import Space
+from ray.rllib.models.modelv2 import restore_original_dimensions
+from ray.rllib.models.torch.misc import SlimConv2d, SlimFC, same_padding
+from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
+from ray.rllib.models.utils import get_activation_fn, get_filter_config
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.typing import ModelConfigDict, TensorType
+
+
+# FIXME: Move these constants to a more appropriate location.
+FIRE_MAP_KEY = "fire_map"
+AGENT_POSITION_KEY = "agent_pos"
+
+
+class CustomMultimodalTorchModel(TorchModelV2, torch.nn.Module):
+    """Custom model for running on multimodal data."""
+
+    def __init__(
+        self,
+        obs_space: Space,
+        action_space: Space,
+        num_outputs: int,
+        model_config: ModelConfigDict,
+        name: str,
+    ):
+        """TODO: Add docstring."""
+        TorchModelV2.__init__(
+            self, obs_space, action_space, num_outputs, model_config, name
+        )
+        torch.nn.Module.__init__(self)
+
+        # NOTE: obs_space layout can change if no preprocessor is used. For more details,
+        # see the `_disable_preprocessor_api` parameter here:
+        # https://docs.ray.io/en/latest/rllib/rllib-training.html#specifying-experimental-features
+        original_obs_space: Dict[str, Space]
+        if self.model_config.get("_disable_preprocessor_api"):
+            original_obs_space = obs_space
+        else:
+            original_obs_space = obs_space.original_space
+
+        # Validate the provided config settings for the conv layers.
+        # TODO: Provide a better default config for conv_filters?
+        if not self.model_config.get("conv_filters"):
+            self.model_config["conv_filters"] = get_filter_config(
+                original_obs_space[FIRE_MAP_KEY].shape
+            )
+        filters = self.model_config["conv_filters"]
+        assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!"
+        conv_activation = get_activation_fn(
+            self.model_config.get("conv_activation"), framework="torch"
+        )
+
+        # Construct conv model, where fire map is expected input.
+        # FIXME: Decide if we want "fire_map" to be 2D or 3D.
+        in_size = original_obs_space[FIRE_MAP_KEY].shape[:2]
+        in_channels = original_obs_space[FIRE_MAP_KEY].shape[-1]
+        layers = []
+        for out_channels, kernel, stride in filters:
+            padding, out_size = same_padding(in_size, kernel, stride)
+            layers.append(
+                SlimConv2d(
+                    in_channels,
+                    out_channels,
+                    kernel,
+                    stride,
+                    padding,
+                    activation_fn=conv_activation,
+                )
+            )
+            in_channels = out_channels
+            in_size = out_size
+
+        layers.append(torch.nn.Flatten())
+        self._conv_model = torch.nn.Sequential(*layers)
+
+        # Construct fc model, where conv model output and position are expected input.
+        fcnet_activation = get_activation_fn(
+            self.model_config.get("fcnet_activation"), framework="torch"
+        )
+        in_size = (
+            out_size[0] * out_size[1] * out_channels
+            + original_obs_space[AGENT_POSITION_KEY].shape[0]
+        )
+        fc_layers, fc_layers_value = [], []
+        for out_size in self.model_config.get("fcnet_hiddens", []):
+            fc_layers.append(
+                SlimFC(
+                    in_size,
+                    out_size,
+                    activation_fn=fcnet_activation,
+                )
+            )
+            fc_layers_value.append(
+                SlimFC(
+                    in_size,
+                    out_size,
+                    activation_fn=fcnet_activation,
+                )
+            )
+            in_size = out_size
+
+        fc_layers.append(SlimFC(in_size, num_outputs))
+        fc_layers_value.append(SlimFC(in_size, 1))
+        self._fc_model = torch.nn.Sequential(*fc_layers)
+        self._fc_value = torch.nn.Sequential(*fc_layers_value)
+
+    @override(TorchModelV2)
+    def forward(
+        self,
+        input_dict: Dict[str, TensorType],
+        state: List[TensorType],
+        seq_lens: TensorType,
+    ) -> Tuple[TensorType, List[TensorType]]:
+        # Extract the original observation from the input_dict
+        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
+            orig_obs = input_dict[SampleBatch.OBS]
+        else:
+            orig_obs = restore_original_dimensions(
+                input_dict[SampleBatch.OBS], self.obs_space, "torch"
+            )
+
+        # FIXME: If we want the fire_map in channel-major format, we can simply return
+        # the correct layout from the harness; should not NEED to transpose here.
+        fire_map = torch.transpose(orig_obs[FIRE_MAP_KEY], 1, -1)
+        conv_out = self._conv_model(fire_map)
+        self._features = torch.cat([conv_out, orig_obs[AGENT_POSITION_KEY]], dim=-1)
+        out = self._fc_model(self._features)
+        return out, state
+
+    @override(TorchModelV2)
+    def value_function(self) -> TensorType:
+        return self._fc_value(self._features).squeeze(1)
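For reference (not part of the diff): the `query` key under `simulation.fire_initial_position` is a pandas `DataFrame.query` string applied to the generated fire-position dataset. A minimal sketch of how such a filter behaves, using made-up rows and the column names mentioned in the config comments (`elapsed_steps`, `percent_area_burned`):

```python
import pandas as pd

# Made-up stand-in for the generated fire-position "dataset"; only the column
# names follow the comments in conf/test_multimodal_model.yaml.
candidates = pd.DataFrame(
    {
        "elapsed_steps": [75, 160, 40, 130],
        "percent_area_burned": [0.12, 0.41, 0.55, 0.30],
    }
)

# Same style of filter as `query: "0.35 < percent_area_burned"` above.
print(candidates.query("0.35 < percent_area_burned"))

# A compound condition, mirroring the commented example in the config.
print(candidates.query("(50 < elapsed_steps) & (0.3 < percent_area_burned)"))
```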
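A self-contained sketch of the per-agent observation layout that `MultiAgentComplexObsReactiveHarness.get_observation_space()` builds (a `fire_map` Box plus an `agent_pos` Box, nested per agent id). The 128x128 shape, the 0-6 value range, and the agent ids are illustrative assumptions, not values taken from the diff:

```python
from collections import OrderedDict

import numpy as np
from gymnasium import spaces

FIRE_MAP_KEY = "fire_map"
AGENT_POSITION_KEY = "agent_pos"

rows, cols = 128, 128  # assumed screen size
agent_obs_space = spaces.Dict(
    OrderedDict(
        {
            # One channel of BurnStatus category values per cell.
            FIRE_MAP_KEY: spaces.Box(low=0, high=6, shape=(rows, cols, 1), dtype=np.float32),
            # (row, col) position of the agent on the fire map.
            AGENT_POSITION_KEY: spaces.Box(low=np.array([0, 0]), high=np.array([rows - 1, cols - 1])),
        }
    )
)
# Top-level space maps each agent id to the same per-agent Dict space.
obs_space = spaces.Dict({f"agent_{i}": agent_obs_space for i in range(3)})
print(obs_space.sample()["agent_0"][AGENT_POSITION_KEY])
```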
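Finally, a minimal plain-PyTorch sketch of the fusion idea in `CustomMultimodalTorchModel`: a conv trunk over the fire map, flatten, concatenate the 2-D agent position, then a small fully connected head. Layer sizes here are illustrative; the actual model builds its layers from `conv_filters`/`fcnet_hiddens` in the RLlib `model_config` and uses `SlimConv2d`/`SlimFC`:

```python
import torch
from torch import nn


class MultimodalSketch(nn.Module):
    def __init__(self, map_size: int = 128, num_outputs: int = 6):
        super().__init__()
        # Conv trunk over a single-channel fire map, then flatten.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        with torch.no_grad():
            conv_out = self.conv(torch.zeros(1, 1, map_size, map_size)).shape[1]
        # Head consumes conv features concatenated with the (row, col) position.
        self.head = nn.Sequential(
            nn.Linear(conv_out + 2, 256),
            nn.ReLU(),
            nn.Linear(256, num_outputs),
        )

    def forward(self, fire_map: torch.Tensor, agent_pos: torch.Tensor) -> torch.Tensor:
        # fire_map: (B, 1, H, W) in channel-major layout; agent_pos: (B, 2).
        features = torch.cat([self.conv(fire_map), agent_pos], dim=-1)
        return self.head(features)


model = MultimodalSketch()
logits = model(torch.zeros(4, 1, 128, 128), torch.zeros(4, 2))
print(logits.shape)  # torch.Size([4, 6])
```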