Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(registry): push to remote automatically only on cloned repos #417

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions gto/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
parse_shortcut,
)
from gto.exceptions import NoRepo, NotImplementedInGTO, RefNotFound, WrongArgs
from gto.git_utils import has_remote
from gto.index import Artifact, RepoIndexManager
from gto.registry import GitRegistry
from gto.tag import parse_name as parse_tag_name
Expand Down Expand Up @@ -97,7 +96,7 @@ def register(
bump_major=bump_major,
bump_minor=bump_minor,
bump_patch=bump_patch,
push=push or has_remote(reg.scm),
push=push,
stdout=stdout,
author=author,
author_email=author_email,
Expand Down Expand Up @@ -131,7 +130,7 @@ def assign(
message=message,
simple=simple,
force=force,
push=push or has_remote(reg.scm),
push=push,
skip_registration=skip_registration,
stdout=stdout,
author=author,
Expand Down Expand Up @@ -165,7 +164,7 @@ def unassign(
simple=simple if simple is not None else False,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down Expand Up @@ -195,7 +194,7 @@ def deregister(
simple=simple if simple is not None else True,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down Expand Up @@ -223,7 +222,7 @@ def deprecate(
simple=simple,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down
2 changes: 1 addition & 1 deletion gto/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def callback_sort( # pylint: disable=inconsistent-return-statements
False,
"--push",
is_flag=True,
help="Push created git tag to `origin` (done automatically for remote repo)",
help="Push created git tag to `origin` (ignored if `repo` option is a remote URL)",
)
option_commit = Option(
False,
Expand Down
25 changes: 13 additions & 12 deletions gto/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tempfile import TemporaryDirectory
from typing import Optional, Union

from scmrepo.exceptions import InvalidRemote, SCMError
from scmrepo.exceptions import SCMError
from scmrepo.git import Git, SyncStatus

from gto.config import RegistryConfig
Expand All @@ -17,7 +17,17 @@
class RemoteRepoMixin:
@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
"""
`cloned` - scm is a remote repo that was cloned locally into a tmp
directory to be used for the duration of the context manager.
Means that we push tags and changes back to the remote repo.
"""
raise NotImplementedError()

@classmethod
Expand Down Expand Up @@ -51,7 +61,7 @@ def from_url(
with cloned_git_repo(url_or_scm) as scm:
if branch:
scm.checkout(branch)
with cls.from_scm(scm=scm, config=config) as obj:
with cls.from_scm(scm=scm, config=config, cloned=True) as obj:
yield obj

def _call_commit_push(
Expand Down Expand Up @@ -152,12 +162,3 @@ def git_add_and_commit_all_changes(scm: Git, message: str) -> None:
def _reset_repo_to_head(scm: Git) -> None:
if scm.stash.push(include_untracked=True):
scm.stash.drop()


def has_remote(scm: Git, remote: str = "origin") -> bool:
try:
scm.validate_git_remote(remote)
return True
except InvalidRemote:
pass
return False
23 changes: 17 additions & 6 deletions gto/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,22 @@ def artifact_centric_representation(self):

class RepoIndexManager(FileIndexManager, RemoteRepoMixin):
scm: Git
cloned: bool

def __init__(self, scm: Git, config):
super().__init__(scm=scm, config=config) # type: ignore[call-arg]
def __init__(self, scm: Git, cloned: bool, config):
super().__init__(scm=scm, cloned=cloned, config=config) # type: ignore[call-arg]

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))
yield cls(scm=scm, config=config)
yield cls(scm=scm, cloned=cloned, config=config)

def add(
self,
Expand All @@ -351,7 +357,7 @@ def add(
commit=commit,
commit_message=commit_message
or generate_annotate_commit_message(name=name, type=type, path=path),
push=push,
push=push or self.cloned,
stdout=stdout,
name=name,
type=type,
Expand Down Expand Up @@ -458,7 +464,12 @@ class EnrichmentManager(BaseManager, RemoteRepoMixin):

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: Optional[bool] = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))
yield cls(scm=scm, config=config)
Expand Down
11 changes: 9 additions & 2 deletions gto/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

class GitRegistry(BaseModel, RemoteRepoMixin):
scm: Git
cloned: bool
artifact_manager: TagArtifactManager
version_manager: TagVersionManager
stage_manager: TagStageManager
Expand All @@ -54,12 +55,18 @@ class Config:

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))

yield cls(
scm=scm,
cloned=cloned,
config=config,
artifact_manager=TagArtifactManager(scm=scm, config=config),
version_manager=TagVersionManager(scm=scm, config=config),
Expand Down Expand Up @@ -572,7 +579,7 @@ def get_stages(self, allowed: bool = False, used: bool = False):
def _push_tag_or_echo_reminder(
self, tag_name: str, push: bool, stdout: bool, delete: bool = False
) -> None:
if push:
if push or self.cloned:
if stdout:
echo(
f"Running `git push{' --delete ' if delete else ' '}origin {tag_name}`"
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
"pytest-mock",
"pytest-test-utils",
"pylint==2.17.5",
# we use this to suppress pytest-related false positives in our tests.
"pylint-pytest",
# we use this to suppress some messages in tests, eg: foo/bar naming,
# and, protected method calls in our tests
"pylint-plugin-utils",
Expand Down
15 changes: 15 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

import pytest
from freezegun import freeze_time
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

import gto
import tests.resources
from gto.api import show
from gto.exceptions import RefNotFound, WrongArgs
from gto.git_utils import cloned_git_repo
from gto.index import RepoIndexManager
from gto.tag import find
from gto.versions import SemVer
Expand Down Expand Up @@ -590,3 +592,16 @@ def test_if_unassign_with_remote_repo_then_invoke_git_push_tag(tmp_dir: TmpDir):
tag_name="churn#staging!#3",
delete=False,
)


def test_action_doesnt_push_even_if_repo_has_remotes_set(mocker: MockFixture):
# test for https://github.com/iterative/gto/issues/405
with cloned_git_repo(tests.resources.SAMPLE_REMOTE_REPO_URL) as scm:
mocked_git_push_tag = mocker.patch("gto.registry.git_push_tag")
gto.api.unassign(
repo=scm,
name="churn",
stage="staging",
version="v3.1.0",
)
mocked_git_push_tag.assert_not_called()
15 changes: 15 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Sequence

import pytest
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

Expand Down Expand Up @@ -111,3 +112,17 @@ def test_check_existence_no_repo(tmp_dir: TmpDir):
tmp_dir.gen("m1.txt", "some content")
assert check_if_path_exists(tmp_dir / "m1.txt")
assert not check_if_path_exists(tmp_dir / "not" / "exists")


def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture):
with RepoIndexManager.from_url(tmp_dir) as idx:
assert idx.cloned is False

with RepoIndexManager.from_url(scm) as idx:
assert idx.cloned is False

cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo")
cloned_git_repo_mock.return_value.__enter__.return_value = scm

with RepoIndexManager.from_url("https://github.com/iterative/gto") as idx:
assert idx.cloned is True
49 changes: 49 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Dict, List

import pytest
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

from gto.registry import GitRegistry

Expand Down Expand Up @@ -443,3 +445,50 @@ def test_registry_state_tag_tag(tmp_dir: TmpDir):
check_obj(
appeared["stages"][key], expected["stages"][key], exclude["stages"]
)


def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture):
with GitRegistry.from_url(tmp_dir) as reg:
assert reg.cloned is False

with GitRegistry.from_url(scm) as reg:
assert reg.cloned is False

cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo")
cloned_git_repo_mock.return_value.__enter__.return_value = scm

with GitRegistry.from_url("https://github.com/iterative/gto") as reg:
assert reg.cloned is True


# Some method parameters (model names, versions, revs, etc) depend and set by
# the `showcase` fixture setup in the conftest.py.
@pytest.mark.parametrize(
"method,args,kwargs",
[
("register", ["new_model", "HEAD"], {}),
("deregister", ["nn"], {"version": "v0.0.1"}),
("assign", ["nn", "new_stage"], {"version": "v0.0.1"}),
("unassign", ["nn", "staging"], {"version": "v0.0.1"}),
("deprecate", ["nn"], {}),
],
)
@pytest.mark.usefixtures("showcase")
def test_tag_is_pushed_if_cloned_is_set(
tmp_dir: TmpDir,
mocker: MockFixture,
method,
args,
kwargs,
):
with GitRegistry.from_url(tmp_dir) as reg:
# imitate that we are doing actions on the remote repo
assert reg.cloned is False
reg.cloned = True

# check that it attempts to push tag to a remote repo, even if
# push=False is set in a call. `cloned` overrides it in this case
git_push_tag_mock = mocker.patch("gto.registry.git_push_tag")
kwargs["push"] = False
getattr(reg, method)(*args, **kwargs)
git_push_tag_mock.assert_called_once()