diff --git a/platform_api/cluster_config.py b/platform_api/cluster_config.py index b17f31e06..aa4b8adc0 100644 --- a/platform_api/cluster_config.py +++ b/platform_api/cluster_config.py @@ -1,11 +1,66 @@ from collections.abc import Sequence from dataclasses import dataclass +from datetime import time, tzinfo from typing import Optional +from zoneinfo import ZoneInfo from yarl import URL from .resource import Preset, ResourcePoolType, TPUResource +UTC = ZoneInfo("UTC") + + +@dataclass(frozen=True) +class EnergySchedulePeriod: + # ISO 8601 weekday number (1-7) + weekday: int + start_time: time + end_time: time + + def __post_init__(self) -> None: + if not self.start_time.tzinfo or not self.end_time.tzinfo: + raise ValueError("start_time and end_time must have tzinfo") + if self.end_time == time.min.replace(tzinfo=self.end_time.tzinfo): + object.__setattr__( + self, "end_time", time.max.replace(tzinfo=self.end_time.tzinfo) + ) + if not 1 <= self.weekday <= 7: + raise ValueError("weekday must be in range 1-7") + if self.start_time >= self.end_time: + raise ValueError("start_time must be less than end_time") + + @classmethod + def create_full_day( + cls, *, weekday: int, timezone: tzinfo + ) -> "EnergySchedulePeriod": + return cls( + weekday=weekday, + start_time=time.min.replace(tzinfo=timezone), + end_time=time.max.replace(tzinfo=timezone), + ) + + +@dataclass(frozen=True) +class EnergySchedule: + name: str + periods: Sequence[EnergySchedulePeriod] = () + + @classmethod + def create_default(cls, *, timezone: tzinfo) -> "EnergySchedule": + return cls( + name="default", + periods=[ + EnergySchedulePeriod.create_full_day(weekday=weekday, timezone=timezone) + for weekday in range(1, 8) + ], + ) + + +@dataclass(frozen=True) +class EnergyConfig: + schedules: Sequence[EnergySchedule] = (EnergySchedule.create_default(timezone=UTC),) + @dataclass(frozen=True) class OrchestratorConfig: @@ -67,3 +122,5 @@ class ClusterConfig: name: str orchestrator: OrchestratorConfig ingress: IngressConfig + timezone: tzinfo = UTC + energy: EnergyConfig = EnergyConfig() diff --git a/platform_api/cluster_config_factory.py b/platform_api/cluster_config_factory.py index 315c6ff4c..b1c8ffe53 100644 --- a/platform_api/cluster_config_factory.py +++ b/platform_api/cluster_config_factory.py @@ -1,12 +1,21 @@ import logging from collections.abc import Sequence +from datetime import time, tzinfo from decimal import Decimal from typing import Any, Optional +from zoneinfo import ZoneInfo import trafaret as t from yarl import URL -from .cluster_config import ClusterConfig, IngressConfig, OrchestratorConfig +from .cluster_config import ( + ClusterConfig, + EnergyConfig, + EnergySchedule, + EnergySchedulePeriod, + IngressConfig, + OrchestratorConfig, +) from .resource import Preset, ResourcePoolType, TPUPreset, TPUResource _cluster_config_validator = t.Dict({"name": t.String}).allow_extra("*") @@ -31,10 +40,13 @@ def create_cluster_configs( def create_cluster_config(self, payload: dict[str, Any]) -> Optional[ClusterConfig]: try: _cluster_config_validator.check(payload) + timezone = self._create_timezone(payload.get("timezone")) return ClusterConfig( name=payload["name"], orchestrator=self._create_orchestrator_config(payload), ingress=self._create_ingress_config(payload), + timezone=timezone, + energy=self._create_energy_config(payload, timezone=timezone), ) except t.DataError as err: logging.warning(f"failed to parse cluster config: {err}") @@ -145,3 +157,44 @@ def _create_tpu_resource( types=tuple(payload["types"]), software_versions=tuple(payload["software_versions"]), ) + + def _create_timezone(self, name: Optional[str]) -> tzinfo: + if not name: + return ClusterConfig.timezone + try: + return ZoneInfo(name) + except Exception: + raise ValueError(f"invalid timezone: {name}") + + def _create_energy_schedule_period( + self, payload: dict[str, Any], *, timezone: tzinfo + ) -> EnergySchedulePeriod: + start_time = time.fromisoformat(payload["start_time"]).replace(tzinfo=timezone) + end_time = time.fromisoformat(payload["end_time"]).replace(tzinfo=timezone) + return EnergySchedulePeriod( + weekday=payload["weekday"], + start_time=start_time, + end_time=end_time, + ) + + def _create_energy_schedule( + self, payload: dict[str, Any], timezone: tzinfo + ) -> EnergySchedule: + return EnergySchedule( + name=payload["name"], + periods=[ + self._create_energy_schedule_period(p, timezone=timezone) + for p in payload["periods"] + ], + ) + + def _create_energy_config( + self, payload: dict[str, Any], *, timezone: tzinfo + ) -> EnergyConfig: + schedules = { + schedule.name: schedule + for s in payload.get("energy", {}).get("schedules", []) + if (schedule := self._create_energy_schedule(s, timezone=timezone)) + } + schedules["default"] = EnergySchedule.create_default(timezone=timezone) + return EnergyConfig(schedules=list(schedules.values())) diff --git a/tests/unit/test_cluster_config_factory.py b/tests/unit/test_cluster_config_factory.py index 888b9416b..3c60c4509 100644 --- a/tests/unit/test_cluster_config_factory.py +++ b/tests/unit/test_cluster_config_factory.py @@ -1,10 +1,19 @@ from collections.abc import Sequence +from datetime import time from decimal import Decimal from typing import Any +from unittest import mock +from zoneinfo import ZoneInfo import pytest from yarl import URL +from platform_api.cluster_config import ( + UTC, + EnergyConfig, + EnergySchedule, + EnergySchedulePeriod, +) from platform_api.cluster_config_factory import ClusterConfigFactory from platform_api.resource import GKEGPUModels, Preset, TPUPreset, TPUResource @@ -299,6 +308,11 @@ def test_valid_cluster_config( ) assert orchestrator.tpu_ipv4_cidr_block == "1.1.1.1/32" + assert cluster.timezone == UTC + assert cluster.energy == EnergyConfig( + schedules=[EnergySchedule.create_default(timezone=UTC)] + ) + def test_orchestrator_resource_presets_default( self, clusters_payload: Sequence[dict[str, Any]] ) -> None: @@ -390,3 +404,88 @@ def test_factory_skips_invalid_cluster_configs( clusters = factory.create_cluster_configs(clusters_payload) assert len(clusters) == 1 + + def test_energy(self, clusters_payload: Sequence[dict[str, Any]]) -> None: + clusters_payload[0]["timezone"] = "Europe/Kyiv" + clusters_payload[0]["energy"] = { + "schedules": [ + { + "name": "default", + "periods": [ + { + "weekday": 1, + "start_time": "00:00", + "end_time": "06:00", + }, + ], + }, + { + "name": "green", + "periods": [ + { + "weekday": 1, + "start_time": "00:00", + "end_time": "06:00", + }, + ], + }, + ] + } + factory = ClusterConfigFactory() + clusters = factory.create_cluster_configs(clusters_payload) + + assert clusters[0].timezone == ZoneInfo("Europe/Kyiv") + assert clusters[0].energy == EnergyConfig(schedules=mock.ANY) + assert clusters[0].energy.schedules == [ + EnergySchedule.create_default(timezone=ZoneInfo("Europe/Kyiv")), + EnergySchedule( + name="green", + periods=[ + EnergySchedulePeriod( + weekday=1, + start_time=time(0, 0, tzinfo=ZoneInfo("Europe/Kyiv")), + end_time=time(6, 0, tzinfo=ZoneInfo("Europe/Kyiv")), + ) + ], + ), + ] + + +class TestEnergySchedulePeriod: + def test__post_init__missing_start_time_tzinfo(self) -> None: + with pytest.raises( + ValueError, match="start_time and end_time must have tzinfo" + ): + EnergySchedulePeriod(weekday=1, start_time=time(0, 0), end_time=time(6, 0)) + + def test__post_init__missing_end_time_tzinfo(self) -> None: + with pytest.raises( + ValueError, match="start_time and end_time must have tzinfo" + ): + EnergySchedulePeriod( + weekday=1, start_time=time(0, 0, tzinfo=UTC), end_time=time(6, 0) + ) + + def test__post_init__end_time_before_start_time(self) -> None: + with pytest.raises(ValueError, match="start_time must be less than end_time"): + EnergySchedulePeriod( + weekday=1, + start_time=time(6, 0, tzinfo=UTC), + end_time=time(6, 0, tzinfo=UTC), + ) + + def test__post_init__invalid_weekday(self) -> None: + with pytest.raises(ValueError, match="weekday must be in range 1-7"): + EnergySchedulePeriod( + weekday=0, + start_time=time(0, 0, tzinfo=UTC), + end_time=time(6, 0, tzinfo=UTC), + ) + + def test__post_init__end_time_00_00_turns_into_time_max(self) -> None: + period = EnergySchedulePeriod( + weekday=1, + start_time=time(0, 0, tzinfo=UTC), + end_time=time(0, 0, tzinfo=UTC), + ) + assert period.end_time == time.max.replace(tzinfo=UTC)