Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement pygit2 backend for fetch_refspec #169

Merged
merged 1 commit into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ install_requires=
pathspec>=0.9.0
asyncssh>=2.7.1,<3
funcy>=1.14
shortuuid>=0.5.0

[options.extras_require]
tests =
Expand Down
10 changes: 8 additions & 2 deletions src/scmrepo/git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __getitem__(self, key: str) -> BaseGitBackend:
"""Lazily initialize backends and cache it afterwards"""
initialized = self.initialized.get(key)
if not initialized:
if key not in self.backends and key in self.DEFAULT:
raise NotImplementedError
backend = self.backends[key]
initialized = backend(*self.args, **self.kwargs)
self.initialized[key] = initialized
Expand Down Expand Up @@ -266,11 +268,13 @@ def no_commits(self):
# https://github.com/iterative/dvc/issues/5641
# https://github.com/iterative/dvc/issues/7458
def _backend_func(self, name, *args, **kwargs):
for key, backend in self.backends.items():
backends: Iterable[str] = kwargs.pop("backends", self.backends)
for key in backends:
if self._last_backend is not None and key != self._last_backend:
self.backends[self._last_backend].close()
self._last_backend = None
try:
backend = self.backends[key]
func = getattr(backend, name)
result = func(*args, **kwargs)
self._last_backend = key
Expand Down Expand Up @@ -333,7 +337,9 @@ def add_commit(
iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
push_refspecs = partialmethod(_backend_func, "push_refspecs")
fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
fetch_refspecs = partialmethod(
_backend_func, "fetch_refspecs", backends=["pygit2", "dulwich"]
)
_stash_iter = partialmethod(_backend_func, "_stash_iter")
_stash_push = partialmethod(_backend_func, "_stash_push")
_stash_apply = partialmethod(_backend_func, "_stash_apply")
Expand Down
106 changes: 104 additions & 2 deletions src/scmrepo/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generator,
Iterable,
List,
Mapping,
Expand All @@ -15,7 +17,8 @@
Union,
)

from funcy import cached_property
from funcy import cached_property, reraise
from shortuuid import uuid

from scmrepo.exceptions import CloneError, MergeConflictError, RevError, SCMError
from scmrepo.utils import relpath
Expand All @@ -27,6 +30,8 @@


if TYPE_CHECKING:
from pygit2.remote import Remote # type: ignore

from scmrepo.progress import GitProgressEvent


Expand Down Expand Up @@ -412,6 +417,52 @@ def push_refspecs(
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _merge_remote_branch(
    self,
    rh: str,
    lh: str,
    force: bool = False,
    on_diverged: Optional[Callable[[str, str], bool]] = None,
) -> SyncStatus:
    """Point the ref ``lh`` at the revision of ``rh`` when it is safe to do so.

    Args:
        rh: Ref (or revision) to take the new value from; it is resolved
            with ``self.resolve_rev``.
        lh: Name of the ref to update.
        force: When true, ``lh`` is overwritten without any merge analysis.
        on_diverged: Optional callback invoked as ``on_diverged(lh, rh_rev)``
            when the merge analysis reports a normal (non fast-forward)
            merge. A truthy return value is reported as success.

    Returns:
        ``SyncStatus.SUCCESS`` when ``lh`` was set (or the callback handled
        the divergence), ``SyncStatus.UP_TO_DATE`` when nothing had to be
        done, and ``SyncStatus.DIVERGED`` otherwise.

    Raises:
        SCMError: If the merge analysis result matches none of the
            handled flags.
    """
    import pygit2

    rh_rev = self.resolve_rev(rh)

    if force:
        self.set_ref(lh, rh_rev)
        return SyncStatus.SUCCESS

    try:
        merge_result, _ = self.repo.merge_analysis(rh_rev, lh)
    except KeyError:
        # NOTE(review): presumably raised when ``lh`` does not exist yet;
        # in that case there is nothing to merge with - confirm.
        self.set_ref(lh, rh_rev)
        return SyncStatus.SUCCESS

    # ``merge_result`` is a bit mask, so the checks are ordered from the
    # most specific outcome to the least specific one.
    if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
        return SyncStatus.UP_TO_DATE
    if merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
        self.set_ref(lh, rh_rev)
        return SyncStatus.SUCCESS
    if merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
        if on_diverged and on_diverged(lh, rh_rev):
            return SyncStatus.SUCCESS
        return SyncStatus.DIVERGED
    # Log the value that was actually returned, not a library constant.
    logger.debug("Unexpected merge result: %s", merge_result)
    raise SCMError("Unknown merge analysis result")

@contextmanager
def get_remote(self, url: str) -> Generator["Remote", None, None]:
    """Yield a remote for ``url``, which is a remote name or a Git URL.

    ``self.repo.remotes`` is first asked for a remote named ``url``. If
    that lookup raises ``ValueError`` the value is treated as a URL: a
    temporary remote with a generated name is created for it and deleted
    again when the ``with`` block exits.

    Only the lookup itself is guarded. Exceptions raised by the caller's
    ``with`` body propagate unchanged instead of being mistaken for a
    failed lookup.

    Raises:
        SCMError: If the lookup raises ``KeyError``.
    """
    try:
        remote = self.repo.remotes[url]
    except ValueError:
        remote = None
    except KeyError as exc:
        raise SCMError(f"'{url}' is not a valid Git remote or URL") from exc

    if remote is not None:
        yield remote
        return

    remote_name = uuid()
    # Created outside the ``try`` so that a failed ``create`` does not
    # trigger deletion of a remote that never existed.
    remote = self.repo.remotes.create(remote_name, url)
    try:
        yield remote
    finally:
        self.repo.remotes.delete(remote_name)

def fetch_refspecs(
self,
url: str,
Expand All @@ -421,7 +472,58 @@ def fetch_refspecs(
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
) -> Mapping[str, SyncStatus]:
raise NotImplementedError
from pygit2 import GitError

if isinstance(refspecs, str):
refspecs = [refspecs]

with self.get_remote(url) as remote:
if os.name == "nt" and remote.url.startswith("ssh://"):
raise NotImplementedError

if os.name == "nt" and remote.url.startswith("file://"):
url = remote.url[len("file://") :]
self.repo.remotes.set_url(remote.name, url)
remote = self.repo.remotes[remote.name]
Comment on lines +484 to +487
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed even after the remotes.create() change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it still fails.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make this change in the PR then, but we probably need to open another issue for it in scmrepo (this likely means that a regular git remote with file:// URL in it is also broken for us on windows). This may still be related to us not setting the path separators or formatting the URL in a way that libgit2 expects and probably needs some more investigation.


fetch_refspecs: List[str] = []
for refspec in refspecs:
if ":" in refspec:
lh, rh = refspec.split(":")
else:
lh = rh = refspec
if not rh.startswith("refs/"):
rh = f"refs/heads/{rh}"
if not lh.startswith("refs/"):
lh = f"refs/heads/{lh}"
rh = rh[len("refs/") :]
refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
fetch_refspecs.append(refspec)

logger.debug("fetch_refspecs: %s", fetch_refspecs)
with reraise(
GitError,
SCMError(f"Git failed to fetch ref from '{url}'"),
):
remote.fetch(refspecs=fetch_refspecs)

result: Dict[str, "SyncStatus"] = {}
for refspec in fetch_refspecs:
_, rh = refspec.split(":")
if not rh.endswith("*"):
refname = rh.split("/", 3)[-1]
refname = f"refs/{refname}"
result[refname] = self._merge_remote_branch(
rh, refname, force, on_diverged
)
continue
rh = rh.rstrip("*").rstrip("/") + "/"
for branch in self.iter_refs(base=rh):
refname = f"refs/{branch[len(rh):]}"
result[refname] = self._merge_remote_branch(
branch, refname, force, on_diverged
)
return result

def _stash_iter(self, ref: str):
raise NotImplementedError
Expand Down
28 changes: 26 additions & 2 deletions tests/test_git.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import shutil
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Type

import pytest
from asyncssh import SFTPClient
from asyncssh.connection import SSHClientConnection
from dulwich.client import LocalGitClient
from git import Repo as GitPythonRepo
from pygit2 import GitError
from pygit2.remote import Remote # type: ignore
from pytest_mock import MockerFixture
from pytest_test_utils import TempDirFactory, TmpDir
from pytest_test_utils.matchers import Matcher
Expand Down Expand Up @@ -306,7 +308,7 @@ def test_push_refspecs(
assert remote_scm.get_ref("refs/foo/baz") is None


@pytest.mark.skip_git_backend("pygit2", "gitpython")
@pytest.mark.skip_git_backend("gitpython")
@pytest.mark.parametrize("use_url", [True, False])
def test_fetch_refspecs(
tmp_dir: TmpDir,
Expand Down Expand Up @@ -362,8 +364,11 @@ def test_fetch_refspecs(

with pytest.raises(SCMError):
mocker.patch.object(LocalGitClient, "fetch", side_effect=KeyError)
mocker.patch.object(Remote, "fetch", side_effect=GitError)
git.fetch_refspecs(remote, "refs/foo/bar:refs/foo/bar")

assert len(scm.pygit2.repo.remotes) == 1


@pytest.mark.skip_git_backend("pygit2", "gitpython")
@pytest.mark.parametrize("use_url", [True, False])
Expand Down Expand Up @@ -1046,3 +1051,22 @@ def test_is_dirty_untracked(
tmp_dir.gen("untracked", "untracked")
assert git.is_dirty(untracked_files=True)
assert not git.is_dirty(untracked_files=False)


@pytest.mark.parametrize(
    "backends", [["gitpython", "dulwich"], ["dulwich", "gitpython"]]
)
def test_backend_func(
    tmp_dir: TmpDir,
    scm: Git,
    backends: List[str],
    mocker: MockerFixture,
):
    """Check that an explicit ``backends`` order decides who serves a call.

    ``add`` is rebound on ``scm`` with a custom backend order, and the
    first backend in that order is expected to receive the call exactly
    once.
    """
    from functools import partial

    tmp_dir.gen({"foo": "foo"})
    scm.add = partial(scm._backend_func, "add", backends=backends)

    preferred = backends[0]
    add_spy = mocker.spy(getattr(scm, preferred), "add")

    scm.add(["foo"])

    add_spy.assert_called_once_with(["foo"])