diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 6308277d..6fd80714 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -270,63 +270,59 @@ def set_action(self, action: int) -> None: pass -class ImagingSatellite(Satellite): - dyn_type = dynamics.ImagingDynModel - fsw_type = fsw.ImagingFSWModel - +class AccessSatellite(Satellite): def __init__( self, - name: str, - sat_args: dict[str, Any], *args, generation_duration: float = 60 * 95 / 10, initial_generation_duration: Optional[float] = None, - target_dist_threshold: float = 1e6, + access_dist_threshold: float = 1e6, **kwargs, ) -> None: - """Satellite with agile imaging capabilities. Ends the simulation when a target - is imaged or missed + """Satellite that can detect access opportunities for ground locations with + elevation constraints. Args: - name: Satellite.name - sat_args: Satellite.sat_args - n_ahead_observe: Number of upcoming targets to include in observations. - n_ahead_act: Number of upcoming targets to include in actions. generation_duration: Duration to calculate additional imaging windows for when windows are exhausted. If `None`, generate for the simulation `time_limit` unless the simulation is infinite. [s] initial_generation_duration: Duration to initially calculate imaging windows [s] - target_dist_threshold: Distance bound [m] for evaluating imaging windows + access_dist_threshold: Distance bound [m] for evaluating imaging windows more exactly. """ - super().__init__(name, sat_args, *args, **kwargs) + super().__init__(*args, **kwargs) self.generation_duration = generation_duration self.initial_generation_duration = initial_generation_duration - self.min_elev = sat_args["imageTargetMinimumElevation"] # Used for window calcs - self.target_dist_threshold = target_dist_threshold - self.fsw: ImagingSatellite.fsw_type - self.dynamics: ImagingSatellite.dyn_type - self.data_store: UniqueImageStore + self.access_dist_threshold = access_dist_threshold def reset_pre_sim(self) -> None: - """Set the buffer parameters based on computed windows""" super().reset_pre_sim() - self.sat_args["transmitterNumBuffers"] = len( - self.data_store.env_knowledge.targets - ) - self.sat_args["bufferNames"] = [ - target.id for target in self.data_store.env_knowledge.targets - ] self.opportunities: list[dict] = [] self.window_calculation_time = 0 - self.current_action = None - self._image_event_name = None - self.imaged = 0 - self.missed = 0 + self.locations_for_access_checking: list[dict[str, Any]] = [] + + def add_location_for_access_checking( + self, + object: Any, + location: np.ndarray, + min_elev: float, + type: str, + ) -> None: + """Adds a location to be included in window calculations. Note that this + location will only be included in future calls to calculate_additional_windows. + + Args: + object: Object to add window for + location: Objects PCPF location [m] + min_elev: Minimum elevation angle for access [rad] + type: Category of windows to add location to + """ + location_dict = dict(location=location, min_elev=min_elev, type=type) + location_dict[type] = object + self.locations_for_access_checking.append(location_dict) def reset_post_sim(self) -> None: - """Handle initial_generation_duration setting and calculate windows""" super().reset_post_sim() if self.initial_generation_duration is None: if self.simulator.time_limit == float("inf"): @@ -336,8 +332,7 @@ def reset_post_sim(self) -> None: self.calculate_additional_windows(self.initial_generation_duration) def calculate_additional_windows(self, duration: float) -> None: - """Use a multiroot finding method to evaluate imaging windows for each target; - data is saved to self.windows. + """Use a multiroot finding method to evaluate imaging windows for each location. Args: duration: Time to calculate windows from end of previous window. @@ -362,20 +357,28 @@ def calculate_additional_windows(self, duration: float) -> None: times = r_BP_P_interp.x[window_calc_span] positions = r_BP_P_interp.y[window_calc_span] - for target in self.data_store.env_knowledge.targets: + for location in self.locations_for_access_checking: candidate_windows = self._find_candidate_windows( - target.location, times, positions, self.target_dist_threshold + location["location"], times, positions, self.access_dist_threshold ) for candidate_window in candidate_windows: roots = self._find_elevation_roots( - r_BP_P_interp, target.location, self.min_elev, candidate_window + r_BP_P_interp, + location["location"], + location["min_elev"], + candidate_window, ) new_windows = self._refine_window( roots, candidate_window, (times[0], times[-1]) ) for new_window in new_windows: - self._add_window(target, new_window, merge_time=times[0]) + self._add_window( + location[location["type"]], + new_window, + type=location["type"], + merge_time=times[0], + ) self.window_calculation_time = calculation_end @@ -386,6 +389,9 @@ def _find_elevation_roots( min_elev: float, window: tuple[float, float], ): + """Find exact times where the satellite's elevation relative to a target is + equal to the minimum elevation.""" + def root_fn(t): return elevation(position_interp(t), location) - min_elev @@ -421,6 +427,8 @@ def _refine_window( candidate_window: tuple[float, float], computation_window: tuple[float, float], ) -> list[tuple[float, float]]: + """Detect if an exact window has been truncated by the edge of the coarse + window.""" endpoints = list(endpoints) if len(endpoints) % 2 == 1: if candidate_window[0] == computation_window[0]: @@ -438,65 +446,228 @@ def _refine_window( def _add_window( self, - target: Target, + object: Any, new_window: tuple[float, float], + type: str, merge_time: Optional[float] = None, ): """ Args: - target: Target to add window for + object: Object to add window for new_window: New window for target + type: Type of window being added merge_time: Time at which merges with existing windows will occur. If None, check all windows for merges """ if new_window[0] == merge_time or merge_time is None: - for window in self.opportunities: - if window["target"] == target and window["window"][1] == new_window[0]: - window["window"] = (window["window"][0], new_window[1]) + for opportunity in self.opportunities: + if ( + opportunity["type"] == type + and opportunity[type] == object + and opportunity["window"][1] == new_window[0] + ): + opportunity["window"] = (opportunity["window"][0], new_window[1]) return bisect.insort( self.opportunities, - {"target": target, "window": new_window}, + {type: object, "window": new_window, "type": type}, key=lambda x: x["window"][1], ) - @property - def windows(self) -> dict[Target, list[tuple[float, float]]]: - """Access windows via dict of targets -> list of windows""" - windows = {} - for opportunity in self.opportunities: - if opportunity["target"] not in windows: - windows[opportunity["target"]] = [] - windows[opportunity["target"]].append(opportunity["window"]) - return windows - @property def upcoming_opportunities(self) -> list[dict]: - """Subset of opportunities that have not yet closed. Attempts to filter out - known imaged windows if data on imaged windows is accessible.""" + """Ordered list of opportunities that have not yet closed. + + Returns: + list: list of upcoming opportunities + """ start = bisect.bisect_left( self.opportunities, self.simulator.sim_time, key=lambda x: x["window"][1] ) upcoming = self.opportunities[start:] - try: # Attempt to filter already known imaged targets - upcoming = [ - opportunity - for opportunity in upcoming - if opportunity["target"] not in self.data_store.data.imaged - ] - except AttributeError: - pass return upcoming + def opportunities_dict( + self, + types: Optional[Union[str, list[str]]] = None, + filter: list = [], + ) -> dict[Any, list[tuple[float, float]]]: + """Dictionary of opportunities that maps objects to lists of windows. + + Args: + types: Types of opportunities to include. If None, include all types. + filter: Objects to exclude from the dictionary. + + Returns: + windows: objects -> windows list + """ + if isinstance(types, str): + types = [types] + + windows = {} + for opportunity in self.opportunities: + type = opportunity["type"] + if (types is None or type in types) and opportunity[type] not in filter: + if opportunity[type] not in windows: + windows[opportunity[type]] = [] + windows[opportunity[type]].append(opportunity["window"]) + return windows + + def upcoming_opportunities_dict( + self, + types: Optional[Union[str, list[str]]] = None, + filter: list = [], + ) -> dict[Any, list[tuple[float, float]]]: + """Dictionary of opportunities that maps objects to lists of windows that have + not yet closed. + + Args: + types: Types of opportunities to include. If None, include all types. + filter: Objects to exclude from the dictionary. + + Returns: + windows: objects -> windows list (upcoming only) + """ + if isinstance(types, str): + types = [types] + + windows = {} + for opportunity in self.upcoming_opportunities: + type = opportunity["type"] + if (types is None or type in types) and opportunity[type] not in filter: + if opportunity[type] not in windows: + windows[opportunity[type]] = [] + windows[opportunity[type]].append(opportunity["window"]) + return windows + + def next_opportunities_dict( + self, + types: Optional[Union[str, list[str]]] = None, + filter: list = [], + ) -> dict[Any, tuple[float, float]]: + """Dictionary of opportunities that maps objects to the next open window. + + Args: + types: Types of opportunities to include. If None, include all types. + filter: Objects to exclude from the dictionary. + + Returns: + windows: objects -> next window + """ + if isinstance(types, str): + types = [types] + + next_windows = {} + for opportunity in self.upcoming_opportunities: + type = opportunity["type"] + if (types is None or type in types) and opportunity[type] not in filter: + if opportunity[type] not in next_windows: + next_windows[opportunity[type]] = opportunity["window"] + return next_windows + + def find_next_opportunities( + self, + n: int, + pad: bool = True, + max_lookahead: int = 100, + types: Optional[Union[str, list[str]]] = None, + filter: list = [], + ) -> list[dict]: + """Find the n nearest opportunities, sorted by window close time. + + Args: + n: Number of opportunities to attempt to include. + pad: If true, duplicates the last target if the number of opportunities + found is less than n. + max_lookahead: Maximum times to call calculate_additional_windows. + types: Types of opportunities to include. If None, include all types. + filter: Objects to exclude from the dictionary. + + Returns: + list: n nearest opportunities, ordered + """ + if isinstance(types, str): + types = [types] + + if n == 0: + return [] + + for _ in range(max_lookahead): + upcoming_opportunities = self.upcoming_opportunities + next_opportunities = [] + for opportunity in upcoming_opportunities: + type = opportunity["type"] + if (types is None or type in types) and opportunity[type] not in filter: + next_opportunities.append(opportunity) + + if len(next_opportunities) >= n: + return next_opportunities + self.calculate_additional_windows(self.generation_duration) + if pad: + next_opportunities += [next_opportunities[-1]] * ( + n - len(next_opportunities) + ) + return next_opportunities + + +class ImagingSatellite(AccessSatellite): + dyn_type = dynamics.ImagingDynModel + fsw_type = fsw.ImagingFSWModel + + def __init__( + self, + *args, + **kwargs, + ) -> None: + """Satellite with agile imaging capabilities. Can stop the simulation when a + target is imaged or missed. + """ + super().__init__(*args, **kwargs) + self.fsw: ImagingSatellite.fsw_type + self.dynamics: ImagingSatellite.dyn_type + self.data_store: UniqueImageStore + + def reset_pre_sim(self) -> None: + """Set the buffer parameters based on computed windows""" + super().reset_pre_sim() + self.sat_args["transmitterNumBuffers"] = len( + self.data_store.env_knowledge.targets + ) + self.sat_args["bufferNames"] = [ + target.id for target in self.data_store.env_knowledge.targets + ] + self._image_event_name = None + self.imaged = 0 + self.missed = 0 + + def reset_post_sim(self) -> None: + """Handle initial_generation_duration setting and calculate windows""" + for target in self.data_store.env_knowledge.targets: + self.add_location_for_access_checking( + object=target, + location=target.location, + min_elev=self.sat_args["imageTargetMinimumElevation"], + type="target", + ) + super().reset_post_sim() + + def _get_imaged_filter(self): + try: + return self.data_store.data.imaged + except AttributeError: + return [] + + @property + def windows(self) -> dict[Target, list[tuple[float, float]]]: + """Access windows via dict of targets -> list of windows""" + return self.opportunities_dict(types="target", filter=self._get_imaged_filter()) + @property def upcoming_windows(self) -> dict[Target, list[tuple[float, float]]]: """Access upcoming windows in a dict of targets -> list of windows.""" - windows = {} - for window in self.upcoming_opportunities: - if window["target"] not in windows: - windows[window["target"]] = [] - windows[window["target"]].append(window["window"]) - return windows + return self.upcoming_opportunities_dict( + types="target", filter=self._get_imaged_filter() + ) @property def next_windows(self) -> dict[Target, tuple[float, float]]: @@ -505,11 +676,9 @@ def next_windows(self) -> dict[Target, tuple[float, float]]: Returns: dict: first non-closed window for each target """ - next_windows = {} - for opportunity in self.upcoming_opportunities: - if opportunity["target"] not in next_windows: - next_windows[opportunity["target"]] = opportunity["window"] - return next_windows + return self.next_opportunities_dict( + types="target", filter=self._get_imaged_filter() + ) def upcoming_targets( self, n: int, pad: bool = True, max_lookahead: int = 100 @@ -526,18 +695,16 @@ def upcoming_targets( Returns: list: n nearest targets, ordered """ - if n == 0: - return [] - for _ in range(max_lookahead): - soonest = self.upcoming_opportunities - if len(soonest) < n: - self.calculate_additional_windows(self.generation_duration) - else: - break - targets = [opportunity["target"] for opportunity in soonest[0:n]] - if pad: - targets += [targets[-1]] * (n - len(targets)) - return targets + return [ + opportunity["target"] + for opportunity in self.find_next_opportunities( + n=n, + pad=pad, + max_lookahead=max_lookahead, + filter=self._get_imaged_filter(), + types="target", + ) + ] def _update_image_event(self, target: Target) -> None: """Create a simulator event that causes the simulation to stop when a target is diff --git a/tests/integration/envs/general_satellite_tasking/scenario/test_int_environment_features.py b/tests/integration/envs/general_satellite_tasking/scenario/test_int_environment_features.py index 50116c71..fe3408b5 100644 --- a/tests/integration/envs/general_satellite_tasking/scenario/test_int_environment_features.py +++ b/tests/integration/envs/general_satellite_tasking/scenario/test_int_environment_features.py @@ -38,6 +38,7 @@ class ImageSat( time_limit=5700.0, max_step_duration=1e9, disable_env_checker=True, + failure_penalty=0, ) return env diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py index 0944064a..ff34ffd4 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py @@ -151,7 +151,7 @@ def test_target_state_normed(self, sat_init): ] ) sat.opportunities = [ - dict(target=target, window=(10.0, 20.0)) + dict(target=target, window=(10.0, 20.0), type="target") for target in sat.upcoming_targets() ] sat.simulator = MagicMock(sim_time=5.0) diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py index 724fe6ae..aa9ae38f 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py @@ -149,6 +149,269 @@ def test_proxy_setters(self): assert sat.dynamics == mock_dyn +@patch( + "bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.__init__", + MagicMock(), +) +@patch( + "bsk_rl.envs.general_satellite_tasking.utils.orbital.elevation", lambda x, y: y - x +) +@patch.multiple(sats.AccessSatellite, __abstractmethods__=set()) +class TestAccessSatellite: + def make_sat(self): + return sats.AccessSatellite( + "TestSat", + sat_args={"imageTargetMinimumElevation": 1}, + ) + + def test_add_location_for_access_checking(self): + sat = self.make_sat() + sat.locations_for_access_checking = [] + target = MagicMock() + sat.add_location_for_access_checking( + object=target, location=[0, 0, 0], min_elev=1.0, type="target" + ) + assert ( + dict(target=target, location=[0, 0, 0], min_elev=1.0, type="target") + in sat.locations_for_access_checking + ) + + @pytest.mark.parametrize("start", [0.0, 100.0]) + @pytest.mark.parametrize("duration", [0.0, 20.0, 500.0]) + @pytest.mark.parametrize("traj_dt", [30.0, 200.0]) + @pytest.mark.parametrize("generation_duration", [60.0, 100.0]) + def test_calculate_windows_duration( + self, start, duration, traj_dt, generation_duration + ): + sat = self.make_sat() + sat.window_calculation_time = start + sat.generation_duration = generation_duration + sat.trajectory = MagicMock( + dt=traj_dt, + r_BP_P=MagicMock( + x=np.linspace(0, start + duration), y=np.linspace(0, start + duration) + ), + ) + sat.locations_for_access_checking = [] + sat.calculate_additional_windows(duration) + if duration == 0.0: + return + assert sat.trajectory.extend_to.call_args[0][0] >= start + duration + assert sat.trajectory.extend_to.call_args[0][0] - start >= traj_dt * 2 + + def test_calculate_windows(self): + tgt = Target("tgt_0", location=[0.0, 0.0, 1.0], priority=1.0) + sat = self.make_sat() + sat.window_calculation_time = 0.0 + sat.opportunities = [] + sat.access_dist_threshold = 5.0 + sat.trajectory = MagicMock( + dt=2.0, + r_BP_P=MagicMock( + x=np.arange(0, 100, 2), + y=np.array([[t - 50.0, 0.0, 2.0] for t in np.arange(0, 100, 2)]), + side_effect=( # noqa: E731 + lambda t: np.array([[ti - 50.0, 0.0, 2.0] for ti in t]) + if isinstance(t, Iterable) + else np.array([t - 50.0, 0.0, 2.0]) + ), + ), + ) + sat.locations_for_access_checking = [ + dict(target=tgt, type="target", min_elev=1.3, location=tgt.location) + ] + sat.calculate_additional_windows(100.0) + assert tgt in sat.opportunities_dict() + assert sat.opportunities[0]["window"][0] == approx( + 50 - 0.27762037530835193, abs=1e-2 + ) + assert sat.opportunities[0]["window"][1] == approx( + 50 + 0.27762037530835193, abs=1e-2 + ) + + def test_find_elevation_roots(self): + interp = ( # noqa: E731 + lambda t: np.array([[ti, 0.0, 2.0] for ti in t]) + if isinstance(t, Iterable) + else np.array([t, 0.0, 2.0]) + ) + loc = np.array([0.0, 0.0, 1.0]) + elev = 1.3 + times = sats.ImagingSatellite._find_elevation_roots(interp, loc, elev, (-1, 1)) + assert len(times) == 2 + assert times[0] == approx(-times[1], abs=1e-5) + assert times[1] == approx(0.27762037530835193, abs=1e-5) + times = sats.ImagingSatellite._find_elevation_roots(interp, loc, elev, (0, 1)) + assert len(times) == 1 + assert times[0] == approx(0.27762037530835193, abs=1e-5) + + @pytest.mark.parametrize( + "location,times,positions,threshold,expected", + [ + ( + np.array([2.5, 0.0]), + np.array([0.0, 10.0, 20.0, 30.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), + 1.0, + [(10.0, 30.0)], + ), + ( + np.array([2.5, 0.0]), + np.array([0.0, 10.0, 20.0, 30.0, 40.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]]), + 1.0, + [(10.0, 40.0)], + ), + ( + np.array([0.5, 0.0]), + np.array([0.0, 10.0, 20.0, 30.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), + 1.0, + [(0.0, 20.0)], + ), + ( + np.array([1.2, 0.0]), + np.array([0.0, 10.0, 20.0, 30.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), + 5.0, + [(0.0, 30.0)], + ), + ( + np.array([2.5, 100.0]), + np.array([0.0, 10.0, 20.0, 30.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), + 1.0, + [], + ), + ( + np.array([-0.1, 0.0]), + np.array([0.0, 10.0, 20.0, 30.0, 40.0]), + np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [-1.0, 0.0]]), + 1.0, + [(0.0, 10.0), (30.0, 40.0)], + ), + ], + ) + def test_find_candidate_windows( + self, location, times, positions, threshold, expected + ): + assert ( + sats.ImagingSatellite._find_candidate_windows( + location, times, positions, threshold + ) + == expected + ) + + @pytest.mark.parametrize( + "endpoints,candidate_window,computation_window,expected", + [ + ([2.4, 14.6], (0.0, 20.0), (0.0, 30.0), [(2.4, 14.6)]), + ([12.4], (0.0, 20.0), (0.0, 30.0), [(0.0, 12.4)]), + ([12.4], (10.0, 30.0), (0.0, 30.0), [(12.4, 30.0)]), + ([2.4, 14.6, 18.8], (0.0, 20.0), (0.0, 30.0), [(0.0, 2.4), (14.6, 18.8)]), + ( + [2.4, 14.6, 18.8, 19.3], + (0.0, 20.0), + (0.0, 30.0), + [(2.4, 14.6), (18.8, 19.3)], + ), + ], + ) + def test_refine_windows( + self, endpoints, candidate_window, computation_window, expected + ): + assert ( + sats.ImagingSatellite._refine_window( + endpoints, candidate_window, computation_window + ) + == expected + ) + + def test_refine_windows_impossible(self): + with pytest.raises(ValueError): + sats.ImagingSatellite._refine_window( + [1.0, 2.0, 3.0], (0.0, 4.0), (0.5, 3.5) + ) + + tgt0 = Target("tgt_0", location=[0.0, 0.0, 0.0], priority=1.0) + tgt1 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) + tgt2 = Target("tgt_2", location=[0.0, 0.0, 0.0], priority=1.0) + + @pytest.mark.parametrize( + "merge_time", + [None, 10.0], + ) + @pytest.mark.parametrize( + "tgt,window,expected_window", + [ + (tgt0, (13.0, 18.0), (13.0, 18.0)), + (tgt2, (13.0, 18.0), (13.0, 18.0)), + (tgt0, (10.0, 18.0), (2.0, 18.0)), # Check that merging works + ], + ) + def test_add_window(self, merge_time, tgt, window, expected_window): + sat = self.make_sat() + sat.opportunities = [ + dict(target=self.tgt1, window=(3.0, 8.0), type="target"), + dict(target=self.tgt0, window=(2.0, 10.0), type="target"), + ] + sat._add_window(tgt, window, merge_time=merge_time, type="target") + assert expected_window in sat.opportunities_dict()[tgt] + + opportunities = [ + dict(downlink="downObj1", window=(10, 20), type="downlink"), + dict(target="tgtObj1", window=(20, 30), type="target"), + dict(downlink="downObj1", window=(30, 40), type="downlink"), + dict(downlink="downObj2", window=(35, 45), type="downlink"), + ] + + def test_upcoming_opportunities(self): + sat = self.make_sat() + sat.opportunities = self.opportunities + sat.simulator = MagicMock(sim_time=25.0) + assert sat.upcoming_opportunities == self.opportunities[1:4] + + def test_opportunities_dict(self): + sat = self.make_sat() + sat.opportunities = self.opportunities + assert sat.opportunities_dict(types="target") == dict(tgtObj1=[(20, 30)]) + assert sat.opportunities_dict(types=None) == sat.opportunities_dict( + types=["target", "downlink"] + ) + assert sat.opportunities_dict(types="downlink", filter=["downObj1"]) == dict( + downObj2=[(35, 45)] + ) + + def test_upcoming_opportunities_dict(self): + sat = self.make_sat() + sat.opportunities = self.opportunities + sat.simulator = MagicMock(sim_time=35.0) + assert sat.upcoming_opportunities_dict(types="target") == {} + assert sat.upcoming_opportunities_dict( + types=None + ) == sat.upcoming_opportunities_dict(types=["target", "downlink"]) + assert sat.upcoming_opportunities_dict( + types="downlink", filter=["downObj2"] + ) == dict(downObj1=[(30, 40)]) + + def test_next_opportunities_dict(self): + sat = self.make_sat() + sat.opportunities = self.opportunities + sat.simulator = MagicMock(sim_time=15.0) + assert sat.next_opportunities_dict() == dict( + downObj1=(10, 20), tgtObj1=(20, 30), downObj2=(35, 45) + ) + assert sat.next_opportunities_dict(types="downlink") == dict( + downObj1=(10, 20), downObj2=(35, 45) + ) + assert sat.next_opportunities_dict(filter=["downObj1"]) == dict( + tgtObj1=(20, 30), downObj2=(35, 45) + ) + + def test_find_next_opportunities(self): + pass # Tested in TestImagingSatellite + + @patch("bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.__init__") @patch.multiple(sats.ImagingSatellite, __abstractmethods__=set()) def test_init(mock_init): @@ -197,6 +460,7 @@ def test_reset_post_sim(self, mock_reset, gen_duration, time_limit, expected): sat.calculate_additional_windows = MagicMock() sat.initial_generation_duration = gen_duration sat.simulator = MagicMock(time_limit=time_limit) + sat.data_store = MagicMock() sat.reset_post_sim() mock_reset.assert_called_once() assert sat.initial_generation_duration == expected @@ -204,24 +468,24 @@ def test_reset_post_sim(self, mock_reset, gen_duration, time_limit, expected): tgt0 = Target("tgt_0", location=[0.0, 0.0, 0.0], priority=1.0) tgt1 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) - tgt2 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) + tgt2 = Target("tgt_2", location=[0.0, 0.0, 0.0], priority=1.0) windows = { tgt0: [(0.0, 10.0), (20.0, 30.0), (40.0, 50.0)], tgt1: [(10.0, 20.0)], tgt2: [(30.0, 40.0)], } - _opportunities = [ - dict(target=tgt0, window=(0.0, 10.0)), - dict(target=tgt1, window=(10.0, 20.0)), - dict(target=tgt0, window=(20.0, 30.0)), - dict(target=tgt2, window=(30.0, 40.0)), - dict(target=tgt0, window=(40.0, 50.0)), + opportunities = [ + dict(target=tgt0, window=(0.0, 10.0), type="target"), + dict(target=tgt1, window=(10.0, 20.0), type="target"), + dict(target=tgt0, window=(20.0, 30.0), type="target"), + dict(target=tgt2, window=(30.0, 40.0), type="target"), + dict(target=tgt0, window=(40.0, 50.0), type="target"), ] def test_upcoming_windows_unfiltered(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=25.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities assert sat.upcoming_windows == { self.tgt0: [(20.0, 30.0), (40.0, 50.0)], self.tgt2: [(30.0, 40.0)], @@ -230,16 +494,21 @@ def test_upcoming_windows_unfiltered(self): def test_upcoming_windows_filtered(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=25.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities sat.data_store = MagicMock(data=MagicMock(imaged=[self.tgt0])) assert sat.upcoming_windows == { self.tgt2: [(30.0, 40.0)], } + def test_windows(self): + sat = self.make_sat() + sat.opportunities = self.opportunities + assert sat.windows == self.windows + def test_next_windows(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities assert sat.next_windows == { self.tgt0: (40.0, 50.0), self.tgt2: (30.0, 40.0), @@ -248,19 +517,19 @@ def test_next_windows(self): def test_upcoming_targets(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities assert sat.upcoming_targets(2) == [self.tgt2, self.tgt0] def test_no_upcoming_targets(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities assert sat.upcoming_targets(0) == [] def test_upcoming_targets_pad(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities sat.calculate_additional_windows = MagicMock() assert sat.upcoming_targets(4, pad=True) == [ self.tgt2, @@ -272,12 +541,21 @@ def test_upcoming_targets_pad(self): def test_upcoming_targets_generate_more(self): sat = self.make_sat() sat.simulator = MagicMock(sim_time=35.0) - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities sat.calculate_additional_windows = MagicMock( - side_effect=lambda t: self._opportunities.append( - dict(target=self.tgt1, window=(60.0, 70.0)) + side_effect=lambda t: self.opportunities.append( + dict(target=self.tgt1, window=(60.0, 70.0), type="target") ) ) + print(sat.opportunities) + print( + sat.upcoming_targets(3, pad=True), + [ + self.tgt2, + self.tgt0, + self.tgt1, + ], + ) assert sat.upcoming_targets(3, pad=True) == [ self.tgt2, self.tgt0, @@ -343,7 +621,7 @@ def test_parse_target_selection_invalid(self): def test_task_target_for_imaging(self): sat = self.make_sat() - sat.opportunities = self._opportunities + sat.opportunities = self.opportunities sat.name = "Sat" sat.fsw = MagicMock() sat.simulator = MagicMock(sim_time=35.0) @@ -358,202 +636,3 @@ def test_task_target_for_imaging(self): assert sat._update_image_event.call_args[0][0] == self.tgt0 sat._update_timed_terminal_event.assert_called_once() assert sat._update_timed_terminal_event.call_args[0][0] == 50.0 - - -@patch( - "bsk_rl.envs.general_satellite_tasking.scenario.satellites.Satellite.__init__", - MagicMock(), -) -@patch( - "bsk_rl.envs.general_satellite_tasking.utils.orbital.elevation", lambda x, y: y - x -) -@patch.multiple(sats.ImagingSatellite, __abstractmethods__=set()) -class TestCalculateWindows: - def make_sat(self): - return sats.ImagingSatellite( - "TestSat", - sat_args={"imageTargetMinimumElevation": 1}, - ) - - @pytest.mark.parametrize("start", [0.0, 100.0]) - @pytest.mark.parametrize("duration", [0.0, 20.0, 500.0]) - @pytest.mark.parametrize("traj_dt", [30.0, 200.0]) - @pytest.mark.parametrize("generation_duration", [60.0, 100.0]) - def test_calculate_windows_duration( - self, start, duration, traj_dt, generation_duration - ): - sat = self.make_sat() - sat.window_calculation_time = start - sat.generation_duration = generation_duration - sat.trajectory = MagicMock( - dt=traj_dt, - r_BP_P=MagicMock( - x=np.linspace(0, start + duration), y=np.linspace(0, start + duration) - ), - ) - sat.data_store = MagicMock() - sat.data_store.env_knowledge.targets = [] - sat.calculate_additional_windows(duration) - if duration == 0.0: - return - assert sat.trajectory.extend_to.call_args[0][0] >= start + duration - assert sat.trajectory.extend_to.call_args[0][0] - start >= traj_dt * 2 - - def test_calculate_windows(self): - tgt = Target("tgt_0", location=[0.0, 0.0, 1.0], priority=1.0) - sat = self.make_sat() - sat.window_calculation_time = 0.0 - sat.opportunities = [] - sat.min_elev = 1.3 - sat.target_dist_threshold = 5.0 - sat.trajectory = MagicMock( - dt=2.0, - r_BP_P=MagicMock( - x=np.arange(0, 100, 2), - y=np.array([[t - 50.0, 0.0, 2.0] for t in np.arange(0, 100, 2)]), - side_effect=( # noqa: E731 - lambda t: np.array([[ti - 50.0, 0.0, 2.0] for ti in t]) - if isinstance(t, Iterable) - else np.array([t - 50.0, 0.0, 2.0]) - ), - ), - ) - sat.data_store = MagicMock() - sat.data_store.env_knowledge.targets = [tgt] - sat.calculate_additional_windows(100.0) - assert tgt in sat.windows - assert sat.opportunities[0]["window"][0] == approx( - 50 - 0.27762037530835193, abs=1e-2 - ) - assert sat.opportunities[0]["window"][1] == approx( - 50 + 0.27762037530835193, abs=1e-2 - ) - - def test_find_elevation_roots(self): - interp = ( # noqa: E731 - lambda t: np.array([[ti, 0.0, 2.0] for ti in t]) - if isinstance(t, Iterable) - else np.array([t, 0.0, 2.0]) - ) - loc = np.array([0.0, 0.0, 1.0]) - elev = 1.3 - times = sats.ImagingSatellite._find_elevation_roots(interp, loc, elev, (-1, 1)) - assert len(times) == 2 - assert times[0] == approx(-times[1], abs=1e-5) - assert times[1] == approx(0.27762037530835193, abs=1e-5) - times = sats.ImagingSatellite._find_elevation_roots(interp, loc, elev, (0, 1)) - assert len(times) == 1 - assert times[0] == approx(0.27762037530835193, abs=1e-5) - - @pytest.mark.parametrize( - "location,times,positions,threshold,expected", - [ - ( - np.array([2.5, 0.0]), - np.array([0.0, 10.0, 20.0, 30.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), - 1.0, - [(10.0, 30.0)], - ), - ( - np.array([2.5, 0.0]), - np.array([0.0, 10.0, 20.0, 30.0, 40.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]]), - 1.0, - [(10.0, 40.0)], - ), - ( - np.array([0.5, 0.0]), - np.array([0.0, 10.0, 20.0, 30.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), - 1.0, - [(0.0, 20.0)], - ), - ( - np.array([1.2, 0.0]), - np.array([0.0, 10.0, 20.0, 30.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), - 5.0, - [(0.0, 30.0)], - ), - ( - np.array([2.5, 100.0]), - np.array([0.0, 10.0, 20.0, 30.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]), - 1.0, - [], - ), - ( - np.array([-0.1, 0.0]), - np.array([0.0, 10.0, 20.0, 30.0, 40.0]), - np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [-1.0, 0.0]]), - 1.0, - [(0.0, 10.0), (30.0, 40.0)], - ), - ], - ) - def test_find_candidate_windows( - self, location, times, positions, threshold, expected - ): - assert ( - sats.ImagingSatellite._find_candidate_windows( - location, times, positions, threshold - ) - == expected - ) - - @pytest.mark.parametrize( - "endpoints,candidate_window,computation_window,expected", - [ - ([2.4, 14.6], (0.0, 20.0), (0.0, 30.0), [(2.4, 14.6)]), - ([12.4], (0.0, 20.0), (0.0, 30.0), [(0.0, 12.4)]), - ([12.4], (10.0, 30.0), (0.0, 30.0), [(12.4, 30.0)]), - ([2.4, 14.6, 18.8], (0.0, 20.0), (0.0, 30.0), [(0.0, 2.4), (14.6, 18.8)]), - ( - [2.4, 14.6, 18.8, 19.3], - (0.0, 20.0), - (0.0, 30.0), - [(2.4, 14.6), (18.8, 19.3)], - ), - ], - ) - def test_refine_windows( - self, endpoints, candidate_window, computation_window, expected - ): - assert ( - sats.ImagingSatellite._refine_window( - endpoints, candidate_window, computation_window - ) - == expected - ) - - def test_refine_windows_impossible(self): - with pytest.raises(ValueError): - sats.ImagingSatellite._refine_window( - [1.0, 2.0, 3.0], (0.0, 4.0), (0.5, 3.5) - ) - - tgt0 = Target("tgt_0", location=[0.0, 0.0, 0.0], priority=1.0) - tgt1 = Target("tgt_1", location=[0.0, 0.0, 0.0], priority=1.0) - tgt2 = Target("tgt_2", location=[0.0, 0.0, 0.0], priority=1.0) - - @pytest.mark.parametrize( - "merge_time", - [None, 10.0], - ) - @pytest.mark.parametrize( - "tgt,window,expected_window", - [ - (tgt0, (13.0, 18.0), (13.0, 18.0)), - (tgt2, (13.0, 18.0), (13.0, 18.0)), - (tgt0, (10.0, 18.0), (2.0, 18.0)), # Check that merging works - ], - ) - def test_add_window(self, merge_time, tgt, window, expected_window): - sat = self.make_sat() - sat.opportunities = [ - dict(target=self.tgt1, window=(3.0, 8.0)), - dict(target=self.tgt0, window=(2.0, 10.0)), - ] - sat._add_window(tgt, window, merge_time=merge_time) - assert expected_window in sat.windows[tgt]