diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
index 7d911b0ebc881..fa3e1cd56f6f7 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -87,6 +87,7 @@ class BigQueryToGCSOperator(BaseOperator):
     :param reattach_states: Set of BigQuery job's states in case of which we should reattach
         to the job. Should be other than final states.
     :param deferrable: Run operator in the deferrable mode
+    :return: URIs for the objects created in Google Cloud Storage
     """

     template_fields: Sequence[str] = (
@@ -275,6 +276,8 @@ def execute(self, context: Context):
         else:
             job.result(timeout=self.result_timeout, retry=self.result_retry)

+        return self.destination_cloud_storage_uris
+
     def execute_complete(self, context: Context, event: dict[str, Any]):
         """
         Return immediately and relies on trigger to throw a success event. Callback for the trigger.
@@ -291,6 +294,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
         # Save job_id as an attribute to be later used by listeners
         self.job_id = event.get("job_id")

+        return self.destination_cloud_storage_uris
+
     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement on_complete as we will include final BQ job id."""
         from airflow.providers.common.compat.openlineage.facet import (
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_gcs.py
index ce459b81706ec..ba1beed291d58 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_gcs.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -112,7 +112,11 @@ def test_execute(self, mock_hook):
             labels=labels,
             project_id=JOB_PROJECT_ID,
         )
-        operator.execute(context=mock.MagicMock())
+        result = operator.execute(context=mock.MagicMock())
+
+        assert result is not None
+        assert isinstance(result, list)
+        assert result == ["gs://some-bucket/some-file.txt"]

         mock_hook.return_value.insert_job.assert_called_once_with(
             job_id="123456_hash",
@@ -207,6 +211,27 @@ def test_execute_complete_reassigns_job_id(self):
         )
         assert operator.job_id == job_id

+    def test_execute_complete_returns_destination_cloud_storage_uris(self):
+        """Assert that self.destination_cloud_storage_uris is returned by execute_complete."""
+
+        operator = BigQueryToGCSOperator(
+            project_id=JOB_PROJECT_ID,
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            destination_cloud_storage_uris=[f"gs://{TEST_BUCKET}/{TEST_FOLDER}/"],
+            deferrable=True,
+            job_id=None,
+        )
+
+        result = operator.execute_complete(
+            context=MagicMock(),
+            event={"status": "success", "message": "Job completed", "job_id": None},
+        )
+
+        assert result is not None
+        assert isinstance(result, list)
+        assert result == [f"gs://{TEST_BUCKET}/{TEST_FOLDER}/"]
+
     @pytest.mark.parametrize(
         ("gcs_uri", "expected_dataset_name"),
         (