diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 3246c5bb6c356..808beb0613209 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -2484,13 +2484,13 @@ def handle_batch_status(
         link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
         if state == Batch.State.FAILED:
             raise AirflowException(
-                f"Batch job {batch_id} failed with error: {state_message}\nDriver Logs: {link}"
+                f"Batch job {batch_id} failed with error: {state_message}.\nDriver logs: {link}"
             )
         if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
-            raise AirflowException(f"Batch job {batch_id} was cancelled. Driver logs: {link}")
+            raise AirflowException(f"Batch job {batch_id} was cancelled.\nDriver logs: {link}")
         if state == Batch.State.STATE_UNSPECIFIED:
-            raise AirflowException(f"Batch job {batch_id} unspecified. Driver logs: {link}")
-        self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
+            raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver logs: {link}")
+        self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id, link)

     def retry_batch_creation(
         self,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 00923cb590a1f..c76c7046db1cc 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -39,6 +39,7 @@
 )
 from airflow.models import DAG, DagBag
 from airflow.providers.google.cloud.links.dataproc import (
+    DATAPROC_BATCH_LINK,
     DATAPROC_CLUSTER_LINK_DEPRECATED,
     DATAPROC_JOB_LINK_DEPRECATED,
 )
@@ -353,6 +354,12 @@
 TEST_JOB_ID = "test-job"
 TEST_WORKFLOW_ID = "test-workflow"

+EXPECTED_LABELS = {
+    "airflow-dag-id": TEST_DAG_ID,
+    "airflow-dag-display-name": TEST_DAG_ID,
+    "airflow-task-id": TASK_ID,
+}
+
 DATAPROC_JOB_LINK_EXPECTED = (
     f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?region={GCP_REGION}&project={GCP_PROJECT}"
 )
@@ -3187,9 +3194,10 @@ def test_missing_region_parameter(self):


 class TestDataprocCreateBatchOperator:
+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    def test_execute(self, mock_hook, to_dict_mock, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3203,7 +3211,10 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = batch_state_succeeded
+
         op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3216,6 +3227,16 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        mock_log.info.assert_has_calls(
+            [
+                mock.call("Starting batch %s", BATCH_ID),
+                mock.call("The batch %s was created.", BATCH_ID),
+                mock.call("Waiting for the completion of batch job %s", BATCH_ID),
+                mock.call("Batch job %s completed.\nDriver logs: %s", BATCH_ID, logs_link),
+            ]
+        )

     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -3234,7 +3255,7 @@ def test_execute_with_result_retry(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_batch.assert_called_once_with(
@@ -3268,8 +3289,9 @@ def test_execute_batch_failed(self, mock_hook, to_dict_mock):
         with pytest.raises(AirflowException):
             op.execute(context=MagicMock())

+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_succeeds(self, mock_hook):
+    def test_execute_batch_already_exists_succeeds(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3283,9 +3305,10 @@ def test_execute_batch_already_exists_succeeds(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
+
         op.execute(context=MagicMock())
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
@@ -3295,9 +3318,23 @@ def test_execute_batch_already_exists_succeeds(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check for succeeded run
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+
+        mock_log.info.assert_has_calls(
+            [
+                mock.call(
+                    "Batch with given id already exists.",
+                ),
+                mock.call("Attaching to the job %s if it is still running.", BATCH_ID),
+                mock.call("Waiting for the completion of batch job %s", BATCH_ID),
+                mock.call("Batch job %s completed.\nDriver logs: %s", BATCH_ID, logs_link),
+            ]
+        )

+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_fails(self, mock_hook):
+    def test_execute_batch_already_exists_fails(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3311,11 +3348,15 @@ def test_execute_batch_already_exists_fails(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
-        with pytest.raises(AirflowException):
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
+
+        with pytest.raises(AirflowException) as exc:
             op.execute(context=MagicMock())
+        # Check msg for FAILED batch state
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        assert str(exc.value) == (f"Batch job {BATCH_ID} failed with error: .\nDriver logs: {logs_link}")
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
             region=GCP_REGION,
@@ -3324,9 +3365,12 @@ def test_execute_batch_already_exists_fails(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check logs for AlreadyExists being called
+        mock_log.info.assert_any_call("Batch with given id already exists.")

+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_batch_already_exists_cancelled(self, mock_hook):
+    def test_execute_batch_already_exists_cancelled(self, mock_hook, mock_log):
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3340,11 +3384,16 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
-        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.CANCELLED)
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
         mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
-        with pytest.raises(AirflowException):
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.CANCELLED)
+
+        with pytest.raises(AirflowException) as exc:
             op.execute(context=MagicMock())
+        # Check msg for CANCELLED batch state
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        assert str(exc.value) == f"Batch job {BATCH_ID} was cancelled.\nDriver logs: {logs_link}"
+
         mock_hook.return_value.wait_for_batch.assert_called_once_with(
             batch_id=BATCH_ID,
             region=GCP_REGION,
@@ -3353,21 +3402,29 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        # Check logs for AlreadyExists being called
+        mock_log.info.assert_any_call("Batch with given id already exists.")

+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_parent_job_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
+        self,
+        mock_hook,
+        to_dict_mock,
+        mock_ol_accessible,
+        mock_static_uuid,
+        mock_log,
     ):
         mock_ol_accessible.return_value = True
         mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {"properties": OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3381,9 +3438,13 @@ def test_execute_openlineage_parent_job_info_injection(
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = batch_state_succeeded
+        mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
         op.execute(context=EXAMPLE_CONTEXT)
+
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
             project_id=GCP_PROJECT,
@@ -3394,14 +3455,19 @@ def test_execute_openlineage_parent_job_info_injection(
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        # Check SUCCEEDED run from the logs
+        mock_log.info.assert_any_call("Batch job %s completed.\nDriver logs: %s", BATCH_ID, logs_link)

+    @mock.patch.object(DataprocCreateBatchOperator, "log", new_callable=mock.MagicMock)
     @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
     @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_openlineage_transport_info_injection(
-        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
+        self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid, mock_log
     ):
         mock_ol_accessible.return_value = True
         mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
@@ -3410,9 +3476,9 @@ def test_execute_openlineage_transport_info_injection(
         )
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {"properties": OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3426,9 +3492,13 @@ def test_execute_openlineage_transport_info_injection(
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_transport_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        batch_state_succeeded = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = batch_state_succeeded
+        mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
         op.execute(context=EXAMPLE_CONTEXT)
+
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
             project_id=GCP_PROJECT,
@@ -3439,6 +3509,14 @@ def test_execute_openlineage_transport_info_injection(
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(batch_state_succeeded)
+        logs_link = DATAPROC_BATCH_LINK.format(region=GCP_REGION, project_id=GCP_PROJECT, batch_id=BATCH_ID)
+        # Verify logs for successful run
+        mock_log.info.assert_any_call(
+            "Batch job %s completed.\nDriver logs: %s",
+            BATCH_ID,
+            logs_link,
+        )

     @mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
     @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@@ -3455,6 +3533,7 @@ def test_execute_openlineage_all_info_injection(
         )
         expected_batch = {
             **BATCH,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES,
@@ -3462,7 +3541,6 @@ def test_execute_openlineage_all_info_injection(
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3477,8 +3555,9 @@ def test_execute_openlineage_all_info_injection(
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
             openlineage_inject_transport_info=True,
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3498,22 +3577,15 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres
         self, mock_hook, to_dict_mock, mock_ol_accessible
     ):
         mock_ol_accessible.return_value = True
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "task-id",
-        }
-
         batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     "spark.openlineage.parentJobName": "dag_id.task_id",
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3527,9 +3599,9 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
-            dag=DAG(dag_id="test_dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3553,23 +3625,15 @@ def test_execute_openlineage_transport_info_injection_skipped_when_already_prese
         mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
             HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
         )
-
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "task-id",
-        }
-
         batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
             "runtime_config": {
                 "properties": {
                     "spark.openlineage.transport.type": "console",
                 }
             },
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3583,7 +3647,7 @@ def test_execute_openlineage_transport_info_injection_skipped_when_already_prese
             timeout=TIMEOUT,
             metadata=METADATA,
             openlineage_inject_transport_info=True,
-            dag=DAG(dag_id="test_dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
         )
         mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
@@ -3609,7 +3673,6 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3624,7 +3687,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless
             metadata=METADATA,
             # not passing openlineage_inject_parent_job_info, should be False by default
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3652,7 +3715,6 @@ def test_execute_openlineage_transport_info_injection_skipped_by_default_unless_
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3667,7 +3729,7 @@ def test_execute_openlineage_transport_info_injection_skipped_by_default_unless_
             metadata=METADATA,
             # not passing openlineage_inject_transport_info, should be False by default
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3691,7 +3753,6 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3706,7 +3767,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces
             metadata=METADATA,
             openlineage_inject_parent_job_info=True,
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3734,7 +3795,6 @@ def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_access
             **BATCH,
             "runtime_config": {"properties": {}},
         }
-
         op = DataprocCreateBatchOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -3749,7 +3809,7 @@ def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_access
             metadata=METADATA,
             openlineage_inject_transport_info=True,
         )
-        mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
+        mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
         op.execute(context=EXAMPLE_CONTEXT)
         mock_hook.return_value.create_batch.assert_called_once_with(
             region=GCP_REGION,
@@ -3778,20 +3838,13 @@ def __assert_batch_create(mock_hook, expected_batch):
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
         }
-
         DataprocCreateBatchOperator(
-            task_id="test-task",
-            dag=DAG(dag_id="test_dag"),
+            task_id=TASK_ID,
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3801,20 +3854,13 @@ def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
     @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook, to_dict_mock):
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = {
             **BATCH,
-            "labels": expected_labels,
+            "labels": EXPECTED_LABELS,
         }
-
         DataprocCreateBatchOperator(
-            task_id="test-TASK",
-            dag=DAG(dag_id="Test_dag"),
+            task_id=TASK_ID,
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3826,7 +3872,7 @@ def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook, to_di
     def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook, to_dict_mock):
         DataprocCreateBatchOperator(
             task_id=".task-id",
-            dag=DAG(dag_id="test-dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3838,7 +3884,7 @@ def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook, to_dict_moc
     def test_create_batch_long_taskid_labels_ignored(self, mock_hook, to_dict_mock):
         DataprocCreateBatchOperator(
             task_id="a" * 65,
-            dag=DAG(dag_id="test-dag"),
+            dag=DAG(dag_id=TEST_DAG_ID),
             batch=BATCH,
             region=GCP_REGION,
         ).execute(context=EXAMPLE_CONTEXT)
@@ -3850,21 +3896,13 @@ def test_create_batch_long_taskid_labels_ignored(self, mock_hook, to_dict_mock):
     def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
         batch = Batch(name="test")
         batch.labels["foo"] = "bar"
-        dag = DAG(dag_id="test_dag")
-
-        expected_labels = {
-            "airflow-dag-id": "test_dag",
-            "airflow-dag-display-name": "test_dag",
-            "airflow-task-id": "test-task",
-        }
-
         expected_batch = deepcopy(batch)
-        expected_batch.labels.update(expected_labels)
+        expected_batch.labels.update(EXPECTED_LABELS)
+        dag = DAG(dag_id=TEST_DAG_ID)

-        DataprocCreateBatchOperator(task_id="test-task", batch=batch, region=GCP_REGION, dag=dag).execute(
+        DataprocCreateBatchOperator(task_id=TASK_ID, batch=batch, region=GCP_REGION, dag=dag).execute(
             context=EXAMPLE_CONTEXT
         )
-
         TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)