Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions providers/git/docs/bundles/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Example of using the GitDagBundle:
"subdir": "dags",
"tracking_ref": "main",
"refresh_interval": 3600
"submodules": False,
"prune_dotgit_folder": True
}
}
Expand Down
39 changes: 38 additions & 1 deletion providers/git/src/airflow/providers/git/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class GitDagBundle(BaseDagBundle):
:param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
:param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional)
:param repo_url: Explicit Git repository URL to override the connection's host. (Optional)
:param submodules: Whether to initialize git submodules. In case of submodules, the .git folder is preserved.
:param prune_dotgit_folder: Remove .git folder from the versions after cloning.

The per-version clone is not a full "git" copy (it makes use of git's `--local` ability
Expand All @@ -62,6 +63,7 @@ def __init__(
subdir: str | None = None,
git_conn_id: str | None = None,
repo_url: str | None = None,
submodules: bool = False,
prune_dotgit_folder: bool = True,
**kwargs,
) -> None:
Expand All @@ -75,7 +77,13 @@ def __init__(
self.repo_path = self.base_dir / "tracking_repo"
self.git_conn_id = git_conn_id
self.repo_url = repo_url
self.prune_dotgit_folder = prune_dotgit_folder
self.submodules = submodules

# Force prune to False if submodules are used, otherwise git links break
if self.submodules:
self.prune_dotgit_folder = False
else:
self.prune_dotgit_folder = prune_dotgit_folder

self._log = log.bind(
bundle_name=self.name,
Expand All @@ -84,6 +92,7 @@ def __init__(
repo_path=self.repo_path,
versions_path=self.versions_dir,
git_conn_id=self.git_conn_id,
submodules=self.submodules,
)

self._log.debug("bundle configured")
Expand Down Expand Up @@ -124,10 +133,20 @@ def _initialize(self):
self.repo.remotes.origin.fetch()
self.repo.head.set_reference(str(self.repo.commit(self.version)))
self.repo.head.reset(index=True, working_tree=True)

if self.submodules:
cm_sub = self.hook.configure_hook_env() if self.hook else nullcontext()
with cm_sub:
try:
self._fetch_submodules()
except GitCommandError as e:
raise RuntimeError("Error pulling submodule from repository") from e

if self.prune_dotgit_folder:
shutil.rmtree(self.repo_path / ".git")
else:
self.refresh()

self.repo.close()

def initialize(self) -> None:
Expand Down Expand Up @@ -212,6 +231,7 @@ def __repr__(self):
f"<GitDagBundle("
f"name={self.name!r}, "
f"tracking_ref={self.tracking_ref!r}, "
f"submodules={self.submodules!r}, "
f"subdir={self.subdir!r}, "
f"version={self.version!r}"
f")>"
Expand Down Expand Up @@ -244,6 +264,16 @@ def _fetch_bare_repo(self):
self.bare_repo.remotes.origin.fetch(refspecs)
self.bare_repo.close()

@retry(
retry=retry_if_exception_type((GitCommandError,)),
stop=stop_after_attempt(2),
reraise=True,
)
def _fetch_submodules(self) -> None:
self._log.info("Initializing and updating submodules", repo_path=self.repo_path)
self.repo.git.submodule("sync", "--recursive")
self.repo.git.submodule("update", "--init", "--recursive", "--jobs", "1")

def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not supported")
Expand All @@ -261,6 +291,13 @@ def refresh(self) -> None:
else:
target = self.tracking_ref
self.repo.head.reset(target, index=True, working_tree=True)

if self.submodules:
try:
self._fetch_submodules()
except GitCommandError as e:
raise RuntimeError("Error pulling submodule from repository") from e

self.repo.close()

@staticmethod
Expand Down
116 changes: 116 additions & 0 deletions providers/git/tests/unit/git/bundles/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,119 @@ def test_clone_bare_repo_invalid_repository_error_retry_fails(

# Verify Repo was called twice (failed attempt + failed retry)
assert mock_repo_class.call_count == 2

@mock.patch("airflow.providers.git.bundles.git.shutil.rmtree")
@mock.patch("airflow.providers.git.bundles.git.os.path.exists")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_initialize_fetches_submodules_when_enabled(
self, mock_repo_class, mock_githook, mock_exists, mock_rmtree
):
"""Test that submodules are synced and updated when submodules=True during initialization."""
mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git"

# Mock exists to return True so we skip the clone logic and go straight to initialization
mock_exists.return_value = True

mock_repo_instance = mock_repo_class.return_value
# Ensure _has_version returns True so we don't try to fetch origin
mock_repo_instance.commit.return_value = mock.MagicMock()

bundle = GitDagBundle(
name="test",
git_conn_id="git_default",
tracking_ref="main",
version="123456",
submodules=True,
)

bundle.initialize()

# Verify submodule commands were called
mock_repo_instance.git.submodule.assert_has_calls(
[mock.call("sync", "--recursive"), mock.call("update", "--init", "--recursive", "--jobs", "1")]
)
mock_rmtree.assert_not_called()

@mock.patch("airflow.providers.git.bundles.git.shutil.rmtree")
@mock.patch("airflow.providers.git.bundles.git.os.path.exists")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_refresh_fetches_submodules_when_enabled(
self, mock_repo_class, mock_githook, mock_exists, mock_rmtree
):
"""Test that submodules are synced and updated when submodules=True during refresh."""
mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git"
mock_exists.return_value = True

mock_repo_instance = mock_repo_class.return_value

bundle = GitDagBundle(
name="test",
git_conn_id="git_default",
tracking_ref="main",
submodules=True,
)

# Calling initialize without a specific version triggers refresh()
bundle.initialize()

# Verify submodule commands were called
mock_repo_instance.git.submodule.assert_has_calls(
[mock.call("sync", "--recursive"), mock.call("update", "--init", "--recursive", "--jobs", "1")]
)
mock_rmtree.assert_not_called()

@mock.patch("airflow.providers.git.bundles.git.shutil.rmtree")
@mock.patch("airflow.providers.git.bundles.git.os.path.exists")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_submodules_disabled_by_default(self, mock_repo_class, mock_githook, mock_exists, mock_rmtree):
"""Test that submodules are NOT fetched by default."""
mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git"
mock_exists.return_value = True

mock_repo_instance = mock_repo_class.return_value

bundle = GitDagBundle(
name="test",
git_conn_id="git_default",
tracking_ref="main",
version="123456",
# submodules defaults to False
)

bundle.initialize()

# Ensure submodule commands were NOT called
mock_repo_instance.git.submodule.assert_not_called()

@mock.patch("airflow.providers.git.bundles.git.shutil.rmtree")
@mock.patch("airflow.providers.git.bundles.git.os.path.exists")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_submodule_fetch_error_raises_runtime_error(
self, mock_repo_class, mock_githook, mock_exists, mock_rmtree
):
"""Test that a GitCommandError during submodule update is raised as a RuntimeError."""
mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git"
mock_exists.return_value = True

mock_repo_instance = mock_repo_class.return_value
mock_repo_instance.commit.return_value = mock.MagicMock()

# Simulate a git error when running submodule update
mock_repo_instance.git.submodule.side_effect = GitCommandError("submodule update", "failed")

bundle = GitDagBundle(
name="test",
git_conn_id="git_default",
tracking_ref="main",
version="123456",
submodules=True,
)

with pytest.raises(RuntimeError, match="Error pulling submodule from repository"):
bundle.initialize()

mock_rmtree.assert_not_called()
Loading