Skip to content

Commit

Permalink
Added ClusterConfig.energy (#2032)
Browse files Browse the repository at this point in the history
* Added ClusterConfig.energy

* add tests
  • Loading branch information
dalazx authored Jan 30, 2023
1 parent 9c0cb78 commit 43dafbb
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 1 deletion.
57 changes: 57 additions & 0 deletions platform_api/cluster_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,66 @@
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import time, tzinfo
from typing import Optional
from zoneinfo import ZoneInfo

from yarl import URL

from .resource import Preset, ResourcePoolType, TPUResource

UTC = ZoneInfo("UTC")


@dataclass(frozen=True)
class EnergySchedulePeriod:
# ISO 8601 weekday number (1-7)
weekday: int
start_time: time
end_time: time

def __post_init__(self) -> None:
if not self.start_time.tzinfo or not self.end_time.tzinfo:
raise ValueError("start_time and end_time must have tzinfo")
if self.end_time == time.min.replace(tzinfo=self.end_time.tzinfo):
object.__setattr__(
self, "end_time", time.max.replace(tzinfo=self.end_time.tzinfo)
)
if not 1 <= self.weekday <= 7:
raise ValueError("weekday must be in range 1-7")
if self.start_time >= self.end_time:
raise ValueError("start_time must be less than end_time")

@classmethod
def create_full_day(
cls, *, weekday: int, timezone: tzinfo
) -> "EnergySchedulePeriod":
return cls(
weekday=weekday,
start_time=time.min.replace(tzinfo=timezone),
end_time=time.max.replace(tzinfo=timezone),
)


@dataclass(frozen=True)
class EnergySchedule:
name: str
periods: Sequence[EnergySchedulePeriod] = ()

@classmethod
def create_default(cls, *, timezone: tzinfo) -> "EnergySchedule":
return cls(
name="default",
periods=[
EnergySchedulePeriod.create_full_day(weekday=weekday, timezone=timezone)
for weekday in range(1, 8)
],
)


@dataclass(frozen=True)
class EnergyConfig:
schedules: Sequence[EnergySchedule] = (EnergySchedule.create_default(timezone=UTC),)


@dataclass(frozen=True)
class OrchestratorConfig:
Expand Down Expand Up @@ -67,3 +122,5 @@ class ClusterConfig:
name: str
orchestrator: OrchestratorConfig
ingress: IngressConfig
timezone: tzinfo = UTC
energy: EnergyConfig = EnergyConfig()
55 changes: 54 additions & 1 deletion platform_api/cluster_config_factory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import logging
from collections.abc import Sequence
from datetime import time, tzinfo
from decimal import Decimal
from typing import Any, Optional
from zoneinfo import ZoneInfo

import trafaret as t
from yarl import URL

from .cluster_config import ClusterConfig, IngressConfig, OrchestratorConfig
from .cluster_config import (
ClusterConfig,
EnergyConfig,
EnergySchedule,
EnergySchedulePeriod,
IngressConfig,
OrchestratorConfig,
)
from .resource import Preset, ResourcePoolType, TPUPreset, TPUResource

_cluster_config_validator = t.Dict({"name": t.String}).allow_extra("*")
Expand All @@ -31,10 +40,13 @@ def create_cluster_configs(
def create_cluster_config(self, payload: dict[str, Any]) -> Optional[ClusterConfig]:
try:
_cluster_config_validator.check(payload)
timezone = self._create_timezone(payload.get("timezone"))
return ClusterConfig(
name=payload["name"],
orchestrator=self._create_orchestrator_config(payload),
ingress=self._create_ingress_config(payload),
timezone=timezone,
energy=self._create_energy_config(payload, timezone=timezone),
)
except t.DataError as err:
logging.warning(f"failed to parse cluster config: {err}")
Expand Down Expand Up @@ -145,3 +157,44 @@ def _create_tpu_resource(
types=tuple(payload["types"]),
software_versions=tuple(payload["software_versions"]),
)

def _create_timezone(self, name: Optional[str]) -> tzinfo:
if not name:
return ClusterConfig.timezone
try:
return ZoneInfo(name)
except Exception:
raise ValueError(f"invalid timezone: {name}")

def _create_energy_schedule_period(
self, payload: dict[str, Any], *, timezone: tzinfo
) -> EnergySchedulePeriod:
start_time = time.fromisoformat(payload["start_time"]).replace(tzinfo=timezone)
end_time = time.fromisoformat(payload["end_time"]).replace(tzinfo=timezone)
return EnergySchedulePeriod(
weekday=payload["weekday"],
start_time=start_time,
end_time=end_time,
)

def _create_energy_schedule(
self, payload: dict[str, Any], timezone: tzinfo
) -> EnergySchedule:
return EnergySchedule(
name=payload["name"],
periods=[
self._create_energy_schedule_period(p, timezone=timezone)
for p in payload["periods"]
],
)

def _create_energy_config(
self, payload: dict[str, Any], *, timezone: tzinfo
) -> EnergyConfig:
schedules = {
schedule.name: schedule
for s in payload.get("energy", {}).get("schedules", [])
if (schedule := self._create_energy_schedule(s, timezone=timezone))
}
schedules["default"] = EnergySchedule.create_default(timezone=timezone)
return EnergyConfig(schedules=list(schedules.values()))
99 changes: 99 additions & 0 deletions tests/unit/test_cluster_config_factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from collections.abc import Sequence
from datetime import time
from decimal import Decimal
from typing import Any
from unittest import mock
from zoneinfo import ZoneInfo

import pytest
from yarl import URL

from platform_api.cluster_config import (
UTC,
EnergyConfig,
EnergySchedule,
EnergySchedulePeriod,
)
from platform_api.cluster_config_factory import ClusterConfigFactory
from platform_api.resource import GKEGPUModels, Preset, TPUPreset, TPUResource

Expand Down Expand Up @@ -299,6 +308,11 @@ def test_valid_cluster_config(
)
assert orchestrator.tpu_ipv4_cidr_block == "1.1.1.1/32"

assert cluster.timezone == UTC
assert cluster.energy == EnergyConfig(
schedules=[EnergySchedule.create_default(timezone=UTC)]
)

def test_orchestrator_resource_presets_default(
self, clusters_payload: Sequence[dict[str, Any]]
) -> None:
Expand Down Expand Up @@ -390,3 +404,88 @@ def test_factory_skips_invalid_cluster_configs(
clusters = factory.create_cluster_configs(clusters_payload)

assert len(clusters) == 1

def test_energy(self, clusters_payload: Sequence[dict[str, Any]]) -> None:
clusters_payload[0]["timezone"] = "Europe/Kyiv"
clusters_payload[0]["energy"] = {
"schedules": [
{
"name": "default",
"periods": [
{
"weekday": 1,
"start_time": "00:00",
"end_time": "06:00",
},
],
},
{
"name": "green",
"periods": [
{
"weekday": 1,
"start_time": "00:00",
"end_time": "06:00",
},
],
},
]
}
factory = ClusterConfigFactory()
clusters = factory.create_cluster_configs(clusters_payload)

assert clusters[0].timezone == ZoneInfo("Europe/Kyiv")
assert clusters[0].energy == EnergyConfig(schedules=mock.ANY)
assert clusters[0].energy.schedules == [
EnergySchedule.create_default(timezone=ZoneInfo("Europe/Kyiv")),
EnergySchedule(
name="green",
periods=[
EnergySchedulePeriod(
weekday=1,
start_time=time(0, 0, tzinfo=ZoneInfo("Europe/Kyiv")),
end_time=time(6, 0, tzinfo=ZoneInfo("Europe/Kyiv")),
)
],
),
]


class TestEnergySchedulePeriod:
def test__post_init__missing_start_time_tzinfo(self) -> None:
with pytest.raises(
ValueError, match="start_time and end_time must have tzinfo"
):
EnergySchedulePeriod(weekday=1, start_time=time(0, 0), end_time=time(6, 0))

def test__post_init__missing_end_time_tzinfo(self) -> None:
with pytest.raises(
ValueError, match="start_time and end_time must have tzinfo"
):
EnergySchedulePeriod(
weekday=1, start_time=time(0, 0, tzinfo=UTC), end_time=time(6, 0)
)

def test__post_init__end_time_before_start_time(self) -> None:
with pytest.raises(ValueError, match="start_time must be less than end_time"):
EnergySchedulePeriod(
weekday=1,
start_time=time(6, 0, tzinfo=UTC),
end_time=time(6, 0, tzinfo=UTC),
)

def test__post_init__invalid_weekday(self) -> None:
with pytest.raises(ValueError, match="weekday must be in range 1-7"):
EnergySchedulePeriod(
weekday=0,
start_time=time(0, 0, tzinfo=UTC),
end_time=time(6, 0, tzinfo=UTC),
)

def test__post_init__end_time_00_00_turns_into_time_max(self) -> None:
period = EnergySchedulePeriod(
weekday=1,
start_time=time(0, 0, tzinfo=UTC),
end_time=time(0, 0, tzinfo=UTC),
)
assert period.end_time == time.max.replace(tzinfo=UTC)

0 comments on commit 43dafbb

Please sign in to comment.