diff --git a/pyproject.toml b/pyproject.toml index 7e240390..30f8df85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,8 @@ ignore = [ "E501", # consider `[meta, header, *data]` instead of concatenation "RUF005", + # Use `X | Y` in `isinstance` call instead of `(X, Y)` + "UP038", # multi-line-summary-first-line "D212", # one-blank-line-before-class diff --git a/src/virtualship/expedition/ship_config.py b/src/virtualship/expedition/ship_config.py index 7bb41e84..5609c182 100644 --- a/src/virtualship/expedition/ship_config.py +++ b/src/virtualship/expedition/ship_config.py @@ -8,6 +8,8 @@ import pydantic import yaml +from virtualship.utils import _validate_numeric_mins_to_timedelta + class ArgoFloatConfig(pydantic.BaseModel): """Configuration for argos floats.""" @@ -37,6 +39,10 @@ class ADCPConfig(pydantic.BaseModel): def _serialize_period(self, value: timedelta, _info): return value.total_seconds() / 60.0 + @pydantic.field_validator("period", mode="before") + def _validate_period(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_mins_to_timedelta(value) + class CTDConfig(pydantic.BaseModel): """Configuration for CTD instrument.""" @@ -55,6 +61,10 @@ class CTDConfig(pydantic.BaseModel): def _serialize_stationkeeping_time(self, value: timedelta, _info): return value.total_seconds() / 60.0 + @pydantic.field_validator("stationkeeping_time", mode="before") + def _validate_stationkeeping_time(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_mins_to_timedelta(value) + class ShipUnderwaterSTConfig(pydantic.BaseModel): """Configuration for underwater ST.""" @@ -71,6 +81,10 @@ class ShipUnderwaterSTConfig(pydantic.BaseModel): def _serialize_period(self, value: timedelta, _info): return value.total_seconds() / 60.0 + @pydantic.field_validator("period", mode="before") + def _validate_period(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_mins_to_timedelta(value) + class DrifterConfig(pydantic.BaseModel): """Configuration for drifters.""" @@ -88,6 +102,10 @@ class DrifterConfig(pydantic.BaseModel): def _serialize_lifetime(self, value: timedelta, _info): return value.total_seconds() / 60.0 + @pydantic.field_validator("lifetime", mode="before") + def _validate_lifetime(cls, value: int | float | timedelta) -> timedelta: + return _validate_numeric_mins_to_timedelta(value) + class XBTConfig(pydantic.BaseModel): """Configuration for xbt instrument.""" diff --git a/src/virtualship/utils.py b/src/virtualship/utils.py index 43608b60..ec97321e 100644 --- a/src/virtualship/utils.py +++ b/src/virtualship/utils.py @@ -1,3 +1,4 @@ +from datetime import timedelta from functools import lru_cache from importlib.resources import files from typing import TextIO @@ -151,3 +152,10 @@ def mfp_to_yaml(excel_file_path: str, yaml_output_path: str): # noqa: D417 # Save to YAML file schedule.to_yaml(yaml_output_path) + + +def _validate_numeric_mins_to_timedelta(value: int | float | timedelta) -> timedelta: + """Convert minutes to timedelta when reading.""" + if isinstance(value, timedelta): + return value + return timedelta(minutes=value) diff --git a/tests/expedition/test_simulate_schedule.py b/tests/expedition/test_simulate_schedule.py index 01544c42..ef090b4e 100644 --- a/tests/expedition/test_simulate_schedule.py +++ b/tests/expedition/test_simulate_schedule.py @@ -46,3 +46,11 @@ def test_simulate_schedule_too_far() -> None: result = simulate_schedule(projection, ship_config, schedule) assert isinstance(result, ScheduleProblem) + + +def test_time_in_minutes_in_ship_schedule() -> None: + """Test whether the pydantic serializer picks up the time *in minutes* in the ship schedule.""" + ship_config = ShipConfig.from_yaml("expedition_dir/ship_config.yaml") + assert ship_config.adcp_config.period == timedelta(minutes=5) + assert ship_config.ctd_config.stationkeeping_time == timedelta(minutes=20) + assert ship_config.ship_underwater_st_config.period == timedelta(minutes=5)