Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(BA-432): Enforce VFolder name length restriction through the API schema #3363

Merged
merged 17 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3363.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enforce VFolder name length restriction through the API schema, not by the DB column constraint
29 changes: 29 additions & 0 deletions src/ai/backend/common/defs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Final

# Redis database IDs depending on purposes
Expand All @@ -10,3 +11,31 @@


DEFAULT_FILE_IO_TIMEOUT: Final = 10

_RESERVED_VFOLDER_PATTERNS = [r"^\.[a-z0-9]+rc$", r"^\.[a-z0-9]+_profile$"]
RESERVED_VFOLDERS = [
".terminfo",
".jupyter",
".tmux.conf",
".ssh",
"/bin",
"/boot",
"/dev",
"/etc",
"/lib",
"/lib64",
"/media",
"/mnt",
"/opt",
"/proc",
"/root",
"/run",
"/sbin",
"/srv",
"/sys",
"/tmp",
"/usr",
"/var",
"/home",
]
RESERVED_VFOLDER_PATTERNS = [re.compile(x) for x in _RESERVED_VFOLDER_PATTERNS]
25 changes: 25 additions & 0 deletions src/ai/backend/common/typed_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema

from .defs import RESERVED_VFOLDER_PATTERNS, RESERVED_VFOLDERS

TVariousDelta: TypeAlias = datetime.timedelta | relativedelta


Expand Down Expand Up @@ -156,3 +158,26 @@ def session_name_validator(s: str) -> str:

SessionName = Annotated[str, AfterValidator(session_name_validator)]
"""Validator with extended re.ASCII option to match session name string literal"""


def _vfolder_name_validator(name: str) -> str:
"""
Although the length constraint of the `vfolders.name` column is 128,
we limit the length to 64 in the create/rename API
because we append a timestamp of deletion to the name when VFolders are deleted.
"""
if (name_len := len(name)) > 64:
fregataa marked this conversation as resolved.
Show resolved Hide resolved
raise AssertionError(
f"The length of VFolder name should be shorter than 64. (len: {name_len})"
)
if name in RESERVED_VFOLDERS:
raise AssertionError(f"VFolder name '{name}' is reserved for internal operations")
for pattern in RESERVED_VFOLDER_PATTERNS:
if pattern.match(name):
raise AssertionError(
f"VFolder name '{name}' matches a reserved pattern (pattern: {pattern})"
)
return name


VFolderName = Annotated[str, AfterValidator(_vfolder_name_validator)]
104 changes: 61 additions & 43 deletions src/ai/backend/manager/api/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from sqlalchemy.orm import load_only, selectinload

from ai.backend.common import msgpack, redis_helper
from ai.backend.common import typed_validators as tv
from ai.backend.common import validators as tx
from ai.backend.common.types import (
QuotaScopeID,
Expand Down Expand Up @@ -344,39 +345,52 @@ async def _wrapped(request: web.Request, *args: P.args, **kwargs: P.kwargs) -> w
return _wrapped


class CreateRequestModel(BaseModel):
name: tv.VFolderName = Field(
description="Name of the vfolder",
)
folder_host: str | None = Field(
validation_alias=AliasChoices("host", "folder_host"),
default=None,
)
usage_mode: VFolderUsageMode = Field(default=VFolderUsageMode.GENERAL)
permission: VFolderPermission = Field(default=VFolderPermission.READ_WRITE)
unmanaged_path: str | None = Field(
validation_alias=AliasChoices("unmanaged_path", "unmanagedPath"),
default=None,
)
group: str | uuid.UUID | None = Field(
validation_alias=AliasChoices("group", "groupId", "group_id"),
default=None,
)
cloneable: bool = Field(
default=False,
)


@auth_required
@server_status_required(ALL_ALLOWED)
@check_api_params(
t.Dict({
t.Key("name"): tx.Slug(allow_dot=True),
t.Key("host", default=None) >> "folder_host": t.String | t.Null,
t.Key("usage_mode", default="general"): tx.Enum(VFolderUsageMode) | t.Null,
t.Key("permission", default="rw"): tx.Enum(VFolderPermission) | t.Null,
tx.AliasedKey(["unmanaged_path", "unmanagedPath"], default=None): t.String | t.Null,
tx.AliasedKey(["group", "groupId", "group_id"], default=None): tx.UUID | t.String | t.Null,
t.Key("cloneable", default=False): t.Bool,
}),
)
async def create(request: web.Request, params: Any) -> web.Response:
@pydantic_params_api_handler(CreateRequestModel)
async def create(request: web.Request, params: CreateRequestModel) -> web.Response:
resp: Dict[str, Any] = {}
root_ctx: RootContext = request.app["_root.context"]
access_key = request["keypair"]["access_key"]
user_role = request["user"]["role"]
user_uuid: uuid.UUID = request["user"]["uuid"]
keypair_resource_policy = request["keypair"]["resource_policy"]
domain_name = request["user"]["domain_name"]
group_id_or_name = params["group"]
group_id_or_name = params.group
log.info(
"VFOLDER.CREATE (email:{}, ak:{}, vf:{}, vfh:{}, umod:{}, perm:{})",
request["user"]["email"],
access_key,
params["name"],
params["folder_host"],
params["usage_mode"].value,
params["permission"].value,
params.name,
params.folder_host,
params.usage_mode.value,
params.permission.value,
)
folder_host = params["folder_host"]
unmanaged_path = params["unmanaged_path"]
folder_host = params.folder_host
unmanaged_path = params.unmanaged_path
# Check if user is trying to created unmanaged vFolder
if unmanaged_path:
# Approve only if user is Admin or Superadmin
Expand All @@ -393,10 +407,8 @@ async def create(request: web.Request, params: Any) -> web.Response:

allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types()

if not verify_vfolder_name(params["name"]):
raise InvalidAPIParameters(f"{params['name']} is reserved for internal operations.")
if params["name"].startswith(".") and params["name"] != ".local":
if params["group"] is not None:
if params.name.startswith(".") and params.name != ".local":
if params.group is not None:
raise InvalidAPIParameters("dot-prefixed vfolders cannot be a group folder.")

group_uuid: uuid.UUID | None = None
Expand Down Expand Up @@ -486,17 +498,18 @@ async def create(request: web.Request, params: Any) -> web.Response:
)

if group_type == ProjectType.MODEL_STORE:
if params["permission"] != VFolderPermission.READ_WRITE:
if params.permission != VFolderPermission.READ_WRITE:
raise InvalidAPIParameters(
"Setting custom permission is not supported for model store vfolder"
)
if params["usage_mode"] != VFolderUsageMode.MODEL:
if params.usage_mode != VFolderUsageMode.MODEL:
raise InvalidAPIParameters(
"Only Model VFolder can be created under the model store project"
)

async with root_ctx.db.begin() as conn:
if not unmanaged_path:
assert folder_host is not None
await ensure_host_permission_allowed(
conn,
folder_host,
Expand Down Expand Up @@ -542,7 +555,7 @@ async def create(request: web.Request, params: Any) -> web.Response:

# Prevent creation of vfolder with duplicated name on all hosts.
extra_vf_conds = [
(vfolders.c.name == params["name"]),
(vfolders.c.name == params.name),
(vfolders.c.status.not_in(HARD_DELETED_VFOLDER_STATUSES)),
]
entries = await query_accessible_vfolders(
Expand All @@ -554,7 +567,7 @@ async def create(request: web.Request, params: Any) -> web.Response:
extra_vf_conds=(sa.and_(*extra_vf_conds)),
)
if len(entries) > 0:
raise VFolderAlreadyExists(extra_data=params["name"])
raise VFolderAlreadyExists(extra_data=params.name)
try:
folder_id = uuid.uuid4()
vfid = VFolderID(quota_scope_id, folder_id)
Expand All @@ -575,6 +588,7 @@ async def create(request: web.Request, params: Any) -> web.Response:
# },
# ):
# pass
assert folder_host is not None
options = {}
if max_quota_scope_size and max_quota_scope_size > 0:
options["initial_max_size_for_quota_scope"] = max_quota_scope_size
Expand All @@ -594,40 +608,40 @@ async def create(request: web.Request, params: Any) -> web.Response:

# By default model store VFolder should be considered as read only for every users but without the creator
if group_type == ProjectType.MODEL_STORE:
params["permission"] = VFolderPermission.READ_ONLY
params.permission = VFolderPermission.READ_ONLY

# TODO: include quota scope ID in the database
# TODO: include quota scope ID in the API response
insert_values = {
"id": vfid.folder_id.hex,
"name": params["name"],
"name": params.name,
"domain_name": domain_name,
"quota_scope_id": str(quota_scope_id),
"usage_mode": params["usage_mode"],
"permission": params["permission"],
"usage_mode": params.usage_mode,
"permission": params.permission,
"last_used": None,
"host": folder_host,
"creator": request["user"]["email"],
"ownership_type": VFolderOwnershipType(ownership_type),
"user": user_uuid if ownership_type == "user" else None,
"group": group_uuid if ownership_type == "group" else None,
"unmanaged_path": "",
"cloneable": params["cloneable"],
"cloneable": params.cloneable,
"status": VFolderOperationStatus.READY,
}
resp = {
"id": vfid.folder_id.hex,
"name": params["name"],
"name": params.name,
"quota_scope_id": str(quota_scope_id),
"host": folder_host,
"usage_mode": params["usage_mode"].value,
"permission": params["permission"].value,
"usage_mode": params.usage_mode.value,
"permission": params.permission.value,
"max_size": 0, # migrated to quota scopes, no longer valid
"creator": request["user"]["email"],
"ownership_type": ownership_type,
"user": str(user_uuid) if ownership_type == "user" else None,
"group": str(group_uuid) if ownership_type == "group" else None,
"cloneable": params["cloneable"],
"cloneable": params.cloneable,
"status": VFolderOperationStatus.READY,
}
if unmanaged_path:
Expand Down Expand Up @@ -1177,16 +1191,20 @@ async def get_used_bytes(request: web.Request, params: Any) -> web.Response:
return web.json_response(usage, status=200)


class RenameRequestModel(BaseModel):
new_name: tv.VFolderName = Field(
description="Name of the vfolder",
)


@auth_required
@server_status_required(ALL_ALLOWED)
@pydantic_params_api_handler(RenameRequestModel)
@with_vfolder_rows_resolved(VFolderPermission.OWNER_PERM)
@check_api_params(
t.Dict({
t.Key("new_name"): tx.Slug(allow_dot=True),
})
)
async def rename_vfolder(
request: web.Request, params: Any, row: Sequence[VFolderRow]
request: web.Request,
row: Sequence[VFolderRow],
params: RenameRequestModel,
) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
old_name = request.match_info["name"]
Expand All @@ -1195,7 +1213,7 @@ async def rename_vfolder(
user_role = request["user"]["role"]
user_uuid = request["user"]["uuid"]
resource_policy = request["keypair"]["resource_policy"]
new_name = params["new_name"]
new_name = params.new_name
allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types()
log.info(
"VFOLDER.RENAME (email:{}, ak:{}, vf.old:{}, vf.new:{})",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""extend length of vfolders.name column

Revision ID: ef9a7960d234
Revises: 0bb88d5a46bf
Create Date: 2025-01-03 16:07:11.407081

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "ef9a7960d234"
down_revision = "0bb88d5a46bf"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column(
"vfolders",
"name",
existing_type=sa.VARCHAR(length=64),
type_=sa.String(length=128),
existing_nullable=False,
)


def downgrade() -> None:
op.alter_column(
"vfolders",
"name",
existing_type=sa.String(length=128),
type_=sa.VARCHAR(length=64),
existing_nullable=False,
)
19 changes: 17 additions & 2 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,13 @@ async def delete_vfolders(

:return: number of deleted rows
"""
from . import VFolderDeletionInfo, initiate_vfolder_deletion, vfolder_permissions, vfolders
from . import (
VFolderDeletionInfo,
VFolderOperationStatus,
initiate_vfolder_deletion,
vfolder_permissions,
vfolders,
)

async with engine.begin_session() as conn:
await conn.execute(
Expand All @@ -1116,7 +1122,16 @@ async def delete_vfolders(
result = await conn.execute(
sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.quota_scope_id])
.select_from(vfolders)
.where(vfolders.c.user == user_uuid),
.where(
sa.and_(
vfolders.c.user == user_uuid,
vfolders.c.status.not_in(
VFolderOperationStatus.DELETE_ONGOING,
VFolderOperationStatus.DELETE_COMPLETE,
VFolderOperationStatus.DELETE_ERROR,
),
)
),
)
target_vfs = result.fetchall()

Expand Down
Loading
Loading