diff --git a/changes/1662.fix.md b/changes/1662.fix.md new file mode 100644 index 0000000000..8333b460be --- /dev/null +++ b/changes/1662.fix.md @@ -0,0 +1 @@ +Change the type of `status_history` from a mapping of statuses to timestamps to a list of log entries, each containing a status and a timestamp, so that timestamps are preserved when a session or kernel re-enters a status (e.g., after session restarts). \ No newline at end of file diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index d0c3185df1..9a6666c925 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -16,7 +16,6 @@ import inquirer import treelib from async_timeout import timeout -from dateutil.parser import isoparse from dateutil.tz import tzutc from faker import Faker from humanize import naturalsize @@ -25,6 +24,8 @@ from ai.backend.cli.main import main from ai.backend.cli.params import CommaSeparatedListType, OptionalType from ai.backend.cli.types import ExitCode, Undefined, undefined +from ai.backend.client.cli.extensions import pass_ctx_obj +from ai.backend.client.cli.types import CLIContext from ai.backend.common.arch import DEFAULT_IMAGE_ARCH from ai.backend.common.types import ClusterMode @@ -34,6 +35,7 @@ from ...output.fields import session_fields from ...output.types import FieldSpec from ...session import AsyncSession, Session +from ...utils import get_first_timestamp_for_status from .. import events from ..pretty import ( ProgressViewer, @@ -778,8 +780,9 @@ def logs(session_id, kernel: str | None): @session.command("status-history") +@pass_ctx_obj @click.argument("session_id", metavar="SESSID") -def status_history(session_id): +def status_history(ctx: CLIContext, session_id): """ Shows the status transition history of the compute session.
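For reviewers, the data-shape change at the heart of this PR looks like the following (invented timestamps):

```python
# Old shape: a status-to-timestamp mapping; re-entering a status
# (e.g., RUNNING again after a restart) overwrote the earlier timestamp.
old_status_history = {
    "PENDING": "2024-01-26T10:00:00+00:00",
    "RUNNING": "2024-01-26T10:05:00+00:00",
}

# New shape: an append-only log; every transition keeps its own entry,
# so repeated statuses are preserved in order.
new_status_history = [
    {"status": "PENDING", "timestamp": "2024-01-26T10:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-01-26T10:05:00+00:00"},
    {"status": "RESTARTING", "timestamp": "2024-01-26T11:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-01-26T11:01:00+00:00"},
]
```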
@@ -791,31 +794,35 @@ def status_history(session_id): kernel = session.ComputeSession(session_id) try: status_history = kernel.get_status_history().get("result") - print_info(f"status_history: {status_history}") - if (preparing := status_history.get("preparing")) is None: - result = { - "result": { - "seconds": 0, - "microseconds": 0, - }, - } - elif (terminated := status_history.get("terminated")) is None: - alloc_time_until_now: timedelta = datetime.now(tzutc()) - isoparse(preparing) - result = { - "result": { - "seconds": alloc_time_until_now.seconds, - "microseconds": alloc_time_until_now.microseconds, - }, - } + + prev_time = None + + for status_record in status_history: + timestamp = datetime.fromisoformat(status_record["timestamp"]) + + if prev_time: + time_diff = timestamp - prev_time + status_record["time_elapsed"] = str(time_diff) + else: + status_record["time_elapsed"] = "-" + + prev_time = timestamp + + ctx.output.print_list( + status_history, + [FieldSpec("status"), FieldSpec("timestamp"), FieldSpec("time_elapsed")], + ) + + if (preparing := get_first_timestamp_for_status(status_history, "PREPARING")) is None: + elapsed = timedelta() + elif ( + terminated := get_first_timestamp_for_status(status_history, "TERMINATED") + ) is None: + elapsed = datetime.now(tzutc()) - preparing else: - alloc_time: timedelta = isoparse(terminated) - isoparse(preparing) - result = { - "result": { - "seconds": alloc_time.seconds, - "microseconds": alloc_time.microseconds, - }, - } - print_done(f"Actual Resource Allocation Time: {result}") + elapsed = terminated - preparing + + print_done(f"Actual Resource Allocation Time: {elapsed.total_seconds()}") except Exception as e: print_error(e) sys.exit(ExitCode.FAILURE) diff --git a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py index 4a1905fc39..f0d9c4fa1a 100644 --- a/src/ai/backend/client/output/fields.py +++ b/src/ai/backend/client/output/fields.py @@ -185,6 +185,8 @@ FieldSpec("created_user_id"), FieldSpec("status"), FieldSpec("status_info"), + FieldSpec("status_history"), + FieldSpec("status_history_log"), FieldSpec("status_data", formatter=nested_dict_formatter), FieldSpec("status_changed", "Last Updated"), FieldSpec("created_at"), diff --git a/src/ai/backend/client/utils.py b/src/ai/backend/client/utils.py index b95fc5c9b8..6561eacedf 100644 --- a/src/ai/backend/client/utils.py +++ b/src/ai/backend/client/utils.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import io import os +from datetime import datetime +from dateutil.parser import parse as dtparse from tqdm import tqdm @@ -48,3 +52,13 @@ def readinto1(self, *args, **kwargs): count = super().readinto1(*args, **kwargs) self.tqdm.set_postfix(file=self._filename, refresh=False) self.tqdm.update(count) + + +def get_first_timestamp_for_status( + status_history: list[dict[str, str]], + status: str, +) -> datetime | None: + for rec in status_history: + if rec["status"] == status: + return dtparse(rec["timestamp"]) + return None diff --git a/src/ai/backend/manager/api/resource.py b/src/ai/backend/manager/api/resource.py index 156b8d9b66..3e5d4e4035 100644 --- a/src/ai/backend/manager/api/resource.py +++ b/src/ai/backend/manager/api/resource.py @@ -465,7 +465,7 @@ async def _pipe_builder(r: Redis) -> RedisPipeline: "status": row["status"].name, "status_info": row["status_info"], "status_changed": str(row["status_changed"]), - "status_history": row["status_history"] or {}, + "status_history": row["status_history"], "cluster_mode": row["cluster_mode"], } if group_id not in objs_per_group:
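A quick sketch of how the reworked command derives the elapsed-time table and the allocation time from the new client-side helper (sample response data, not actual server output):

```python
from datetime import datetime, timedelta

from dateutil.tz import tzutc

from ai.backend.client.utils import get_first_timestamp_for_status

# Hypothetical payload as returned by GET /session/{name}/status-history.
status_history = [
    {"status": "PENDING", "timestamp": "2024-01-26T10:00:00+00:00"},
    {"status": "PREPARING", "timestamp": "2024-01-26T10:01:00+00:00"},
    {"status": "TERMINATED", "timestamp": "2024-01-26T10:31:00+00:00"},
]

preparing = get_first_timestamp_for_status(status_history, "PREPARING")
terminated = get_first_timestamp_for_status(status_history, "TERMINATED")
if preparing is None:
    elapsed = timedelta()  # never reached PREPARING: no allocation time
elif terminated is None:
    elapsed = datetime.now(tzutc()) - preparing  # still running
else:
    elapsed = terminated - preparing
print(elapsed.total_seconds())  # 1800.0 for the sample above
```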
diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 76df1d941e..0e875c44a8 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -583,7 +583,10 @@ type ComputeSession implements Item { status_changed: DateTime status_info: String status_data: JSONString - status_history: JSONString + status_history: JSONString @deprecated(reason: "Deprecated since 24.09.0; use `status_history_log`") + + """Added in 24.09.0""" + status_history_log: JSONString created_at: DateTime terminated_at: DateTime starts_at: DateTime diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 9446381372..873edc1ec2 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -2186,6 +2186,36 @@ async def get_container_logs( return web.json_response(resp, status=200) +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["session_name", "sessionName", "task_id", "taskId"]) >> "kernel_id": tx.UUID, + t.Key("owner_access_key", default=None): t.Null | t.String, + }) +) +async def get_status_history(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app["_root.context"] + session_name: str = request.match_info["session_name"] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + "GET_STATUS_HISTORY (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_name ) + resp: dict[str, Mapping] = {"result": {}} + + async with root_ctx.db.begin_readonly_session() as db_sess: + compute_session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + allow_stale=True, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, + ) + resp["result"] = compute_session.status_history + + return web.json_response(resp, status=200) + + @server_status_required(READ_ALLOWED) @auth_required @check_api_params( @@ -2321,6 +2351,7 @@ def create_app( app.router.add_route("GET", "/{session_name}/direct-access-info", get_direct_access_info) ) cors.add(app.router.add_route("GET", "/{session_name}/logs", get_container_logs)) + cors.add(app.router.add_route("GET", "/{session_name}/status-history", get_status_history)) cors.add(app.router.add_route("POST", "/{session_name}/rename", rename_session)) cors.add(app.router.add_route("POST", "/{session_name}/interrupt", interrupt)) cors.add(app.router.add_route("POST", "/{session_name}/complete", complete)) diff --git a/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py b/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py new file mode 100644 index 0000000000..64b25e9781 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py @@ -0,0 +1,110 @@ +"""Replace the status_history map with a list of log entries in the sessions and kernels tables + +Revision ID: 8c8e90aebacd +Revises: 59a622c31820 +Create Date: 2024-01-26 11:19:23.075014 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic.
+revision = "8c8e90aebacd" +down_revision = "59a622c31820" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute( + """ + WITH data AS ( + SELECT id, + (jsonb_each(status_history)).key AS status, + (jsonb_each(status_history)).value AS timestamp + FROM kernels + WHERE jsonb_typeof(status_history) = 'object' + ) + UPDATE kernels + SET status_history = ( + SELECT jsonb_agg( + jsonb_build_object('status', status, 'timestamp', timestamp) + ) + FROM data + WHERE data.id = kernels.id + AND jsonb_typeof(kernels.status_history) = 'object' + ); + """ + ) + op.execute("UPDATE kernels SET status_history = '[]'::jsonb WHERE status_history IS NULL;") + op.alter_column("kernels", "status_history", nullable=False, default=[]) + + op.execute( + """ + WITH data AS ( + SELECT id, + (jsonb_each(status_history)).key AS status, + (jsonb_each(status_history)).value AS timestamp + FROM sessions + WHERE jsonb_typeof(status_history) = 'object' + ) + UPDATE sessions + SET status_history = ( + SELECT jsonb_agg( + jsonb_build_object('status', status, 'timestamp', timestamp) + ) + FROM data + WHERE data.id = sessions.id + AND jsonb_typeof(sessions.status_history) = 'object' + ); + """ + ) + op.execute("UPDATE sessions SET status_history = '[]'::jsonb WHERE status_history IS NULL;") + op.alter_column("sessions", "status_history", nullable=False, default=[]) + + +def downgrade(): + op.execute( + """ + WITH data AS ( + SELECT id, + jsonb_object_agg( + elem->>'status', elem->>'timestamp' + ) AS new_status_history + FROM kernels, + jsonb_array_elements(status_history) AS elem + WHERE jsonb_typeof(status_history) = 'array' + GROUP BY id + ) + UPDATE kernels + SET status_history = data.new_status_history + FROM data + WHERE data.id = kernels.id + AND jsonb_typeof(kernels.status_history) = 'array'; + """ + ) + op.alter_column("kernels", "status_history", nullable=True, default=None) + op.execute("UPDATE kernels SET status_history = NULL WHERE status_history = '[]'::jsonb;") + + op.execute( + """ + WITH data AS ( + SELECT id, + jsonb_object_agg( + elem->>'status', elem->>'timestamp' + ) AS new_status_history + FROM sessions, + jsonb_array_elements(status_history) AS elem + WHERE jsonb_typeof(status_history) = 'array' + GROUP BY id + ) + UPDATE sessions + SET status_history = data.new_status_history + FROM data + WHERE data.id = sessions.id + AND jsonb_typeof(sessions.status_history) = 'array'; + """ + ) + op.alter_column("sessions", "status_history", nullable=True, default=None) + op.execute("UPDATE sessions SET status_history = NULL WHERE status_history = '[]'::jsonb;") diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index d78ef7c92f..a31b06458c 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -75,11 +75,17 @@ ) from .group import groups from .image import ImageNode, ImageRow -from .minilang import JSONFieldItem from .minilang.ordering import ColumnMapType, QueryOrderParser from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter from .user import users -from .utils import ExtendedAsyncSAEngine, JSONCoalesceExpr, execute_with_retry, sql_json_merge +from .utils import ( + ExtendedAsyncSAEngine, + JSONCoalesceExpr, + execute_with_retry, + get_first_timestamp_for_status, + sql_append_dict_to_list, + sql_json_merge, +) if TYPE_CHECKING: from .gql import GraphQueryContext @@ -538,7 +544,14 @@ class KernelRow(Base): # // used to prevent duplication of SessionTerminatedEvent # } # } - 
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index d78ef7c92f..a31b06458c 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -75,11 +75,17 @@ ) from .group import groups from .image import ImageNode, ImageRow -from .minilang import JSONFieldItem from .minilang.ordering import ColumnMapType, QueryOrderParser from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter from .user import users -from .utils import ExtendedAsyncSAEngine, JSONCoalesceExpr, execute_with_retry, sql_json_merge +from .utils import ( + ExtendedAsyncSAEngine, + JSONCoalesceExpr, + execute_with_retry, + get_first_timestamp_for_status, + sql_append_dict_to_list, + sql_json_merge, +) if TYPE_CHECKING: from .gql import GraphQueryContext @@ -538,7 +544,14 @@ class KernelRow(Base): # // used to prevent duplication of SessionTerminatedEvent # } # } - status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null()) + status_history = sa.Column("status_history", pgsql.JSONB(), nullable=False, default=[]) + # status_history records all status changes + # e.g.) + # [ + # {"status": "PENDING", "timestamp": "2022-10-22T10:22:30"}, + # {"status": "SCHEDULED", "timestamp": "2022-10-22T11:40:30"}, + # {"status": "PREPARING", "timestamp": "2022-10-25T10:22:30"} + # ] callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null()) startup_command = sa.Column("startup_command", sa.Text, nullable=True) result = sa.Column( @@ -723,12 +736,9 @@ async def set_kernel_status( data = { "status": status, "status_changed": now, - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - status.name: now.isoformat(), # ["PULLING", "PREPARING"] - }, + "status_history": sql_append_dict_to_list( + KernelRow.status_history, + {"status": status.name, "timestamp": now.isoformat()}, ), } if status_data is not None: @@ -774,12 +784,9 @@ async def _update() -> bool: if update_data is None: update_values = { "status": new_status, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), - { - new_status.name: now.isoformat(), - }, + {"status": new_status.name, "timestamp": now.isoformat()}, ), } else: @@ -921,7 +928,9 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]: hide_agents = False else: hide_agents = ctx.local_config["manager"]["hide-agents"] - status_history = row.status_history or {} + status_history = cast(list[dict[str, str]], row.status_history) + scheduled_at = get_first_timestamp_for_status(status_history, KernelStatus.SCHEDULED) + return { # identity "id": row.id, @@ -947,7 +956,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]: "created_at": row.created_at, "terminated_at": row.terminated_at, "starts_at": row.starts_at, - "scheduled_at": status_history.get(KernelStatus.SCHEDULED.name), + "scheduled_at": scheduled_at, "occupied_slots": row.occupied_slots.to_json(), # resources "agent": row.agent if not hide_agents else None, @@ -1001,7 +1010,7 @@ async def resolve_abusing_report( "created_at": ("created_at", dtparse), "status_changed": ("status_changed", dtparse), "terminated_at": ("terminated_at", dtparse), - "scheduled_at": (JSONFieldItem("status_history", KernelStatus.SCHEDULED.name), dtparse), + "scheduled_at": ("scheduled_at", None), } _queryorder_colmap: ColumnMapType = { @@ -1018,7 +1027,7 @@ async def resolve_abusing_report( "status_changed": ("status_info", None), "created_at": ("created_at", None), "terminated_at": ("terminated_at", None), - "scheduled_at": (JSONFieldItem("status_history", KernelStatus.SCHEDULED.name), None), + "scheduled_at": ("scheduled_at", None), } @classmethod
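As a quick illustration of the manager-side helper used above (sample records with invented timestamps; duplicate statuses are exactly what the list form now preserves):

```python
from ai.backend.manager.models.kernel import KernelStatus
from ai.backend.manager.models.utils import get_first_timestamp_for_status

records = [
    {"status": "SCHEDULED", "timestamp": "2024-01-26T10:00:30+00:00"},
    {"status": "RUNNING", "timestamp": "2024-01-26T10:01:00+00:00"},
    # Re-scheduled after a restart; the earlier entry above is kept.
    {"status": "SCHEDULED", "timestamp": "2024-01-27T09:00:00+00:00"},
]

# Returns the first occurrence as a timezone-aware datetime, or None if the
# status never appears in the log.
scheduled_at = get_first_timestamp_for_status(records, KernelStatus.SCHEDULED)
assert scheduled_at is not None and scheduled_at.day == 26
```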
diff --git a/src/ai/backend/manager/models/resource_usage.py b/src/ai/backend/manager/models/resource_usage.py index d8997867d7..0916cadf3c 100644 --- a/src/ai/backend/manager/models/resource_usage.py +++ b/src/ai/backend/manager/models/resource_usage.py @@ -1,8 +1,9 @@ from __future__ import annotations +import json from datetime import datetime from enum import Enum -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Mapping, Optional, Sequence, cast from uuid import UUID import attrs @@ -14,6 +15,7 @@ from sqlalchemy.orm import joinedload, load_only from ai.backend.common import redis_helper +from ai.backend.common.json import ExtendedJSONEncoder from ai.backend.common.types import RedisConnectionInfo from ai.backend.common.utils import nmget @@ -21,7 +23,7 @@ from .kernel import LIVE_STATUS, RESOURCE_USAGE_KERNEL_STATUSES, KernelRow, KernelStatus from .session import SessionRow from .user import UserRow -from .utils import ExtendedAsyncSAEngine +from .utils import ExtendedAsyncSAEngine, get_first_timestamp_for_status __all__: Sequence[str] = ( "ResourceGroupUnit", @@ -516,7 +518,11 @@ async def _pipe_builder(r: Redis) -> RedisPipeline: session_row=kern.session, created_at=kern.created_at, terminated_at=kern.terminated_at, - scheduled_at=kern.status_history.get(KernelStatus.SCHEDULED.name), + scheduled_at=str( + get_first_timestamp_for_status( + cast(list[dict[str, str]], kern.status_history), KernelStatus.SCHEDULED + ) + ), used_time=kern.used_time, used_days=kern.get_used_days(local_tz), last_stat=stat_map[kern.id], @@ -534,7 +540,7 @@ async def _pipe_builder(r: Redis) -> RedisPipeline: images={kern.image}, agents={kern.agent}, status=kern.status.name, - status_history=kern.status_history, + status_history=json.dumps(kern.status_history, cls=ExtendedJSONEncoder), cluster_mode=kern.cluster_mode, status_info=kern.status_info, group_unit=ResourceGroupUnit.KERNEL, diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 42c7967a2c..8cbdd06afb 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -70,7 +70,7 @@ ) from .group import GroupRow from .kernel import ComputeContainer, KernelRow, KernelStatus -from .minilang import ArrayFieldItem, JSONFieldItem +from .minilang import ArrayFieldItem from .minilang.ordering import ColumnMapType, QueryOrderParser from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter from .user import UserRow @@ -79,6 +79,8 @@ JSONCoalesceExpr, agg_to_array, execute_with_retry, + get_first_timestamp_for_status, + sql_append_dict_to_list, sql_json_merge, ) @@ -679,7 +681,14 @@ class SessionRow(Base): # // used to prevent duplication of SessionTerminatedEvent # } # } - status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null()) + status_history = sa.Column("status_history", pgsql.JSONB(), nullable=False, default=[]) + # status_history records all status changes + # e.g.) + # [ + # {"status": "PENDING", "timestamp": "2022-10-22T10:22:30"}, + # {"status": "SCHEDULED", "timestamp": "2022-10-22T11:40:30"}, + # {"status": "PREPARING", "timestamp": "2022-10-25T10:22:30"} + # ] callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null()) startup_command = sa.Column("startup_command", sa.Text, nullable=True) @@ -725,13 +734,8 @@ def main_kernel(self) -> KernelRow: return kerns[0] @property - def status_changed(self) -> Optional[datetime]: - if self.status_history is None: - return None - try: - return datetime.fromisoformat(self.status_history[self.status.name]) - except KeyError: - return None + def status_changed(self) -> datetime | None: + return get_first_timestamp_for_status(self.status_history, self.status) @property def resource_opts(self) -> dict[str, Any]: @@ -805,12 +809,9 @@ async def _check_and_update() -> SessionStatus | None: update_values = { "status": determined_status, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( SessionRow.status_history, - (), - { - determined_status.name: now.isoformat(), - }, + {"status": determined_status.name, "timestamp": now.isoformat()}, ), } if determined_status in
(SessionStatus.CANCELLED, SessionStatus.TERMINATED): @@ -911,12 +912,9 @@ async def set_session_status( now = status_changed_at data = { "status": status, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( SessionRow.status_history, - (), - { - status.name: datetime.now(tzutc()).isoformat(), - }, + {"status": status.name, "timestamp": datetime.now(tzutc()).isoformat()}, ), } if status_data is not None: @@ -1283,7 +1281,10 @@ class Meta: status_changed = GQLDateTime() status_info = graphene.String() status_data = graphene.JSONString() - status_history = graphene.JSONString() + status_history = graphene.JSONString( + deprecation_reason="Deprecated since 24.09.0; use `status_history_log`" + ) + status_history_log = graphene.JSONString(description="Added in 24.09.0") created_at = GQLDateTime() terminated_at = GQLDateTime() starts_at = GQLDateTime() @@ -1324,8 +1325,8 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: full_name = getattr(row, "full_name") group_name = getattr(row, "group_name") row = row.SessionRow - status_history = row.status_history or {} - raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name) + scheduled_at = get_first_timestamp_for_status(row.status_history, SessionStatus.SCHEDULED) + return { # identity "id": row.id, @@ -1357,13 +1358,11 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: "status_changed": row.status_changed, "status_info": row.status_info, "status_data": row.status_data, - "status_history": status_history, + "status_history_log": row.status_history, "created_at": row.created_at, "terminated_at": row.terminated_at, "starts_at": row.starts_at, - "scheduled_at": ( - datetime.fromisoformat(raw_scheduled_at) if raw_scheduled_at is not None else None - ), + "scheduled_at": scheduled_at, "startup_command": row.startup_command, "result": row.result.name, # resources @@ -1443,6 +1442,10 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, graph_ctx: GraphQueryContext = info.context return await graph_ctx.idle_checker_host.get_idle_check_report(self.session_id) + # legacy + async def resolve_status_history(self, _info: graphene.ResolveInfo) -> Mapping[str, Any]: + return {item["status"]: item["timestamp"] for item in self.status_history_log} + _queryfilter_fieldspec: FieldSpecType = { "id": ("sessions_id", None), "type": ("sessions_session_type", enum_field_getter(SessionTypes)), @@ -1466,10 +1469,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, "created_at": ("sessions_created_at", dtparse), "terminated_at": ("sessions_terminated_at", dtparse), "starts_at": ("sessions_starts_at", dtparse), - "scheduled_at": ( - JSONFieldItem("sessions_status_history", SessionStatus.SCHEDULED.name), - dtparse, - ), + "scheduled_at": ("scheduled_at", None), "startup_command": ("sessions_startup_command", None), } @@ -1497,10 +1497,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, "created_at": ("sessions_created_at", None), "terminated_at": ("sessions_terminated_at", None), "starts_at": ("sessions_starts_at", None), - "scheduled_at": ( - JSONFieldItem("sessions_status_history", SessionStatus.SCHEDULED.name), - None, - ), + "scheduled_at": ("scheduled_at", None), } @classmethod diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py index 4ad6d2a195..7756576263 100644 --- a/src/ai/backend/manager/models/utils.py +++ 
b/src/ai/backend/manager/models/utils.py @@ -6,6 +6,7 @@ import logging from contextlib import AbstractAsyncContextManager as AbstractAsyncCtxMgr from contextlib import asynccontextmanager as actxmgr +from datetime import datetime from typing import ( TYPE_CHECKING, Any, @@ -23,6 +24,7 @@ from urllib.parse import quote_plus as urlquote import sqlalchemy as sa +from dateutil.parser import parse as dtparse from sqlalchemy.dialects import postgresql as psql from sqlalchemy.engine import create_engine as _create_engine from sqlalchemy.exc import DBAPIError @@ -44,6 +46,10 @@ if TYPE_CHECKING: from ..config import LocalConfig + from . import ( + KernelStatus, + SessionStatus, + ) from ..defs import LockID from ..types import Sentinel @@ -452,6 +458,16 @@ def sql_json_merge( return expr +def sql_append_dict_to_list(col, arg: dict): + """ + Generate an SQLAlchemy column update expression that appends a dictionary to + the existing JSONB array. + """ + # Bind the new entry as a JSONB parameter so quotes in values cannot break the SQL. + expr = col.op("||")(sa.cast(json.dumps([arg]), psql.JSONB)) + return expr + + def sql_json_increment( col, key: Tuple[str, ...], @@ -526,3 +542,17 @@ async def vacuum_db( vacuum_sql = "VACUUM FULL" if vacuum_full else "VACUUM" log.info(f"Perfoming {vacuum_sql} operation...") await conn.exec_driver_sql(vacuum_sql) + + +def get_first_timestamp_for_status( + status_history_records: list[dict[str, str]], + status: KernelStatus | SessionStatus, +) -> datetime | None: + """ + Get the first occurrence time of the given status from the status history records. + """ + + for status_history in status_history_records: + if status_history["status"] == status.name: + return dtparse(status_history["timestamp"]) + return None
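For context, a standalone sketch of the expression this helper builds (hypothetical minimal table standing in for `kernels`; SQL rendering abbreviated):

```python
import json

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql

metadata = sa.MetaData()
kernels_t = sa.Table(
    "kernels",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("status_history", psql.JSONB, nullable=False),
)

entry = {"status": "RUNNING", "timestamp": "2024-01-26T10:05:00+00:00"}
# JSONB `||` concatenates arrays, so wrapping the entry in a one-element
# list appends it to the stored log.
expr = kernels_t.c.status_history.op("||")(sa.cast(json.dumps([entry]), psql.JSONB))
stmt = sa.update(kernels_t).values(status_history=expr)
# Renders roughly as:
#   UPDATE kernels SET status_history = kernels.status_history || CAST(%(param_1)s AS JSONB)
```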
diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 283b9f3877..9062305038 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -179,6 +179,7 @@ is_db_retry_error, reenter_txn, reenter_txn_session, + sql_append_dict_to_list, sql_json_merge, ) from .types import UserScope @@ -1006,9 +1007,12 @@ async def enqueue_session( session_data = { "id": session_id, "status": SessionStatus.PENDING, - "status_history": { - SessionStatus.PENDING.name: datetime.now(tzutc()).isoformat(), - }, + "status_history": [ + { + "status": SessionStatus.PENDING.name, + "timestamp": datetime.now(tzutc()).isoformat(), + } + ], "creation_id": session_creation_id, "name": session_name, "session_type": session_type, @@ -1029,9 +1033,12 @@ kernel_shared_data = { "status": KernelStatus.PENDING, - "status_history": { - KernelStatus.PENDING.name: datetime.now(tzutc()).isoformat(), - }, + "status_history": [ + { + "status": KernelStatus.PENDING.name, + "timestamp": datetime.now(tzutc()).isoformat(), + }, + ], "session_creation_id": session_creation_id, "session_id": session_id, "session_name": session_name, @@ -1583,6 +1590,7 @@ async def finalize_running( created_info["resource_spec"]["allocations"] ) new_status = KernelStatus.RUNNING + update_data = { "occupied_slots": actual_allocs, "scaling_group": created_info["scaling_group"], @@ -1595,14 +1603,12 @@ "stdin_port": created_info["stdin_port"], "stdout_port": created_info["stdout_port"], "service_ports": service_ports, - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - new_status.name: datetime.now(tzutc()).isoformat(), - }, + "status_history": sql_append_dict_to_list( + KernelRow.status_history, + {"status": new_status.name, "timestamp": datetime.now(tzutc()).isoformat()}, ), } + self._kernel_actual_allocated_resources[kernel_id] = actual_allocs async def _update_session_occupying_slots(db_session: AsyncSession) -> None: @@ -1787,7 +1793,6 @@ async def _update_kernel() -> None: log.warning("_create_kernels_in_one_agent(s:{}) cancelled", scheduled_session.id) except Exception as e: ex = e - err_info = convert_to_status_data(ex, self.debug) # The agent has already cancelled or issued the destruction lifecycle event # for this batch of kernels. @@ -1805,16 +1810,14 @@ async def _update_failure() -> None: status_info=f"other-error ({ex!r})", status_changed=now, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.ERROR.name: ( - now.isoformat() - ), # ["PULLING", "PREPARING"] + "status": KernelStatus.ERROR.name, + "timestamp": now.isoformat(), }, ), - status_data=err_info, + status_data=convert_to_status_data(ex, self.debug), ) ) await db_sess.execute(query) @@ -2222,21 +2225,21 @@ async def _destroy(db_session: AsyncSession) -> SessionRow: kern.status = kernel_target_status kern.terminated_at = current_time kern.status_info = destroy_reason - kern.status_history = sql_json_merge( + kern.status_history = sql_append_dict_to_list( KernelRow.status_history, - (), { - kernel_target_status.name: current_time.isoformat(), + "status": kernel_target_status.name, + "timestamp": current_time.isoformat(), }, ) session_row.status = target_status session_row.terminated_at = current_time session_row.status_info = destroy_reason - session_row.status_history = sql_json_merge( + session_row.status_history = sql_append_dict_to_list( SessionRow.status_history, - (), { - target_status.name: current_time.isoformat(), + "status": target_status.name, + "timestamp": current_time.isoformat(), }, ) return session_row @@ -2416,11 +2419,11 @@ async def _update() -> None: "status_info": reason, "status_changed": now, "terminated_at": now, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.TERMINATED.name: now.isoformat(), + "status": KernelStatus.TERMINATED.name, + "timestamp": now.isoformat(), }, ), } @@ -2463,11 +2466,11 @@ async def _update() -> None: "kernel": {"exit_code": None}, "session": {"status": "terminating"}, }, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.TERMINATING.name: now.isoformat(), + "status": KernelStatus.TERMINATING.name, + "timestamp": now.isoformat(), }, ), } @@ -2633,11 +2636,11 @@ async def _restarting_session() -> None: sa.update(SessionRow) .values( status=SessionStatus.RESTARTING, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), { - SessionStatus.RESTARTING.name: datetime.now(tzutc()).isoformat(), + "status": SessionStatus.RESTARTING.name, + "timestamp": datetime.now(tzutc()).isoformat(), }, ), ) @@ -2672,12 +2675,9 @@ async def _restart_kernel(kernel: KernelRow) -> None: "stdin_port": kernel_info["stdin_port"], "stdout_port": kernel_info["stdout_port"], "service_ports": kernel_info.get("service_ports", []), - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), - { - KernelStatus.RUNNING.name: now.isoformat(), - }, + {"status": KernelStatus.RUNNING.name, "timestamp": now.isoformat()}, ), } await KernelRow.update_kernel( @@ -3237,11 +3237,11
@@ async def _update_kernel() -> tuple[AccessKey, AgentId] | None: ("kernel",), {"exit_code": exit_code}, ), - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.TERMINATED.name: now.isoformat(), + "status": KernelStatus.TERMINATED.name, + "timestamp": now.isoformat(), }, ), "terminated_at": now, diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index e65ae6c694..6c54527a01 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -92,7 +92,12 @@ recalc_concurrency_used, ) from ..models.utils import ExtendedAsyncSAEngine as SAEngine -from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge +from ..models.utils import ( + execute_with_retry, + sql_append_dict_to_list, + sql_json_increment, + sql_json_merge, +) from .predicates import ( check_concurrency, check_dependencies, @@ -371,12 +376,9 @@ async def _apply_cancellation( status=KernelStatus.CANCELLED, status_info=reason, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), - { - KernelStatus.CANCELLED.name: now.isoformat(), - }, + {"status": KernelStatus.CANCELLED.name, "timestamp": now.isoformat()}, ), ) .where(KernelRow.session_id.in_(session_ids)) @@ -388,12 +390,9 @@ status=SessionStatus.CANCELLED, status_info=reason, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), - { - SessionStatus.CANCELLED.name: now.isoformat(), - }, + {"status": SessionStatus.CANCELLED.name, "timestamp": now.isoformat()}, ), ) .where(SessionRow.id.in_(session_ids)) @@ -962,11 +961,11 @@ async def _finalize_scheduled() -> None: status_info="scheduled", status_data={}, status_changed=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.SCHEDULED.name: now.isoformat(), + "status": KernelStatus.SCHEDULED.name, + "timestamp": now.isoformat(), }, ), ) @@ -984,12 +983,9 @@ async def _finalize_scheduled() -> None: status=SessionStatus.SCHEDULED, status_info="scheduled", status_data={}, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), - { - SessionStatus.SCHEDULED.name: now.isoformat(), - }, + {"status": SessionStatus.SCHEDULED.name, "timestamp": now.isoformat()}, ), ) .where(SessionRow.id == sess_ctx.id) @@ -1199,11 +1195,11 @@ async def _finalize_scheduled() -> None: status_info="scheduled", status_data={}, status_changed=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.SCHEDULED.name: now.isoformat(), + "status": KernelStatus.SCHEDULED.name, + "timestamp": now.isoformat(), }, ), ) @@ -1222,12 +1218,9 @@ async def _finalize_scheduled() -> None: status_info="scheduled", status_data={}, # status_changed=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), - { - SessionStatus.SCHEDULED.name: now.isoformat(), - }, + {"status": SessionStatus.SCHEDULED.name, "timestamp": now.isoformat()}, ), ) .where(SessionRow.id == sess_ctx.id) @@ -1288,11 +1281,11 @@ async def _mark_session_preparing() -> Sequence[SessionRow]: status_changed=now, status_info="", status_data={}, - status_history=sql_json_merge( +
status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.PREPARING.name: now.isoformat(), + "status": KernelStatus.PREPARING.name, + "timestamp": now.isoformat(), }, ), ) @@ -1308,11 +1301,11 @@ async def _mark_session_preparing() -> Sequence[SessionRow]: # status_changed=now, status_info="", status_data={}, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), { - SessionStatus.PREPARING.name: now.isoformat(), + "status": SessionStatus.PREPARING.name, + "timestamp": now.isoformat(), }, ), ) @@ -1612,11 +1605,11 @@ async def _mark_session_cancelled() -> None: status_info="failed-to-start", status_data=status_data, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.CANCELLED.name: now.isoformat(), + "status": KernelStatus.CANCELLED.name, + "timestamp": now.isoformat(), }, ), ) @@ -1631,11 +1624,11 @@ async def _mark_session_cancelled() -> None: status_info="failed-to-start", status_data=status_data, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( SessionRow.status_history, - (), { - SessionStatus.CANCELLED.name: now.isoformat(), + "status": SessionStatus.CANCELLED.name, + "timestamp": now.isoformat(), }, ), ) diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py index 84f7f7fdc5..93c275c840 100644 --- a/src/ai/backend/manager/server.py +++ b/src/ai/backend/manager/server.py @@ -555,7 +555,6 @@ async def hanging_session_scanner_ctx(root_ctx: RootContext) -> AsyncIterator[No import sqlalchemy as sa from dateutil.relativedelta import relativedelta - from dateutil.tz import tzutc from sqlalchemy.orm import load_only, noload from .config import session_hang_tolerance_iv @@ -573,13 +572,19 @@ async def _fetch_hanging_sessions( sa.select(SessionRow) .where(SessionRow.status == status) .where( - ( - datetime.now(tz=tzutc()) - - SessionRow.status_history[status.name].astext.cast( - sa.types.DateTime(timezone=True) + sa.text( + """ + EXISTS ( + SELECT 1 + FROM jsonb_array_elements(status_history) AS session_history + WHERE + session_history->>'status' = :status_name AND + ( + now() - CAST(session_history->>'timestamp' AS TIMESTAMP WITH TIME ZONE) + ) > :threshold ) - ) - > threshold + """ + ).bindparams(status_name=status.name, threshold=threshold) ) .options( noload("*"), diff --git a/src/ai/backend/manager/utils.py b/src/ai/backend/manager/utils.py index 091b9a767c..8a4c070a9b 100644 --- a/src/ai/backend/manager/utils.py +++ b/src/ai/backend/manager/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional from uuid import UUID diff --git a/tests/manager/models/test_utils.py b/tests/manager/models/test_utils.py index 233fd795c6..4331c1baab 100644 --- a/tests/manager/models/test_utils.py +++ b/tests/manager/models/test_utils.py @@ -1,14 +1,15 @@ import uuid -from datetime import datetime -from typing import Any, Dict, Optional, Union +from typing import Union import pytest import sqlalchemy import sqlalchemy as sa -from dateutil.tz import tzutc from ai.backend.manager.models import KernelRow, SessionRow, kernels -from ai.backend.manager.models.utils import agg_to_array, agg_to_str, sql_json_merge +from ai.backend.manager.models.utils import ( + agg_to_array, + agg_to_str, +) async def _select_kernel_row( @@ -20,225 +21,6 @@ async def _select_kernel_row( return kernel -@pytest.mark.asyncio -async def
test_sql_json_merge__default(session_info): - session_id, conn = session_info - expected: Optional[Dict[str, Any]] = None - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - -@pytest.mark.asyncio -async def test_sql_json_merge__deeper_object(session_info): - session_id, conn = session_info - timestamp = datetime.now(tzutc()).isoformat() - expected = { - "kernel": { - "session": { - "PENDING": timestamp, - "PREPARING": timestamp, - }, - }, - } - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - ("kernel", "session"), - { - "PENDING": timestamp, - "PREPARING": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - -@pytest.mark.asyncio -async def test_sql_json_merge__append_values(session_info): - session_id, conn = session_info - timestamp = datetime.now(tzutc()).isoformat() - expected = { - "kernel": { - "session": { - "PENDING": timestamp, - "PREPARING": timestamp, - "TERMINATED": timestamp, - "TERMINATING": timestamp, - }, - }, - } - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - ("kernel", "session"), - { - "PENDING": timestamp, - "PREPARING": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - ("kernel", "session"), - { - "TERMINATING": timestamp, - "TERMINATED": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - -@pytest.mark.asyncio -async def test_sql_json_merge__kernel_status_history(session_info): - session_id, conn = session_info - timestamp = datetime.now(tzutc()).isoformat() - expected = { - "PENDING": timestamp, - "PREPARING": timestamp, - "TERMINATING": timestamp, - "TERMINATED": timestamp, - } - query = ( - kernels.update() - .values({ - # "status_history": sqlalchemy.func.coalesce(sqlalchemy.text("'{}'::jsonb")).concat( - # sqlalchemy.func.cast( - # {"PENDING": timestamp, "PREPARING": timestamp}, - # sqlalchemy.dialects.postgresql.JSONB, - # ), - # ), - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - "PENDING": timestamp, - "PREPARING": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - "TERMINATING": timestamp, - "TERMINATED": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - -@pytest.mark.asyncio -async def test_sql_json_merge__mixed_formats(session_info): - session_id, conn = session_info - timestamp = datetime.now(tzutc()).isoformat() - expected = { - "PENDING": timestamp, - "kernel": { - "PREPARING": timestamp, - }, - } - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - "PENDING": timestamp, - }, - ), - }) - 
.where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - ("kernel",), - { - "PREPARING": timestamp, - }, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - -@pytest.mark.asyncio -async def test_sql_json_merge__json_serializable_types(session_info): - session_id, conn = session_info - expected = { - "boolean": True, - "integer": 10101010, - "float": 1010.1010, - "string": "10101010", - # "bytes": b"10101010", - "list": [ - 10101010, - "10101010", - ], - "dict": { - "10101010": 10101010, - }, - } - query = ( - kernels.update() - .values({ - "status_history": sql_json_merge( - kernels.c.status_history, - (), - expected, - ), - }) - .where(kernels.c.session_id == session_id) - ) - await conn.execute(query) - kernel = await _select_kernel_row(conn, session_id) - assert kernel is not None - assert kernel.status_history == expected - - @pytest.mark.asyncio async def test_agg_to_str(session_info): session_id, conn = session_info
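This diff removes the `sql_json_merge` tests without adding coverage for the new list-based helper; a minimal sketch of what such a test might look like, reusing the existing `session_info` fixture and the `_select_kernel_row` helper kept in this module (illustrative only; assumes `status_history` now defaults to an empty JSONB array):

```python
import pytest

from ai.backend.manager.models import kernels
from ai.backend.manager.models.utils import sql_append_dict_to_list


@pytest.mark.asyncio
async def test_sql_append_dict_to_list(session_info):
    session_id, conn = session_info
    entry = {"status": "PENDING", "timestamp": "2024-01-26T10:00:00+00:00"}
    query = (
        kernels.update()
        .values({
            "status_history": sql_append_dict_to_list(kernels.c.status_history, entry),
        })
        .where(kernels.c.session_id == session_id)
    )
    await conn.execute(query)
    kernel = await _select_kernel_row(conn, session_id)
    assert kernel is not None
    # The appended entry becomes the last element of the log.
    assert kernel.status_history[-1] == entry
```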