Skip to content

Commit

Permalink
Merge branch '104-implement-multi-agent-version-of-damageawarereactiv…
Browse files Browse the repository at this point in the history
…eharness' into 'dev'

Resolve "Implement Multi Agent version of DamageAwareReactiveHarness"

Closes #104

See merge request fireline/reinforcementlearning/simharness!53

This branch is confirmed to run without errors, assuming a correctly
configured python environment, with:

python main.py -cn test_multimodal_model cli.data_dir=$HOME/.simharness environment=marl_damage_aware environment.disable_env_checking=true
  • Loading branch information
afennelly-mitre committed Apr 30, 2024
2 parents fcfaee7 + 3c17ae0 commit 84aadc6
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 5 deletions.
7 changes: 7 additions & 0 deletions conf/environment/marl_damage_aware.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- marl_complex_obs
- _self_

env: simharness2.environments.MultiAgentComplexObsDamageAwareReactiveHarness
env_config:
benchmark_sim: ${.sim}
9 changes: 8 additions & 1 deletion simharness2/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import Tuple, Any
from dataclasses import dataclass
from typing import Any, Tuple

import numpy as np

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
Expand Down Expand Up @@ -52,6 +54,8 @@ class ReactiveAgent:
agent_id: Any # ex: "agent_0", "dozer_0", "handcrew_0", "ff_0", etc.
sim_id: int # should be contained within sim.agents.keys()
initial_position: Tuple[int, int]
# FIXME: Maybe use InitVar since we only need this to build array in post_init.
fire_map_shape: Tuple[int, int]

# Attributes with default values
latest_movement: int = None
Expand All @@ -64,6 +68,9 @@ def __post_init__(self):
self.x, self.y = self.initial_position
self.row, self.col = self.y, self.x

# Create array used to store coords adjacent to "true" mitigations placed by.
self.adj_to_mitigation = np.zeros(self.fire_map_shape, dtype=bool)

@property
def current_position(self) -> Tuple[int, int]:
return self._current_position
Expand Down
5 changes: 4 additions & 1 deletion simharness2/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
MultiAgentComplexObsReactiveHarness,
)
from simharness2.environments.multi_agent_fire_harness import MultiAgentFireHarness

from simharness2.environments.multi_agent_damage_aware_harness import (
MultiAgentComplexObsDamageAwareReactiveHarness,
)

__all__ = [
"FireHarness",
Expand All @@ -17,4 +19,5 @@
"ReactiveHarness",
"DamageAwareReactiveHarness",
"MultiAgentComplexObsReactiveHarness",
"MultiAgentComplexObsDamageAwareReactiveHarness",
]
20 changes: 17 additions & 3 deletions simharness2/environments/fire_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from simharness2.agents.agent import ReactiveAgent
from simharness2.environments.harness import Harness, get_unsupported_attributes

from simharness2.environments import utils as env_utils

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,6 +126,9 @@ def __init__(
# If provided, construct the class used to perform reward calculation.
self._setup_reward_cls(reward_cls_partial)

# Indicator flag to determine if fire in the sim can spread diagonally.
self._fire_diagonal_spread = self.sim.config.fire.diagonal_spread

def get_observation_space(self) -> spaces.Space:
"""TODO."""
# NOTE: calling `reshape()` to switch to channel-minor format.
Expand Down Expand Up @@ -261,6 +264,11 @@ def _update_mitigation(self, agent: ReactiveAgent) -> None:
self.sim.update_mitigation([mitigation_update])
agent.mitigation_placed = True

row, col, shape = agent.row, agent.col, self.sim.fire_map.shape
diag_spread = self._fire_diagonal_spread
adj_rows, adj_cols = env_utils.get_adjacent_points(row, col, shape, diag_spread)
agent.adj_to_mitigation[adj_rows, adj_cols] = 1

def _update_agent_position(self, agent: ReactiveAgent) -> None:
"""Update the agent's position on the map by performing the provided movement."""
# Store agent's current position in a temporary variable to avoid overwriting it.
Expand Down Expand Up @@ -421,6 +429,7 @@ def create_agents(
) -> Dict[str, ReactiveAgent]:
"""Create ReactiveAgent object (s) that will interact w/ the FireSimulation."""
agents_dict = {}
fire_map_shape = self.sim.fire_map.shape
# Use the user-provided agent positions to initialize the agents on the map.
if method == "manual":
# NOTE: The provided pos_list must be the same length as the number of agents
Expand All @@ -438,7 +447,12 @@ def create_agents(
agent_ids, pos_list, self._sim_agent_ids
):
x, y = agent_info
agent = ReactiveAgent(agent_str, sim_id, (x, y))
agent = ReactiveAgent(
agent_str,
sim_id,
(x, y),
fire_map_shape,
)
agents_dict[agent_str] = agent
return agents_dict

Expand All @@ -458,7 +472,7 @@ def create_agents(
# Populate the `self.agents` dict with `ReactiveAgent` object (s).
agent_ids = sorted(self._agent_ids, key=lambda x: int(x.split("_")[-1]))
for agent_str, sim_id, loc in zip(agent_ids, self._sim_agent_ids, agent_locs):
agent = ReactiveAgent(agent_str, sim_id, tuple(loc))
agent = ReactiveAgent(agent_str, sim_id, tuple(loc), fire_map_shape)
agents_dict[agent_str] = agent
return agents_dict

Expand Down
141 changes: 141 additions & 0 deletions simharness2/environments/multi_agent_damage_aware_harness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import logging
from typing import Optional, Tuple, TypeVar

import numpy as np
from simfire.sim.simulation import FireSimulation

from simharness2.agents.agent import ReactiveAgent
from simharness2.environments.multi_agent_complex_harness import (
MultiAgentComplexObsReactiveHarness,
)


logger = logging.getLogger(__name__)

AnyFireSimulation = TypeVar("AnyFireSimulation", bound=FireSimulation)


# FIXME: Consider effect of long time horizons, ie. see "max_bench_length" (this is a
# note to discuss this later).
# FIXME: Discuss harness naming convention with team - I hate how long this is...
class MultiAgentComplexObsDamageAwareReactiveHarness(
MultiAgentComplexObsReactiveHarness[AnyFireSimulation]
):
def __init__(
self,
*,
terminate_if_greater_damage: bool = True,
max_bench_length: int = 600,
**kwargs,
):
super().__init__(**kwargs)
# Bool to toggle the ability to terminate the agent simulation early if at the
# current timestep of the agent simulation, the agents have caused more burn
# damage (burned + burning) than the final state of the benchmark fire map.
self._terminate_if_greater_damage = terminate_if_greater_damage

# TODO: Define `benchmark_sim` here, and make it required (ie. never None).
# Store the firemaps from the benchmark simulation if used in the state
if self.benchmark_sim is None:
# The benchmark_sim is required for rewards and termination!
raise ValueError(
"The benchmark simulation must be provided to use this harness."
)
else:
# Create static list to store the episode benchsim firemaps
self._max_bench_length = max_bench_length
# FIXME: Suboptimal implementation and should be refactored later.
self._bench_firemaps = [0] * self._max_bench_length

def _should_terminate(self) -> bool:
# Retrieve original value, based on `FireHarness` definition of terminated.
terminated = super()._should_terminate()

# Terminate episode early if burn damage in Agent Sim > final bench firemap
if self.benchmark_sim:
if self._terminate_if_greater_damage:
# breakpoint()
total_area = self.sim.fire_map.size

sim_data = self.harness_analytics.sim_analytics.data
sim_damaged_total = sim_data.burned + sim_data.burning
benchsim_data = self.harness_analytics.benchmark_sim_analytics.data
benchsim_damaged_total = total_area - benchsim_data.unburned

logger.debug(f"sim_damaged_total: {sim_damaged_total}")
logger.debug(f"benchsim_damaged_total: {benchsim_damaged_total}")
if sim_damaged_total > benchsim_damaged_total:
# TODO: add static negative penalty for making the fire worse?
logger.info(
"Terminating episode early because the agents have caused more "
"burn damage than the final state of the benchmark fire map."
)
terminated = True

return terminated

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[np.ndarray, dict]:
"""Reset environment to an initial state, returning an initial obs and info."""
# Use the following line to seed `self.np_random`
initial_state, infos = super().reset(seed=seed, options=options)

# Verify initial fire scenario is the same for both the sim and benchsim.
sim_fire_init_pos = self.sim.config.fire.fire_initial_position
benchmark_fire_init_pos = self.benchmark_sim.config.fire.fire_initial_position
if sim_fire_init_pos != benchmark_fire_init_pos:
raise ValueError(
"The initial fire scenario for the simulation and benchmark simulation "
"must be the same."
)

# TODO: Only call _run_benchmark if fire scenario differs from previous episode.
# This is somewhat tricky - must ensure that analytics.data is NOT reset!
# Run new benchsim to completion to obtain data for reward and policy.
self.benchmark_sim.reset()
# NOTE: The call below will do a few things:
# - Run bench sim to completion (self.benchmark_sim.run(1))
# - Update bench sim analytics (update_bench_after_one_simulation_step)
# - Store each bench fire map at the sim step in self._bench_firemaps
self._run_benchmark()

return initial_state, infos

def _run_benchmark(self):
"""Run benchmark sim until fire propagation is complete."""
# TODO: We can remove the benchmark_sim entirely, and get this behavior by simply
# running self.sim until it terminates, then reset self.sim to the initial state.

if self.benchmark_sim.elapsed_steps > 0:
raise RuntimeError(
"Benchmark simulation must be reset before running it again."
)

timesteps = 0
while self.benchmark_sim.active:
run_sim = timesteps % self.agent_speed == 0

if run_sim:
# TODO: Refactor logic into a method, and call it here.
# Run for one timestep, then update respective metrics.
self.benchmark_sim.run(1)
# FIXME: This method call is VERY redundant (see method logic)
self.harness_analytics.update_bench_after_one_simulation_step(
timestep=timesteps
)

curr_step = self.harness_analytics.benchmark_sim_analytics.num_sim_steps
# Store the bench fire map at the sim step
if curr_step < self._max_bench_length - 1:
self._bench_firemaps[curr_step - 1] = np.copy(
self.benchmark_sim.fire_map
)
else:
self._bench_firemaps.append(np.copy(self.benchmark_sim.fire_map))
self._max_bench_length += 1

timesteps += 1
60 changes: 60 additions & 0 deletions simharness2/environments/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Tuple

import numpy as np


def get_adjacent_points(
row: int, col: int, shape: Tuple[int, int], include_diagonals: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
"""Return points adjacent to the provided point (excluding point itself).
The current implementation considers the 4 cardinal directions (N, S, E, W) as
adjacent points. If `include_diagonals` is set to True, the diagonal points
(NE, NW, SE, SW) are also considered as adjacent points.
Arguments:
row: The row index of the current point.
col: The column index of the current point.
shape: A 2-tuple representing the shape of the map.
include_diagonals: A boolean indicating whether to include diagonal points
as adjacent points. Defaults to True.
Returns:
A tuple containing two numpy arrays, adjacent rows and adjacent columns. The
returned arrays can be used as an advanced index to access the adjacent
points, ex: `fire_map[adj_rows, adj_cols]`.
"""
# TODO: Logic below is copied from a method in simfire, namely
# simfire.game.managers.fire.FireManager._get_new_locs(). It would be good to
# refactor this logic into a utility function in simfire, and then call it here.
x, y = col, row
# Generate all possible adjacent points around the current point.
if include_diagonals:
new_locs = (
(x + 1, y),
(x + 1, y + 1),
(x, y + 1),
(x - 1, y + 1),
(x - 1, y),
(x - 1, y - 1),
(x, y - 1),
(x + 1, y - 1),
)
else:
new_locs = (
(x + 1, y),
(x, y + 1),
(x - 1, y),
(x, y - 1),
)

col_coords, row_coords = zip(*new_locs)
adj_array = np.array([row_coords, col_coords], dtype=np.int32)

# Clip the adjacent points to ensure they are within the map boundaries
row_max, col_max = [dim - 1 for dim in shape]
adj_array = np.clip(adj_array, a_min=[[0], [0]], a_max=[[row_max], [col_max]])
# Remove the point itself from the list of adjacent points, if it exists.
adj_array = adj_array[:, ~np.all(adj_array == [[row], [col]], axis=0)]

return adj_array[0], adj_array[1]

0 comments on commit 84aadc6

Please sign in to comment.