Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config refactor #61

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 65 additions & 91 deletions ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Dict, List, Optional
from typing_extensions import Self

from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
from ocf_data_sampler.constants import NWP_PROVIDERS
from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from ocf_datapipes.utils.consts import NWP_PROVIDERS

logger = logging.getLogger(__name__)

Expand All @@ -34,27 +34,12 @@ class Config:
class General(Base):
"""General pydantic model"""

name: str = Field("example", description="The name of this configuration file.")
name: str = Field("example", description="The name of this configuration file")
description: str = Field(
"example configuration", description="Description of this configuration file"
)


class DataSourceMixin(Base):
"""Mixin class, to add forecast and history minutes"""

forecast_minutes: int = Field(
...,
ge=0,
description="how many minutes to forecast in the future. ",
)
history_minutes: int = Field(
...,
ge=0,
description="how many historic minutes to use. ",
)


# noinspection PyMethodParameters
class DropoutMixin(Base):
"""Mixin class, to add dropout minutes"""
Expand All @@ -65,7 +50,12 @@ class DropoutMixin(Base):
"negative or zero.",
)

dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")
dropout_fraction: float = Field(
default=0,
description="Chance of dropout being applied to each sample",
ge=0,
le=1,
)

@field_validator("dropout_timedeltas_minutes")
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
Expand All @@ -75,12 +65,6 @@ def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
assert m <= 0, "Dropout timedeltas must be negative"
return v

@field_validator("dropout_fraction")
def dropout_fraction_valid(cls, v: float) -> float:
"""Validate 'dropout_fraction'"""
assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1"
return v

@model_validator(mode="after")
def dropout_instructions_consistent(self) -> Self:
if self.dropout_fraction == 0:
Expand All @@ -93,36 +77,66 @@ def dropout_instructions_consistent(self) -> Self:


# noinspection PyMethodParameters
class TimeResolutionMixin(Base):
class TimeWindowMixin(Base):
"""Time resolution mix in"""

time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)

forecast_minutes: int = Field(
...,
ge=0,
description="how many minutes to forecast in the future",
)
history_minutes: int = Field(
...,
ge=0,
description="how many historic minutes to use",
)

class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
"""Satellite configuration model"""
@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
    """Validate that 'forecast_minutes' is a whole number of time steps.

    Args:
        v: The candidate 'forecast_minutes' value.
        values: Validation context passed in by pydantic; its ``data``
            mapping is read for "time_resolution_minutes".

    Returns:
        ``v`` unchanged when it is valid.

    Raises:
        ValueError: If ``v`` is not divisible by the time resolution.
    """
    # Only fields that already validated successfully appear in
    # `values.data`. If the resolution itself was rejected there is nothing
    # to compare against, and indexing directly would surface an unrelated
    # KeyError instead of the real validation problem.
    time_resolution = values.data.get("time_resolution_minutes")
    if time_resolution is None:
        return v

    if v % time_resolution != 0:
        message = "Forecast duration must be divisible by time resolution"
        logger.error(message)
        # ValueError (not a bare Exception) so pydantic reports this as a
        # normal validation error with the field location attached.
        raise ValueError(message)
    return v

# Todo: remove 'satellite' from names
satellite_zarr_path: str | tuple[str] | list[str] = Field(
@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
    """Validate that 'history_minutes' is a whole number of time steps.

    Args:
        v: The candidate 'history_minutes' value.
        values: Validation context passed in by pydantic; its ``data``
            mapping is read for "time_resolution_minutes".

    Returns:
        ``v`` unchanged when it is valid.

    Raises:
        ValueError: If ``v`` is not divisible by the time resolution.
    """
    # Only fields that already validated successfully appear in
    # `values.data`. If the resolution itself was rejected there is nothing
    # to compare against, and indexing directly would surface an unrelated
    # KeyError instead of the real validation problem.
    time_resolution = values.data.get("time_resolution_minutes")
    if time_resolution is None:
        return v

    if v % time_resolution != 0:
        message = "History duration must be divisible by time resolution"
        logger.error(message)
        # ValueError (not a bare Exception) so pydantic reports this as a
        # normal validation error with the field location attached.
        raise ValueError(message)
    return v


class DataSourceBase(TimeWindowMixin, DropoutMixin):
"""Mixin class, to add path and image size"""
zarr_path: str | tuple[str] | list[str] = Field(
...,
description="The path or list of paths which hold the satellite zarr",
)
satellite_channels: list[str] = Field(
..., description="the satellite channels that are used"
description="The path or list of paths which hold the data zarr",
)
satellite_image_size_pixels_height: int = Field(

image_size_pixels_height: int = Field(
...,
description="The number of pixels of the height of the region of interest"
" for non-HRV satellite channels.",
description="The number of pixels of the height of the region of interest",
)

satellite_image_size_pixels_width: int = Field(
image_size_pixels_width: int = Field(
...,
description="The number of pixels of the width of the region "
"of interest for non-HRV satellite channels.",
description="The number of pixels of the width of the region of interest",
)


class Satellite(DataSourceBase):
"""Satellite configuration model"""

channels: list[str] = Field(
..., description="the satellite channels that are used"
)

live_delay_minutes: int = Field(
Expand All @@ -131,21 +145,16 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):


# noinspection PyMethodParameters
class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
class NWP(DataSourceBase):
"""NWP configuration model"""

nwp_zarr_path: str | tuple[str] | list[str] = Field(
...,
description="The path which holds the NWP zarr",
)
nwp_channels: list[str] = Field(
channels: list[str] = Field(
..., description="the channels used in the nwp data"
)
nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels")
nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels")

nwp_provider: str = Field(..., description="The provider of the NWP data")
provider: str = Field(..., description="The provider of the NWP data")

accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")

max_staleness_minutes: Optional[int] = Field(
None,
Expand All @@ -154,33 +163,15 @@ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
" the maximum forecast horizon of the NWP and the requested forecast length.",
)


@field_validator("nwp_provider")
def validate_nwp_provider(cls, v: str) -> str:
"""Validate 'nwp_provider'"""
@field_validator("provider")
def validate_provider(cls, v: str) -> str:
    """Validate 'provider' against the known NWP providers.

    The comparison is done on the lower-cased value, but the value is
    returned exactly as it was supplied.

    Args:
        v: The candidate provider name.

    Returns:
        ``v`` unchanged when its lower-cased form is in ``NWP_PROVIDERS``.

    Raises:
        ValueError: If the provider is not in ``NWP_PROVIDERS``.
    """
    if v.lower() not in NWP_PROVIDERS:
        message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
        logger.warning(message)
        # ValueError (not a bare Exception) so pydantic reports this as a
        # normal validation error with the field location attached.
        raise ValueError(message)
    return v

# Todo: put into time mixin when moving intervals there
@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class MultiNWP(RootModel):
"""Configuration for multiple NWPs"""
Expand Down Expand Up @@ -208,27 +199,10 @@ def items(self):
return self.root.items()


# noinspection PyMethodParameters
class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
class GSP(TimeWindowMixin, DropoutMixin):
"""GSP configuration model"""

gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr")

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v
zarr_path: str = Field(..., description="The path which holds the GSP zarr")


# noinspection PyPep8Naming
Expand All @@ -246,4 +220,4 @@ class Configuration(Base):
"""Configuration model for the dataset"""

general: General = General()
input_data: InputData = InputData()
input_data: InputData = InputData()
26 changes: 13 additions & 13 deletions ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr
datasets_dict = {}

# Load GSP data unless the path is None
if in_config.gsp.gsp_zarr_path:
da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute()
if in_config.gsp.zarr_path:
da_gsp = open_gsp(zarr_path=in_config.gsp.zarr_path).compute()

# Remove national GSP
datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
Expand All @@ -76,19 +76,19 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr
datasets_dict["nwp"] = {}
for nwp_source, nwp_config in in_config.nwp.items():

da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
da_nwp = open_nwp(nwp_config.zarr_path, provider=nwp_config.provider)

da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
da_nwp = da_nwp.sel(channel=list(nwp_config.channels))

datasets_dict["nwp"][nwp_source] = da_nwp

# Load satellite data if in config
if in_config.satellite:
sat_config = config.input_data.satellite

da_sat = open_sat_data(sat_config.satellite_zarr_path)
da_sat = open_sat_data(sat_config.zarr_path)

da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
da_sat = da_sat.sel(channel=list(sat_config.channels))

datasets_dict["sat"] = da_sat

Expand Down Expand Up @@ -127,7 +127,7 @@ def find_valid_t0_times(
max_staleness = minutes(nwp_config.max_staleness_minutes)

# The last step of the forecast is lost if we have to diff channels
if len(nwp_config.nwp_accum_channels) > 0:
if len(nwp_config.accum_channels) > 0:
end_buffer = minutes(nwp_config.time_resolution_minutes)
else:
end_buffer = minutes(0)
Expand Down Expand Up @@ -229,8 +229,8 @@ def slice_datasets_by_space(
sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels(
datasets_dict["nwp"][nwp_key],
location,
height_pixels=nwp_config.nwp_image_size_pixels_height,
width_pixels=nwp_config.nwp_image_size_pixels_width,
height_pixels=nwp_config.image_size_pixels_height,
width_pixels=nwp_config.image_size_pixels_width,
)

if "sat" in datasets_dict:
Expand All @@ -239,8 +239,8 @@ def slice_datasets_by_space(
sliced_datasets_dict["sat"] = select_spatial_slice_pixels(
datasets_dict["sat"],
location,
height_pixels=sat_config.satellite_image_size_pixels_height,
width_pixels=sat_config.satellite_image_size_pixels_width,
height_pixels=sat_config.image_size_pixels_height,
width_pixels=sat_config.image_size_pixels_width,
)

if "gsp" in datasets_dict:
Expand Down Expand Up @@ -280,7 +280,7 @@ def slice_datasets_by_time(
forecast_duration=minutes(nwp_config.forecast_minutes),
dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
dropout_frac=nwp_config.dropout_fraction,
accum_channels=nwp_config.nwp_accum_channels,
accum_channels=nwp_config.accum_channels,
)

if "sat" in datasets_dict:
Expand Down Expand Up @@ -383,7 +383,7 @@ def process_and_combine_datasets(

for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise
provider = config.input_data.nwp[nwp_key].nwp_provider
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
# Convert to NumpyBatch
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
Expand Down
15 changes: 9 additions & 6 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
)


def test_default():
def test_default_configuration():
    """Check that Configuration can be built from its defaults alone."""
    # Construction is the whole test: it passes as long as nothing raises.
    Configuration()


def test_yaml_load_test_config(test_config_filename):
def test_load_yaml_configuration(test_config_filename):
"""
Test that yaml loading works for 'test_config.yaml'
and fails for an empty .yaml file
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename):
assert test_config == tmp_config


def test_extra_field():
def test_extra_field_error():
"""
Check that an extra parameter in the config causes an error
"""
Expand Down Expand Up @@ -99,10 +99,11 @@ def test_incorrect_nwp_provider(test_config_filename):

configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp['ukv'].nwp_provider = "unexpected_provider"
configuration.input_data.nwp['ukv'].provider = "unexpected_provider"
with pytest.raises(Exception, match="NWP provider"):
_ = Configuration(**configuration.model_dump())


def test_incorrect_dropout(test_config_filename):
"""
Check a dropout timedelta over 0 causes error and 0 doesn't
Expand All @@ -119,6 +120,7 @@ def test_incorrect_dropout(test_config_filename):
configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0]
_ = Configuration(**configuration.model_dump())


def test_incorrect_dropout_fraction(test_config_filename):
"""
Check dropout fraction outside of range causes error
Expand All @@ -127,11 +129,12 @@ def test_incorrect_dropout_fraction(test_config_filename):
configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp['ukv'].dropout_fraction= 1.1
with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):

with pytest.raises(ValidationError, match="Input should be less than or equal to 1"):
_ = Configuration(**configuration.model_dump())

configuration.input_data.nwp['ukv'].dropout_fraction= -0.1
with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"):
with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"):
_ = Configuration(**configuration.model_dump())


Expand Down
Loading
Loading