From d99db91c0f54e6312eda294cfe2292a67d438607 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 7 Dec 2023 08:13:04 +0100 Subject: [PATCH] Fix Python-based decorators templating Templating of Python-based decorators has been broken since implementation. The decorators used template_fields definition as defined originally in PythonOperator rather than the ones from virtualenv because template fields were redefined in _PythonDecoratedOperator class and they took precedence (MRO). This PR adds explicit copying of template_fields from the operators that they are decorating. Fixes: #36102 --- airflow/decorators/branch_external_python.py | 1 + airflow/decorators/branch_python.py | 1 + airflow/decorators/branch_virtualenv.py | 1 + airflow/decorators/external_python.py | 1 + airflow/decorators/python_virtualenv.py | 1 + airflow/decorators/short_circuit.py | 1 + airflow/models/abstractoperator.py | 1 - tests/decorators/test_branch_virtualenv.py | 9 ++++--- tests/decorators/test_external_python.py | 14 +++++++++++ tests/decorators/test_python_virtualenv.py | 26 ++++++++++++++++++++ 10 files changed, 52 insertions(+), 4 deletions(-) diff --git a/airflow/decorators/branch_external_python.py b/airflow/decorators/branch_external_python.py index 8e945541c594e..2902a47c67741 100644 --- a/airflow/decorators/branch_external_python.py +++ b/airflow/decorators/branch_external_python.py @@ -29,6 +29,7 @@ class _BranchExternalPythonDecoratedOperator(_PythonDecoratedOperator, BranchExternalPythonOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = BranchExternalPythonOperator.template_fields custom_operator_name: str = "@task.branch_external_python" diff --git a/airflow/decorators/branch_python.py b/airflow/decorators/branch_python.py index 3ac11f0efa256..31750ef657a94 100644 --- a/airflow/decorators/branch_python.py +++ b/airflow/decorators/branch_python.py @@ -29,6 +29,7 @@ class 
_BranchPythonDecoratedOperator(_PythonDecoratedOperator, BranchPythonOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = BranchPythonOperator.template_fields custom_operator_name: str = "@task.branch" diff --git a/airflow/decorators/branch_virtualenv.py b/airflow/decorators/branch_virtualenv.py index 3e4c3fcaf1b8e..c96638ee20246 100644 --- a/airflow/decorators/branch_virtualenv.py +++ b/airflow/decorators/branch_virtualenv.py @@ -29,6 +29,7 @@ class _BranchPythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, BranchPythonVirtualenvOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = BranchPythonVirtualenvOperator.template_fields custom_operator_name: str = "@task.branch_virtualenv" diff --git a/airflow/decorators/external_python.py b/airflow/decorators/external_python.py index 1e39ed561bb59..2d8e2603f94dd 100644 --- a/airflow/decorators/external_python.py +++ b/airflow/decorators/external_python.py @@ -29,6 +29,7 @@ class _PythonExternalDecoratedOperator(_PythonDecoratedOperator, ExternalPythonOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = ExternalPythonOperator.template_fields custom_operator_name: str = "@task.external_python" diff --git a/airflow/decorators/python_virtualenv.py b/airflow/decorators/python_virtualenv.py index 2eb86787795a8..d0eb93a0d7aa6 100644 --- a/airflow/decorators/python_virtualenv.py +++ b/airflow/decorators/python_virtualenv.py @@ -29,6 +29,7 @@ class _PythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, PythonVirtualenvOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = PythonVirtualenvOperator.template_fields custom_operator_name: str = "@task.virtualenv" diff --git a/airflow/decorators/short_circuit.py b/airflow/decorators/short_circuit.py index 210a0e04537cb..c964ed6bb75fd 100644 
--- a/airflow/decorators/short_circuit.py +++ b/airflow/decorators/short_circuit.py @@ -29,6 +29,7 @@ class _ShortCircuitDecoratedOperator(_PythonDecoratedOperator, ShortCircuitOperator): """Wraps a Python callable and captures args/kwargs when called for execution.""" + template_fields = ShortCircuitOperator.template_fields custom_operator_name: str = "@task.short_circuit" diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index df0e6cb34964a..f5a266f4b1875 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -679,7 +679,6 @@ def _do_render_template_fields( f"{attr_name!r} is configured as a template field " f"but {parent.task_type} does not have this attribute." ) - try: if not value: continue diff --git a/tests/decorators/test_branch_virtualenv.py b/tests/decorators/test_branch_virtualenv.py index 2b5f9bb95ed73..57db52f167746 100644 --- a/tests/decorators/test_branch_virtualenv.py +++ b/tests/decorators/test_branch_virtualenv.py @@ -31,7 +31,10 @@ class Test_BranchPythonVirtualenvDecoratedOperator: # possibilities. 
So we are increasing the timeout for this test to 3x of the default timeout @pytest.mark.execution_timeout(180) @pytest.mark.parametrize("branch_task_name", ["task_1", "task_2"]) - def test_branch_one(self, dag_maker, branch_task_name): + def test_branch_one(self, dag_maker, branch_task_name, tmp_path): + requirements_file = tmp_path / "requirements.txt" + requirements_file.write_text("funcsigs==0.4") + @task def dummy_f(): pass @@ -57,14 +60,14 @@ def branch_operator(): else: - @task.branch_virtualenv(task_id="branching", requirements=["funcsigs"]) + @task.branch_virtualenv(task_id="branching", requirements="requirements.txt") def branch_operator(): import funcsigs print(f"We successfully imported funcsigs version {funcsigs.__version__}") return "task_2" - with dag_maker(): + with dag_maker(template_searchpath=tmp_path.as_posix()): branchoperator = branch_operator() df = dummy_f() task_1 = task_1() diff --git a/tests/decorators/test_external_python.py b/tests/decorators/test_external_python.py index cdd8c6cd49de0..27d8b0ed100c8 100644 --- a/tests/decorators/test_external_python.py +++ b/tests/decorators/test_external_python.py @@ -74,6 +74,20 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_with_templated_python(self, dag_maker, venv_python_with_dill): + # add template that produces empty string when rendered + templated_python_with_dill = venv_python_with_dill.as_posix() + "{{ '' }}" + + @task.external_python(python=templated_python_with_dill, use_dill=True) + def f(): + """Import dill to double-check it is installed .""" + import dill # noqa: F401 + + with dag_maker(): + ret = f() + + ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_no_dill_installed_raises_exception_when_use_dill(self, dag_maker, venv_python): @task.external_python(python=venv_python, use_dill=True) def f(): diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index 
fc604ac4643a8..a069aee8b1ce5 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -103,6 +103,32 @@ def f(): ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_with_requirements_file(self, dag_maker, tmp_path): + requirements_file = tmp_path / "requirements.txt" + requirements_file.write_text("funcsigs==0.4\nattrs==23.1.0") + + @task.virtualenv( + system_site_packages=False, + requirements="requirements.txt", + python_version=PYTHON_VERSION, + use_dill=True, + ) + def f(): + import funcsigs + + if funcsigs.__version__ != "0.4": + raise Exception + + import attrs + + if attrs.__version__ != "23.1.0": + raise Exception + + with dag_maker(template_searchpath=tmp_path.as_posix()): + ret = f() + + ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + def test_unpinned_requirements(self, dag_maker): @task.virtualenv( system_site_packages=False,