Skip to content

Commit

Permalink
support job priority
Browse files: browse the repository at this point in the history
  • Loading branch information
zubenkoivan committed May 7, 2022
1 parent 27f8be4 commit 5d73c75
Show file tree
Hide file tree
Showing 19 changed files with 705 additions and 144 deletions.
1 change: 1 addition & 0 deletions charts/platform-api-poller/templates/priority-class.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
name: {{ include "platformApiPoller.fullname" . }}-job
labels: {{ include "platformApiPoller.labels.standard" . | nindent 4 }}
value: 100
globalDefault: false
8 changes: 8 additions & 0 deletions platform_api/cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class OrchestratorConfig:
job_schedule_scaleup_timeout: float = 15 * 60

allow_privileged_mode: bool = False
allow_job_priority: bool = False

@property
def tpu_resources(self) -> Sequence[TPUResource]:
Expand All @@ -35,6 +36,13 @@ def tpu_ipv4_cidr_block(self) -> Optional[str]:
return None
return tpus[0].ipv4_cidr_block

@property
def has_scheduler_enabled_presets(self) -> bool:
    """Tell whether at least one entry of ``self.presets`` has a truthy
    ``scheduler_enabled`` attribute.

    Returns ``False`` when ``self.presets`` is empty.
    """
    return any(preset.scheduler_enabled for preset in self.presets)


@dataclass(frozen=True)
class IngressConfig:
Expand Down
6 changes: 4 additions & 2 deletions platform_api/cluster_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ def _create_orchestrator_config(
OrchestratorConfig.job_schedule_scaleup_timeout,
),
allow_privileged_mode=orchestrator.get(
"allow_privileged_mode",
OrchestratorConfig.allow_privileged_mode,
"allow_privileged_mode", OrchestratorConfig.allow_privileged_mode
),
allow_job_priority=orchestrator.get(
"allow_job_priority", OrchestratorConfig.allow_job_priority
),
)

Expand Down
9 changes: 9 additions & 0 deletions platform_api/handlers/jobs_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from platform_api.orchestrator.job import (
JOB_USER_NAMES_SEPARATOR,
Job,
JobPriority,
JobRestartPolicy,
JobStatusItem,
JobStatusReason,
Expand Down Expand Up @@ -121,6 +122,10 @@ def _take_first(data: dict[str, Any]) -> dict[str, Any]:
t.Key("pass_config", optional=True, default=False): t.Bool,
t.Key("wait_for_jobs_quota", optional=True, default=False): t.Bool,
t.Key("privileged", optional=True, default=False): t.Bool,
t.Key(
"priority", optional=True, default=JobPriority.NORMAL.to_name()
): t.Enum(*(p.to_name() for p in JobPriority))
>> JobPriority.from_name,
t.Key("schedule_timeout", optional=True): t.Float(gte=1, lt=30 * 24 * 3600),
t.Key("max_run_time_minutes", optional=True): t.Int(gte=1),
t.Key("cluster_name", default=cluster_name): t.Atom(cluster_name),
Expand Down Expand Up @@ -275,6 +280,7 @@ def create_job_response_validator() -> t.Trafaret:
t.Key("max_run_time_minutes", optional=True): t.Int,
"restart_policy": t.String,
"privileged": t.Bool,
"priority": t.String,
}
)

Expand Down Expand Up @@ -442,6 +448,7 @@ def convert_job_to_job_response(job: Job) -> dict[str, Any]:
"logs_removed": job.logs_removed,
"total_price_credits": str(job.total_price_credits),
"price_credits_per_hour": str(job.price_credits_per_hour),
"priority": job.priority.to_name(),
}
if job.name:
response_payload["name"] = job.name
Expand Down Expand Up @@ -671,6 +678,7 @@ async def create_job(self, request: aiohttp.web.Request) -> aiohttp.web.Response
schedule_timeout = request_payload.get("schedule_timeout")
max_run_time_minutes = request_payload.get("max_run_time_minutes")
wait_for_jobs_quota = request_payload.get("wait_for_jobs_quota")
priority = request_payload["priority"]
job_request = JobRequest.create(container, description)
job, _ = await self._jobs_service.create_job(
job_request,
Expand All @@ -688,6 +696,7 @@ async def create_job(self, request: aiohttp.web.Request) -> aiohttp.web.Response
schedule_timeout=schedule_timeout,
max_run_time_minutes=max_run_time_minutes,
restart_policy=request_payload["restart_policy"],
priority=priority,
)
response_payload = convert_job_to_job_response(job)
self._job_response_validator.check(response_payload)
Expand Down
29 changes: 29 additions & 0 deletions platform_api/orchestrator/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def is_pending(self) -> bool:
def is_running(self) -> bool:
return self.status.is_running

@property
def is_suspended(self) -> bool:
    """Return the ``is_suspended`` flag of this object's ``status``."""
    current_status = self.status
    return current_status.is_suspended

@property
def is_finished(self) -> bool:
return self.status.is_finished
Expand Down Expand Up @@ -216,6 +220,10 @@ def started_at_str(self) -> Optional[str]:
def is_running(self) -> bool:
return self.last.is_running

@property
def is_suspended(self) -> bool:
    """Return the ``is_suspended`` flag of the ``last`` entry of this object."""
    latest = self.last
    return latest.is_suspended

@property
def is_finished(self) -> bool:
return self.last.is_finished
Expand Down Expand Up @@ -258,6 +266,20 @@ def __repr__(self) -> str:
return self.__str__().__repr__()


@enum.unique
class JobPriority(enum.IntEnum):
    """Priority levels of a job, backed by integers (LOW < NORMAL < HIGH)."""

    LOW = -1
    NORMAL = 0
    HIGH = 1

    def to_name(self) -> str:
        """Return the member name in lowercase, e.g. ``"high"`` for ``HIGH``."""
        member_name = self.name
        return member_name.lower()

    @classmethod
    def from_name(cls, name: str) -> "JobPriority":
        """Look up a member by its name, ignoring letter case.

        Raises:
            KeyError: if no member is named ``name.upper()``.
        """
        lookup_key = name.upper()
        return cls[lookup_key]


@dataclass
class JobRecord:
request: JobRequest
Expand All @@ -278,6 +300,7 @@ class JobRecord:
internal_hostname_named: Optional[str] = None
schedule_timeout: Optional[float] = None
restart_policy: JobRestartPolicy = JobRestartPolicy.NEVER
priority: JobPriority = JobPriority.NORMAL

# Billing in credits
fully_billed: bool = False # True if job has final price
Expand Down Expand Up @@ -453,6 +476,7 @@ def to_primitive(self) -> dict[str, Any]:
"restart_policy": str(self.restart_policy),
"fully_billed": self.fully_billed,
"total_price_credits": str(self.total_price_credits),
"priority": int(self.priority),
}
if self.schedule_timeout:
result["schedule_timeout"] = self.schedule_timeout
Expand Down Expand Up @@ -511,6 +535,7 @@ def from_primitive(
restart_policy=JobRestartPolicy(
payload.get("restart_policy", str(cls.restart_policy))
),
priority=JobPriority(payload.get("priority", int(cls.priority))),
fully_billed=payload.get("fully_billed", True), # Default for old jobs
total_price_credits=Decimal(payload.get("total_price_credits", "0")),
last_billed=datetime.fromisoformat(payload["last_billed"])
Expand Down Expand Up @@ -878,6 +903,10 @@ def total_price_credits(self) -> Decimal:
def org_name(self) -> Optional[str]:
return self._record.org_name

@property
def priority(self) -> JobPriority:
    """Return the ``priority`` stored on the underlying ``_record``."""
    record = self._record
    return record.priority

def to_primitive(self) -> dict[str, Any]:
return self._record.to_primitive()

Expand Down
4 changes: 4 additions & 0 deletions platform_api/orchestrator/job_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,10 @@ def is_pending(self) -> bool:
def is_running(self) -> bool:
return self == self.RUNNING

@property
def is_suspended(self) -> bool:
    """Return whether this value equals the ``SUSPENDED`` member."""
    suspended = self.SUSPENDED
    return self == suspended

@property
def is_finished(self) -> bool:
return self in (self.SUCCEEDED, self.FAILED, self.CANCELLED)
Expand Down
9 changes: 8 additions & 1 deletion platform_api/orchestrator/jobs_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
from yarl import URL

from ..cluster import SingleClusterUpdater
from .job import JobRecord, JobRestartPolicy, JobStatusHistory, JobStatusItem
from .job import (
JobPriority,
JobRecord,
JobRestartPolicy,
JobStatusHistory,
JobStatusItem,
)
from .job_request import (
Container,
ContainerHTTPServer,
Expand Down Expand Up @@ -139,6 +145,7 @@ def _parse_container(data: Mapping[str, Any]) -> Container:
internal_hostname_named=payload.get("internal_hostname_named"),
schedule_timeout=payload.get("schedule_timeout"),
restart_policy=JobRestartPolicy(payload["restart_policy"]),
priority=JobPriority.from_name(payload["priority"]),
)


Expand Down
11 changes: 11 additions & 0 deletions platform_api/orchestrator/jobs_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .job import (
Job,
JobPriority,
JobRecord,
JobRestartPolicy,
JobStatusHistory,
Expand Down Expand Up @@ -232,6 +233,7 @@ async def create_job(
schedule_timeout: Optional[float] = None,
max_run_time_minutes: Optional[int] = None,
restart_policy: JobRestartPolicy = JobRestartPolicy.NEVER,
priority: JobPriority = JobPriority.NORMAL,
) -> tuple[Job, Status]:
base_name = user.name.split("/", 1)[
0
Expand Down Expand Up @@ -295,6 +297,7 @@ async def create_job(
max_run_time_minutes=max_run_time_minutes,
restart_policy=restart_policy,
privileged=privileged,
priority=priority,
)
job_id = job_request.job_id

Expand All @@ -309,6 +312,14 @@ async def create_job(
f"Cluster {cluster_name} does not allow privileged jobs"
)

if (
record.priority != JobPriority.NORMAL
and not cluster_config.orchestrator.allow_job_priority
):
raise JobsServiceException(
f"Cluster {cluster_name} does not allow specifying job priority"
)

async with self._create_job_in_storage(record) as record:
job = self._make_job(record, cluster_config)
await self._prepare_job_hostnames(job, cluster_config.orchestrator)
Expand Down
Loading

0 comments on commit 5d73c75

Please sign in to comment.