Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OS-7421. Ability to set model version tags and aliases #26

Merged
merged 4 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,19 @@ To create a model, use the model method with the following parameters:
arcee.model("my_model", "/home/user/my_model")
```

To set custom model version, use the set_model_version method with the following parameter:
To set a custom model version, use the model_version method with the following parameter:
- version (str): version name
```
arcee.set_model_version("1.2.3-release")
```
arcee.model_version("1.2.3-release")
```

To set a model version alias, use the model_version_alias method with the following parameter:
- alias (str): alias name
```
arcee.model_version_alias("winner")
```

To add tags to model version (key, value):
```
arcee.model_version_tag("env", "staging demo")
```
4 changes: 3 additions & 1 deletion examples/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
arcee.tag("test1", "test2")
arcee.milestone("just a milestone")
arcee.model("model_key", "/src/simple.py")
arcee.set_model_version("1.23.45-rc")
arcee.model_version("1.23.45-rc")
arcee.model_version_alias("winner")
arcee.model_version_tag("key", "value")
arcee.send({"t": 2})
print(arcee.info())
3 changes: 2 additions & 1 deletion optscale_arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa: F401
from .arcee import (init, send, tag, milestone, info, finish, error, stage,
dataset, hyperparam, model, set_model_version)
dataset, hyperparam, model, model_version,
model_version_alias, model_version_tag)
52 changes: 51 additions & 1 deletion optscale_arcee/arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(
self._hyperparams = dict()
self._dataset = None
self._model = None
self._model_version = None
self._model_version_tags = dict()
self._model_version_aliases = list()

@property
def run(self):
Expand Down Expand Up @@ -111,6 +114,33 @@ def model(self):
def model(self, value):
self._model = value

@property
def model_version(self):
return self._model_version

@model_version.setter
def model_version(self, value):
self._model_version = value

@property
def model_version_tags(self):
return self._model_version_tags

@model_version_tags.setter
def model_version_tags(self, value):
k, v = value
self._model_version_tags.update({k: v})

@property
def model_version_aliases(self):
return self._model_version_aliases

@model_version_aliases.setter
def model_version_aliases(self, value):
aliases = set(self._model_version_aliases)
aliases.add(value)
self._model_version_aliases = list(aliases)


def init(
token, task_key=None, run_name=None, endpoint_url=None, ssl=True, period=1,
Expand Down Expand Up @@ -261,10 +291,30 @@ def model(key, path=None):
)


def set_model_version(version):
def model_version(version):
arcee = Arcee()
asyncio.run(
arcee.sender.add_version(
arcee.model, arcee.run, arcee.token, version
)
)


def model_version_alias(alias):
arcee = Arcee()
arcee.model_version_aliases = alias
asyncio.run(
arcee.sender.add_version_aliases(
arcee.model, arcee.run, arcee.token, arcee.model_version_aliases
)
)


def model_version_tag(key, value):
arcee = Arcee()
arcee.model_version_tags = (key, value)
asyncio.run(
arcee.sender.add_version_tags(
arcee.model, arcee.run, arcee.token, arcee.model_version_tags
)
)
15 changes: 13 additions & 2 deletions optscale_arcee/sender/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,19 @@ async def assign_model_run(self, model_id, run_id, token, path=None):
await self.send_post_request(uri, headers, body)

@check_shutdown_flag_set
async def add_version(self, model_id, run_id, token, version):
async def patch_model_version(self, model_id, run_id, token, params):
headers = {"x-api-key": token, "Content-Type": "application/json"}
uri = f'{self.endpoint_url}/models/{model_id}/runs/{run_id}'
await self.send_patch_request(uri, headers, params)

async def add_version(self, model_id, run_id, token, version):
body = {'version': str(version)}
await self.send_patch_request(uri, headers, body)
await self.patch_model_version(model_id, run_id, token, body)

async def add_version_aliases(self, model_id, run_id, token, aliases):
body = {'aliases': aliases}
await self.patch_model_version(model_id, run_id, token, body)

async def add_version_tags(self, model_id, run_id, token, tags):
body = {'tags': tags}
await self.patch_model_version(model_id, run_id, token, body)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# setup.cfg
[metadata]
name = optscale_arcee
version = 0.1.38
version = 0.1.39
author = Hystax
description = ML profiling tool for OptScale
long_description = file: README.md
Expand Down
Loading