Skip to content

Commit

Permalink
Fixing stages lookup (#165)
Browse files Browse the repository at this point in the history
* improving stages API/CLI

* rename config keys: stage_allowed -> stages, type_allowed -> types

* add tests

* mark tests xfail

* add --json for 'gto stages'
  • Loading branch information
aguschin authored Jun 2, 2022
1 parent 7ac36b7 commit d859e43
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 36 deletions.
4 changes: 2 additions & 2 deletions gto/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def _get_state(repo: Union[str, Repo]):
return GitRegistry.from_repo(repo).get_state()


def get_stages(repo: Union[str, Repo], allowed: bool = False):
return GitRegistry.from_repo(repo).get_stages(allowed=allowed)
def get_stages(repo: Union[str, Repo], allowed: bool = False, used: bool = False):
return GitRegistry.from_repo(repo).get_stages(allowed=allowed, used=used)


# TODO: make this work the same as CLI version
Expand Down
14 changes: 13 additions & 1 deletion gto/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,14 +740,26 @@ def stages(
help="Show allowed stages from config",
show_default=True,
),
used: bool = Option(
False,
"--used",
is_flag=True,
help="Show stages that were ever used (from all git tags)",
show_default=True,
),
json: bool = option_json,
):
"""Print list of stages used in the registry
Examples:
$ gto stages
$ gto stages --allowed
"""
format_echo(gto.api.get_stages(repo, allowed=allowed), "lines")
result = gto.api.get_stages(repo, allowed=allowed, used=used)
if json:
format_echo(result, "json")
else:
format_echo(result, "lines")


@gto_command(hidden=True)
Expand Down
30 changes: 14 additions & 16 deletions gto/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def load(self) -> Enrichment:

class NoFileConfig(BaseSettings):
INDEX: str = "artifacts.yaml"
TYPE_ALLOWED: List[str] = []
STAGE_ALLOWED: List[str] = []
TYPES: Optional[List[str]]
STAGES: Optional[List[str]]
LOG_LEVEL: str = "INFO"
DEBUG: bool = False
ENRICHMENTS: List[EnrichmentConfig] = []
Expand All @@ -56,13 +56,13 @@ class Config:

def assert_type(self, name):
assert_name_is_valid(name)
if self.TYPE_ALLOWED and name not in self.TYPE_ALLOWED:
raise UnknownType(name, self.TYPE_ALLOWED)
if self.TYPES is not None and name not in self.TYPES:
raise UnknownType(name, self.TYPES)

def assert_stage(self, name):
assert_name_is_valid(name)
if self.stages and name not in self.stages:
raise UnknownStage(name, self.stages)
if self.STAGES is not None and name not in self.STAGES:
raise UnknownStage(name, self.STAGES)

@property
def enrichments(self) -> Dict[str, Enrichment]:
Expand All @@ -71,20 +71,18 @@ def enrichments(self) -> Dict[str, Enrichment]:
return {**find_enrichments(), **res}
return res

@property
def stages(self) -> List[str]:
return self.STAGE_ALLOWED

@validator("TYPE_ALLOWED")
@validator("TYPES")
def types_are_valid(cls, v):
for name in v:
assert_name_is_valid(name)
if v:
for name in v:
assert_name_is_valid(name)
return v

@validator("STAGE_ALLOWED")
@validator("STAGES")
def stages_are_valid(cls, v):
for name in v:
assert_name_is_valid(name)
if v:
for name in v:
assert_name_is_valid(name)
return v


Expand Down
24 changes: 19 additions & 5 deletions gto/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,26 @@ def latest(self, name: str, all: bool = False, registered: bool = True):
return artifact.sort_versions(registered=registered)
return artifact.get_latest_version(registered_only=registered)

def get_stages(self, allowed: bool = False):
def _get_allowed_stages(self):
return self.config.STAGES

def _get_used_stages(self):
return sorted(
{stage for o in self.get_artifacts().values() for stage in o.unique_stages}
)

def get_stages(self, allowed: bool = False, used: bool = False):
"""Return list of stages in the registry.
If "allowed", return stages that are allowed in config.
If "used", return stages that were used in registry.
"""
assert not (allowed and used), """Either "allowed" or "used" can be set"""
if allowed:
return self.config.stages
return sorted(
{stage for o in self.get_artifacts().values() for stage in o.unique_stages}
)
return self._get_allowed_stages()
if used:
return self._get_used_stages()
# if stages in config are set, return them
if self._get_allowed_stages() is not None:
return self._get_allowed_stages()
# if stages aren't set in config, return those in use
return self._get_used_stages()
5 changes: 3 additions & 2 deletions gto/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def format_echo(result, format, format_table=None, if_empty="", missing_value="-
else if_empty
)
elif format == "lines":
for line in result:
click.echo(line)
if result:
for line in result:
click.echo(line)
else:
raise NotImplementedError(f"Format {format} is not implemented")

Expand Down
58 changes: 48 additions & 10 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,40 @@
import pytest
from typer.testing import CliRunner

from gto.api import annotate, promote, register
from gto.api import annotate, get_stages, promote, register
from gto.cli import app
from gto.config import CONFIG_FILE_NAME, check_name_is_valid
from gto.exceptions import UnknownType, ValidationError
from gto.index import init_index_manager
from gto.registry import GitRegistry

CONFIG_CONTENT = """
types: [model, dataset]
stages: [dev, prod]
"""

PROHIBIT_CONFIG_CONTENT = """
types: []
stages: []
"""


@pytest.fixture
def init_repo(empty_git_repo: Tuple[git.Repo, Callable]):
repo, write_file = empty_git_repo

write_file(
CONFIG_FILE_NAME,
"type_allowed: [model, dataset]",
)
write_file(CONFIG_FILE_NAME, CONFIG_CONTENT)
return repo


def test_config_load_index(init_repo):
index = init_index_manager(init_repo)
assert index.config.TYPE_ALLOWED == ["model", "dataset"]
assert index.config.TYPES == ["model", "dataset"]


def test_config_load_registry(init_repo):
registry = GitRegistry.from_repo(init_repo)
assert registry.config.TYPE_ALLOWED == ["model", "dataset"]
assert registry.config.TYPES == ["model", "dataset"]


def test_adding_allowed_type(init_repo):
Expand All @@ -43,6 +50,12 @@ def test_adding_not_allowed_type(init_repo):
annotate(init_repo, "name", type="unknown")


def test_stages(init_repo):
    """Check stage lookup on a repo whose config declares ``dev`` and ``prod``.

    Both the default lookup and ``allowed=True`` are expected to report the
    configured stages, while ``used=True`` is expected to report an empty list
    for the repo provided by the fixture.
    """
    configured = ["dev", "prod"]
    assert get_stages(init_repo) == configured
    assert get_stages(init_repo, allowed=True) == configured
    assert get_stages(init_repo, used=True) == []


def test_correct_name(init_repo):
    """Annotating an artifact under a plain, valid name should not raise."""
    artifact_name = "model"
    annotate(init_repo, artifact_name)

Expand All @@ -62,9 +75,10 @@ def test_register_incorrect_name(init_repo):
register(init_repo, "###", ref="HEAD")


# def test_register_incorrect_version(init_repo):
# with pytest.raises(ValidationError):
# register(init_repo, "model", ref="HEAD", version="###")
@pytest.mark.xfail
def test_register_incorrect_version(init_repo):
    """Registering with a malformed version string should raise ValidationError.

    NOTE(review): marked xfail; presumably version validation does not raise
    yet -- confirm before removing the marker.
    """
    malformed_version = "###"
    with pytest.raises(ValidationError):
        register(init_repo, "model", ref="HEAD", version=malformed_version)


def test_promote_incorrect_name(init_repo):
Expand All @@ -91,6 +105,30 @@ def test_config_is_not_needed(empty_git_repo: Tuple[git.Repo, Callable], request
assert result.exit_code == 0


@pytest.fixture
def init_repo_prohibit(empty_git_repo: Tuple[git.Repo, Callable]):
    """Provide a repo whose config file sets ``types`` and ``stages`` to ``[]``."""
    git_repo, writer = empty_git_repo
    writer(CONFIG_FILE_NAME, PROHIBIT_CONFIG_CONTENT)
    return git_repo


def test_prohibit_config_type(init_repo_prohibit):
    """With ``types: []`` in the config, annotating with a type must fail."""
    requested_type = "model"
    with pytest.raises(UnknownType):
        annotate(init_repo_prohibit, "name", type=requested_type)


@pytest.mark.xfail
def test_prohibit_config_promote_incorrect_stage(init_repo_prohibit):
    """Promoting to a stage must be rejected when the config has ``stages: []``.

    Requests the prohibiting fixture (empty ``types``/``stages`` lists) rather
    than the ``dev``/``prod`` one, so the test exercises the configuration its
    name refers to.

    NOTE(review): kept xfail; presumably promotion does not yet raise
    ValidationError for this case -- confirm before removing the marker.
    """
    with pytest.raises(ValidationError):
        promote(init_repo_prohibit, "model", promote_ref="HEAD", stage="dev")


def test_empty_config_type(empty_git_repo):
    """Annotating with a type should not raise when this test writes no config."""
    bare_repo, _unused_writer = empty_git_repo
    annotate(bare_repo, "name", type="model")


@pytest.mark.parametrize(
"name",
[
Expand Down

0 comments on commit d859e43

Please sign in to comment.