Skip to content

Commit

Permalink
feat(session): added new feature to delete sessions with name prefix,…
Browse files Browse the repository at this point in the history
… kind and status

fixes #37
  • Loading branch information
shinybrar committed Oct 25, 2024
1 parent 41e1886 commit 056254b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 72 deletions.
65 changes: 13 additions & 52 deletions skaha/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from base64 import b64encode
from os import environ
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Literal, Optional

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import Self

KINDS: List[str] = ["desktop", "notebook", "carta", "headless"]
STATUS: List[str] = ["Pending", "Running", "Terminating", "Succeeded", "Error"]
VIEW: List[str] = ["all"]
KINDS = Literal["desktop", "notebook", "carta", "headless"]
STATUS = Literal["Pending", "Running", "Terminating", "Succeeded", "Error"]
VIEW = Literal["all"]


class CreateSpec(BaseModel):
Expand All @@ -32,7 +32,7 @@ class CreateSpec(BaseModel):
)
cores: int = Field(1, description="Number of cores.", ge=1, le=256)
ram: int = Field(4, description="Amount of RAM (GB).", ge=1, le=512)
kind: str = Field(
kind: KINDS = Field(
..., description="Type of skaha session.", examples=["headless", "notebook"]
)
gpus: Optional[int] = Field(None, description="Number of GPUs.", ge=1, le=28)
Expand All @@ -47,6 +47,8 @@ class CreateSpec(BaseModel):
1, description="Number of sessions to launch.", ge=1, le=256, exclude=True
)

model_config = ConfigDict(validate_assignment=True, populate_by_name=True)

# Validate that cmd, args and env are only used with headless sessions.
@model_validator(mode="after")
def validate_headless(self) -> Self:
Expand All @@ -58,7 +60,6 @@ def validate_headless(self) -> Self:
Returns:
Dict[str, Any]: Validated values.
"""
assert self.kind in KINDS, f"kind must be one of: {KINDS}"
if self.cmd or self.args or self.env:
assert (
self.kind == "headless"
Expand All @@ -76,55 +77,15 @@ class FetchSpec(BaseModel):
object: Pydantic BaseModel object.
"""

kind: Optional[str] = Field(
None, description="Type of skaha session.", examples=["headless"]
kind: Optional[KINDS] = Field(
None, description="Type of skaha session.", examples=["headless"], alias="type"
)
status: Optional[str] = Field(
status: Optional[STATUS] = Field(
None, description="Status of the session.", examples=["Running"]
)
view: Optional[str] = Field(None, description="Number of views.", examples=["all"])

@field_validator("kind")
@classmethod
def validate_kind(cls, value: str) -> str:
"""Validate kind.
Args:
value (str): Value to validate.
Returns:
str: Validated value.
"""
assert value in KINDS, f"kind must be one of: {KINDS}"
return value

@field_validator("status")
@classmethod
def validate_status(cls, value: str) -> str:
"""Validate status.
Args:
value (str): Value to validate.
Returns:
str: Validated value.
"""
assert value in STATUS, f"status must be one of: {STATUS}"
return value

@field_validator("view")
@classmethod
def validate_view(cls, value: str) -> str:
"""Validate view.
view: Optional[VIEW] = Field(None, description="Number of views.", examples=["all"])

Args:
value (str): Value to validate.
Returns:
str: Validated value.
"""
assert value in VIEW, f"views must be one of: {VIEW}"
return value
model_config = ConfigDict(validate_assignment=True, populate_by_name=True)


class ContainerRegistry(BaseModel):
Expand Down
65 changes: 50 additions & 15 deletions skaha/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import Self

from skaha.client import SkahaClient
from skaha.models import CreateSpec, FetchSpec
from skaha.models import KINDS, STATUS, VIEW, CreateSpec, FetchSpec
from skaha.utils import convert, logs
from skaha.utils.threaded import scale

Expand Down Expand Up @@ -37,21 +37,21 @@ def _set_server(self) -> Self:
"""Sets the server path after validation."""
suffix = "session"
self.server = f"{self.server}/{self.version}/{suffix}" # type: ignore
log.debug(f"Server set to {self.server}")
log.debug("Server set to %s", self.server)
return self

def fetch(
self,
kind: Optional[str] = None,
status: Optional[str] = None,
view: Optional[str] = None,
kind: Optional[KINDS] = None,
status: Optional[STATUS] = None,
view: Optional[VIEW] = None,
) -> List[Dict[str, str]]:
"""List open sessions for the user.
Args:
kind (str, optional): Session kind. Defaults to None.
status (str, optional): Session status. Defaults to None.
view (str, optional): Session view level. Defaults to None.
kind (Optional[KINDS], optional): Session kind. Defaults to None.
status (Optional[STATUS], optional): Session status. Defaults to None.
view (Optional[VIEW], optional): Session view level. Defaults to None.
Notes:
By default, only the calling user's sessions are listed. If views is
Expand Down Expand Up @@ -84,12 +84,15 @@ def fetch(
'startTime': '2222-12-07T05:45:58Z'},
...]
"""
values: Dict[str, str] = {}
values: Dict[str, Any] = {}
for key, value in {"kind": kind, "status": status, "view": view}.items():
if value:
values[key] = value
spec = FetchSpec(**values)
parameters = spec.model_dump(exclude_none=True)
# Kind is an alias for type in the API. It is renamed in the Python Client
# to avoid conflicts with the built-in type function. by_alias true,
# returns, {"type": "headless"} instead of {"kind": "headless"}
parameters = spec.model_dump(exclude_none=True, by_alias=True)
log.debug(parameters)
response: Response = self.session.get(url=self.server, params=parameters) # type: ignore # noqa: E501
response.raise_for_status()
Expand Down Expand Up @@ -117,7 +120,7 @@ def stats(self) -> Dict[str, Any]:
response.raise_for_status()
return response.json()

def info(self, id: Union[List[str], str]) -> List[Dict[str, Any]]:
def info(self, ids: Union[List[str], str]) -> List[Dict[str, Any]]:
"""Get information about session[s].
Args:
Expand All @@ -131,11 +134,11 @@ def info(self, id: Union[List[str], str]) -> List[Dict[str, Any]]:
>>> session.info(id=["hjko98yghj", "ikvp1jtp"])
"""
# Convert id to list if it is a string
if isinstance(id, str):
id = [id]
if isinstance(ids, str):
ids = [ids]
parameters: Dict[str, str] = {"view": "event"}
arguments: List[Any] = []
for value in id:
for value in ids:
arguments.append({"url": f"{self.server}/{value}", "params": parameters})
loop = get_event_loop()
results = loop.run_until_complete(scale(self.session.get, arguments))
Expand Down Expand Up @@ -185,7 +188,7 @@ def create(
image: str,
cores: int = 2,
ram: int = 4,
kind: str = "headless",
kind: KINDS = "headless",
gpu: Optional[int] = None,
cmd: Optional[str] = None,
args: Optional[str] = None,
Expand Down Expand Up @@ -295,3 +298,35 @@ def destroy(self, ids: Union[str, List[str]]) -> Dict[str, bool]:
log.error(err)
responses[identity] = False
return responses

def destroy_with(
self, prefix: str, kind: KINDS = "headless", status: STATUS = "Succeeded"
) -> Dict[str, bool]:
"""Destroy skaha session[s] matching search criteria.
Args:
prefix (str): Prefix to match in the session name.
kind (KINDS): Type of skaha session. Defaults to "headless".
status (STATUS): Status of the session. Defaults to "Succeeded".
Returns:
Dict[str, bool]: A dictionary of session IDs
and a bool indicating if the session was destroyed.
Notes:
The prefix is case-sensitive.
This method is useful for destroying multiple sessions at once.
Examples:
>>> session.destroy_with(prefix="test")
>>> session.destroy_with(prefix="test", kind="desktop")
>>> session.destroy_with(prefix="test", kind="headless", status="Running")
"""
sessions = self.fetch(kind=kind, status=status)
ids: List[str] = []
for session in sessions:
if session["name"].startswith(prefix):
ids.append(session["id"])
return self.destroy(ids)
10 changes: 5 additions & 5 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ def test_fetch_with_kind(session: Session):
def test_fetch_malformed_kind(session: Session):
"""Test fetching images with malformed kind."""
with pytest.raises(ValidationError):
session.fetch(kind="invalid")
session.fetch(kind="invalid") # type: ignore


def test_fetch_with_malformed_view(session: Session):
"""Test fetching images with malformed view."""
with pytest.raises(ValidationError):
session.fetch(view="invalid")
session.fetch(view="invalid") # type: ignore


def test_fetch_with_malformed_status(session: Session):
"""Test fetching images with malformed status."""
with pytest.raises(ValidationError):
session.fetch(status="invalid")
session.fetch(status="invalid") # type: ignore


def test_session_stats(session: Session):
Expand All @@ -59,7 +59,7 @@ def test_create_session_with_malformed_kind(session: Session, name: str):
with pytest.raises(ValidationError):
session.create(
name=name,
kind="invalid",
kind="invalid", # type: ignore
image="ubuntu:latest",
cmd="bash",
replicas=1,
Expand Down Expand Up @@ -130,5 +130,5 @@ def test_session_logs(session: Session, name: str):
def test_delete_session(session: Session, name: str):
"""Test deleting a session."""
# Delete the session
deletion = session.destroy(pytest.IDENTITY[0]) # type: ignore
deletion = session.destroy_with(prefix=name) # type: ignore
assert deletion == {pytest.IDENTITY[0]: True} # type: ignore

0 comments on commit 056254b

Please sign in to comment.