diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index 7b773fda5db8..457401aa29a7 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -879,7 +879,8 @@ def collect(
     options=None,
     force_compute=False,
     force_tuple=False,
-    raw_records=False):
+    raw_records=False,
+    wait_for_inputs=True):
   """Materializes the elements from a PCollection into a Dataframe.
 
   This reads each element from file and reads only the amount that it needs
@@ -903,6 +904,10 @@ def collect(
       the bare results if only one PCollection is computed
     raw_records: (optional) if True, return a list of collected records
       without converting to a DataFrame. Default False.
+    wait_for_inputs: Whether to wait until the asynchronous dependencies
+      have been computed. Setting this to False schedules the computation
+      immediately, but may result in the same pipeline stages being run
+      multiple times. Defaults to True.
 
   For example::
 
@@ -980,7 +985,8 @@ def as_pcollection(pcoll_or_df):
       max_duration=duration,
       runner=runner,
       options=options,
-      force_compute=force_compute)
+      force_compute=force_compute,
+      wait_for_inputs=wait_for_inputs)
 
   try:
     for pcoll in uncomputed:
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
index 21163fc121c5..f0bb69ef249d 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
@@ -387,6 +387,115 @@ def test_collect_raw_records_true_force_tuple(self):
     self.assertIsInstance(result[0], list)
     self.assertEqual(result[0], data)
 
+  def test_collect_wait_for_inputs_true(self):
+    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
+               ) as mock_current_env:
+      mock_env = MagicMock()
+      mock_current_env.return_value = mock_env
+      mock_rm = MagicMock()
+      mock_env.get_recording_manager.return_value = mock_rm
+      mock_env.computed_pcollections = set()
+      mock_env.user_pipeline.side_effect = lambda x: x
+
+      p = beam.Pipeline(ir.InteractiveRunner())
+      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
+
+      # Simulate pcoll1 being computed asynchronously
+      mock_env.is_pcollection_computing.return_value = True
+      async_res = MagicMock(spec=AsyncComputationResult)
+      mock_rm._async_computations = {'id1': async_res}
+      mock_rm._get_all_dependencies.return_value = {pcoll1}
+      mock_rm._wait_for_dependencies.return_value = True
+
+      # Set up return value for record
+      mock_recording = MagicMock()
+      mock_rm.record.return_value = mock_recording
+
+      ib.collect(pcoll2, wait_for_inputs=True)
+
+      # Check that record was called with wait_for_inputs=True
+      mock_rm.record.assert_called_once_with({pcoll2},
+                                             max_n=float('inf'),
+                                             max_duration=float('inf'),
+                                             runner=None,
+                                             options=None,
+                                             force_compute=False,
+                                             wait_for_inputs=True)
+
+  def test_collect_wait_for_inputs_false(self):
+    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
+               ) as mock_current_env:
+      mock_env = MagicMock()
+      mock_current_env.return_value = mock_env
+      mock_rm = MagicMock()
+      mock_env.get_recording_manager.return_value = mock_rm
+      mock_env.computed_pcollections = set()
+      mock_env.user_pipeline.side_effect = lambda x: x
+
+      p = beam.Pipeline(ir.InteractiveRunner())
+      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
+
+      # Simulate pcoll1 being computed asynchronously
+      mock_env.is_pcollection_computing.return_value = True
+      async_res = MagicMock(spec=AsyncComputationResult)
+      mock_rm._async_computations = {'id1': async_res}
+      mock_rm._get_all_dependencies.return_value = {pcoll1}
+
+      # Set up return value for record
+      mock_recording = MagicMock()
+      mock_rm.record.return_value = mock_recording
+
+      ib.collect(pcoll2, wait_for_inputs=False)
+
+      # Check that wait_for_dependencies was NOT called
+      mock_rm._wait_for_dependencies.assert_not_called()
+      # Check that record was called with wait_for_inputs=False
+      mock_rm.record.assert_called_once_with({pcoll2},
+                                             max_n=float('inf'),
+                                             max_duration=float('inf'),
+                                             runner=None,
+                                             options=None,
+                                             force_compute=False,
+                                             wait_for_inputs=False)
+
+  def test_collect_wait_for_inputs_default(self):
+    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
+               ) as mock_current_env:
+      mock_env = MagicMock()
+      mock_current_env.return_value = mock_env
+      mock_rm = MagicMock()
+      mock_env.get_recording_manager.return_value = mock_rm
+      mock_env.computed_pcollections = set()
+      mock_env.user_pipeline.side_effect = lambda x: x
+
+      p = beam.Pipeline(ir.InteractiveRunner())
+      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
+      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
+
+      # Simulate pcoll1 being computed asynchronously
+      mock_env.is_pcollection_computing.return_value = True
+      async_res = MagicMock(spec=AsyncComputationResult)
+      mock_rm._async_computations = {'id1': async_res}
+      mock_rm._get_all_dependencies.return_value = {pcoll1}
+      mock_rm._wait_for_dependencies.return_value = True
+
+      # Set up return value for record
+      mock_recording = MagicMock()
+      mock_rm.record.return_value = mock_recording
+
+      ib.collect(pcoll2)  # wait_for_inputs defaults to True
+
+      # Check that record was called with wait_for_inputs=True
+      mock_rm.record.assert_called_once_with({pcoll2},
+                                             max_n=float('inf'),
+                                             max_duration=float('inf'),
+                                             runner=None,
+                                             options=None,
+                                             force_compute=False,
+                                             wait_for_inputs=True)
+
   @unittest.skipIf(
       not ie.current_env().is_interactive_ready,
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index c19b60b64fd2..c768e4e6d943 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -849,7 +849,8 @@ def record(
       max_duration: Union[int, str],
       runner: runner.PipelineRunner = None,
       options: pipeline_options.PipelineOptions = None,
-      force_compute: bool = False) -> Recording:
+      force_compute: bool = False,
+      wait_for_inputs: bool = True) -> Recording:  # noqa: F821
     """Records the given PCollections."""
 
@@ -886,10 +887,11 @@ def record(
     # Start a pipeline fragment to start computing the PCollections.
     uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
     if uncomputed_pcolls:
-      if not self._wait_for_dependencies(uncomputed_pcolls):
-        raise RuntimeError(
-            'Cannot record because a dependency failed to compute'
-            ' asynchronously.')
+      if wait_for_inputs:
+        if not self._wait_for_dependencies(uncomputed_pcolls):
+          raise RuntimeError(
+              'Cannot record because a dependency failed to compute'
+              ' asynchronously.')
 
     self._clear()
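
Usage sketch, for reference: a minimal example of how the new wait_for_inputs
flag would be used from a notebook session. The pipeline, transform, and
variable names below are illustrative assumptions, not code from this patch::

    import apache_beam as beam
    from apache_beam.runners.interactive import interactive_beam as ib
    from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

    # Illustrative interactive pipeline.
    p = beam.Pipeline(InteractiveRunner())
    squares = p | beam.Create(range(10)) | beam.Map(lambda x: x * x)

    # Default: block until any asynchronously computed dependencies finish,
    # then materialize the results into a DataFrame.
    df = ib.collect(squares)

    # With wait_for_inputs=False: schedule the computation immediately instead
    # of waiting for async dependencies; shared stages may be recomputed.
    df_now = ib.collect(squares, wait_for_inputs=False)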