diff --git a/providers/standard/docs/operators/python.rst b/providers/standard/docs/operators/python.rst index 60e2c83b5f1e7..bfd069cdbadb2 100644 --- a/providers/standard/docs/operators/python.rst +++ b/providers/standard/docs/operators/python.rst @@ -196,6 +196,9 @@ If you want to use additional task specific private python repositories to setup pip install configurations. Passed index urls replace the standard system configured index url settings. To prevent adding secrets to the private repository in your DAG code you can use the Airflow :doc:`apache-airflow:authoring-and-scheduling/connections`. For this purpose the connection type ``Package Index (Python)`` can be used. +In a ``Package Index (Python)`` connection you can specify the index URL and the credentials for the private repository. +After creating such a connection, pass its connection ID (or a list of connection IDs) to the ``PythonVirtualenvOperator`` via the ``index_urls_from_connection_ids`` parameter. +The ``PythonVirtualenvOperator`` then resolves each connection to an index URL, including the provided credentials, and appends it to the ``index_urls`` handed to the pip installer. In the special case you want to prevent remote calls for setup of a virtual environment, pass the ``index_urls`` as empty list as ``index_urls=[]`` which forced pip installer to use the ``--no-index`` option. 
diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 92387d28cbffd..5b78fab6b121e 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -49,6 +49,7 @@ DeserializingResultError, ) from airflow.models.variable import Variable +from airflow.providers.standard.hooks.package_index import PackageIndexHook from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS from airflow.utils import hashlib_wrapper @@ -660,6 +661,8 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): exit code will be treated as a failure. :param index_urls: an optional list of index urls to load Python packages from. If not provided the system pip conf will be used to source packages from. + :param index_urls_from_connection_ids: An optional list of ``PackageIndex`` connection IDs. + Will be appended to ``index_urls``. :param venv_cache_path: Optional path to the virtual environment parent folder in which the virtual environment will be cached, creates a sub-folder venv-{hash} whereas hash will be replaced with a checksum of requirements. 
def _retrieve_index_urls_from_connection_ids(self) -> None:
    """
    Append the index URLs of the configured Package Index connections to ``index_urls``.

    Every ID in ``self.index_urls_from_connection_ids`` is resolved through
    ``PackageIndexHook(conn_id).get_connection_url()`` and the resulting URL is
    appended to ``self.index_urls`` (which is created as an empty list first
    when it is ``None``).

    The method is safe to call without any connection IDs and safe to call
    more than once on the same operator instance:

    * a ``None`` or empty ID list leaves ``self.index_urls`` untouched, so an
      unset ``index_urls`` is not silently turned into ``[]``;
    * a URL that is already contained in ``self.index_urls`` is not added again.
    """
    conn_ids = self.index_urls_from_connection_ids
    if not conn_ids:
        # Nothing to resolve. Keep ``index_urls`` as it is: ``None`` and ``[]``
        # are handled differently when the virtualenv is prepared.
        return

    if self.index_urls is None:
        self.index_urls = []

    for conn_id in conn_ids:
        conn_url = PackageIndexHook(conn_id).get_connection_url()
        # Skip known URLs so that a repeated execution does not pile up duplicates.
        if conn_url not in self.index_urls:
            self.index_urls.append(conn_url)
def test_with_index_url_from_connection(self, monkeypatch):
    """Verify that a Package Index connection URL ends up in ``pip.conf`` next to an explicit index URL."""

    class MockConnection(Connection):
        """Mock for the Connection class."""

        def __init__(self, host: str | None, login: str | None, password: str | None):
            super().__init__()
            # Assigned left to right: host, login, password.
            self.host, self.login, self.password = host, login, password

    def fake_get_connection(*_):
        # Ignores whatever the hook passes in and always yields the same credentials.
        return MockConnection("https://my.package.index", "my_username", "my_password")

    monkeypatch.setattr(
        "airflow.providers.standard.hooks.package_index.PackageIndexHook.get_connection",
        fake_get_connection,
    )

    def f(a):
        import sys
        from pathlib import Path

        # pip.conf is read from the directory two levels above the interpreter.
        venv_root = Path(sys.executable).parents[1]
        pip_conf = (venv_root / "pip.conf").read_text()
        assert "abc.def.de" in pip_conf
        assert "https://my_username:my_password@my.package.index" in pip_conf
        return a

    operator_kwargs = {
        "index_urls": ["https://abc.def.de"],
        "index_urls_from_connection_ids": ["my_connection"],
        "op_args": [4],
    }
    self.run_as_task(f, **operator_kwargs)