Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dalazx committed Jan 27, 2023
1 parent 2a69de8 commit 2c5c4f0
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 1 deletion.
12 changes: 12 additions & 0 deletions platform_api/cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ class EnergySchedulePeriod:
start_time: time
end_time: time

def __post_init__(self) -> None:
if not self.start_time.tzinfo or not self.end_time.tzinfo:
raise ValueError("start_time and end_time must have tzinfo")
if self.end_time == time.min.replace(tzinfo=self.end_time.tzinfo):
object.__setattr__(
self, "end_time", time.max.replace(tzinfo=self.end_time.tzinfo)
)
if not 1 <= self.weekday <= 7:
raise ValueError("weekday must be in range 1-7")
if self.start_time >= self.end_time:
raise ValueError("start_time must be less than end_time")

@classmethod
def create_full_day(
cls, *, weekday: int, timezone: tzinfo
Expand Down
2 changes: 1 addition & 1 deletion platform_api/cluster_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _create_energy_config(
) -> EnergyConfig:
schedules = {
schedule.name: schedule
for s in payload["schedules"]
for s in payload.get("energy", {}).get("schedules", [])
if (schedule := self._create_energy_schedule(s, timezone=timezone))
}
schedules["default"] = EnergySchedule.create_default(timezone=timezone)
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/test_cluster_config_factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from collections.abc import Sequence
from datetime import time
from decimal import Decimal
from typing import Any
from unittest import mock
from zoneinfo import ZoneInfo

import pytest
from yarl import URL

from platform_api.cluster_config import (
UTC,
EnergyConfig,
EnergySchedule,
EnergySchedulePeriod,
)
from platform_api.cluster_config_factory import ClusterConfigFactory
from platform_api.resource import GKEGPUModels, Preset, TPUPreset, TPUResource

Expand Down Expand Up @@ -299,6 +308,11 @@ def test_valid_cluster_config(
)
assert orchestrator.tpu_ipv4_cidr_block == "1.1.1.1/32"

assert cluster.timezone == UTC
assert cluster.energy == EnergyConfig(
schedules=[EnergySchedule.create_default(timezone=UTC)]
)

def test_orchestrator_resource_presets_default(
self, clusters_payload: Sequence[dict[str, Any]]
) -> None:
Expand Down Expand Up @@ -390,3 +404,88 @@ def test_factory_skips_invalid_cluster_configs(
clusters = factory.create_cluster_configs(clusters_payload)

assert len(clusters) == 1

def test_energy(self, clusters_payload: Sequence[dict[str, Any]]) -> None:
clusters_payload[0]["timezone"] = "Europe/Kyiv"
clusters_payload[0]["energy"] = {
"schedules": [
{
"name": "default",
"periods": [
{
"weekday": 1,
"start_time": "00:00",
"end_time": "06:00",
},
],
},
{
"name": "green",
"periods": [
{
"weekday": 1,
"start_time": "00:00",
"end_time": "06:00",
},
],
},
]
}
factory = ClusterConfigFactory()
clusters = factory.create_cluster_configs(clusters_payload)

assert clusters[0].timezone == ZoneInfo("Europe/Kyiv")
assert clusters[0].energy == EnergyConfig(schedules=mock.ANY)
assert clusters[0].energy.schedules == [
EnergySchedule.create_default(timezone=ZoneInfo("Europe/Kyiv")),
EnergySchedule(
name="green",
periods=[
EnergySchedulePeriod(
weekday=1,
start_time=time(0, 0, tzinfo=ZoneInfo("Europe/Kyiv")),
end_time=time(6, 0, tzinfo=ZoneInfo("Europe/Kyiv")),
)
],
),
]


class TestEnergySchedulePeriod:
def test__post_init__missing_start_time_tzinfo(self) -> None:
with pytest.raises(
ValueError, match="start_time and end_time must have tzinfo"
):
EnergySchedulePeriod(weekday=1, start_time=time(0, 0), end_time=time(6, 0))

def test__post_init__missing_end_time_tzinfo(self) -> None:
with pytest.raises(
ValueError, match="start_time and end_time must have tzinfo"
):
EnergySchedulePeriod(
weekday=1, start_time=time(0, 0, tzinfo=UTC), end_time=time(6, 0)
)

def test__post_init__end_time_before_start_time(self) -> None:
with pytest.raises(ValueError, match="start_time must be less than end_time"):
EnergySchedulePeriod(
weekday=1,
start_time=time(6, 0, tzinfo=UTC),
end_time=time(6, 0, tzinfo=UTC),
)

def test__post_init__invalid_weekday(self) -> None:
with pytest.raises(ValueError, match="weekday must be in range 1-7"):
EnergySchedulePeriod(
weekday=0,
start_time=time(0, 0, tzinfo=UTC),
end_time=time(6, 0, tzinfo=UTC),
)

def test__post_init__end_time_00_00_turns_into_time_max(self) -> None:
period = EnergySchedulePeriod(
weekday=1,
start_time=time(0, 0, tzinfo=UTC),
end_time=time(0, 0, tzinfo=UTC),
)
assert period.end_time == time.max.replace(tzinfo=UTC)

0 comments on commit 2c5c4f0

Please sign in to comment.