From 51b76d60ec2ddcc8b3f0576c6b67bd2172028930 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Mon, 18 Sep 2023 19:14:08 -0700 Subject: [PATCH] fix(registry): push to remote automatically only on cloned repos --- gto/api.py | 11 +++++----- gto/cli.py | 2 +- gto/git_utils.py | 25 ++++++++++----------- gto/index.py | 23 ++++++++++++++------ gto/registry.py | 11 ++++++++-- setup.py | 2 -- tests/test_api.py | 15 +++++++++++++ tests/test_index.py | 15 +++++++++++++ tests/test_registry.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 124 insertions(+), 29 deletions(-) diff --git a/gto/api.py b/gto/api.py index c4a45759..79c419d9 100644 --- a/gto/api.py +++ b/gto/api.py @@ -19,7 +19,6 @@ parse_shortcut, ) from gto.exceptions import NoRepo, NotImplementedInGTO, RefNotFound, WrongArgs -from gto.git_utils import has_remote from gto.index import Artifact, RepoIndexManager from gto.registry import GitRegistry from gto.tag import parse_name as parse_tag_name @@ -97,7 +96,7 @@ def register( bump_major=bump_major, bump_minor=bump_minor, bump_patch=bump_patch, - push=push or has_remote(reg.scm), + push=push, stdout=stdout, author=author, author_email=author_email, @@ -131,7 +130,7 @@ def assign( message=message, simple=simple, force=force, - push=push or has_remote(reg.scm), + push=push, skip_registration=skip_registration, stdout=stdout, author=author, @@ -165,7 +164,7 @@ def unassign( simple=simple if simple is not None else False, force=force, delete=delete, - push=push or has_remote(reg.scm), + push=push, author=author, author_email=author_email, ) @@ -195,7 +194,7 @@ def deregister( simple=simple if simple is not None else True, force=force, delete=delete, - push=push or has_remote(reg.scm), + push=push, author=author, author_email=author_email, ) @@ -223,7 +222,7 @@ def deprecate( simple=simple, force=force, delete=delete, - push=push or has_remote(reg.scm), + push=push, author=author, author_email=author_email, ) diff --git a/gto/cli.py b/gto/cli.py index f282619a..ca13371e 100644 --- a/gto/cli.py +++ b/gto/cli.py @@ -375,7 +375,7 @@ def callback_sort( # pylint: disable=inconsistent-return-statements False, "--push", is_flag=True, - help="Push created git tag to `origin` (done automatically for remote repo)", + help="Push created git tag to `origin` (ignored if `repo` option is a remote URL)", ) option_commit = Option( False, diff --git a/gto/git_utils.py b/gto/git_utils.py index f24f704e..85beae64 100644 --- a/gto/git_utils.py +++ b/gto/git_utils.py @@ -6,7 +6,7 @@ from tempfile import TemporaryDirectory from typing import Optional, Union -from scmrepo.exceptions import InvalidRemote, SCMError +from scmrepo.exceptions import SCMError from scmrepo.git import Git, SyncStatus from gto.config import RegistryConfig @@ -17,7 +17,17 @@ class RemoteRepoMixin: @classmethod @contextmanager - def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None): + def from_scm( + cls, + scm: Git, + cloned: bool = False, + config: Optional[RegistryConfig] = None, + ): + """ + `cloned` - scm is a remote repo that was cloned locally into a tmp + directory to be used for the duration of the context manager. + Means that we push tags and changes back to the remote repo. + """ raise NotImplementedError() @classmethod @@ -51,7 +61,7 @@ def from_url( with cloned_git_repo(url_or_scm) as scm: if branch: scm.checkout(branch) - with cls.from_scm(scm=scm, config=config) as obj: + with cls.from_scm(scm=scm, config=config, cloned=True) as obj: yield obj def _call_commit_push( @@ -152,12 +162,3 @@ def git_add_and_commit_all_changes(scm: Git, message: str) -> None: def _reset_repo_to_head(scm: Git) -> None: if scm.stash.push(include_untracked=True): scm.stash.drop() - - -def has_remote(scm: Git, remote: str = "origin") -> bool: - try: - scm.validate_git_remote(remote) - return True - except InvalidRemote: - pass - return False diff --git a/gto/index.py b/gto/index.py index 35b55437..4b1ac8c1 100644 --- a/gto/index.py +++ b/gto/index.py @@ -319,16 +319,22 @@ def artifact_centric_representation(self): class RepoIndexManager(FileIndexManager, RemoteRepoMixin): scm: Git + cloned: bool - def __init__(self, scm: Git, config): - super().__init__(scm=scm, config=config) # type: ignore[call-arg] + def __init__(self, scm: Git, cloned: bool, config): + super().__init__(scm=scm, cloned=cloned, config=config) # type: ignore[call-arg] @classmethod @contextmanager - def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None): + def from_scm( + cls, + scm: Git, + cloned: bool = False, + config: Optional[RegistryConfig] = None, + ): if config is None: config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME)) - yield cls(scm=scm, config=config) + yield cls(scm=scm, cloned=cloned, config=config) def add( self, @@ -351,7 +357,7 @@ def add( commit=commit, commit_message=commit_message or generate_annotate_commit_message(name=name, type=type, path=path), - push=push, + push=push or self.cloned, stdout=stdout, name=name, type=type, @@ -458,7 +464,12 @@ class EnrichmentManager(BaseManager, RemoteRepoMixin): @classmethod @contextmanager - def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None): + def from_scm( + cls, + scm: Git, + cloned: Optional[bool] = False, + config: Optional[RegistryConfig] = None, + ): if config is None: config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME)) yield cls(scm=scm, config=config) diff --git a/gto/registry.py b/gto/registry.py index 569ee584..ba1ff944 100644 --- a/gto/registry.py +++ b/gto/registry.py @@ -43,6 +43,7 @@ class GitRegistry(BaseModel, RemoteRepoMixin): scm: Git + cloned: bool artifact_manager: TagArtifactManager version_manager: TagVersionManager stage_manager: TagStageManager @@ -54,12 +55,18 @@ class Config: @classmethod @contextmanager - def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None): + def from_scm( + cls, + scm: Git, + cloned: bool = False, + config: Optional[RegistryConfig] = None, + ): if config is None: config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME)) yield cls( scm=scm, + cloned=cloned, config=config, artifact_manager=TagArtifactManager(scm=scm, config=config), version_manager=TagVersionManager(scm=scm, config=config), @@ -572,7 +579,7 @@ def get_stages(self, allowed: bool = False, used: bool = False): def _push_tag_or_echo_reminder( self, tag_name: str, push: bool, stdout: bool, delete: bool = False ) -> None: - if push: + if push or self.cloned: if stdout: echo( f"Running `git push{' --delete ' if delete else ' '}origin {tag_name}`" diff --git a/setup.py b/setup.py index cc8190b8..5c8f3827 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,6 @@ "pytest-mock", "pytest-test-utils", "pylint==2.17.5", - # we use this to suppress pytest-related false positives in our tests. - "pylint-pytest", # we use this to suppress some messages in tests, eg: foo/bar naming, # and, protected method calls in our tests "pylint-plugin-utils", diff --git a/tests/test_api.py b/tests/test_api.py index ad64a493..a7eb9cc7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,6 +8,7 @@ import pytest from freezegun import freeze_time +from pytest_mock import MockFixture from pytest_test_utils import TmpDir from scmrepo.git import Git @@ -15,6 +16,7 @@ import tests.resources from gto.api import show from gto.exceptions import RefNotFound, WrongArgs +from gto.git_utils import cloned_git_repo from gto.index import RepoIndexManager from gto.tag import find from gto.versions import SemVer @@ -590,3 +592,16 @@ def test_if_unassign_with_remote_repo_then_invoke_git_push_tag(tmp_dir: TmpDir): tag_name="churn#staging!#3", delete=False, ) + + +def test_action_doesnt_push_even_if_repo_has_remotes_set(mocker: MockFixture): + # test for https://github.com/iterative/gto/issues/405 + with cloned_git_repo(tests.resources.SAMPLE_REMOTE_REPO_URL) as scm: + mocked_git_push_tag = mocker.patch("gto.registry.git_push_tag") + gto.api.unassign( + repo=scm, + name="churn", + stage="staging", + version="v3.1.0", + ) + mocked_git_push_tag.assert_not_called() diff --git a/tests/test_index.py b/tests/test_index.py index 1e665bd8..015ecae7 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1,6 +1,7 @@ from typing import Sequence import pytest +from pytest_mock import MockFixture from pytest_test_utils import TmpDir from scmrepo.git import Git @@ -111,3 +112,17 @@ def test_check_existence_no_repo(tmp_dir: TmpDir): tmp_dir.gen("m1.txt", "some content") assert check_if_path_exists(tmp_dir / "m1.txt") assert not check_if_path_exists(tmp_dir / "not" / "exists") + + +def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture): + with RepoIndexManager.from_url(tmp_dir) as idx: + assert idx.cloned is False + + with RepoIndexManager.from_url(scm) as idx: + assert idx.cloned is False + + cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo") + cloned_git_repo_mock.return_value.__enter__.return_value = scm + + with RepoIndexManager.from_url("https://github.com/iterative/gto") as idx: + assert idx.cloned is True diff --git a/tests/test_registry.py b/tests/test_registry.py index abe43914..5f78b0e7 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -2,7 +2,9 @@ from typing import Dict, List import pytest +from pytest_mock import MockFixture from pytest_test_utils import TmpDir +from scmrepo.git import Git from gto.registry import GitRegistry @@ -443,3 +445,50 @@ def test_registry_state_tag_tag(tmp_dir: TmpDir): check_obj( appeared["stages"][key], expected["stages"][key], exclude["stages"] ) + + +def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture): + with GitRegistry.from_url(tmp_dir) as reg: + assert reg.cloned is False + + with GitRegistry.from_url(scm) as reg: + assert reg.cloned is False + + cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo") + cloned_git_repo_mock.return_value.__enter__.return_value = scm + + with GitRegistry.from_url("https://github.com/iterative/gto") as reg: + assert reg.cloned is True + + +# Some method parameters (model names, versions, revs, etc) depend and set by +# the `showcase` fixture setup in the conftest.py. +@pytest.mark.parametrize( + "method,args,kwargs", + [ + ("register", ["new_model", "HEAD"], {}), + ("deregister", ["nn"], {"version": "v0.0.1"}), + ("assign", ["nn", "new_stage"], {"version": "v0.0.1"}), + ("unassign", ["nn", "staging"], {"version": "v0.0.1"}), + ("deprecate", ["nn"], {}), + ], +) +@pytest.mark.usefixtures("showcase") +def test_tag_is_pushed_if_cloned_is_set( + tmp_dir: TmpDir, + mocker: MockFixture, + method, + args, + kwargs, +): + with GitRegistry.from_url(tmp_dir) as reg: + # imitate that we are doing actions on the remote repo + assert reg.cloned is False + reg.cloned = True + + # check that it attempts to push tag to a remote repo, even if + # push=False is set in a call. `cloned` overrides it in this case + git_push_tag_mock = mocker.patch("gto.registry.git_push_tag") + kwargs["push"] = False + getattr(reg, method)(*args, **kwargs) + git_push_tag_mock.assert_called_once()