add pv padding and hf model handling to backtest_sites.py (#262)
* added pv padding and hf model handling to backtest_sites.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com>

* Update scripts/backtest_sites.py

* Update pyproject.toml

* undo Update pyproject.toml

* linting

* docstring scripts/backtest_sites.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docstring scripts/backtest_sites.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com>
3 people authored Oct 18, 2024
1 parent aa76d4a commit eb8b445
Showing 1 changed file: scripts/backtest_sites.py (83 additions, 6 deletions)
@@ -23,6 +23,7 @@
except RuntimeError:
pass

import json
import logging
import os
import sys
@@ -32,6 +33,8 @@
import pandas as pd
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_datapipes.batch import (
BatchKey,
NumpyBatch,
@@ -50,7 +53,7 @@
)
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

@@ -67,6 +70,10 @@
# checkpoint on the val set
model_chckpoint_dir = "PLACEHOLDER"

hf_revision = None
hf_token = None
hf_model_id = None
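# e.g. to pull the model from the HuggingFace Hub instead of a local
# checkpoint (illustrative values, not part of this commit):
# hf_model_id = "openclimatefix/pvnet_uk_region"
# hf_revision = "main"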

# Forecasts will be made for all available init times between these
start_datetime = "2022-05-08 00:00"
end_datetime = "2022-05-08 00:30"
@@ -101,11 +108,70 @@
# FUNCTIONS


@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
"""
Pads forecast pv.
Sun position is calculated based off of pv time index
and for t0's close to end of pv data can have wrong shape as pv starts
to run out of data to slice for the forecast part.
"""

def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64):
"""Init"""

super().__init__()
self.pv_dp = pv_dp
self.forecast_duration = forecast_duration

def __iter__(self):
"""Iter"""

for xr_data in self.pv_dp:
t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])]
pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"])
t_end = t0 + self.forecast_duration + pv_step
time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step)
yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
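# --- Editor's usage sketch, not part of the commit: a toy DataArray carrying
# the same attrs ("t0_idx", "sample_period_duration") the real pipeline sets.
# Steps beyond the available PV data come back filled with -1.
example_times = np.arange(
    np.datetime64("2022-05-08T00:00"),
    np.datetime64("2022-05-08T01:00"),
    np.timedelta64(30, "m"),
)  # only two 30-minute PV steps available
example_pv = xr.DataArray(
    np.ones(len(example_times)),
    coords={"time_utc": example_times},
    dims="time_utc",
    attrs={"t0_idx": 0, "sample_period_duration": np.timedelta64(30, "m")},
)
# @functional_datapipe exposes pad_forward_pv as a chainable method
example_dp = IterableWrapper([example_pv]).pad_forward_pv(
    forecast_duration=np.timedelta64(120, "m")
)
example_padded = next(iter(example_dp))
# example_padded.time_utc now runs to t0 + 120 min + one step: the three
# steps past the real data hold the fill value -1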


def load_model_from_hf(model_id: str, revision: str, token: str):
"""
Loads model from HuggingFace
"""

model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
token=token,
)

# load config file
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
token=token,
)

with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)

model = hydra.utils.instantiate(config)

state_dict = torch.load(model_file, map_location=torch.device("cuda"))
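# note: map_location="cuda" assumes a GPU is available; this load would fail
# on a CPU-only machine, where map_location="cpu" would be needed instead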
model.load_state_dict(state_dict) # type: ignore
model.eval() # type: ignore

return model
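# --- Editor's usage sketch, not part of the commit: the repo id below is
# illustrative; any Hub repo exposing a hydra-instantiable config.json plus
# pytorch_model.bin weights loads the same way.
example_hf_model = load_model_from_hf(
    model_id="openclimatefix/pvnet_uk_region",  # illustrative repo id
    revision="main",
    token=None,  # only needed for private repos
)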


def preds_to_dataarray(preds, model, valid_times, site_ids):
"""Put numpy array of predictions into a dataarray"""

if model.use_quantile_regression:
output_labels = model.output_quantiles
output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
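# e.g. output_quantiles = [0.1, 0.5, 0.9] gives
# ["forecast_mw_plevel_10", "forecast_mw", "forecast_mw_plevel_90"]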
else:
@@ -333,7 +399,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
init_time_utc=[t0]
init_time_utc=np.array([t0], dtype="datetime64[ns]")
)
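# an explicit datetime64[ns] array pins the init_time_utc coord dtype, which
# a bare [t0] list would leave to xarray's inference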

return da_abs_site
@@ -362,6 +428,11 @@ def get_datapipe(config_path: str) -> NumpyBatch:
t0_datapipe,
)

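# pad the PV stream forward so the forecast window keeps a fixed length even
# when the raw PV data runs out near the end of the backtest period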
config = load_yaml_configuration(config_path)
data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m")
)

data_pipeline = DictDatasetIterDataPipe(
{k: v for k, v in data_pipeline.items() if k != "config"},
).map(split_dataset_dict_dp)
@@ -412,7 +483,13 @@ def main(config: DictConfig):
# Create a dataloader for the concurrent batches and use multiprocessing
dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
# Load the PVNet model
model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
if model_chckpoint_dir:
model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
elif hf_model_id:
model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
else:
raise ValueError("Provide a model checkpoint or a HuggingFace model")

model = model.eval().to(device)

# Create object to make predictions for each input batch
@@ -426,13 +503,13 @@

t0 = ds_abs_all.init_time_utc.values[0]

# Save the predictioons
# Save the predictions
filename = f"{output_dir}/{t0}.nc"
ds_abs_all.to_netcdf(filename)

pbar.update()
except Exception as e:
print(f"Exception {e} at {i}")
print(f"Exception {e} at batch {i}")
pass

# Close down
