def __update_batch_labels(self):
    """
    Attach Airflow provenance labels to ``self.batch`` before the batch is created.

    Private method of the Dataproc create-batch operator, called from ``execute``.
    The labels ``airflow-dag-id`` and ``airflow-task-id`` are always proposed;
    ``airflow-dag-display-name`` is added when the task is attached to a DAG and the
    display name is itself a valid label value.

    The update is best-effort and never raises for invalid input. Nothing is changed when:

    * the normalised DAG id or task id is not a valid label value, or
    * merging the new labels would push the batch over the 32-label limit.

    ``self.batch`` may be a ``Batch`` message or a plain dict; it is modified in place.
    """
    # A label value here must start with a lowercase letter and be at most 63
    # characters long in total: 1 leading letter + up to 62 further characters.
    label_value_regex = re.compile(r"[a-z][\w-]{0,62}")
    labels_limit = 32

    def _as_label_value(raw):
        # Dots and whitespace are not allowed in labels; replace them with "_".
        return re.sub(r"[.\s]", "_", raw.lower())

    dag_id = _as_label_value(self.dag_id)
    task_id = _as_label_value(self.task_id)
    if not (label_value_regex.fullmatch(dag_id) and label_value_regex.fullmatch(task_id)):
        self.log.debug("DAG id or task id is not a valid label value; skipping Airflow batch labels.")
        return

    new_labels = {"airflow-dag-id": dag_id, "airflow-task-id": task_id}

    if self._dag:
        dag_display_name = _as_label_value(self._dag.dag_display_name)
        # Validate the display name itself: it can differ from dag_id, so a valid
        # dag_id says nothing about whether the display name is label-safe.
        if label_value_regex.fullmatch(dag_display_name):
            new_labels["airflow-dag-display-name"] = dag_display_name

    def _fits(existing_labels):
        # Count the merged key set so that overwriting an already-present
        # "airflow-*" key is not counted twice against the limit.
        return len({*existing_labels, *new_labels}) <= labels_limit

    if isinstance(self.batch, Batch):
        if _fits(self.batch.labels):
            self.batch.labels.update(new_labels)
        return

    # Dict-style batch definition.
    existing = self.batch.get("labels")
    if existing is None:
        # No labels yet (key missing or explicitly None): start from ours.
        self.batch["labels"] = new_labels
    elif isinstance(existing, dict) and _fits(existing):
        existing.update(new_labels)
from copy import deepcopy
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock, call


class TestDataprocCreateBatchOperator:
    # Label-injection tests for DataprocCreateBatchOperator. Only the tests added for
    # the Airflow batch labels are shown here; other tests of this class are outside
    # this view.
    #
    # Every test hands the operator its own copy of the shared ``BATCH`` fixture. The
    # operator writes labels into the batch it receives, so passing ``BATCH`` directly
    # would leak labels into other tests and make the "labels ignored" assertions
    # compare ``BATCH`` with itself, which can never fail.

    @staticmethod
    def __assert_batch_create(mock_hook, expected_batch):
        """Assert ``create_batch`` was called exactly once with ``expected_batch``; other arguments are ignored."""
        mock_hook.return_value.create_batch.assert_called_once_with(
            region=ANY,
            project_id=ANY,
            batch=expected_batch,
            batch_id=ANY,
            request_id=ANY,
            retry=ANY,
            timeout=ANY,
            metadata=ANY,
        )

    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_create_batch_asdict_labels_updated(self, mock_hook, to_dict_mock):
        """A dict batch without labels receives the Airflow labels."""
        expected_labels = {
            "airflow-dag-id": "test_dag",
            "airflow-dag-display-name": "test_dag",
            "airflow-task-id": "test-task",
        }
        expected_batch = {
            **BATCH,
            "labels": expected_labels,
        }

        DataprocCreateBatchOperator(
            task_id="test-task",
            dag=DAG(dag_id="test_dag"),
            batch=deepcopy(BATCH),
            region=GCP_REGION,
        ).execute(context=EXAMPLE_CONTEXT)

        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)

    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_create_batch_asdict_labels_uppercase_transformed(self, mock_hook, to_dict_mock):
        """Upper-case characters in DAG and task ids are lower-cased in the labels."""
        expected_labels = {
            "airflow-dag-id": "test_dag",
            "airflow-dag-display-name": "test_dag",
            "airflow-task-id": "test-task",
        }
        expected_batch = {
            **BATCH,
            "labels": expected_labels,
        }

        DataprocCreateBatchOperator(
            task_id="test-TASK",
            dag=DAG(dag_id="Test_dag"),
            batch=deepcopy(BATCH),
            region=GCP_REGION,
        ).execute(context=EXAMPLE_CONTEXT)

        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)

    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_create_batch_invalid_taskid_labels_ignored(self, mock_hook, to_dict_mock):
        """A task id that does not start with a letter leaves the batch unlabeled."""
        DataprocCreateBatchOperator(
            task_id=".task-id",
            dag=DAG(dag_id="test-dag"),
            batch=deepcopy(BATCH),
            region=GCP_REGION,
        ).execute(context=EXAMPLE_CONTEXT)

        # Compared against the pristine fixture, not the object handed to the operator.
        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)

    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_create_batch_long_taskid_labels_ignored(self, mock_hook, to_dict_mock):
        """A task id longer than a label value may be leaves the batch unlabeled."""
        DataprocCreateBatchOperator(
            task_id="a" * 65,
            dag=DAG(dag_id="test-dag"),
            batch=deepcopy(BATCH),
            region=GCP_REGION,
        ).execute(context=EXAMPLE_CONTEXT)

        # Compared against the pristine fixture, not the object handed to the operator.
        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, BATCH)

    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_create_batch_asobj_labels_updated(self, mock_hook, to_dict_mock):
        """A ``Batch`` object keeps its own labels and gains the Airflow ones."""
        batch = Batch(name="test")
        batch.labels["foo"] = "bar"
        dag = DAG(dag_id="test_dag")

        expected_labels = {
            "airflow-dag-id": "test_dag",
            "airflow-dag-display-name": "test_dag",
            "airflow-task-id": "test-task",
        }

        # Snapshot before execute(): the operator mutates ``batch`` in place.
        expected_batch = deepcopy(batch)
        expected_batch.labels.update(expected_labels)

        DataprocCreateBatchOperator(task_id="test-task", batch=batch, region=GCP_REGION, dag=dag).execute(
            context=EXAMPLE_CONTEXT
        )

        TestDataprocCreateBatchOperator.__assert_batch_create(mock_hook, expected_batch)