diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py
index e764abf8..3572daa3 100644
--- a/scripts/backtest_sites.py
+++ b/scripts/backtest_sites.py
@@ -23,6 +23,7 @@
 except RuntimeError:
     pass
 
+import json
 import logging
 import os
 import sys
@@ -32,6 +33,8 @@
 import pandas as pd
 import torch
 import xarray as xr
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from ocf_datapipes.batch import (
     BatchKey,
     NumpyBatch,
@@ -50,7 +53,7 @@
 )
 from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
 from omegaconf import DictConfig
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
 from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
@@ -67,6 +70,12 @@
 # checkpoint on the val set
 model_chckpoint_dir = "PLACEHOLDER"
 
+# To load the model from HuggingFace instead, set `model_chckpoint_dir = None` and fill in
+# `hf_model_id`. `hf_revision` and `hf_token` are passed on to `hf_hub_download`
+hf_revision = None
+hf_token = None
+hf_model_id = None
+
 # Forecasts will be made for all available init times between these
 start_datetime = "2022-05-08 00:00"
 end_datetime = "2022-05-08 00:30"
@@ -101,11 +110,90 @@
 # FUNCTIONS
 
 
+@functional_datapipe("pad_forward_pv")
+class PadForwardPVIterDataPipe(IterDataPipe):
+    """
+    Pads the forecast part of each PV sample so the time index covers the full horizon.
+
+    Sun position is calculated from the PV time index. For t0's close to the end of
+    the PV data there is not enough data left to slice the forecast part, so the
+    sample can end up with the wrong shape. Each sample is therefore reindexed onto
+    the full time range, and the missing steps are filled with -1.
+    """
+
+    def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
+        """
+        Init
+
+        Args:
+            pv_dp: Datapipe yielding PV samples with a `time_utc` coordinate and
+                `t0_idx` and `sample_period_duration` attrs
+            forecast_duration: Length of the forecast part after t0
+        """
+
+        super().__init__()
+        self.pv_dp = pv_dp
+        self.forecast_duration = forecast_duration
+
+    def __iter__(self):
+        """Yield each PV sample reindexed onto its full time range"""
+
+        for xr_data in self.pv_dp:
+            t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
+            pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
+            # np.arange excludes the stop, so add one step to keep the final forecast time
+            t_end = t0 + self.forecast_duration + pv_step
+            time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
+            yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
+
+
+def load_model_from_hf(model_id: str, revision: str, token: str):
+    """
+    Loads model from HuggingFace
+
+    Args:
+        model_id: HuggingFace repo id to download the weights and config from
+        revision: Revision of the repo to download, passed to `hf_hub_download`
+        token: HuggingFace access token, passed to `hf_hub_download`
+
+    Returns:
+        The model instantiated from the config, with the weights loaded, in eval mode
+    """
+
+    model_file = hf_hub_download(
+        repo_id=model_id,
+        filename=PYTORCH_WEIGHTS_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    # load config file
+    config_file = hf_hub_download(
+        repo_id=model_id,
+        filename=CONFIG_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    with open(config_file, "r", encoding="utf-8") as f:
+        config = json.load(f)
+
+    model = hydra.utils.instantiate(config)
+
+    # Load the weights onto the CPU so this also works on machines without a GPU; the
+    # caller moves the model onto the target device afterwards.
+    # NOTE(review): torch.load unpickles the downloaded file - only use trusted repos
+    state_dict = torch.load(model_file, map_location=torch.device("cpu"))
+    model.load_state_dict(state_dict)  # type: ignore
+    model.eval()  # type: ignore
+
+    return model
+
+
 def preds_to_dataarray(preds, model, valid_times, site_ids):
     """Put numpy array of predictions into a dataarray"""
 
     if model.use_quantile_regression:
-        output_labels = model.output_quantiles
         output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
         output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
     else:
@@ -333,7 +421,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
         da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)
 
         da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
-            init_time_utc=[t0]
+            init_time_utc=np.array([t0], dtype="datetime64[ns]")
         )
 
         return da_abs_site
@@ -362,6 +450,12 @@ def get_datapipe(config_path: str) -> NumpyBatch:
         t0_datapipe,
     )
 
+    # Pad the PV data forward so samples near the end of the data keep their full length
+    config = load_yaml_configuration(config_path)
+    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
+        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
+    )
+
     data_pipeline = DictDatasetIterDataPipe(
         {k: v for k, v in data_pipeline.items() if k != "config"},
     ).map(split_dataset_dict_dp)
@@ -412,7 +506,13 @@ def main(config: DictConfig):
     # Create a dataloader for the concurrent batches and use multiprocessing
     dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
     # Load the PVNet model
-    model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    if model_chckpoint_dir:
+        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    elif hf_model_id:
+        model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
+    else:
+        raise ValueError("Provide a model checkpoint or a HuggingFace model")
+
     model = model.eval().to(device)
 
     # Create object to make predictions for each input batch
@@ -426,13 +526,13 @@ def main(config: DictConfig):
 
             t0 = ds_abs_all.init_time_utc.values[0]
 
-            # Save the predictioons
+            # Save the predictions
             filename = f"{output_dir}/{t0}.nc"
             ds_abs_all.to_netcdf(filename)
 
             pbar.update()
         except Exception as e:
-            print(f"Exception {e} at {i}")
+            print(f"Exception {e} at batch {i}")
             pass
 
     # Close down