diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py b/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py index 474110cd..23187786 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py @@ -31,7 +31,11 @@ def __init__(self, name: str, location: Iterable[float], priority: float) -> Non @property def id(self) -> str: """str: Unique human-readable identifier""" - return f"{self.name}_{id(self)}" + try: + return self._id + except AttributeError: + self._id = f"{self.name}_{id(self)}" + return self._id def __hash__(self) -> int: return hash((self.id)) diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index a6cf276a..25475553 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -1,3 +1,4 @@ +import bisect import inspect from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Iterable, Optional, Union @@ -316,7 +317,7 @@ def reset_pre_sim(self) -> None: self.sat_args["bufferNames"] = [ target.id for target in self.data_store.env_knowledge.targets ] - self.windows = {} + self.opportunities: list[dict] = [] self.window_calculation_time = 0 self.current_action = None self._image_event_name = None @@ -447,60 +448,73 @@ def _add_window( merge_time: Time at which merges with existing windows will occur. If None, check all windows for merges """ - if target in self.windows: - # Merge touching windows - if new_window[0] == merge_time or merge_time is None: - for window in self.windows[target]: - if window[1] == new_window[0]: - self.windows[target].remove(window) - window = (window[0], new_window[1]) - self.windows[target].append(window) - return - self.windows[target].append(new_window) - else: - self.windows[target] = [new_window] + if new_window[0] == merge_time or merge_time is None: + for window in self.opportunities: + if window["target"] == target and window["window"][1] == new_window[0]: + window["window"] = (window["window"][0], new_window[1]) + return + bisect.insort( + self.opportunities, + {"target": target, "window": new_window}, + key=lambda x: x["window"][1], + ) @property - def upcoming_windows(self) -> dict[Target, list[tuple[float, float]]]: - """Subset of windows that have not yet closed. Attempts to filter out known - imaged windows if data is accessible + def windows(self) -> dict[Target, list[tuple[float, float]]]: + """Access windows via dict of targets -> list of windows""" + windows = {} + for opportunity in self.opportunities: + if opportunity["target"] not in windows: + windows[opportunity["target"]] = [] + windows[opportunity["target"]].append(opportunity["window"]) + return windows - Returns: - filtered windows - """ + @property + def upcoming_opportunities(self) -> list[dict]: + """Subset of opportunities that have not yet closed. Attempts to filter out + known imaged windows if data on imaged windows is accessible.""" + start = bisect.bisect_left( + self.opportunities, self.simulator.sim_time, key=lambda x: x["window"][1] + ) + upcoming = self.opportunities[start:] try: # Attempt to filter already known imaged targets - return { - tgt: [ - window for window in windows if window[1] > self.simulator.sim_time - ] - for tgt, windows in self.windows.items() - if any(window[1] > self.simulator.sim_time for window in windows) - and tgt not in self.data_store.data.imaged - } + upcoming = [ + opportunity + for opportunity in upcoming + if opportunity["target"] not in self.data_store.data.imaged + ] except AttributeError: - return { - tgt: [ - window for window in windows if window[1] > self.simulator.sim_time - ] - for tgt, windows in self.windows.items() - if any(window[1] > self.simulator.sim_time for window in windows) - } + pass + return upcoming + + @property + def upcoming_windows(self) -> dict[Target, list[tuple[float, float]]]: + """Access upcoming windows in a dict of targets -> list of windows.""" + windows = {} + for window in self.upcoming_opportunities: + if window["target"] not in windows: + windows[window["target"]] = [] + windows[window["target"]].append(window["window"]) + return windows @property def next_windows(self) -> dict[Target, tuple[float, float]]: - """Soonest window for each target + """Soonest window for each target. Returns: dict: first non-closed window for each target """ - return {tgt: windows[0] for tgt, windows in self.upcoming_windows.items()} + next_windows = {} + for opportunity in self.upcoming_opportunities: + if opportunity["target"] not in next_windows: + next_windows[opportunity["target"]] = opportunity["window"] + return next_windows def upcoming_targets( self, n: int, pad: bool = True, max_lookahead: int = 100 ) -> list[Target]: """Find the n nearest targets. Targets are sorted by window close time; - currently open windows are included. Only the first window for a target is - accounted for. + currently open windows are included. Args: n: number of windows to look ahead @@ -514,12 +528,12 @@ def upcoming_targets( if n == 0: return [] for _ in range(max_lookahead): - soonest = sorted(self.next_windows.items(), key=lambda x: x[1][1]) + soonest = self.upcoming_opportunities if len(soonest) < n: self.calculate_additional_windows(self.generation_duration) else: break - targets = [target for target, _ in soonest[0:n]] + targets = [opportunity["target"] for opportunity in soonest[0:n]] if pad: targets += [targets[-1]] * (n - len(targets)) return targets diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py index 56ad7cca..0944064a 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py @@ -150,7 +150,10 @@ def test_target_state_normed(self, sat_init): for i in range(n_ahead) ] ) - sat.windows = {target: [(10.0, 20.0)] for target in sat.upcoming_targets()} + sat.opportunities = [ + dict(target=target, window=(10.0, 20.0)) + for target in sat.upcoming_targets() + ] sat.simulator = MagicMock(sim_time=5.0) expected = dict( diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py index 27f5a71f..d4a038ef 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py @@ -210,11 +210,18 @@ def test_reset_post_sim(self, mock_reset, gen_duration, time_limit, expected): tgt1: [(10.0, 20.0)], tgt2: [(30.0, 40.0)], } + _opportunities = [ + dict(target=tgt0, window=(0.0, 10.0)), + dict(target=tgt1, window=(10.0, 20.0)), + dict(target=tgt0, window=(20.0, 30.0)), + dict(target=tgt2, window=(30.0, 40.0)), + dict(target=tgt0, window=(40.0, 50.0)), + ] def test_upcoming_windows_unfiltered(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=25.0) - sat.windows = self.windows + sat.opportunities = self._opportunities assert sat.upcoming_windows == { self.tgt0: [(20.0, 30.0), (40.0, 50.0)], self.tgt2: [(30.0, 40.0)], @@ -223,7 +230,7 @@ def test_upcoming_windows_unfiltered(self): def test_upcoming_windows_filtered(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=25.0) - sat.windows = self.windows + sat.opportunities = self._opportunities sat.data_store = MagicMock(data=MagicMock(imaged=[self.tgt0])) assert sat.upcoming_windows == { self.tgt2: [(30.0, 40.0)], @@ -232,7 +239,7 @@ def test_upcoming_windows_filtered(self): def test_next_windows(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.windows = self.windows + sat.opportunities = self._opportunities assert sat.next_windows == { self.tgt0: (40.0, 50.0), self.tgt2: (30.0, 40.0), @@ -241,19 +248,19 @@ def test_next_windows(self): def test_upcoming_targets(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.windows = self.windows + sat.opportunities = self._opportunities assert sat.upcoming_targets(2) == [self.tgt2, self.tgt0] def test_no_upcoming_targets(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.windows = self.windows + sat.opportunities = self._opportunities assert sat.upcoming_targets(0) == [] def test_upcoming_targets_pad(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.windows = self.windows + sat.opportunities = self._opportunities sat.calculate_additional_windows = MagicMock() assert sat.upcoming_targets(4, pad=True) == [ self.tgt2, @@ -265,9 +272,11 @@ def test_upcoming_targets_pad(self): def test_upcoming_targets_generate_more(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.windows = self.windows + sat.opportunities = self._opportunities sat.calculate_additional_windows = MagicMock( - side_effect=lambda t: self.windows[self.tgt1].append((60.0, 70.0)) + side_effect=lambda t: self._opportunities.append( + dict(target=self.tgt1, window=(60.0, 70.0)) + ) ) assert sat.upcoming_targets(3, pad=True) == [ self.tgt2, @@ -334,7 +343,7 @@ def test_parse_target_selection_invalid(self): def test_task_target_for_imaging(self): sat = self.make_sat() - sat.windows = self.windows + sat.opportunities = self._opportunities sat.name = "Sat" sat.fsw = MagicMock() sat.simulator = MagicMock(sim_time=35.0) @@ -394,7 +403,7 @@ def test_calculate_windows(self): tgt = Target("tgt_0", location=[0.0, 0.0, 1.0], priority=1.0) sat = self.make_sat() sat.window_calculation_time = 0.0 - sat.windows = {} + sat.opportunities = [] sat.min_elev = 1.3 sat.target_dist_threshold = 5.0 sat.trajectory = MagicMock( @@ -413,8 +422,12 @@ def test_calculate_windows(self): sat.data_store.env_knowledge.targets = [tgt] sat.calculate_additional_windows(100.0) assert tgt in sat.windows - assert sat.windows[tgt][0][0] == approx(50 - 0.27762037530835193, abs=1e-2) - assert sat.windows[tgt][0][1] == approx(50 + 0.27762037530835193, abs=1e-2) + assert sat.opportunities[0]["window"][0] == approx( + 50 - 0.27762037530835193, abs=1e-2 + ) + assert sat.opportunities[0]["window"][1] == approx( + 50 + 0.27762037530835193, abs=1e-2 + ) def test_find_elevation_roots(self): interp = ( # noqa: E731 @@ -522,7 +535,7 @@ def test_refine_windows_impossible(self): tgt0 = Target("tgt_0", location=[0.0, 0.0, 0.0], priority=1.0) tgt1 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) - tgt2 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) + tgt2 = Target("tgt_2", location=[0.0, 0.0, 0.0], priority=1.0) @pytest.mark.parametrize( "merge_time", @@ -538,6 +551,9 @@ def test_refine_windows_impossible(self): ) def test_add_window(self, merge_time, tgt, window, expected_window): sat = self.make_sat() - sat.windows = {self.tgt0: [(2.0, 10.0)], self.tgt1: [(3.0, 8.0)]} + sat.opportunities = [ + dict(target=self.tgt1, window=(3.0, 8.0)), + dict(target=self.tgt0, window=(2.0, 10.0)), + ] sat._add_window(tgt, window, merge_time=merge_time) assert expected_window in sat.windows[tgt]