diff --git a/changes/1393.feature.md b/changes/1393.feature.md new file mode 100644 index 0000000000..4255243d5b --- /dev/null +++ b/changes/1393.feature.md @@ -0,0 +1 @@ +Add `OptionalType` class as a new parameter type wrapper, allowing the client CLI to manage arguments of the `undefined` type. diff --git a/changes/1393.fix.md b/changes/1393.fix.md new file mode 100644 index 0000000000..c02810dba4 --- /dev/null +++ b/changes/1393.fix.md @@ -0,0 +1 @@ +Minor fixes to execute `backend.ai sesstpl create` and `backend.ai session create-from-template` commands diff --git a/src/ai/backend/client/cli/params.py b/src/ai/backend/client/cli/params.py index 7d80c2d43b..c9586596d3 100644 --- a/src/ai/backend/client/cli/params.py +++ b/src/ai/backend/client/cli/params.py @@ -5,6 +5,8 @@ import click +from ..types import undefined + class ByteSizeParamType(click.ParamType): name = "byte" @@ -166,3 +168,19 @@ def convert(self, arg, param, ctx): return arg.split(",") except ValueError as e: self.fail(repr(e), param, ctx) + + +class OptionalType(click.ParamType): + name = "Optional Type Wrapper" + + def __init__(self, type_: type) -> None: + super().__init__() + self.type_ = type_ + + def convert(self, value: Any, param, ctx): + try: + if value is None or value is undefined: + return value + return self.type_(value) + except ValueError: + self.fail(f"{value!r} is not valid `{self.type_}` or `undefined`", param, ctx) diff --git a/src/ai/backend/client/cli/session.py b/src/ai/backend/client/cli/session.py index bbb5c24e68..1a190d314e 100644 --- a/src/ai/backend/client/cli/session.py +++ b/src/ai/backend/client/cli/session.py @@ -31,7 +31,7 @@ from ..session import AsyncSession, Session from ..types import Undefined, undefined from . import events -from .params import CommaSeparatedListType +from .params import CommaSeparatedListType, OptionalType from .pretty import print_done, print_error, print_fail, print_info, print_wait, print_warn from .run import format_stats, prepare_env_arg, prepare_mount_arg, prepare_resource_arg from .ssh import container_ssh_ctx @@ -391,6 +391,7 @@ def _create_from_template_cmd(docs: str = None): "--name", "--client-token", metavar="NAME", + type=OptionalType(str), default=undefined, help="Specify a human-readable session name. If not set, a random hex string is used.", ) @@ -399,6 +400,7 @@ def _create_from_template_cmd(docs: str = None): "--owner", "--owner-access-key", metavar="ACCESS_KEY", + type=OptionalType(str), default=undefined, help="Set the owner of the target session explicitly.", ) @@ -418,11 +420,18 @@ def _create_from_template_cmd(docs: str = None): default=None, help="Let session to be started at a specific or relative time.", ) - @click.option("-i", "--image", default=undefined, help="Set compute_session image to run.") + @click.option( + "-i", + "--image", + type=OptionalType(str), + default=undefined, + help="Set compute_session image to run.", + ) @click.option( "-c", "--startup-command", metavar="COMMAND", + type=OptionalType(str), default=undefined, help="Set the command to execute for batch-type sessions.", ) @@ -434,7 +443,7 @@ def _create_from_template_cmd(docs: str = None): @click.option( "--max-wait", metavar="SECONDS", - type=int, + type=OptionalType(int), default=undefined, help="The maximum duration to wait until the session starts.", ) @@ -469,7 +478,10 @@ def _create_from_template_cmd(docs: str = None): ) # extra options @click.option( - "--tag", type=str, default=undefined, help="User-defined tag string to annotate sessions." + "--tag", + type=OptionalType(str), + default=undefined, + help="User-defined tag string to annotate sessions.", ) # resource spec @click.option( @@ -489,7 +501,7 @@ def _create_from_template_cmd(docs: str = None): @click.option( "--scaling-group", "--sgroup", - type=str, + type=OptionalType(str), default=undefined, help=( "The scaling group to execute session. If not specified, " @@ -512,7 +524,7 @@ def _create_from_template_cmd(docs: str = None): @click.option( "--cluster-size", metavar="NUMBER", - type=int, + type=OptionalType(int), default=undefined, help="The size of cluster in number of containers.", ) @@ -538,7 +550,8 @@ def _create_from_template_cmd(docs: str = None): "-g", "--group", metavar="GROUP_NAME", - default=None, + type=OptionalType(str), + default=undefined, help=( "Group name where the session is spawned. " "User should be a member of the group to execute the code." diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 5d688838b8..a242907978 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -358,6 +358,7 @@ async def create_from_template( enqueue_only: bool | Undefined = undefined, max_wait: int | Undefined = undefined, dependencies: Sequence[str] = None, # cannot be stored in templates + callback_url: str | Undefined = undefined, no_reuse: bool | Undefined = undefined, image: str | Undefined = undefined, mounts: Union[List[str], Undefined] = undefined, @@ -468,6 +469,8 @@ async def create_from_template( "bootstrap_script": bootstrap_script, "enqueueOnly": enqueue_only, "maxWaitSeconds": max_wait, + "dependencies": dependencies, + "callbackURL": callback_url, "reuseIfExists": not no_reuse, "startupCommand": startup_command, "owner_access_key": owner_access_key, diff --git a/src/ai/backend/client/func/session_template.py b/src/ai/backend/client/func/session_template.py index ff77e64764..397ceeca17 100644 --- a/src/ai/backend/client/func/session_template.py +++ b/src/ai/backend/client/func/session_template.py @@ -32,7 +32,8 @@ async def create( rqst.set_json(body) async with rqst.fetch() as resp: response = await resp.json() - return cls(response["id"], owner_access_key=owner_access_key) + template_id = ", ".join(data["id"] for data in response) + return cls(template_id, owner_access_key=owner_access_key) @api_function @classmethod diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 636128dc27..adcc2bb97b 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -298,8 +298,8 @@ async def query_userinfo( return await _query_userinfo( conn, request["user"]["uuid"], - request["user"]["role"], request["keypair"]["access_key"], + request["user"]["role"], request["user"]["domain_name"], request["keypair"]["resource_policy"], params["domain"] or request["user"]["domain_name"], @@ -384,10 +384,10 @@ async def _create(request: web.Request, params: dict[str, Any]) -> web.Response: tx.AliasedKey(["name", "session_name", "clientSessionToken"], default=undefined) >> "session_name": UndefChecker | t.Regexp(r"^(?=.{4,64}$)\w[\w.-]*\w$", re.ASCII), tx.AliasedKey(["image", "lang"], default=undefined): UndefChecker | t.Null | t.String, - tx.AliasedKey(["arch", "architecture"], default=DEFAULT_IMAGE_ARCH) - >> "architecture": t.String, - tx.AliasedKey(["type", "sessionType"], default="interactive") - >> "session_type": tx.Enum(SessionTypes), + tx.AliasedKey(["arch", "architecture"], default=undefined) + >> "architecture": t.String | UndefChecker, + tx.AliasedKey(["type", "sessionType"], default=undefined) + >> "session_type": tx.Enum(SessionTypes) | UndefChecker, tx.AliasedKey(["group", "groupName", "group_name"], default=undefined): ( UndefChecker | t.Null | t.String ), @@ -470,7 +470,7 @@ async def create_from_template(request: web.Request, params: dict[str, Any]) -> param_from_template = { "image": template["spec"]["kernel"]["image"], - "architecture": template["spec"]["kernel"].get("architecture", DEFAULT_IMAGE_ARCH), + "architecture": template["spec"]["kernel"]["architecture"], } if "domain_name" in template_info: param_from_template["domain"] = template_info["domain_name"] diff --git a/src/ai/backend/manager/api/session_template.py b/src/ai/backend/manager/api/session_template.py index 527fb47dfe..9294c49f69 100644 --- a/src/ai/backend/manager/api/session_template.py +++ b/src/ai/backend/manager/api/session_template.py @@ -50,6 +50,7 @@ async def create(request: web.Request, params: Any) -> web.Response: owner_access_key if owner_access_key != requester_access_key else "*", ) root_ctx: RootContext = request.app["_root.context"] + resp = [] async with root_ctx.db.begin() as conn: user_uuid, group_id, _ = await query_userinfo(request, params, conn) log.debug("Params: {0}", params) @@ -57,10 +58,10 @@ async def create(request: web.Request, params: Any) -> web.Response: body = json.loads(params["payload"]) except json.JSONDecodeError: try: - body = yaml.safe_load(params["payload"]) + body = yaml.safe_load_all(params["payload"]) except (yaml.YAMLError, yaml.MarkedYAMLError): raise InvalidAPIParameters("Malformed payload") - for st in body["session_templates"]: + for st in body: template_data = check_task_template(st["template"]) template_id = uuid.uuid4().hex name = st["name"] if "name" in st else template_data["metadata"]["name"] @@ -81,10 +82,12 @@ async def create(request: web.Request, params: Any) -> web.Response: } ) result = await conn.execute(query) - resp = { - "id": template_id, - "user": user_uuid if isinstance(user_uuid, str) else user_uuid.hex, - } + resp.append( + { + "id": template_id, + "user": user_uuid if isinstance(user_uuid, str) else user_uuid.hex, + } + ) assert result.rowcount == 1 return web.json_response(resp)