diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd7565f0b4d7..ded656d9dcb7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -318,7 +318,6 @@ repos:
           ^airflow\/providers\/google\/cloud\/operators\/bigquery\.py$|
           ^airflow\/providers\/amazon\/aws\/transfers\/gcs_to_s3\.py$|
           ^airflow\/providers\/databricks\/operators\/databricks\.py$|
-          ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$|
           ^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$|
           ^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$|
           ^airflow\/providers\/google\/cloud\/operators\/compute\.py$|
diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index ea6a7eacdb25..9f25ac58f728 100644
--- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -236,7 +236,9 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.body = deepcopy(body)
+        self.body = body
+        if isinstance(self.body, dict):
+            self.body = deepcopy(body)
         self.aws_conn_id = aws_conn_id
         self.gcp_conn_id = gcp_conn_id
         self.api_version = api_version
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 40c2b87b0a26..1ef5d6b729da 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -114,6 +114,10 @@
     SCHEDULE: SCHEDULE_DICT,
     TRANSFER_SPEC: {GCS_DATA_SINK: {BUCKET_NAME: GCS_BUCKET_NAME, PATH: DESTINATION_PATH}},
 }
+VALID_TRANSFER_JOB_JINJA = deepcopy(VALID_TRANSFER_JOB_BASE)
+VALID_TRANSFER_JOB_JINJA[NAME] = "{{ dag.dag_id }}"
+VALID_TRANSFER_JOB_JINJA_RENDERED = deepcopy(VALID_TRANSFER_JOB_JINJA)
+VALID_TRANSFER_JOB_JINJA_RENDERED[NAME] = "TestGcpStorageTransferJobCreateOperator"
 VALID_TRANSFER_JOB_GCS = deepcopy(VALID_TRANSFER_JOB_BASE)
 VALID_TRANSFER_JOB_GCS[TRANSFER_SPEC].update(deepcopy(SOURCE_GCS))
 VALID_TRANSFER_JOB_AWS = deepcopy(VALID_TRANSFER_JOB_BASE)
@@ -324,21 +328,25 @@ def test_job_create_multiple(self, aws_hook, gcp_hook):
     # Setting all of the operator's input parameters as templated dag_ids
     # (could be anything else) just to test if the templating works for all
     # fields
     @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "body, expected",
+        [(VALID_TRANSFER_JOB_JINJA, VALID_TRANSFER_JOB_JINJA_RENDERED)],
+    )
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
-        dag_id = "TestGcpStorageTransferJobCreateOperator_test_templates"
+    def test_templates(self, _, create_task_instance_of_operator, body, expected):
+        dag_id = "TestGcpStorageTransferJobCreateOperator"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceCreateJobOperator,
             dag_id=dag_id,
-            body={"description": "{{ dag.dag_id }}"},
+            body=body,
             gcp_conn_id="{{ dag.dag_id }}",
             aws_conn_id="{{ dag.dag_id }}",
             task_id="task-id",
         )
         ti.render_templates()
-        assert dag_id == getattr(ti.task, "body")[DESCRIPTION]
+        assert expected == getattr(ti.task, "body")
         assert dag_id == getattr(ti.task, "gcp_conn_id")
         assert dag_id == getattr(ti.task, "aws_conn_id")
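
For context, a minimal usage sketch of what the guarded deepcopy enables: ``body`` is a templated field on CloudDataTransferServiceCreateJobOperator, so Jinja expressions anywhere inside it are rendered at run time. The DAG id, schedule, and body fields below are illustrative assumptions, not taken from this patch, and the body is not a complete TransferJob payload.

    import pendulum

    from airflow import DAG
    from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
        CloudDataTransferServiceCreateJobOperator,
    )

    with DAG(
        dag_id="example_transfer_job",  # hypothetical DAG id
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=None,
    ) as dag:
        CloudDataTransferServiceCreateJobOperator(
            task_id="create_transfer_job",
            # ``body`` is rendered by Jinja before execute(); here the job
            # name resolves to "example_transfer_job" at run time.
            body={
                "description": "nightly copy",
                "name": "{{ dag.dag_id }}",
            },
            gcp_conn_id="google_cloud_default",
        )

Because ``deepcopy`` now runs only when ``body`` is already a dict, a non-dict value (such as a lazy XComArg reference or a Jinja string to be rendered into a dict) can presumably pass through untouched and be resolved by the normal templating machinery instead of being copied eagerly in the constructor.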