Skip to content

Commit

Permalink
Enhance sorting (#162)
Browse files Browse the repository at this point in the history
* sort desc by default; refactor sorting in a separate method

* try to adapt to Studio BE

* making sort flag infer flag from version name
  • Loading branch information
aguschin authored May 27, 2022
1 parent 22cf39e commit 65e34c9
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 41 deletions.
2 changes: 1 addition & 1 deletion gto/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def format_hexsha(hexsha):
commits + registration + promotion,
key=lambda x: (x["timestamp"], events_order[x["event"]]),
)
if ascending:
if not ascending:
events.reverse()
if artifact:
events = [event for event in events if event["artifact"] == artifact]
Expand Down
78 changes: 42 additions & 36 deletions gto/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,44 @@ def dict_status(self):
return version


def sort_versions(
versions,
sort=VersionSort.SemVer,
ascending=False,
version="name",
timestamp="created_at",
):
def get(obj, key):
if isinstance(obj, dict):
return obj[key]
if isinstance(obj, BaseModel):
return getattr(obj, key)
raise NotImplementedError("Can sort either dict or BaseModel")

sort = sort if isinstance(sort, VersionSort) else VersionSort[sort]
if sort == VersionSort.SemVer:
# sorting SemVer versions in a right way
sorted_versions = sorted(
(v for v in versions if SemVer.is_valid(get(v, version))),
key=lambda x: SemVer(get(x, version)),
)[:: 1 if ascending else -1]
# sorting hexsha versions alphabetically
sorted_versions.extend(
sorted(
(v for v in versions if not SemVer.is_valid(get(v, version))),
key=lambda x: get(x, version),
)[:: 1 if ascending else -1]
)
else:
sorted_versions = sorted(
versions,
key=lambda x: get(x, timestamp),
)[:: 1 if ascending else -1]
# if ascending:
# sorted_versions.reverse()
return sorted_versions


class BaseArtifact(BaseModel):
name: str
versions: List[BaseVersion]
Expand All @@ -82,44 +120,14 @@ def get_versions(
sort=VersionSort.SemVer,
ascending=False,
) -> List[BaseVersion]:
sort = sort if isinstance(sort, VersionSort) else VersionSort[sort]
all_versions = [
versions = [
v
for v in self.versions
if (v.is_registered and not v.discovered)
or (include_discovered and v.discovered)
or (include_non_explicit and not v.is_registered)
]
if sort == VersionSort.SemVer:
# sorting SemVer versions in a right way
versions = sorted(
(v for v in all_versions if not v.discovered and v.is_registered),
key=lambda x: x.version,
)
# sorting hexsha versions alphabetically
if include_non_explicit:
versions.extend(
sorted(
(
v
for v in all_versions
if not v.discovered and not v.is_registered
),
key=lambda x: x.name,
)
)
else:
versions = sorted(
(
v
for v in all_versions
if not v.discovered and (include_non_explicit or v.is_registered)
),
key=lambda x: x.created_at,
)
if ascending:
versions.reverse()
return versions
return sort_versions(versions, sort=sort, ascending=ascending)

def get_latest_version(
self, registered_only=False, sort=VersionSort.SemVer
Expand All @@ -128,7 +136,7 @@ def get_latest_version(
include_non_explicit=not registered_only, sort=sort
)
if versions:
return versions[-1]
return versions[0]
return None

def get_promotions(
Expand All @@ -146,11 +154,9 @@ def get_promotions(
promotion = version.stage
if promotion:
stages[promotion.stage] = stages.get(promotion.stage, []) + [promotion]
# else:
# stages[promotion.stage] = stages.get(promotion.stage) or promotion
if all:
return stages
return {stage: promotions[-1] for stage, promotions in stages.items()}
return {stage: promotions[0] for stage, promotions in stages.items()}

def add_version(self, version: BaseVersion):
self.versions.append(version)
Expand Down
6 changes: 3 additions & 3 deletions gto/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def which(
ref: bool = option_ref_bool,
all: bool = option_all,
registered_only: bool = option_registered_only,
ascending: bool = option_ascending,
# ascending: bool = option_ascending,
):
"""Find the latest artifact version in a given stage
Expand All @@ -558,8 +558,8 @@ def which(
)
if version:
if all:
if ascending:
version.reverse()
# if ascending:
# version.reverse()
format_echo([v.version for v in version], "lines")
elif ref:
echo(version.tag or version.commit_hexsha)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_commands(showcase):
)
_check_successful_cmd(
"which",
["-r", path, "rf", "production", "--all", "--ascending"],
["-r", path, "rf", "production", "--all"],
"v1.2.4\nv1.2.3\n",
)
_check_successful_cmd(
Expand Down

0 comments on commit 65e34c9

Please sign in to comment.