Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add OptionalType click parameter type #1406

Merged
merged 23 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1393.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `OptionalType` class as a new parameter type wrapper, allowing the client CLI to manage arguments of the `undefined` type.
1 change: 1 addition & 0 deletions changes/1393.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Minor fixes to execute `backend.ai sesstpl create` and `backend.ai session create-from-template` commands
18 changes: 18 additions & 0 deletions src/ai/backend/client/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import click

from ..types import undefined


class ByteSizeParamType(click.ParamType):
name = "byte"
Expand Down Expand Up @@ -166,3 +168,19 @@ def convert(self, arg, param, ctx):
return arg.split(",")
except ValueError as e:
self.fail(repr(e), param, ctx)


class OptionalType(click.ParamType):
name = "Optional Type Wrapper"

def __init__(self, type_: type) -> None:
super().__init__()
self.type_ = type_

def convert(self, value: Any, param, ctx):
try:
if value is None or value is undefined:
return value
return self.type_(value)
except ValueError:
self.fail(f"{value!r} is not valid `{self.type_}` or `undefined`", param, ctx)
27 changes: 20 additions & 7 deletions src/ai/backend/client/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..session import AsyncSession, Session
from ..types import Undefined, undefined
from . import events
from .params import CommaSeparatedListType
from .params import CommaSeparatedListType, OptionalType
from .pretty import print_done, print_error, print_fail, print_info, print_wait, print_warn
from .run import format_stats, prepare_env_arg, prepare_mount_arg, prepare_resource_arg
from .ssh import container_ssh_ctx
Expand Down Expand Up @@ -391,6 +391,7 @@ def _create_from_template_cmd(docs: str = None):
"--name",
"--client-token",
metavar="NAME",
type=OptionalType(str),
default=undefined,
help="Specify a human-readable session name. If not set, a random hex string is used.",
)
Expand All @@ -399,6 +400,7 @@ def _create_from_template_cmd(docs: str = None):
"--owner",
"--owner-access-key",
metavar="ACCESS_KEY",
type=OptionalType(str),
default=undefined,
help="Set the owner of the target session explicitly.",
)
Expand All @@ -418,11 +420,18 @@ def _create_from_template_cmd(docs: str = None):
default=None,
help="Let session to be started at a specific or relative time.",
)
@click.option("-i", "--image", default=undefined, help="Set compute_session image to run.")
@click.option(
"-i",
"--image",
type=OptionalType(str),
default=undefined,
help="Set compute_session image to run.",
)
@click.option(
"-c",
"--startup-command",
metavar="COMMAND",
type=OptionalType(str),
default=undefined,
help="Set the command to execute for batch-type sessions.",
)
Expand All @@ -434,7 +443,7 @@ def _create_from_template_cmd(docs: str = None):
@click.option(
"--max-wait",
metavar="SECONDS",
type=int,
type=OptionalType(int),
default=undefined,
help="The maximum duration to wait until the session starts.",
)
Expand Down Expand Up @@ -469,7 +478,10 @@ def _create_from_template_cmd(docs: str = None):
)
# extra options
@click.option(
"--tag", type=str, default=undefined, help="User-defined tag string to annotate sessions."
"--tag",
type=OptionalType(str),
default=undefined,
help="User-defined tag string to annotate sessions.",
)
# resource spec
@click.option(
Expand All @@ -489,7 +501,7 @@ def _create_from_template_cmd(docs: str = None):
@click.option(
"--scaling-group",
"--sgroup",
type=str,
type=OptionalType(str),
default=undefined,
help=(
"The scaling group to execute session. If not specified, "
Expand All @@ -512,7 +524,7 @@ def _create_from_template_cmd(docs: str = None):
@click.option(
"--cluster-size",
metavar="NUMBER",
type=int,
type=OptionalType(int),
default=undefined,
help="The size of cluster in number of containers.",
)
Expand All @@ -538,7 +550,8 @@ def _create_from_template_cmd(docs: str = None):
"-g",
"--group",
metavar="GROUP_NAME",
default=None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a None value is passed, the group value gets overwritten with None, leading to an error.

collected an error log [ERROR] "ai.backend.manager.api.exceptions.InvalidAPIParameters: Missing or invalid API parameters. (Invalid group)" from manager

type=OptionalType(str),
default=undefined,
help=(
"Group name where the session is spawned. "
"User should be a member of the group to execute the code."
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ async def create_from_template(
enqueue_only: bool | Undefined = undefined,
max_wait: int | Undefined = undefined,
dependencies: Sequence[str] = None, # cannot be stored in templates
callback_url: str | Undefined = undefined,
no_reuse: bool | Undefined = undefined,
image: str | Undefined = undefined,
mounts: Union[List[str], Undefined] = undefined,
Expand Down Expand Up @@ -468,6 +469,8 @@ async def create_from_template(
"bootstrap_script": bootstrap_script,
"enqueueOnly": enqueue_only,
"maxWaitSeconds": max_wait,
"dependencies": dependencies,
"callbackURL": callback_url,
"reuseIfExists": not no_reuse,
"startupCommand": startup_command,
"owner_access_key": owner_access_key,
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/client/func/session_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ async def create(
rqst.set_json(body)
async with rqst.fetch() as resp:
response = await resp.json()
return cls(response["id"], owner_access_key=owner_access_key)
template_id = ", ".join(data["id"] for data in response)
return cls(template_id, owner_access_key=owner_access_key)

@api_function
@classmethod
Expand Down
12 changes: 6 additions & 6 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ async def query_userinfo(
return await _query_userinfo(
conn,
request["user"]["uuid"],
request["user"]["role"],
request["keypair"]["access_key"],
request["user"]["role"],
request["user"]["domain_name"],
request["keypair"]["resource_policy"],
params["domain"] or request["user"]["domain_name"],
Expand Down Expand Up @@ -384,10 +384,10 @@ async def _create(request: web.Request, params: dict[str, Any]) -> web.Response:
tx.AliasedKey(["name", "session_name", "clientSessionToken"], default=undefined)
>> "session_name": UndefChecker | t.Regexp(r"^(?=.{4,64}$)\w[\w.-]*\w$", re.ASCII),
tx.AliasedKey(["image", "lang"], default=undefined): UndefChecker | t.Null | t.String,
tx.AliasedKey(["arch", "architecture"], default=DEFAULT_IMAGE_ARCH)
>> "architecture": t.String,
tx.AliasedKey(["type", "sessionType"], default="interactive")
>> "session_type": tx.Enum(SessionTypes),
Comment on lines 386 to -390
Copy link
Contributor Author

@kimjinmyeong kimjinmyeong Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a default value is set in here, the configuration value of the template gets overwritten by this default value.

tx.AliasedKey(["arch", "architecture"], default=undefined)
>> "architecture": t.String | UndefChecker,
tx.AliasedKey(["type", "sessionType"], default=undefined)
>> "session_type": tx.Enum(SessionTypes) | UndefChecker,
tx.AliasedKey(["group", "groupName", "group_name"], default=undefined): (
UndefChecker | t.Null | t.String
),
Expand Down Expand Up @@ -470,7 +470,7 @@ async def create_from_template(request: web.Request, params: dict[str, Any]) ->

param_from_template = {
"image": template["spec"]["kernel"]["image"],
"architecture": template["spec"]["kernel"].get("architecture", DEFAULT_IMAGE_ARCH),
"architecture": template["spec"]["kernel"]["architecture"],
}
if "domain_name" in template_info:
param_from_template["domain"] = template_info["domain_name"]
Expand Down
15 changes: 9 additions & 6 deletions src/ai/backend/manager/api/session_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,18 @@ async def create(request: web.Request, params: Any) -> web.Response:
owner_access_key if owner_access_key != requester_access_key else "*",
)
root_ctx: RootContext = request.app["_root.context"]
resp = []
async with root_ctx.db.begin() as conn:
user_uuid, group_id, _ = await query_userinfo(request, params, conn)
log.debug("Params: {0}", params)
try:
body = json.loads(params["payload"])
except json.JSONDecodeError:
try:
body = yaml.safe_load(params["payload"])
body = yaml.safe_load_all(params["payload"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

except (yaml.YAMLError, yaml.MarkedYAMLError):
raise InvalidAPIParameters("Malformed payload")
for st in body["session_templates"]:
for st in body:
template_data = check_task_template(st["template"])
template_id = uuid.uuid4().hex
name = st["name"] if "name" in st else template_data["metadata"]["name"]
Expand All @@ -81,10 +82,12 @@ async def create(request: web.Request, params: Any) -> web.Response:
}
)
result = await conn.execute(query)
resp = {
"id": template_id,
"user": user_uuid if isinstance(user_uuid, str) else user_uuid.hex,
}
resp.append(
{
"id": template_id,
"user": user_uuid if isinstance(user_uuid, str) else user_uuid.hex,
}
)
assert result.rowcount == 1
return web.json_response(resp)

Expand Down