From 702fabd2b6b22b686e836f386037c327c90b7e21 Mon Sep 17 00:00:00 2001
From: karajan1001
Date: Mon, 16 Jan 2023 16:27:59 +0800
Subject: [PATCH] implement pygit2 backend for fetch_refspec

fix #168
1. Add order select for `_backend_func`.
2. Raise exception for fetch_refspec for ssh:// repo on Windows.
3. Add order select for _backend_func
---
 setup.cfg                         |   1 +
 src/scmrepo/git/__init__.py       |  10 +-
 src/scmrepo/git/backend/pygit2.py | 147 +++++++++++++++++++++++++++++-
 tests/test_git.py                 |  28 +++++-
 4 files changed, 180 insertions(+), 6 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index f47282b1..601365a7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,7 @@ install_requires=
     pathspec>=0.9.0
     asyncssh>=2.7.1,<3
     funcy>=1.14
+    shortuuid>=0.5.0
 
 [options.extras_require]
 tests =
diff --git a/src/scmrepo/git/__init__.py b/src/scmrepo/git/__init__.py
index 98768b1b..5dd3667b 100644
--- a/src/scmrepo/git/__init__.py
+++ b/src/scmrepo/git/__init__.py
@@ -41,6 +41,8 @@ def __getitem__(self, key: str) -> BaseGitBackend:
         """Lazily initialize backends and cache it afterwards"""
         initialized = self.initialized.get(key)
         if not initialized:
+            if key not in self.backends and key in self.DEFAULT:
+                raise NotImplementedError
             backend = self.backends[key]
             initialized = backend(*self.args, **self.kwargs)
             self.initialized[key] = initialized
@@ -266,11 +268,13 @@ def no_commits(self):
     # https://github.com/iterative/dvc/issues/5641
     # https://github.com/iterative/dvc/issues/7458
     def _backend_func(self, name, *args, **kwargs):
-        for key, backend in self.backends.items():
+        backends: Iterable[str] = kwargs.pop("backends", self.backends)
+        for key in backends:
             if self._last_backend is not None and key != self._last_backend:
                 self.backends[self._last_backend].close()
                 self._last_backend = None
             try:
+                backend = self.backends[key]
                 func = getattr(backend, name)
                 result = func(*args, **kwargs)
                 self._last_backend = key
@@ -333,7 +337,9 @@ def add_commit(
     iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
     get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
     push_refspecs = partialmethod(_backend_func, "push_refspecs")
-    fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
+    fetch_refspecs = partialmethod(
+        _backend_func, "fetch_refspecs", backends=["pygit2", "dulwich"]
+    )
     _stash_iter = partialmethod(_backend_func, "_stash_iter")
     _stash_push = partialmethod(_backend_func, "_stash_push")
     _stash_apply = partialmethod(_backend_func, "_stash_apply")
diff --git a/src/scmrepo/git/backend/pygit2.py b/src/scmrepo/git/backend/pygit2.py
index 2fedecaf..b26da06a 100644
--- a/src/scmrepo/git/backend/pygit2.py
+++ b/src/scmrepo/git/backend/pygit2.py
@@ -7,6 +7,8 @@
 from typing import (
     TYPE_CHECKING,
     Callable,
+    Dict,
+    Generator,
     Iterable,
     List,
     Mapping,
@@ -15,7 +17,8 @@
     Union,
 )
 
-from funcy import cached_property
+from funcy import cached_property, reraise
+from shortuuid import uuid
 
 from scmrepo.exceptions import CloneError, MergeConflictError, RevError, SCMError
 from scmrepo.utils import relpath
@@ -27,6 +30,8 @@
 
 
 if TYPE_CHECKING:
+    from pygit2.remote import Remote  # type: ignore
+
     from scmrepo.progress import GitProgressEvent
 
 
@@ -412,6 +417,79 @@ def push_refspecs(
     ) -> Mapping[str, SyncStatus]:
         raise NotImplementedError
 
+    def _merge_remote_branch(
+        self,
+        rh: str,
+        lh: str,
+        force: bool = False,
+        on_diverged: Optional[Callable[[str, str], bool]] = None,
+    ) -> SyncStatus:
+        """Bring local ref `lh` in line with the fetched ref `rh`.
+
+        `lh` is set to the revision of `rh` when `force` is true, when the
+        merge analysis raises `KeyError`, or on a fast-forward. For diverged
+        history `on_diverged(lh, rev)` decides between SUCCESS and DIVERGED.
+        Raises `SCMError` for any other merge analysis result.
+        """
+        import pygit2
+
+        rh_rev = self.resolve_rev(rh)
+
+        if force:
+            self.set_ref(lh, rh_rev)
+            return SyncStatus.SUCCESS
+
+        try:
+            merge_result, _ = self.repo.merge_analysis(rh_rev, lh)
+        except KeyError:
+            self.set_ref(lh, rh_rev)
+            return SyncStatus.SUCCESS
+
+        if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
+            return SyncStatus.UP_TO_DATE
+        if merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
+            self.set_ref(lh, rh_rev)
+            return SyncStatus.SUCCESS
+        if merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
+            if on_diverged and on_diverged(lh, rh_rev):
+                return SyncStatus.SUCCESS
+            return SyncStatus.DIVERGED
+        logger.debug("Unexpected merge result: %s", merge_result)
+        raise SCMError("Unknown merge analysis result")
+
+    @contextmanager
+    def get_remote(self, url: str) -> Generator["Remote", None, None]:
+        """Yield a remote for `url`, which is a remote name or a Git URL.
+
+        A `ValueError` from the name lookup is treated as "this is a URL":
+        a temporary remote with a random name is created for it and deleted
+        again on exit.
+
+        Raises:
+            SCMError: if the lookup raises `KeyError`.
+        """
+        # Resolve the remote before yielding, so that exceptions raised in
+        # the body of the `with` block are never mistaken for lookup errors.
+        try:
+            remote = self.repo.remotes[url]
+        except ValueError:
+            remote = None
+        except KeyError as exc:
+            raise SCMError(
+                f"'{url}' is not a valid Git remote or URL"
+            ) from exc
+
+        if remote is not None:
+            yield remote
+            return
+
+        remote_name = uuid()
+        remote = self.repo.remotes.create(remote_name, url)
+        try:
+            yield remote
+        finally:
+            self.repo.remotes.delete(remote_name)
+
     def fetch_refspecs(
         self,
         url: str,
@@ -421,7 +499,72 @@ def fetch_refspecs(
         progress: Callable[["GitProgressEvent"], None] = None,
         **kwargs,
     ) -> Mapping[str, SyncStatus]:
-        raise NotImplementedError
+        """Fetch `refspecs` from `url` and update the matching local refs.
+
+        Every refspec is first fetched into `refs/remotes/<remote>/...` and
+        then applied to the local ref via `_merge_remote_branch`.
+
+        Returns:
+            Mapping of local ref name to its `SyncStatus`.
+
+        Raises:
+            NotImplementedError: for `ssh://` remotes on Windows.
+            SCMError: if the fetch fails with a pygit2 `GitError`.
+        """
+        from pygit2 import GitError
+
+        if isinstance(refspecs, str):
+            refspecs = [refspecs]
+
+        with self.get_remote(url) as remote:
+            if os.name == "nt" and remote.url.startswith("ssh://"):
+                raise NotImplementedError
+
+            if os.name == "nt" and remote.url.startswith("file://"):
+                url = remote.url[len("file://") :]
+                self.repo.remotes.set_url(remote.name, url)
+                remote = self.repo.remotes[remote.name]
+
+            fetch_refspecs: List[str] = []
+            for refspec in refspecs:
+                if ":" in refspec:
+                    lh, rh = refspec.split(":")
+                else:
+                    lh = rh = refspec
+                if not rh.startswith("refs/"):
+                    rh = f"refs/heads/{rh}"
+                if not lh.startswith("refs/"):
+                    lh = f"refs/heads/{lh}"
+                rh = rh[len("refs/") :]
+                refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
+                fetch_refspecs.append(refspec)
+
+            logger.debug("fetch_refspecs: %s", fetch_refspecs)
+            with reraise(
+                GitError,
+                SCMError(f"Git failed to fetch ref from '{url}'"),
+            ):
+                remote.fetch(refspecs=fetch_refspecs)
+
+            result: Dict[str, "SyncStatus"] = {}
+            for refspec in fetch_refspecs:
+                _, rh = refspec.split(":")
+                if not rh.endswith("*"):
+                    refname = rh.split("/", 3)[-1]
+                    refname = f"refs/{refname}"
+                    result[refname] = self._merge_remote_branch(
+                        rh, refname, force, on_diverged
+                    )
+                    continue
+                rh = rh.rstrip("*").rstrip("/") + "/"
+                for branch in self.iter_refs(base=rh):
+                    # Strip only `refs/remotes/<remote>/` so the local ref
+                    # keeps its namespace (e.g. `heads/`).
+                    refname = f"refs/{branch.split('/', 3)[-1]}"
+                    result[refname] = self._merge_remote_branch(
+                        branch, refname, force, on_diverged
+                    )
+        return result
 
     def _stash_iter(self, ref: str):
         raise NotImplementedError
diff --git a/tests/test_git.py b/tests/test_git.py
index 323be54c..04e1fae1 100644
--- a/tests/test_git.py
+++ b/tests/test_git.py
@@ -1,12 +1,14 @@
 import os
 import shutil
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, List, Optional, Type
 
 import pytest
 from asyncssh import SFTPClient
 from asyncssh.connection import SSHClientConnection
 from dulwich.client import LocalGitClient
 from git import Repo as GitPythonRepo
+from pygit2 import GitError
+from pygit2.remote import Remote  # type: ignore
 from pytest_mock import MockerFixture
 from pytest_test_utils import TempDirFactory, TmpDir
 from pytest_test_utils.matchers import Matcher
@@ -306,7 +308,7 @@ def test_push_refspecs(
     assert remote_scm.get_ref("refs/foo/baz") is None
 
 
-@pytest.mark.skip_git_backend("pygit2", "gitpython")
+@pytest.mark.skip_git_backend("gitpython")
 @pytest.mark.parametrize("use_url", [True, False])
 def test_fetch_refspecs(
     tmp_dir: TmpDir,
@@ -362,8 +364,11 @@ def test_fetch_refspecs(
 
     with pytest.raises(SCMError):
         mocker.patch.object(LocalGitClient, "fetch", side_effect=KeyError)
+        mocker.patch.object(Remote, "fetch", side_effect=GitError)
         git.fetch_refspecs(remote, "refs/foo/bar:refs/foo/bar")
 
+    assert len(scm.pygit2.repo.remotes) == 1
+
 
 @pytest.mark.skip_git_backend("pygit2", "gitpython")
 @pytest.mark.parametrize("use_url", [True, False])
@@ -1046,3 +1051,22 @@ def test_is_dirty_untracked(
     tmp_dir.gen("untracked", "untracked")
     assert git.is_dirty(untracked_files=True)
     assert not git.is_dirty(untracked_files=False)
+
+
+@pytest.mark.parametrize(
+    "backends", [["gitpython", "dulwich"], ["dulwich", "gitpython"]]
+)
+def test_backend_func(
+    tmp_dir: TmpDir,
+    scm: Git,
+    backends: List[str],
+    mocker: MockerFixture,
+):
+    from functools import partial
+
+    scm.add = partial(scm._backend_func, "add", backends=backends)
+    tmp_dir.gen({"foo": "foo"})
+    backend = getattr(scm, backends[0])
+    mock = mocker.spy(backend, "add")
+    scm.add(["foo"])
+    mock.assert_called_once_with(["foo"])