diff --git a/providers/git/docs/bundles/index.rst b/providers/git/docs/bundles/index.rst index 30e6ade89896b..f9e9efe9a1720 100644 --- a/providers/git/docs/bundles/index.rst +++ b/providers/git/docs/bundles/index.rst @@ -35,6 +35,7 @@ Example of using the GitDagBundle: "subdir": "dags", "tracking_ref": "main", "refresh_interval": 3600 + "submodules": False, "prune_dotgit_folder": True } } diff --git a/providers/git/src/airflow/providers/git/bundles/git.py b/providers/git/src/airflow/providers/git/bundles/git.py index 508efec76ca76..b10fc0e7f9489 100644 --- a/providers/git/src/airflow/providers/git/bundles/git.py +++ b/providers/git/src/airflow/providers/git/bundles/git.py @@ -45,6 +45,7 @@ class GitDagBundle(BaseDagBundle): :param subdir: Subdirectory within the repository where the DAGs are stored (Optional) :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional) :param repo_url: Explicit Git repository URL to override the connection's host. (Optional) + :param submodules: Whether to initialize git submodules. In case of submodules, the .git folder is preserved. :param prune_dotgit_folder: Remove .git folder from the versions after cloning. The per-version clone is not a full "git" copy (it makes use of git's `--local` ability @@ -62,6 +63,7 @@ def __init__( subdir: str | None = None, git_conn_id: str | None = None, repo_url: str | None = None, + submodules: bool = False, prune_dotgit_folder: bool = True, **kwargs, ) -> None: @@ -75,7 +77,13 @@ def __init__( self.repo_path = self.base_dir / "tracking_repo" self.git_conn_id = git_conn_id self.repo_url = repo_url - self.prune_dotgit_folder = prune_dotgit_folder + self.submodules = submodules + + # Force prune to False if submodules are used, otherwise git links break + if self.submodules: + self.prune_dotgit_folder = False + else: + self.prune_dotgit_folder = prune_dotgit_folder self._log = log.bind( bundle_name=self.name, @@ -84,6 +92,7 @@ def __init__( repo_path=self.repo_path, versions_path=self.versions_dir, git_conn_id=self.git_conn_id, + submodules=self.submodules, ) self._log.debug("bundle configured") @@ -124,10 +133,20 @@ def _initialize(self): self.repo.remotes.origin.fetch() self.repo.head.set_reference(str(self.repo.commit(self.version))) self.repo.head.reset(index=True, working_tree=True) + + if self.submodules: + cm_sub = self.hook.configure_hook_env() if self.hook else nullcontext() + with cm_sub: + try: + self._fetch_submodules() + except GitCommandError as e: + raise RuntimeError("Error pulling submodule from repository") from e + if self.prune_dotgit_folder: shutil.rmtree(self.repo_path / ".git") else: self.refresh() + self.repo.close() def initialize(self) -> None: @@ -212,6 +231,7 @@ def __repr__(self): f"" @@ -244,6 +264,16 @@ def _fetch_bare_repo(self): self.bare_repo.remotes.origin.fetch(refspecs) self.bare_repo.close() + @retry( + retry=retry_if_exception_type((GitCommandError,)), + stop=stop_after_attempt(2), + reraise=True, + ) + def _fetch_submodules(self) -> None: + self._log.info("Initializing and updating submodules", repo_path=self.repo_path) + self.repo.git.submodule("sync", "--recursive") + self.repo.git.submodule("update", "--init", "--recursive", "--jobs", "1") + def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") @@ -261,6 +291,13 @@ def refresh(self) -> None: else: target = self.tracking_ref self.repo.head.reset(target, index=True, working_tree=True) + + if self.submodules: + try: + self._fetch_submodules() + except GitCommandError as e: + raise RuntimeError("Error pulling submodule from repository") from e + self.repo.close() @staticmethod diff --git a/providers/git/tests/unit/git/bundles/test_git.py b/providers/git/tests/unit/git/bundles/test_git.py index 413aa7737c447..ffeb3250bd7c0 100644 --- a/providers/git/tests/unit/git/bundles/test_git.py +++ b/providers/git/tests/unit/git/bundles/test_git.py @@ -860,3 +860,119 @@ def test_clone_bare_repo_invalid_repository_error_retry_fails( # Verify Repo was called twice (failed attempt + failed retry) assert mock_repo_class.call_count == 2 + + @mock.patch("airflow.providers.git.bundles.git.shutil.rmtree") + @mock.patch("airflow.providers.git.bundles.git.os.path.exists") + @mock.patch("airflow.providers.git.bundles.git.GitHook") + @mock.patch("airflow.providers.git.bundles.git.Repo") + def test_initialize_fetches_submodules_when_enabled( + self, mock_repo_class, mock_githook, mock_exists, mock_rmtree + ): + """Test that submodules are synced and updated when submodules=True during initialization.""" + mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git" + + # Mock exists to return True so we skip the clone logic and go straight to initialization + mock_exists.return_value = True + + mock_repo_instance = mock_repo_class.return_value + # Ensure _has_version returns True so we don't try to fetch origin + mock_repo_instance.commit.return_value = mock.MagicMock() + + bundle = GitDagBundle( + name="test", + git_conn_id="git_default", + tracking_ref="main", + version="123456", + submodules=True, + ) + + bundle.initialize() + + # Verify submodule commands were called + mock_repo_instance.git.submodule.assert_has_calls( + [mock.call("sync", "--recursive"), mock.call("update", "--init", "--recursive", "--jobs", "1")] + ) + mock_rmtree.assert_not_called() + + @mock.patch("airflow.providers.git.bundles.git.shutil.rmtree") + @mock.patch("airflow.providers.git.bundles.git.os.path.exists") + @mock.patch("airflow.providers.git.bundles.git.GitHook") + @mock.patch("airflow.providers.git.bundles.git.Repo") + def test_refresh_fetches_submodules_when_enabled( + self, mock_repo_class, mock_githook, mock_exists, mock_rmtree + ): + """Test that submodules are synced and updated when submodules=True during refresh.""" + mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git" + mock_exists.return_value = True + + mock_repo_instance = mock_repo_class.return_value + + bundle = GitDagBundle( + name="test", + git_conn_id="git_default", + tracking_ref="main", + submodules=True, + ) + + # Calling initialize without a specific version triggers refresh() + bundle.initialize() + + # Verify submodule commands were called + mock_repo_instance.git.submodule.assert_has_calls( + [mock.call("sync", "--recursive"), mock.call("update", "--init", "--recursive", "--jobs", "1")] + ) + mock_rmtree.assert_not_called() + + @mock.patch("airflow.providers.git.bundles.git.shutil.rmtree") + @mock.patch("airflow.providers.git.bundles.git.os.path.exists") + @mock.patch("airflow.providers.git.bundles.git.GitHook") + @mock.patch("airflow.providers.git.bundles.git.Repo") + def test_submodules_disabled_by_default(self, mock_repo_class, mock_githook, mock_exists, mock_rmtree): + """Test that submodules are NOT fetched by default.""" + mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git" + mock_exists.return_value = True + + mock_repo_instance = mock_repo_class.return_value + + bundle = GitDagBundle( + name="test", + git_conn_id="git_default", + tracking_ref="main", + version="123456", + # submodules defaults to False + ) + + bundle.initialize() + + # Ensure submodule commands were NOT called + mock_repo_instance.git.submodule.assert_not_called() + + @mock.patch("airflow.providers.git.bundles.git.shutil.rmtree") + @mock.patch("airflow.providers.git.bundles.git.os.path.exists") + @mock.patch("airflow.providers.git.bundles.git.GitHook") + @mock.patch("airflow.providers.git.bundles.git.Repo") + def test_submodule_fetch_error_raises_runtime_error( + self, mock_repo_class, mock_githook, mock_exists, mock_rmtree + ): + """Test that a GitCommandError during submodule update is raised as a RuntimeError.""" + mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git" + mock_exists.return_value = True + + mock_repo_instance = mock_repo_class.return_value + mock_repo_instance.commit.return_value = mock.MagicMock() + + # Simulate a git error when running submodule update + mock_repo_instance.git.submodule.side_effect = GitCommandError("submodule update", "failed") + + bundle = GitDagBundle( + name="test", + git_conn_id="git_default", + tracking_ref="main", + version="123456", + submodules=True, + ) + + with pytest.raises(RuntimeError, match="Error pulling submodule from repository"): + bundle.initialize() + + mock_rmtree.assert_not_called()