Skip to content

Commit

Permalink
registry: fix check_ref with full ref name (#382)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla authored Aug 21, 2023
1 parent 7663159 commit 6296c29
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
9 changes: 4 additions & 5 deletions gto/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from contextlib import contextmanager
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar, cast

from funcy import distinct
from pydantic import BaseModel
Expand Down Expand Up @@ -472,8 +472,7 @@ def _return_event(self, tag) -> TBaseEvent:
event = self.check_ref(tag)
if len(event) > 1:
raise NotImplementedInGTO("Can't process a tag that caused multiple events")
event = event[0]
return event
return cast(TBaseEvent, event[0])

@staticmethod
def _echo_git_suggestion(tag, delete=False):
Expand All @@ -493,7 +492,7 @@ def _delete_tags(self, tags, stdout, push: bool):
delete=True,
)

def check_ref(self, ref: str):
def check_ref(self, ref: str) -> List[BaseEvent]:
"Find out what was registered/assigned in this ref"
try:
name = ""
Expand All @@ -513,7 +512,7 @@ def check_ref(self, ref: str):
if aname == name
for event in artifact.get_events()
# TODO: support matching the shortened commit hashes
if event.ref == ref
if event.ref == tag_name
]

def find_commit(self, name, version):
Expand Down
18 changes: 11 additions & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# pylint: disable=unused-variable, protected-access
# pylint: disable=protected-access
"""TODO: add more tests for API"""
import os
from contextlib import contextmanager
from time import sleep
from typing import Optional, Tuple
from typing import Optional
from unittest.mock import ANY, call, patch

import pytest
Expand Down Expand Up @@ -267,7 +267,9 @@ def environ(**overrides):
os.environ.pop(name, None)


def test_check_ref_detailed(scm: Git, artifact: str):
@pytest.mark.usefixtures("artifact")
@pytest.mark.parametrize("with_prefix", [True, False])
def test_check_ref_detailed(scm: Git, with_prefix: bool):
NAME = "model"
SEMVER = "v1.2.3"
GIT_AUTHOR_NAME = "Alexander Guschin"
Expand All @@ -283,7 +285,10 @@ def test_check_ref_detailed(scm: Git, artifact: str):
):
gto.api.register(scm, name=NAME, ref="HEAD", version=SEMVER)

events = gto.api.check_ref(scm, f"{NAME}@{SEMVER}")
ref = f"{NAME}@{SEMVER}"
if with_prefix:
ref = f"refs/tags/{ref}"
events = gto.api.check_ref(scm, ref)
assert len(events) == 1, "Should return one event"
check_obj(
events[0].dict_state(),
Expand All @@ -299,9 +304,8 @@ def test_check_ref_detailed(scm: Git, artifact: str):
)


def test_check_ref_multiple_showcase(scm: Git, showcase: Tuple[str, str]):
first_commit, second_commit = showcase

@pytest.mark.usefixtures("showcase")
def test_check_ref_multiple_showcase(scm: Git):
for tag in find(scm=scm):
events = gto.api.check_ref(scm, tag.name)
assert len(events) == 1, "Should return one event"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def test_commands(tmp_dir: TmpDir, showcase: Tuple[str, str]):
["-r", tmp_dir, "rf#production#3", "--name"],
"rf\n",
)
_check_successful_cmd(
"check-ref",
["-r", tmp_dir, "refs/tags/rf#production#3", "--name"],
"rf\n",
)
_check_successful_cmd(
"check-ref",
["-r", tmp_dir, "rf#production#3", "--stage"],
Expand Down

0 comments on commit 6296c29

Please sign in to comment.