diff --git a/gto/api.py b/gto/api.py index 38691413..3208c3e3 100644 --- a/gto/api.py +++ b/gto/api.py @@ -33,8 +33,8 @@ def _get_state(repo: Union[str, Repo]): return GitRegistry.from_repo(repo).get_state() -def get_stages(repo: Union[str, Repo], allowed: bool = False): - return GitRegistry.from_repo(repo).get_stages(allowed=allowed) +def get_stages(repo: Union[str, Repo], allowed: bool = False, used: bool = False): + return GitRegistry.from_repo(repo).get_stages(allowed=allowed, used=used) # TODO: make this work the same as CLI version diff --git a/gto/cli.py b/gto/cli.py index 320e222d..0493ce8f 100644 --- a/gto/cli.py +++ b/gto/cli.py @@ -740,6 +740,14 @@ def stages( help="Show allowed stages from config", show_default=True, ), + used: bool = Option( + False, + "--used", + is_flag=True, + help="Show stages that were ever used (from all git tags)", + show_default=True, + ), + json: bool = option_json, ): """Print list of stages used in the registry @@ -747,7 +755,11 @@ def stages( $ gto stages $ gto stages --allowed """ - format_echo(gto.api.get_stages(repo, allowed=allowed), "lines") + result = gto.api.get_stages(repo, allowed=allowed, used=used) + if json: + format_echo(result, "json") + else: + format_echo(result, "lines") @gto_command(hidden=True) diff --git a/gto/config.py b/gto/config.py index 80ed5b8a..2a253fc8 100644 --- a/gto/config.py +++ b/gto/config.py @@ -42,8 +42,8 @@ def load(self) -> Enrichment: class NoFileConfig(BaseSettings): INDEX: str = "artifacts.yaml" - TYPE_ALLOWED: List[str] = [] - STAGE_ALLOWED: List[str] = [] + TYPES: Optional[List[str]] + STAGES: Optional[List[str]] LOG_LEVEL: str = "INFO" DEBUG: bool = False ENRICHMENTS: List[EnrichmentConfig] = [] @@ -56,13 +56,13 @@ class Config: def assert_type(self, name): assert_name_is_valid(name) - if self.TYPE_ALLOWED and name not in self.TYPE_ALLOWED: - raise UnknownType(name, self.TYPE_ALLOWED) + if self.TYPES is not None and name not in self.TYPES: + raise UnknownType(name, self.TYPES) def assert_stage(self, name): assert_name_is_valid(name) - if self.stages and name not in self.stages: - raise UnknownStage(name, self.stages) + if self.STAGES is not None and name not in self.STAGES: + raise UnknownStage(name, self.STAGES) @property def enrichments(self) -> Dict[str, Enrichment]: @@ -71,20 +71,18 @@ def enrichments(self) -> Dict[str, Enrichment]: return {**find_enrichments(), **res} return res - @property - def stages(self) -> List[str]: - return self.STAGE_ALLOWED - - @validator("TYPE_ALLOWED") + @validator("TYPES") def types_are_valid(cls, v): - for name in v: - assert_name_is_valid(name) + if v: + for name in v: + assert_name_is_valid(name) return v - @validator("STAGE_ALLOWED") + @validator("STAGES") def stages_are_valid(cls, v): - for name in v: - assert_name_is_valid(name) + if v: + for name in v: + assert_name_is_valid(name) return v diff --git a/gto/registry.py b/gto/registry.py index be7a86d1..8acc20df 100644 --- a/gto/registry.py +++ b/gto/registry.py @@ -265,12 +265,26 @@ def latest(self, name: str, all: bool = False, registered: bool = True): return artifact.sort_versions(registered=registered) return artifact.get_latest_version(registered_only=registered) - def get_stages(self, allowed: bool = False): + def _get_allowed_stages(self): + return self.config.STAGES + + def _get_used_stages(self): + return sorted( + {stage for o in self.get_artifacts().values() for stage in o.unique_stages} + ) + + def get_stages(self, allowed: bool = False, used: bool = False): """Return list of stages in the registry. If "allowed", return stages that are allowed in config. + If "used", return stages that were used in registry. """ + assert not (allowed and used), """Either "allowed" or "used" can be set""" if allowed: - return self.config.stages - return sorted( - {stage for o in self.get_artifacts().values() for stage in o.unique_stages} - ) + return self._get_allowed_stages() + if used: + return self._get_used_stages() + # if stages in config are set, return them + if self._get_allowed_stages() is not None: + return self._get_allowed_stages() + # if stages aren't set in config, return those in use + return self._get_used_stages() diff --git a/gto/utils.py b/gto/utils.py index e89211e3..da808f60 100644 --- a/gto/utils.py +++ b/gto/utils.py @@ -70,8 +70,9 @@ def format_echo(result, format, format_table=None, if_empty="", missing_value="- else if_empty ) elif format == "lines": - for line in result: - click.echo(line) + if result: + for line in result: + click.echo(line) else: raise NotImplementedError(f"Format {format} is not implemented") diff --git a/tests/test_config.py b/tests/test_config.py index a2c54365..cb98bcff 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,33 +5,40 @@ import pytest from typer.testing import CliRunner -from gto.api import annotate, promote, register +from gto.api import annotate, get_stages, promote, register from gto.cli import app from gto.config import CONFIG_FILE_NAME, check_name_is_valid from gto.exceptions import UnknownType, ValidationError from gto.index import init_index_manager from gto.registry import GitRegistry +CONFIG_CONTENT = """ +types: [model, dataset] +stages: [dev, prod] +""" + +PROHIBIT_CONFIG_CONTENT = """ +types: [] +stages: [] +""" + @pytest.fixture def init_repo(empty_git_repo: Tuple[git.Repo, Callable]): repo, write_file = empty_git_repo - write_file( - CONFIG_FILE_NAME, - "type_allowed: [model, dataset]", - ) + write_file(CONFIG_FILE_NAME, CONFIG_CONTENT) return repo def test_config_load_index(init_repo): index = init_index_manager(init_repo) - assert index.config.TYPE_ALLOWED == ["model", "dataset"] + assert index.config.TYPES == ["model", "dataset"] def test_config_load_registry(init_repo): registry = GitRegistry.from_repo(init_repo) - assert registry.config.TYPE_ALLOWED == ["model", "dataset"] + assert registry.config.TYPES == ["model", "dataset"] def test_adding_allowed_type(init_repo): @@ -43,6 +50,12 @@ def test_adding_not_allowed_type(init_repo): annotate(init_repo, "name", type="unknown") +def test_stages(init_repo): + assert get_stages(init_repo) == ["dev", "prod"] + assert get_stages(init_repo, allowed=True) == ["dev", "prod"] + assert get_stages(init_repo, used=True) == [] + + def test_correct_name(init_repo): annotate(init_repo, "model") @@ -62,9 +75,10 @@ def test_register_incorrect_name(init_repo): register(init_repo, "###", ref="HEAD") -# def test_register_incorrect_version(init_repo): -# with pytest.raises(ValidationError): -# register(init_repo, "model", ref="HEAD", version="###") +@pytest.mark.xfail +def test_register_incorrect_version(init_repo): + with pytest.raises(ValidationError): + register(init_repo, "model", ref="HEAD", version="###") def test_promote_incorrect_name(init_repo): @@ -91,6 +105,30 @@ def test_config_is_not_needed(empty_git_repo: Tuple[git.Repo, Callable], request assert result.exit_code == 0 +@pytest.fixture +def init_repo_prohibit(empty_git_repo: Tuple[git.Repo, Callable]): + repo, write_file = empty_git_repo + + write_file(CONFIG_FILE_NAME, PROHIBIT_CONFIG_CONTENT) + return repo + + +def test_prohibit_config_type(init_repo_prohibit): + with pytest.raises(UnknownType): + annotate(init_repo_prohibit, "name", type="model") + + +@pytest.mark.xfail +def test_prohibit_config_promote_incorrect_stage(init_repo): + with pytest.raises(ValidationError): + promote(init_repo, "model", promote_ref="HEAD", stage="dev") + + +def test_empty_config_type(empty_git_repo): + repo, _ = empty_git_repo + annotate(repo, "name", type="model") + + @pytest.mark.parametrize( "name", [