-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch '104-implement-multi-agent-version-of-damageawarereactiveharness' into 'dev'
Resolve "Implement Multi Agent version of DamageAwareReactiveHarness". Closes #104. See merge request fireline/reinforcementlearning/simharness!53. This branch is confirmed to run without errors, assuming a correctly configured Python environment, with: python main.py -cn test_multimodal_model cli.data_dir=$HOME/.simharness environment=marl_damage_aware environment.disable_env_checking=true
- Loading branch information
Showing
6 changed files
with
237 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Defaults list: `marl_complex_obs` is listed before `_self_`, so the keys defined in
# this file are composed after it (presumably overriding it - confirm against the
# Hydra version in use).
defaults:
  - marl_complex_obs
  - _self_

# Dotted import path of the environment class to use.
env: simharness2.environments.MultiAgentComplexObsDamageAwareReactiveHarness
env_config:
  # Relative interpolation: resolves to the sibling `sim` entry of `env_config`
  # (presumably provided by `marl_complex_obs` - TODO confirm).
  benchmark_sim: ${.sim}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
simharness2/environments/multi_agent_damage_aware_harness.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import logging | ||
from typing import Optional, Tuple, TypeVar | ||
|
||
import numpy as np | ||
from simfire.sim.simulation import FireSimulation | ||
|
||
from simharness2.agents.agent import ReactiveAgent | ||
from simharness2.environments.multi_agent_complex_harness import ( | ||
MultiAgentComplexObsReactiveHarness, | ||
) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
AnyFireSimulation = TypeVar("AnyFireSimulation", bound=FireSimulation) | ||
|
||
|
||
# FIXME: Consider effect of long time horizons, ie. see "max_bench_length" (this is a | ||
# note to discuss this later). | ||
# FIXME: Discuss harness naming convention with team - I hate how long this is... | ||
class MultiAgentComplexObsDamageAwareReactiveHarness(
    MultiAgentComplexObsReactiveHarness[AnyFireSimulation]
):
    """Multi-agent reactive harness that tracks a benchmark simulation.

    On every `reset`, the benchmark simulation is reset and run until it is no
    longer active, and a copy of its fire map is stored after each benchmark step.
    `_should_terminate` can then end an episode early once the agent simulation has
    accumulated more burn damage (burned + burning) than the benchmark simulation.
    """

    def __init__(
        self,
        *,
        terminate_if_greater_damage: bool = True,
        max_bench_length: int = 600,
        **kwargs,
    ):
        """Initialize the harness and the storage for benchmark fire maps.

        Arguments:
            terminate_if_greater_damage: Toggle early termination of the agent
                simulation when, at the current timestep, the agents have caused
                more burn damage (burned + burning) than the final state of the
                benchmark fire map.
            max_bench_length: Number of slots preallocated for the benchmark fire
                maps. The storage grows automatically when the benchmark
                simulation runs for more steps than this.
            **kwargs: Forwarded unchanged to the parent harness.

        Raises:
            ValueError: If no benchmark simulation is available after the parent
                harness has been initialized.
        """
        super().__init__(**kwargs)
        self._terminate_if_greater_damage = terminate_if_greater_damage

        # TODO: Define `benchmark_sim` here, and make it required (ie. never None).
        if self.benchmark_sim is None:
            # The benchmark_sim is required for rewards and termination!
            raise ValueError(
                "The benchmark simulation must be provided to use this harness."
            )

        # Preallocated, index-addressed storage for the benchmark fire maps of the
        # current episode; slots hold the placeholder 0 until they are filled.
        # FIXME: Suboptimal implementation and should be refactored later.
        self._max_bench_length = max_bench_length
        self._bench_firemaps = [0] * self._max_bench_length

    def _should_terminate(self) -> bool:
        """Return whether the current episode should terminate.

        Starts from the parent's verdict and additionally terminates when early
        termination is enabled and the agent simulation's damaged area
        (burned + burning) exceeds the damaged area of the benchmark simulation
        (total area minus unburned).
        """
        # Retrieve original value, based on the parent definition of terminated.
        terminated = super()._should_terminate()

        if self.benchmark_sim is None or not self._terminate_if_greater_damage:
            return terminated

        total_area = self.sim.fire_map.size

        sim_data = self.harness_analytics.sim_analytics.data
        sim_damaged_total = sim_data.burned + sim_data.burning
        benchsim_data = self.harness_analytics.benchmark_sim_analytics.data
        benchsim_damaged_total = total_area - benchsim_data.unburned

        logger.debug("sim_damaged_total: %s", sim_damaged_total)
        logger.debug("benchsim_damaged_total: %s", benchsim_damaged_total)
        if sim_damaged_total > benchsim_damaged_total:
            # TODO: add static negative penalty for making the fire worse?
            logger.info(
                "Terminating episode early because the agents have caused more "
                "burn damage than the final state of the benchmark fire map."
            )
            terminated = True

        return terminated

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[np.ndarray, dict]:
        """Reset environment to an initial state, returning an initial obs and info.

        After the parent reset, the benchmark simulation is reset and run to
        completion so that its data is available for the new episode.

        Raises:
            ValueError: If the initial fire position configured for the simulation
                differs from the one configured for the benchmark simulation.
        """
        # Use the following line to seed `self.np_random`
        initial_state, infos = super().reset(seed=seed, options=options)

        # Verify initial fire scenario is the same for both the sim and benchsim.
        sim_fire_init_pos = self.sim.config.fire.fire_initial_position
        benchmark_fire_init_pos = self.benchmark_sim.config.fire.fire_initial_position
        if sim_fire_init_pos != benchmark_fire_init_pos:
            raise ValueError(
                "The initial fire scenario for the simulation and benchmark simulation "
                "must be the same."
            )

        # TODO: Only call _run_benchmark if fire scenario differs from previous episode.
        # This is somewhat tricky - must ensure that analytics.data is NOT reset!
        # Run new benchsim to completion to obtain data for reward and policy.
        self.benchmark_sim.reset()
        # NOTE: The call below will do a few things:
        #   - Run bench sim to completion (self.benchmark_sim.run(1))
        #   - Update bench sim analytics (update_bench_after_one_simulation_step)
        #   - Store each bench fire map at the sim step in self._bench_firemaps
        self._run_benchmark()

        return initial_state, infos

    def _run_benchmark(self):
        """Run benchmark sim until fire propagation is complete.

        After every benchmark step, the benchmark analytics are updated and a copy
        of the benchmark fire map is stored in `self._bench_firemaps` at index
        `num_sim_steps - 1`. When that index is past the end of the storage, the
        map is appended and `self._max_bench_length` is updated to the new length.

        Raises:
            RuntimeError: If the benchmark simulation has not been reset since it
                was last run (ie. its `elapsed_steps` is greater than 0).
        """
        # TODO: We can remove the benchmark_sim entirely, and get this behavior by simply
        # running self.sim until it terminates, then reset self.sim to the initial state.

        if self.benchmark_sim.elapsed_steps > 0:
            raise RuntimeError(
                "Benchmark simulation must be reset before running it again."
            )

        timesteps = 0
        while self.benchmark_sim.active:
            # Only advance the benchmark sim on timesteps aligned with agent_speed.
            if timesteps % self.agent_speed == 0:
                # TODO: Refactor logic into a method, and call it here.
                # Run for one timestep, then update respective metrics.
                self.benchmark_sim.run(1)
                # FIXME: This method call is VERY redundant (see method logic)
                self.harness_analytics.update_bench_after_one_simulation_step(
                    timestep=timesteps
                )

                curr_step = self.harness_analytics.benchmark_sim_analytics.num_sim_steps
                fire_map_snapshot = np.copy(self.benchmark_sim.fire_map)
                # Always store at index `curr_step - 1`. Growing the list only when
                # that slot does not exist keeps the indexing contiguous; the
                # previous capacity check appended while two preallocated slots
                # were still free, shifting every later map by two positions.
                try:
                    self._bench_firemaps[curr_step - 1] = fire_map_snapshot
                except IndexError:
                    self._bench_firemaps.append(fire_map_snapshot)
                    self._max_bench_length = len(self._bench_firemaps)

            timesteps += 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
def get_adjacent_points(
    row: int, col: int, shape: Tuple[int, int], include_diagonals: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """Return points adjacent to the provided point (excluding point itself).

    The current implementation considers the 4 cardinal directions (N, S, E, W) as
    adjacent points. If `include_diagonals` is set to True, the diagonal points
    (NE, NW, SE, SW) are also considered as adjacent points. Candidate points that
    fall outside the map are dropped, so every returned point is unique.

    Arguments:
        row: The row index of the current point.
        col: The column index of the current point.
        shape: A 2-tuple representing the shape of the map.
        include_diagonals: A boolean indicating whether to include diagonal points
            as adjacent points. Defaults to True.

    Returns:
        A tuple containing two numpy arrays, adjacent rows and adjacent columns. The
        returned arrays can be used as an advanced index to access the adjacent
        points, ex: `fire_map[adj_rows, adj_cols]`.
    """
    # TODO: Logic below is adapted from a method in simfire, namely
    # simfire.game.managers.fire.FireManager._get_new_locs(). It would be good to
    # refactor this logic into a utility function in simfire, and then call it here.
    # Offsets are (d_row, d_col) pairs, kept in the original enumeration order.
    if include_diagonals:
        offsets = (
            (0, 1),
            (1, 1),
            (1, 0),
            (1, -1),
            (0, -1),
            (-1, -1),
            (-1, 0),
            (-1, 1),
        )
    else:
        offsets = (
            (0, 1),
            (1, 0),
            (0, -1),
            (-1, 0),
        )

    candidates = np.array(offsets, dtype=np.int32) + np.array(
        [row, col], dtype=np.int32
    )

    # Drop out-of-bounds candidates instead of clipping them: clipping a diagonal
    # neighbor onto the map makes it coincide with a cardinal neighbor, which
    # yields duplicate points for cells on the border of the map.
    num_rows, num_cols = shape
    cand_rows, cand_cols = candidates[:, 0], candidates[:, 1]
    in_bounds = (
        (cand_rows >= 0)
        & (cand_rows < num_rows)
        & (cand_cols >= 0)
        & (cand_cols < num_cols)
    )

    adj_rows, adj_cols = candidates[in_bounds].T
    return adj_rows, adj_cols