diff --git a/samcli/lib/observability/observability_info_puller.py b/samcli/lib/observability/observability_info_puller.py index e0691a0433..f658195f6e 100644 --- a/samcli/lib/observability/observability_info_puller.py +++ b/samcli/lib/observability/observability_info_puller.py @@ -88,6 +88,9 @@ def load_events(self, event_ids: Union[List[Any], Dict]): List of event ids that will be pulled """ + def stop_tailing(self): + self.cancelled = True + # pylint: disable=fixme # fixme add ABC parent class back once we bump the pylint to a version 2.8.2 or higher @@ -187,8 +190,7 @@ def tail(self, start_time: Optional[datetime] = None, filter_pattern: Optional[s async_context.run_async() except KeyboardInterrupt: LOG.info(" CTRL+C received, cancelling...") - for puller in self._pullers: - puller.cancelled = True + self.stop_tailing() def load_time_period( self, @@ -218,3 +220,9 @@ def load_events(self, event_ids: Union[List[Any], Dict]): async_context.add_async_task(puller.load_events, event_ids) LOG.debug("Running all 'load_time_period' tasks in parallel") async_context.run_async() + + def stop_tailing(self): + # if ObservabilityCombinedPuller A is a child puller in other ObservabilityCombinedPuller B, make sure A's child + # pullers stop as well when B stops. + for puller in self._pullers: + puller.stop_tailing() diff --git a/tests/unit/lib/observability/test_observability_info_puller.py b/tests/unit/lib/observability/test_observability_info_puller.py index 2b3b6b2016..a050533fc9 100644 --- a/tests/unit/lib/observability/test_observability_info_puller.py +++ b/tests/unit/lib/observability/test_observability_info_puller.py @@ -85,8 +85,11 @@ def test_tail_cancel(self, patched_async_context): mock_puller_1 = Mock() mock_puller_2 = Mock() + mock_puller_3 = Mock() - combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2]) + child_combined_puller = ObservabilityCombinedPuller([mock_puller_3]) + + combined_puller = ObservabilityCombinedPuller([mock_puller_1, mock_puller_2, child_combined_puller]) given_start_time = Mock() given_filter_pattern = Mock() @@ -97,12 +100,14 @@ def test_tail_cancel(self, patched_async_context): [ call.add_async_task(mock_puller_1.tail, given_start_time, given_filter_pattern), call.add_async_task(mock_puller_2.tail, given_start_time, given_filter_pattern), + call.add_async_task(child_combined_puller.tail, given_start_time, given_filter_pattern), call.run_async(), ] ) - self.assertTrue(mock_puller_1.cancelled) - self.assertTrue(mock_puller_2.cancelled) + self.assertTrue(mock_puller_1.stop_tailing.called) + self.assertTrue(mock_puller_2.stop_tailing.called) + self.assertTrue(mock_puller_3.stop_tailing.called) @patch("samcli.lib.observability.observability_info_puller.AsyncContext") def test_load_time_period(self, patched_async_context):