diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index 3572daa3..65d63be2 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -118,21 +118,38 @@ class PadForwardPVIterDataPipe(IterDataPipe): to run out of data to slice for the forecast part. """ - def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64): + def __init__( + self, + pv_dp: IterDataPipe, + forecast_duration: np.timedelta64, + history_duration: np.timedelta64, + time_resolution_minutes: np.timedelta64, + ): """Init""" super().__init__() self.pv_dp = pv_dp self.forecast_duration = forecast_duration + self.history_duration = history_duration + self.time_resolution_minutes = time_resolution_minutes + + self.min_seq_length = history_duration // time_resolution_minutes def __iter__(self): """Iter""" for xr_data in self.pv_dp: - t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])] - pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"]) - t_end = t0 + self.forecast_duration + pv_step - time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step) + t_end = ( + xr_data.time_utc.data[0] + + self.history_duration + + self.forecast_duration + + self.time_resolution_minutes + ) + time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes) + + if len(xr_data.time_utc.data) < self.min_seq_length: + raise ValueError("Not enough PV data to predict") + yield xr_data.reindex(time_utc=time_idx, fill_value=-1) @@ -430,7 +447,9 @@ def get_datapipe(config_path: str) -> NumpyBatch: config = load_yaml_configuration(config_path) data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv( - forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m") + forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"), + history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"), + time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"), ) data_pipeline = DictDatasetIterDataPipe(