Skip to content

Commit

Permalink
implement pygit2 backend for fetch_refspec
Browse files Browse the repository at this point in the history
fix #168

1. Add order select for `_backend_func`.
2. Raise exception for fetch_refspec for ssh:// repo on Windows.
3. Add order select for _backend_func
  • Loading branch information
karajan1001 committed Feb 1, 2023
1 parent 4984eb3 commit 702fabd
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 6 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ install_requires=
pathspec>=0.9.0
asyncssh>=2.7.1,<3
funcy>=1.14
shortuuid>=0.5.0

[options.extras_require]
tests =
Expand Down
10 changes: 8 additions & 2 deletions src/scmrepo/git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __getitem__(self, key: str) -> BaseGitBackend:
"""Lazily initialize backends and cache it afterwards"""
initialized = self.initialized.get(key)
if not initialized:
if key not in self.backends and key in self.DEFAULT:
raise NotImplementedError
backend = self.backends[key]
initialized = backend(*self.args, **self.kwargs)
self.initialized[key] = initialized
Expand Down Expand Up @@ -266,11 +268,13 @@ def no_commits(self):
# https://github.com/iterative/dvc/issues/5641
# https://github.com/iterative/dvc/issues/7458
def _backend_func(self, name, *args, **kwargs):
for key, backend in self.backends.items():
backends: Iterable[str] = kwargs.pop("backends", self.backends)
for key in backends:
if self._last_backend is not None and key != self._last_backend:
self.backends[self._last_backend].close()
self._last_backend = None
try:
backend = self.backends[key]
func = getattr(backend, name)
result = func(*args, **kwargs)
self._last_backend = key
Expand Down Expand Up @@ -333,7 +337,9 @@ def add_commit(
iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
push_refspecs = partialmethod(_backend_func, "push_refspecs")
fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
fetch_refspecs = partialmethod(
_backend_func, "fetch_refspecs", backends=["pygit2", "dulwich"]
)
_stash_iter = partialmethod(_backend_func, "_stash_iter")
_stash_push = partialmethod(_backend_func, "_stash_push")
_stash_apply = partialmethod(_backend_func, "_stash_apply")
Expand Down
106 changes: 104 additions & 2 deletions src/scmrepo/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generator,
Iterable,
List,
Mapping,
Expand All @@ -15,7 +17,8 @@
Union,
)

from funcy import cached_property
from funcy import cached_property, reraise
from shortuuid import uuid

from scmrepo.exceptions import CloneError, MergeConflictError, RevError, SCMError
from scmrepo.utils import relpath
Expand All @@ -27,6 +30,8 @@


if TYPE_CHECKING:
from pygit2.remote import Remote # type: ignore

from scmrepo.progress import GitProgressEvent


Expand Down Expand Up @@ -412,6 +417,52 @@ def push_refspecs(
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _merge_remote_branch(
self,
rh: str,
lh: str,
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
) -> SyncStatus:
import pygit2

rh_rev = self.resolve_rev(rh)

if force:
self.set_ref(lh, rh_rev)
return SyncStatus.SUCCESS

try:
merge_result, _ = self.repo.merge_analysis(rh_rev, lh)
except KeyError:
self.set_ref(lh, rh_rev)
return SyncStatus.SUCCESS

if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
return SyncStatus.UP_TO_DATE
if merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
self.set_ref(lh, rh_rev)
return SyncStatus.SUCCESS
if merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
if on_diverged and on_diverged(lh, rh_rev):
return SyncStatus.SUCCESS
return SyncStatus.DIVERGED
logger.debug("Unexpected merge result: %s", pygit2.GIT_MERGE_ANALYSIS_NORMAL)
raise SCMError("Unknown merge analysis result")

@contextmanager
def get_remote(self, url: str) -> Generator["Remote", None, None]:
try:
yield self.repo.remotes[url]
except ValueError:
try:
remote_name = uuid()
yield self.repo.remotes.create(remote_name, url)
finally:
self.repo.remotes.delete(remote_name)
except KeyError:
raise SCMError(f"'{url}' is not a valid Git remote or URL")

def fetch_refspecs(
self,
url: str,
Expand All @@ -421,7 +472,58 @@ def fetch_refspecs(
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
) -> Mapping[str, SyncStatus]:
raise NotImplementedError
from pygit2 import GitError

if isinstance(refspecs, str):
refspecs = [refspecs]

with self.get_remote(url) as remote:
if os.name == "nt" and remote.url.startswith("ssh://"):
raise NotImplementedError

if os.name == "nt" and remote.url.startswith("file://"):
url = remote.url[len("file://") :]
self.repo.remotes.set_url(remote.name, url)
remote = self.repo.remotes[remote.name]

fetch_refspecs: List[str] = []
for refspec in refspecs:
if ":" in refspec:
lh, rh = refspec.split(":")
else:
lh = rh = refspec
if not rh.startswith("refs/"):
rh = f"refs/heads/{rh}"
if not lh.startswith("refs/"):
lh = f"refs/heads/{lh}"
rh = rh[len("refs/") :]
refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
fetch_refspecs.append(refspec)

logger.debug("fetch_refspecs: %s", fetch_refspecs)
with reraise(
GitError,
SCMError(f"Git failed to fetch ref from '{url}'"),
):
remote.fetch(refspecs=fetch_refspecs)

result: Dict[str, "SyncStatus"] = {}
for refspec in fetch_refspecs:
_, rh = refspec.split(":")
if not rh.endswith("*"):
refname = rh.split("/", 3)[-1]
refname = f"refs/{refname}"
result[refname] = self._merge_remote_branch(
rh, refname, force, on_diverged
)
continue
rh = rh.rstrip("*").rstrip("/") + "/"
for branch in self.iter_refs(base=rh):
refname = f"refs/{branch[len(rh):]}"
result[refname] = self._merge_remote_branch(
branch, refname, force, on_diverged
)
return result

def _stash_iter(self, ref: str):
raise NotImplementedError
Expand Down
28 changes: 26 additions & 2 deletions tests/test_git.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import shutil
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Type

import pytest
from asyncssh import SFTPClient
from asyncssh.connection import SSHClientConnection
from dulwich.client import LocalGitClient
from git import Repo as GitPythonRepo
from pygit2 import GitError
from pygit2.remote import Remote # type: ignore
from pytest_mock import MockerFixture
from pytest_test_utils import TempDirFactory, TmpDir
from pytest_test_utils.matchers import Matcher
Expand Down Expand Up @@ -306,7 +308,7 @@ def test_push_refspecs(
assert remote_scm.get_ref("refs/foo/baz") is None


@pytest.mark.skip_git_backend("pygit2", "gitpython")
@pytest.mark.skip_git_backend("gitpython")
@pytest.mark.parametrize("use_url", [True, False])
def test_fetch_refspecs(
tmp_dir: TmpDir,
Expand Down Expand Up @@ -362,8 +364,11 @@ def test_fetch_refspecs(

with pytest.raises(SCMError):
mocker.patch.object(LocalGitClient, "fetch", side_effect=KeyError)
mocker.patch.object(Remote, "fetch", side_effect=GitError)
git.fetch_refspecs(remote, "refs/foo/bar:refs/foo/bar")

assert len(scm.pygit2.repo.remotes) == 1


@pytest.mark.skip_git_backend("pygit2", "gitpython")
@pytest.mark.parametrize("use_url", [True, False])
Expand Down Expand Up @@ -1046,3 +1051,22 @@ def test_is_dirty_untracked(
tmp_dir.gen("untracked", "untracked")
assert git.is_dirty(untracked_files=True)
assert not git.is_dirty(untracked_files=False)


@pytest.mark.parametrize(
"backends", [["gitpython", "dulwich"], ["dulwich", "gitpython"]]
)
def test_backend_func(
tmp_dir: TmpDir,
scm: Git,
backends: List[str],
mocker: MockerFixture,
):
from functools import partial

scm.add = partial(scm._backend_func, "add", backends=backends)
tmp_dir.gen({"foo": "foo"})
backend = getattr(scm, backends[0])
mock = mocker.spy(backend, "add")
scm.add(["foo"])
mock.assert_called_once_with(["foo"])

0 comments on commit 702fabd

Please sign in to comment.