diff --git a/ocf_datapipes/training/metnet/metnet_gsp_national.py b/ocf_datapipes/training/metnet/metnet_gsp_national.py
deleted file mode 100644
index 50cdb6216..000000000
--- a/ocf_datapipes/training/metnet/metnet_gsp_national.py
+++ /dev/null
@@ -1,188 +0,0 @@
-"""Create the training/validation datapipe for training the national MetNet/-2 Model"""
-
-import datetime
-import logging
-from pathlib import Path
-from typing import Union
-
-import xarray
-from torch.utils.data.datapipes.datapipe import IterDataPipe
-
-from ocf_datapipes.convert import ConvertGSPToNumpy
-from ocf_datapipes.select import FilterGSPIDs, PickLocations
-from ocf_datapipes.training.common import (
-    add_selected_time_slices_from_datapipes,
-    get_and_return_overlapping_time_periods_and_t0,
-    open_and_return_datapipes,
-)
-from ocf_datapipes.training.metnet.metnet_preprocessor import (
-    PreProcessMetNetIterDataPipe as PreProcessMetNet,
-)
-from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD
-
-xarray.set_options(keep_attrs=True)
-logger = logging.getLogger("metnet_datapipe")
-logger.setLevel(logging.DEBUG)
-
-
-def normalize_gsp(x):  # So it can be pickled
-    """
-    Normalize the GSP data
-
-    Args:
-        x: Input DataArray
-
-    Returns:
-        Normalized DataArray
-    """
-    return x / x.nominal_capacity_mwp
-
-
-def normalize_pv(x):  # So it can be pickled
-    """
-    Normalize the PV data
-
-    Args:
-        x: Input DataArray
-
-    Returns:
-        Normalized DataArray
-    """
-    return x / x.observed_capacity_wp
-
-
-def _select_non_nan_times(x):
-    return x.fillna(0.0)
-
-
-def metnet_national_datapipe(
-    configuration_filename: Union[Path, str],
-    use_sun: bool = True,
-    use_nwp: bool = True,
-    use_sat: bool = True,
-    use_hrv: bool = True,
-    use_pv: bool = False,
-    use_gsp: bool = True,
-    use_topo: bool = True,
-    output_size: int = 256,
-    gsp_in_image: bool = False,
-    start_time: datetime.datetime = datetime.datetime(2014, 1, 1),
-    end_time: datetime.datetime = datetime.datetime(2023, 1, 1),
-) -> IterDataPipe:
-    """
-    Make GSP national data pipe
-
-    Currently only includes GSP and NWP data
-
-    Args:
-        configuration_filename: the configuration filename for the pipe
-        use_sun: Whether to add sun features or not
-        use_pv: Whether to use PV input or not
-        use_hrv: Whether to use HRV Satellite or not
-        use_sat: Whether to use non-HRV Satellite or not
-        use_nwp: Whether to use NWP or not
-        use_topo: Whether to use topographic map or not
-        use_gsp: Whether to use GSP history
-        start_time: Start time to select on
-        end_time: End time to select from
-        output_size: Size, in pixels, of the output image
-        gsp_in_image: Add GSP history as channels in MetNet image
-
-    Returns: datapipe
-    """
-
-    # load datasets
-    used_datapipes = open_and_return_datapipes(
-        configuration_filename=configuration_filename,
-        use_nwp=use_nwp,
-        use_topo=use_topo,
-        use_sat=use_sat,
-        use_hrv=use_hrv,
-        use_gsp=use_gsp,
-        use_pv=use_pv,
-    )
-    # Load GSP national data
-    used_datapipes["gsp"] = used_datapipes["gsp"].filter_times(start_time, end_time)
-
-    # Now get overlapping time periods
-    used_datapipes = get_and_return_overlapping_time_periods_and_t0(used_datapipes)
-
-    # And now get time slices
-    used_datapipes = add_selected_time_slices_from_datapipes(used_datapipes)
-
-    # Now do the extra processing
-    gsp_history = used_datapipes["gsp"].normalize(normalize_fn=normalize_gsp)
-    gsp_datapipe = used_datapipes["gsp_future"].normalize(normalize_fn=normalize_gsp)
-    # Split into GSP for target, only national, and one for history
-    gsp_datapipe = 
FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-
-    if "nwp" in used_datapipes.keys():
-        # take nwp time slices
-        logger.debug("Take NWP time slices")
-        nwp_datapipe = used_datapipes["nwp"].normalize(mean=UKV_MEAN, std=UKV_STD)
-
-    if "sat" in used_datapipes.keys():
-        logger.debug("Take Satellite time slices")
-        # take sat time slices
-        sat_datapipe = used_datapipes["sat"].normalize(mean=RSS_MEAN, std=RSS_STD)
-
-    if "hrv" in used_datapipes.keys():
-        logger.debug("Take HRV Satellite time slices")
-        sat_hrv_datapipe = used_datapipes["hrv"].normalize(
-            mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")
-        )
-
-    if "topo" in used_datapipes.keys():
-        topo_datapipe = used_datapipes["topo"].map(_select_non_nan_times)
-
-    # Now combine in the MetNet format
-    modalities = []
-    if gsp_in_image and "hrv" in used_datapipes.keys():
-        sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2)
-        gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image(
-            image_datapipe=sat_gsp_datapipe
-        )
-    elif gsp_in_image and "sat" in used_datapipes.keys():
-        sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2)
-        gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image(
-            image_datapipe=sat_gsp_datapipe
-        )
-    elif gsp_in_image and "nwp" in used_datapipes.keys():
-        nwp_datapipe, nwp_gsp_datapipe = nwp_datapipe.fork(2)
-        gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image(
-            image_datapipe=nwp_gsp_datapipe, image_dim="osgb"
-        )
-    if "nwp" in used_datapipes.keys():
-        modalities.append(nwp_datapipe)
-    if "hrv" in used_datapipes.keys():
-        modalities.append(sat_hrv_datapipe)
-    if "sat" in used_datapipes.keys():
-        modalities.append(sat_datapipe)
-    if "topo" in used_datapipes.keys():
-        modalities.append(topo_datapipe)
-    if gsp_in_image:
-        modalities.append(gsp_history)
-
-    gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5)
-
-    location_datapipe = PickLocations(gsp_loc_datapipe)
-
-    metnet_datapipe = PreProcessMetNet(
-        modalities,
-        location_datapipe=location_datapipe,
-        center_width=500_000,
-        center_height=1_000_000,
-        context_height=10_000_000,
-        context_width=10_000_000,
-        output_width_pixels=output_size,
-        output_height_pixels=output_size,
-        add_sun_features=use_sun,
-    )
-    gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe)
-
-    if not gsp_in_image:
-        gsp_history = gsp_history.map(_select_non_nan_times)
-        gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True)
-        return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe)  # Makes (Inputs, Label) tuples
-    else:
-        return metnet_datapipe.zip(gsp_datapipe)
diff --git a/ocf_datapipes/training/metnet/metnet_national.md b/ocf_datapipes/training/metnet/metnet_national.md
deleted file mode 100644
index eaf210986..000000000
--- a/ocf_datapipes/training/metnet/metnet_national.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# MetNet National Pipeline
-
-metnet_national.py is a training pipeline for loading NWP, PV, Satellite, and Topographic data and transforming it as in
-the MetNet paper.
-
-The location is chosen using the center of the National GSP shape. Only the desired modalities are loaded.
-Then a time is chosen, and PV and NWP examples are made. 
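For orientation, here is a minimal consumption sketch (the configuration path is a placeholder; the function and its keyword arguments are those of `metnet_national_datapipe` in `metnet_national.py` below):

```python
from ocf_datapipes.training.metnet.metnet_national import metnet_national_datapipe

# Placeholder config path; each modality can be toggled independently.
datapipe = metnet_national_datapipe(
    configuration_filename="national_config.yaml",
    use_nwp=True,
    use_sat=False,
    mode="train",
)
# In "train" mode the pipe yields (inputs, label) tuples via zip_ocf.
inputs, label = next(iter(datapipe))
```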
- -```mermaid -graph TD - A[Load GSP] -->|Select Train/Test Times| B(Drop Regional GSP) --> A1 - C[Load NWP] --> CA[Filter] --> A1 - D[Load Satellite] --> DA[Filter] --> A1 - E[Load PV] --> EA[Filter] --> A1 - F[Load Topo] - A1[Select Joint Time Periods] - B1[Select T0 Time] - A1 --> B1 - B1 --> C1 - A1 --> C1 - B1 --> CAA - CA --> CAA[Convert to Target Time] - DA --> C1 - EA --> C1 - C1[Select Time Slice] - AA[Get Location] - B --> AA - A11[PreProcess MetNet] - C1 --> A11 - CAA --> A11 - F --> A11 - AA --> A11 - A111[Return Example] - A11 --> A111 -``` diff --git a/ocf_datapipes/training/metnet/metnet_national.py b/ocf_datapipes/training/metnet/metnet_national.py deleted file mode 100644 index 60c3c867d..000000000 --- a/ocf_datapipes/training/metnet/metnet_national.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_gsp(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.nominal_capacity_mwp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if 
configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv} " - f"PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - gsp_datapipe, gsp_time_periods_datapipe, gsp_t0_datapipe = ( - gsp_datapipe.normalize(normalize_fn=normalize_gsp) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ) - .fork(3) - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - ) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - 
sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times().fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") - sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])], - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if use_hrv: - 
image_datapipe = OpenSatellite( - configuration.input_data.hrvsatellite.hrvsatellite_zarr_path - ) - elif use_sat: - image_datapipe = OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - elif use_nwp: - image_datapipe = OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - - pv_datapipe = pv_datapipe.create_pv_image( - image_datapipe, - normalize=True, - max_num_pv_systems=max_num_pv_systems, - ) - - if use_topo: - topo_datapipe = OpenTopography( - configuration.input_data.topographic.topographic_filename - ).map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if use_nwp: - modalities.append(nwp_datapipe) - if use_hrv: - modalities.append(sat_hrv_datapipe) - if use_sat: - modalities.append(sat_datapipe) - if use_pv: - modalities.append(pv_datapipe) - if use_topo: - modalities.append(topo_datapipe) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - combined_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=256, - output_height_pixels=256, - add_sun_features=use_sun, - ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - if mode == "train": - return combined_datapipe.zip_ocf(gsp_datapipe) # Makes (Inputs, Label) tuples - else: - start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1] # The one extra one - return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_national_class.py b/ocf_datapipes/training/metnet/metnet_national_class.py deleted file mode 100644 index 69bcbd4e8..000000000 --- a/ocf_datapipes/training/metnet/metnet_national_class.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to 
use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv} " - f"PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - ( - gsp_datapipe, - gsp_time_periods_datapipe, - gsp_t0_datapipe, - ) = gsp_datapipe.add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ).fork( - 3 - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - 
) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times().fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") 
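        # The fork index used just below is sum([use_nwp, use_sat, use_hrv]):
        # GSP always consumes t0_datapipes[0], and each enabled modality takes
        # the next fork in declaration order, so summing the flags up to and
        # including this one gives its slot. A hypothetical example:
        #   use_nwp, use_sat, use_hrv = True, False, True
        #   sum([use_nwp, use_sat, use_hrv]) == 2, so HRV reads t0_datapipes[2]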
- sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])], - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if use_hrv: - image_datapipe = OpenSatellite( - configuration.input_data.hrvsatellite.hrvsatellite_zarr_path - ) - elif use_sat: - image_datapipe = OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - elif use_nwp: - image_datapipe = OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - - pv_datapipe = pv_datapipe.create_pv_image( - image_datapipe, - normalize=True, - max_num_pv_systems=max_num_pv_systems, - ) - - if use_topo: - topo_datapipe = OpenTopography( - configuration.input_data.topographic.topographic_filename - ).map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if use_nwp: - modalities.append(nwp_datapipe) - if use_hrv: - modalities.append(sat_hrv_datapipe) - if use_sat: - modalities.append(sat_datapipe) - if use_pv: - modalities.append(pv_datapipe) - if use_topo: - modalities.append(topo_datapipe) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - combined_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=256, - output_height_pixels=256, - add_sun_features=use_sun, - ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - if mode == "train": - return combined_datapipe.zip_ocf(gsp_datapipe) # Makes (Inputs, Label) tuples - else: - start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1] # The one extra one - return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_preprocessor.py b/ocf_datapipes/training/metnet/metnet_preprocessor.py deleted file mode 100644 index cb474cc7f..000000000 --- a/ocf_datapipes/training/metnet/metnet_preprocessor.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Preprocessing for MetNet-type inputs""" - -import itertools -from typing import List - -import numpy as np -import pvlib -import xarray as xr -from torch.utils.data import IterDataPipe, functional_datapipe - -from ocf_datapipes.select.select_spatial_slice import convert_coords_to_match_xarray -from ocf_datapipes.utils import Zipper -from ocf_datapipes.utils.consts import ( - AZIMUTH_MEAN, - AZIMUTH_STD, - ELEVATION_MEAN, - ELEVATION_STD, -) -from ocf_datapipes.utils.geospatial import ( - geostationary_area_coords_to_lonlat, - move_lon_lat_by_meters, - osgb_to_lon_lat, - spatial_coord_type, -) -from ocf_datapipes.utils.parallel import run_with_threadpool -from ocf_datapipes.utils.utils import trigonometric_datetime_transformation - - -@functional_datapipe("preprocess_metnet") -class PreProcessMetNetIterDataPipe(IterDataPipe): - """Preprocess set of Xarray datasets similar to MetNet-1""" - - def __init__( - self, - 
source_datapipes: List[IterDataPipe], - location_datapipe: IterDataPipe, - context_width: int, - context_height: int, - center_width: int, - center_height: int, - output_height_pixels: int, - output_width_pixels: int, - add_sun_features: bool = False, - only_sun: bool = False, - ): - """ - - Processes set of Xarray datasets similar to MetNet - - In terms of taking all available source datapipes: - 1. selecting the same context area of interest - 2. Creating a center crop of the center_height, center_width - 3. Downsampling the context area of interest to the same shape as the center crop - 4. Stacking those context images on the center crop. - 5. Add Month, Day, Hour channels for each input time - 6. Add Sun position as well? - - This would be designed originally for NWP+Satellite+Topographic data sources. - To add the PV power for lots of sites, the PV power would - need to be able to be on a grid for the context/center - crops and then for the downsample - - This also appends Lat/Lon coordinates to the stack, - and returns a new Numpy array with the stacked data - - Args: - source_datapipes: Datapipes that emit xarray datasets - with latitude/longitude coordinates included - location_datapipe: Datapipe emitting location coordinate for center of example - context_width: Width of the context area - context_height: Height of the context area - center_width: Center width of the area of interest - center_height: Center height of the area of interest - output_height_pixels: Output height in pixels - output_width_pixels: Output width in pixels - add_sun_features: Whether to calculate and - add Sun elevation and azimuth for each center pixel - only_sun: Whether to only output sun features - Assumes only one input to give the coordinates - """ - self.source_datapipes = source_datapipes - self.location_datapipe = location_datapipe - self.context_width = context_width - self.context_height = context_height - self.center_width = center_width - self.center_height = center_height - self.output_height_pixels = output_height_pixels - self.output_width_pixels = output_width_pixels - self.add_sun_features = add_sun_features - self.only_sun = only_sun - - def __iter__(self) -> np.ndarray: - for xr_datas, location in Zipper(Zipper(*self.source_datapipes), self.location_datapipe): - # TODO Use the Lat/Long coordinates of the center array for the lat/lon stuff - # Do the resampling and cropping in parallel - xr_datas = run_with_threadpool( - zip( - _bicycle(xr_datas), - itertools.repeat(location), - itertools.chain.from_iterable( - zip( - itertools.repeat(self.center_width), - itertools.repeat(self.context_width), - ) - ), - itertools.chain.from_iterable( - zip( - itertools.repeat(self.center_height), - itertools.repeat(self.context_height), - ) - ), - itertools.repeat(self.output_height_pixels), - itertools.repeat(self.output_width_pixels), - ), - _crop_and_resample_wrapper, - max_workers=8, - scheduled_tasks=int(len(xr_datas) * 2), # One for center, one for context - ) - xr_datas = list(xr_datas) - # Output is then list of center, context, center, context, etc. 
- # So we need to split the list into two lists of the same length, - # one with centers, one with contexts - centers = xr_datas[::2] - contexts = xr_datas[1::2] - # Now do the first one for the sun and other features - xr_center = centers[0] - _extra_time_dim = ( - "target_time_utc" if "target_time_utc" in xr_center.dims else "time_utc" - ) - # Add in time features for each timestep - time_image = _create_time_image( - xr_center, - time_dim=_extra_time_dim, - output_height_pixels=self.output_height_pixels, - output_width_pixels=self.output_width_pixels, - ) - contexts.append(time_image) - # Need to add sun features - if self.add_sun_features: - sun_image = _create_sun_image( - image_xr=xr_center, - x_dim="x_osgb" if "x_osgb" in xr_center.dims else "x_geostationary", - y_dim="y_osgb" if "y_osgb" in xr_center.dims else "y_geostationary", - time_dim=_extra_time_dim, - normalize=True, - ) - if self.only_sun: - contexts = [time_image, sun_image] - else: - contexts.append(sun_image) - for xr_index in range(len(centers)): - xr_center = centers[xr_index] - xr_context = contexts[xr_index] - xr_center = xr_center.to_numpy() - xr_context = xr_context.to_numpy() - if len(xr_center.shape) == 2: # Need to add channel dimension - xr_center = np.expand_dims(xr_center, axis=0) - xr_context = np.expand_dims(xr_context, axis=0) - if len(xr_center.shape) == 3: # Need to add channel dimension - xr_center = np.expand_dims(xr_center, axis=1) - xr_context = np.expand_dims(xr_context, axis=1) - centers[xr_index] = xr_center - contexts[xr_index] = xr_context - # Pad out time dimension to be the same, using the largest one - # All should have 4 dimensions at this point - max_time_len = max( - np.max([c.shape[0] for c in centers]), np.max([c.shape[0] for c in contexts]) - ) - for i in range(len(centers)): - centers[i] = np.pad( - centers[i], - pad_width=( - (0, max_time_len - centers[i].shape[0]), - (0, 0), - (0, 0), - (0, 0), - ), - mode="constant", - constant_values=0.0, - ) - for i in range(len(contexts)): - contexts[i] = np.pad( - contexts[i], - pad_width=( - (0, max_time_len - contexts[i].shape[0]), - (0, 0), - (0, 0), - (0, 0), - ), - mode="constant", - constant_values=0.0, - ) - stacked_data = np.concatenate([*centers, *contexts], axis=1) - yield stacked_data - - -def _crop_and_resample_wrapper(args): - return _crop_and_resample(*args) - - -def _bicycle(xr_datas): - for xr_data in xr_datas: - yield xr_data - yield xr_data - - -def _crop_and_resample( - xr_data: xr.Dataset, - location, - context_width, - context_height, - output_height_pixels, - output_width_pixels, -): - xr_context: xr.Dataset = _get_spatial_crop( - xr_data, - location=location, - roi_width_meters=context_width, - roi_height_meters=context_height, - ) - - # Resamples to the same number of pixels for both center and contexts - xr_context = _resample_to_pixel_size(xr_context, output_height_pixels, output_width_pixels) - return xr_context - - -def _get_spatial_crop(xr_data, location, roi_height_meters: int, roi_width_meters: int): - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - # Compute the index for left and right: - half_height = roi_height_meters // 2 - half_width = roi_width_meters // 2 - - # Find the bounding box values for the location in either lon-lat or OSGB coord systems - if location.coordinate_system == "lon_lat": - right, top = move_lon_lat_by_meters( - location.x, - location.y, - half_width, - half_height, - ) - left, bottom = move_lon_lat_by_meters( - location.x, - location.y, - -half_width, - -half_height, - ) 
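    # For lon/lat coordinates the bounding box above is built by moving half
    # the ROI width/height in metres with move_lon_lat_by_meters; OSGB is
    # already metre-based, so the elif below offsets the point directly.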
- - elif location.coordinate_system == "osgb": - left = location.x - half_width - right = location.x + half_width - bottom = location.y - half_height - top = location.y + half_height - - else: - raise ValueError(f"Location coord system not recognized: {location.coordinate_system}") - - (left, right), (bottom, top) = convert_coords_to_match_xarray( - x=np.array([left, right], dtype=np.float32), - y=np.array([bottom, top], dtype=np.float32), - from_coords=location.coordinate_system, - xr_data=xr_data, - ) - - # Select a patch from the xarray data - x_mask = (left <= xr_data[xr_x_dim]) & (xr_data[xr_x_dim] <= right) - y_mask = (bottom <= xr_data[xr_y_dim]) & (xr_data[xr_y_dim] <= top) - selected = xr_data.isel({xr_x_dim: x_mask, xr_y_dim: y_mask}) - - return selected - - -def _resample_to_pixel_size(xr_data, height_pixels, width_pixels) -> np.ndarray: - if "x_geostationary" in xr_data.dims: - x_coords = xr_data["x_geostationary"].values - y_coords = xr_data["y_geostationary"].values - elif "x_osgb" in xr_data.dims: - x_coords = xr_data["x_osgb"].values - y_coords = xr_data["y_osgb"].values - else: - x_coords = xr_data["x"].values - y_coords = xr_data["y"].values - # Resample down to the number of pixels wanted - x_coords = np.linspace(x_coords[0], x_coords[-1], num=width_pixels) - y_coords = np.linspace(y_coords[0], y_coords[-1], num=height_pixels) - if "x_geostationary" in xr_data.dims: - xr_data = xr_data.interp( - x_geostationary=x_coords, y_geostationary=y_coords, method="linear" - ) - elif "x_osgb" in xr_data.dims: - xr_data = xr_data.interp(x_osgb=x_coords, y_osgb=y_coords, method="linear") - else: - xr_data = xr_data.interp(x=x_coords, y=y_coords, method="linear") - # Extract just the data now - return xr_data - - -def _create_time_image(xr_data, time_dim: str, output_height_pixels: int, output_width_pixels: int): - # Create trig decomposition of datetime values, tiled over output height and width - datetimes = xr_data[time_dim].values - trig_decomposition = trigonometric_datetime_transformation(datetimes) - tiled_data = np.expand_dims(trig_decomposition, (2, 3)) - tiled_data = np.tile(tiled_data, (1, 1, output_height_pixels, output_width_pixels)) - return tiled_data - - -def _create_sun_image(image_xr, x_dim, y_dim, time_dim, normalize): - # Create empty image to use for the PV Systems, assumes image has x and y coordinates - sun_image = np.zeros( - ( - 2, # Azimuth and elevation - len(image_xr[y_dim]), - len(image_xr[x_dim]), - len(image_xr[time_dim]), - ), - dtype=np.float32, - ) - if "geostationary" in x_dim: - lons, lats = geostationary_area_coords_to_lonlat( - x=image_xr[x_dim].values, y=image_xr[y_dim].values, xr_data=image_xr - ) - else: - lons, lats = osgb_to_lon_lat(x=image_xr.x_osgb.values, y=image_xr.y_osgb.values) - time_utc = image_xr[time_dim].values - - # Loop round each example to get the Sun's elevation and azimuth: - # Go through each time on its own, lat lons still in order of image - # TODO Make this faster - # dt = pd.DatetimeIndex(dt) # pvlib expects a `pd.DatetimeIndex`. - for example_idx, (lat, lon) in enumerate(zip(lats, lons)): - solpos = pvlib.solarposition.get_solarposition( - time=time_utc, - latitude=lat, - longitude=lon, - # Which `method` to use? - # pyephem seemed to be a good mix between speed and ease but causes segfaults! - # nrel_numba doesn't work when using multiple worker processes. 
- # nrel_c is probably fastest but requires C code to be manually compiled: - # https://midcdmz.nrel.gov/spa/ - ) - sun_image[0][:][example_idx] = solpos["azimuth"] - sun_image[1][example_idx][:] = solpos["elevation"] - - # Flip back to normal ordering - sun_image = np.transpose(sun_image, [3, 0, 1, 2]) - - # Normalize. - if normalize: - sun_image[:, 0] = (sun_image[:, 0] - AZIMUTH_MEAN) / AZIMUTH_STD - sun_image[:, 1] = (sun_image[:, 1] - ELEVATION_MEAN) / ELEVATION_STD - return sun_image diff --git a/ocf_datapipes/training/metnet/metnet_pv_national.py b/ocf_datapipes/training/metnet/metnet_pv_national.py deleted file mode 100644 index ae088e2cb..000000000 --- a/ocf_datapipes/training/metnet/metnet_pv_national.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_gsp(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.nominal_capacity_mwp - - -def normalize_pv(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.observed_capacity_wp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if 
configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv}" - f" PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - gsp_datapipe, gsp_time_periods_datapipe, gsp_t0_datapipe = ( - gsp_datapipe.normalize(normalize_fn=normalize_gsp) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ) - .fork(3) - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - ) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - 
minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times( - return_all_times=False # if mode == "train" else True - ).fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") - sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv) - pv_datapipe = pv_datapipe.select_time_slice( - 
-        t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])],
-        history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes),
-        forecast_duration=timedelta(minutes=0),
-        sample_period_duration=timedelta(minutes=5),
-    )
-
-    if use_topo:
-        topo_datapipe = OpenTopography(
-            configuration.input_data.topographic.topographic_filename
-        ).map(_select_non_nan_times)
-
-    # Now combine in the MetNet format
-    modalities = []
-    if use_nwp:
-        modalities.append(nwp_datapipe)
-    if use_hrv:
-        modalities.append(sat_hrv_datapipe)
-    if use_sat:
-        modalities.append(sat_datapipe)
-    if use_topo:
-        modalities.append(topo_datapipe)
-
-    gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2)
-
-    location_datapipe = PickLocations(gsp_loc_datapipe)
-
-    metnet_datapipe = PreProcessMetNet(
-        modalities,
-        location_datapipe=location_datapipe,
-        center_width=500_000,
-        center_height=1_000_000,
-        context_height=10_000_000,
-        context_width=10_000_000,
-        output_width_pixels=256,
-        output_height_pixels=256,
-        add_sun_features=use_sun,
-    )
-
-    pv_datapipe = (
-        pv_datapipe.ensure_n_pv_systems_per_example(n_pv_systems_per_example=max_num_pv_systems)
-        .map(_select_non_nan_times)
-        .convert_pv_to_numpy(return_pv_system_row=True)
-    )
-    combined_datapipe = metnet_datapipe.zip_ocf(pv_datapipe)
-    gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe)
-    if mode == "train":
-        return combined_datapipe.zip_ocf(gsp_datapipe)  # Makes (Inputs, Label) tuples
-    else:
-        start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1]  # The last one is the extra start-time datapipe
-        return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe)
diff --git a/ocf_datapipes/training/metnet/metnet_pv_site.py b/ocf_datapipes/training/metnet/metnet_pv_site.py
deleted file mode 100644
index 7878a5aa2..000000000
--- a/ocf_datapipes/training/metnet/metnet_pv_site.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""Create the training/validation datapipe for training the site-level MetNet/-2 Model"""
-
-import datetime
-import logging
-from pathlib import Path
-from typing import Union
-
-import xarray
-from torch.utils.data.datapipes.datapipe import IterDataPipe
-
-from ocf_datapipes.convert import ConvertPVToNumpy
-from ocf_datapipes.select import PickLocations
-from ocf_datapipes.training.common import (
-    add_selected_time_slices_from_datapipes,
-    get_and_return_overlapping_time_periods_and_t0,
-    open_and_return_datapipes,
-)
-from ocf_datapipes.training.metnet.metnet_preprocessor import (
-    PreProcessMetNetIterDataPipe as PreProcessMetNet,
-)
-from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD
-from ocf_datapipes.utils.future import ThreadPoolMapperIterDataPipe as ThreadPoolMapper
-
-xarray.set_options(keep_attrs=True)
-logger = logging.getLogger("metnet_datapipe")
-logger.setLevel(logging.DEBUG)
-
-
-def normalize_pv(x):  # So it can be pickled
-    """
-    Normalize the PV data
-
-    Args:
-        x: Input DataArray
-
-    Returns:
-        Normalized DataArray
-    """
-    return x / x.observed_capacity_wp
-
-
-def _select_non_nan_times(x):
-    return x.fillna(0.0)
-
-
-def _load_xarray_values(x):
-    return x.load()
-
-
-def metnet_site_datapipe(
-    configuration_filename: Union[Path, str],
-    use_sun: bool = True,
-    use_nwp: bool = True,
-    use_sat: bool = True,
-    use_hrv: bool = True,
-    use_pv: bool = True,
-    use_topo: bool = True,
-    output_size: int = 256,
-    pv_in_image: bool = False,
-    start_time: datetime.datetime = datetime.datetime(2014, 1, 1),
-    end_time: datetime.datetime = datetime.datetime(2023, 1, 1),
-    center_size_meters: int = 64_000,
-    context_size_meters: int = 512_000,
-    batch_size: int = 1,
-) -> IterDataPipe:
-    """
-    Make the PV site datapipe
-
-    Args:
-        configuration_filename: the configuration filename for the pipe
-        use_sun: Whether to add sun features or not
-        use_pv: Whether to use PV input or not
-        use_hrv: Whether to use HRV Satellite or not
-        use_sat: Whether to use non-HRV Satellite or not
-        use_nwp: Whether to use NWP or not
-        use_topo: Whether to use topographic map or not
-        start_time: Start time to select from
-        end_time: End time to select up to
-        output_size: Size, in pixels, of the output image
-        pv_in_image: Add PV history as channels in MetNet image
-        center_size_meters: Center size for MetNet cutouts, in meters
-        context_size_meters: Context area size in meters
-        batch_size: Batch size for the datapipe
-
-    Returns: datapipe
-    """
-
-    # load datasets
-    used_datapipes = open_and_return_datapipes(
-        configuration_filename=configuration_filename,
-        use_nwp=use_nwp,
-        use_topo=use_topo,
-        use_sat=use_sat,
-        use_hrv=use_hrv,
-        use_gsp=False,
-        use_pv=use_pv,
-    )
-    # Load PV data
-    used_datapipes["pv"] = (
-        used_datapipes["pv"].filter_times(start_time, end_time).pv_interpolate_infill()
-    )
-
-    # Now get overlapping time periods
-    used_datapipes = get_and_return_overlapping_time_periods_and_t0(used_datapipes, key_for_t0="pv")
-
-    # And now get time slices
-    used_datapipes = add_selected_time_slices_from_datapipes(used_datapipes)
-
-    # Now do the extra processing
-    pv_history = used_datapipes["pv"].normalize(normalize_fn=normalize_pv)
-    pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv)
-    # Split into PV for target, and one for history
-    pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2)
-    pv_loc_datapipe, pv_id_datapipe = PickLocations(pv_loc_datapipe).fork(2)
-    pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv")
-
-    if "nwp" in used_datapipes.keys():
-        # take nwp time slices
-        logger.debug("Take NWP time slices")
-        nwp_datapipe = used_datapipes["nwp"].normalize(mean=UKV_MEAN, std=UKV_STD)
-        pv_loc_datapipe, pv_nwp_image_loc_datapipe = pv_loc_datapipe.fork(2)
-        # context_size is the largest it would need
-        nwp_datapipe = nwp_datapipe.select_spatial_slice_meters(
-            pv_nwp_image_loc_datapipe,
-            roi_height_meters=context_size_meters,
-            roi_width_meters=context_size_meters,
-            dim_name=None,
-        )
-        # Multithread the data
-        nwp_datapipe = ThreadPoolMapper(
-            nwp_datapipe, _load_xarray_values, max_workers=8, scheduled_tasks=batch_size
-        )
-
-    if "sat" in used_datapipes.keys():
-        logger.debug("Take Satellite time slices")
-        # take sat time slices
-        sat_datapipe = used_datapipes["sat"].normalize(mean=RSS_MEAN, std=RSS_STD)
-        pv_loc_datapipe, pv_sat_image_loc_datapipe = pv_loc_datapipe.fork(2)
-        sat_datapipe = sat_datapipe.select_spatial_slice_meters(
-            pv_sat_image_loc_datapipe,
-            roi_height_meters=context_size_meters,
-            roi_width_meters=context_size_meters,
-            dim_name=None,
-        )
-        sat_datapipe = ThreadPoolMapper(
-            sat_datapipe, _load_xarray_values, max_workers=8, scheduled_tasks=batch_size
-        )
-
-    if "hrv" in used_datapipes.keys():
-        logger.debug("Take HRV Satellite time slices")
-        sat_hrv_datapipe = used_datapipes["hrv"].normalize(mean=RSS_MEAN, std=RSS_STD)
-        pv_loc_datapipe, pv_hrv_image_loc_datapipe = pv_loc_datapipe.fork(2)
-        sat_hrv_datapipe = sat_hrv_datapipe.select_spatial_slice_meters(
-            pv_hrv_image_loc_datapipe,
-            roi_height_meters=context_size_meters,
-            roi_width_meters=context_size_meters,
-            dim_name=None,
-        )
-        sat_hrv_datapipe = ThreadPoolMapper(
-            sat_hrv_datapipe, _load_xarray_values, max_workers=8, scheduled_tasks=batch_size
-        )
-
-    if "topo" in used_datapipes.keys():
-        topo_datapipe = used_datapipes["topo"].map(_select_non_nan_times)
-
-    # Now combine in the MetNet format
-    modalities = []
-
-    if pv_in_image and "hrv" in used_datapipes.keys():
-        sat_hrv_datapipe, sat_pv_datapipe = sat_hrv_datapipe.fork(2)
-        pv_history = pv_history.create_pv_history_image(image_datapipe=sat_pv_datapipe)
-    elif pv_in_image and "sat" in used_datapipes.keys():
-        sat_datapipe, sat_pv_datapipe = sat_datapipe.fork(2)
-        pv_history = pv_history.create_pv_history_image(image_datapipe=sat_pv_datapipe)
-    elif pv_in_image and "nwp" in used_datapipes.keys():
-        nwp_datapipe, nwp_pv_datapipe = nwp_datapipe.fork(2)
-        pv_history = pv_history.create_pv_history_image(
-            image_datapipe=nwp_pv_datapipe, image_dim="osgb"
-        )
-
-    if "nwp" in used_datapipes.keys():
-        modalities.append(nwp_datapipe)
-    if "hrv" in used_datapipes.keys():
-        modalities.append(sat_hrv_datapipe)
-    if "sat" in used_datapipes.keys():
-        modalities.append(sat_datapipe)
-    if "topo" in used_datapipes.keys():
-        modalities.append(topo_datapipe)
-    if pv_in_image:
-        modalities.append(pv_history)
-
-    metnet_datapipe = PreProcessMetNet(
-        modalities,
-        location_datapipe=pv_loc_datapipe,
-        center_width=center_size_meters,
-        center_height=center_size_meters,  # 64km
-        context_height=context_size_meters,
-        context_width=context_size_meters,  # 512km
-        output_width_pixels=output_size,
-        output_height_pixels=output_size,
-        add_sun_features=use_sun,
-    )
-
-    pv_datapipe = ConvertPVToNumpy(pv_datapipe)
-
-    if not pv_in_image:
-        pv_history = pv_history.map(_select_non_nan_times)
-        pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True)
-        return metnet_datapipe.batch(batch_size).zip_ocf(
-            pv_history.batch(batch_size), pv_datapipe.batch(batch_size)
-        )
-    else:
-        return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size))
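For reference, a minimal sketch of how metnet_site_datapipe above was consumed. The configuration path is a hypothetical placeholder (mirroring the test files below), and the keyword values simply echo the signature defaults:

    # Sketch only: the YAML path is hypothetical; pv_in_image=True makes the
    # pipeline yield (inputs, pv_target) pairs, each batched to batch_size.
    from ocf_datapipes.training.metnet.metnet_pv_site import metnet_site_datapipe

    datapipe = metnet_site_datapipe(
        "tests/config/test.yaml",  # hypothetical configuration file
        use_nwp=False,
        pv_in_image=True,  # PV history becomes extra image channels
    )
    inputs, pv_target = next(iter(datapipe))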
diff --git a/tests/end2end/test_metnet_training.py b/tests/end2end/test_metnet_training.py
deleted file mode 100644
index 86752bd8b..000000000
--- a/tests/end2end/test_metnet_training.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import numpy as np
-import torch
-import xarray
-from torch.utils.data.datapipes._decorator import functional_datapipe
-from torch.utils.data.datapipes.iter import IterableWrapper
-
-
-xarray.set_options(keep_attrs=True)
-
-from datetime import timedelta
-
-from ocf_datapipes.select import (
-    FilterGSPIDs,
-    PickLocations,
-    SelectLiveTimeSlice,
-    SelectSpatialSliceMeters,
-    SelectTimeSliceNWP,
-)
-
-from ocf_datapipes.transform.xarray import (
-    AddT0IdxAndSamplePeriodDuration,
-    CreatePVImage,
-    Downsample,
-    Normalize,
-    ReprojectTopography,
-)
-from ocf_datapipes.training.metnet.metnet_preprocessor import (
-    PreProcessMetNetIterDataPipe as PreProcessMetNet,
-)
-
-from ocf_datapipes.utils.consts import UKV_MEAN, UKV_STD, RSS_MEAN, RSS_STD
-
-import pytest
-
-
-def last_time(ds, time_dim="time_utc"):
-    return ds[time_dim].values[-1]
-
-
-# N.B. The first change which broke this test was changing the NWP data in the test directory
-# to include more forecast steps
-@pytest.mark.skip(reason="Not maintained for the moment")
-def test_metnet_production(
-    sat_hrv_datapipe, sat_datapipe, passiv_datapipe, topo_datapipe, gsp_datapipe, nwp_datapipe
-):
-    ####################################
-    #
-    # Equivalent to PP's loading and filtering methods
-    #
-    #####################################
-    # Normalize GSP and PV on whole dataset here
-    pv_datapipe = passiv_datapipe
-    gsp_datapipe, gsp_loc_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]).fork(2)
-    gsp_datapipe = Normalize(gsp_datapipe, normalize_fn=lambda x: x / x.installedcapacity_mwp)
-    topo_datapipe = ReprojectTopography(topo_datapipe)
-    sat_hrv_datapipe = AddT0IdxAndSamplePeriodDuration(
-        sat_hrv_datapipe,
-        sample_period_duration=timedelta(minutes=5),
-        history_duration=timedelta(minutes=60),
-    )
-    sat_datapipe = AddT0IdxAndSamplePeriodDuration(
-        sat_datapipe,
-        sample_period_duration=timedelta(minutes=5),
-        history_duration=timedelta(minutes=60),
-    )
-    pv_datapipe = AddT0IdxAndSamplePeriodDuration(
-        pv_datapipe,
-        sample_period_duration=timedelta(minutes=5),
-        history_duration=timedelta(minutes=60),
-    )
-    gsp_datapipe, gsp_t0_datapipe = AddT0IdxAndSamplePeriodDuration(
-        gsp_datapipe,
-        sample_period_duration=timedelta(minutes=30),
-        history_duration=timedelta(hours=2),
-    ).fork(2)
-    nwp_datapipe = AddT0IdxAndSamplePeriodDuration(
-        nwp_datapipe, sample_period_duration=timedelta(hours=1), history_duration=timedelta(hours=2)
-    )
-
-    ####################################
-    #
-    # Equivalent to PP's xr_batch_processors and normal loading/selecting
-    #
-    #####################################
-
-    (
-        location_datapipe1,
-        location_datapipe2,
-        location_datapipe3,
-        location_datapipe4,
-        location_datapipe5,
-    ) = PickLocations(gsp_loc_datapipe, return_all_locations=True).fork(
-        5
-    )  # It's in order then
-    pv_datapipe, pv_t0_datapipe = SelectSpatialSliceMeters(
-        pv_datapipe,
-        location_datapipe=location_datapipe1,
-        roi_width_meters=100_000,
-        roi_height_meters=100_000,
-    ).fork(
-        2
-    )  # Has to be large, as the test PV systems don't seem to be in the first 20 GSPs
-    nwp_datapipe, nwp_t0_datapipe = Downsample(nwp_datapipe, y_coarsen=16, x_coarsen=16).fork(2)
-    nwp_t0_datapipe = nwp_t0_datapipe.map(lambda x: last_time(x, "init_time_utc"))
-    nwp_datapipe = SelectTimeSliceNWP(
-        nwp_datapipe,
-        t0_datapipe=nwp_t0_datapipe,
-        sample_period_duration=timedelta(hours=1),
-        history_duration=timedelta(hours=2),
-        forecast_duration=timedelta(hours=3),
-    )
-    gsp_t0_datapipe = gsp_t0_datapipe.map(last_time)
-    gsp_datapipe = SelectLiveTimeSlice(
-        gsp_datapipe,
-        t0_datapipe=gsp_t0_datapipe,
-        history_duration=timedelta(hours=2),
-    )
-    sat_t0_datapipe = sat_datapipe.map(last_time)
-    sat_datapipe, image_datapipe = SelectLiveTimeSlice(
-        sat_datapipe,
-        t0_datapipe=sat_t0_datapipe,
-        history_duration=timedelta(hours=1),
-    ).fork(2)
-    sat_hrv_t0_datapipe = sat_hrv_datapipe.map(last_time)
-    sat_hrv_datapipe = SelectLiveTimeSlice(
-        sat_hrv_datapipe,
-        t0_datapipe=sat_hrv_t0_datapipe,
-        history_duration=timedelta(hours=1),
-    )
-    passiv_t0_datapipe = pv_t0_datapipe.map(last_time)
-    pv_datapipe = SelectLiveTimeSlice(
-        pv_datapipe,
-        t0_datapipe=passiv_t0_datapipe,
-        history_duration=timedelta(hours=1),
-    )
-    gsp_datapipe = SelectSpatialSliceMeters(
-        gsp_datapipe,
-        location_datapipe=location_datapipe4,
-        dim_name="gsp_id",
-        roi_width_meters=10,
-        roi_height_meters=10,
-    )
-
-    pv_datapipe = CreatePVImage(pv_datapipe, image_datapipe)
-
-    sat_hrv_datapipe = Normalize(
-        sat_hrv_datapipe, mean=RSS_MEAN.sel(channel="HRV") / 4, std=RSS_STD.sel(channel="HRV") / 4
-    ).map(
-        lambda x: x.resample(time_utc="5min").interpolate("linear")
-    )  # Interpolate to 5 minutes in case it's 15 minutes
-    sat_datapipe = Normalize(sat_datapipe, mean=RSS_MEAN, std=RSS_STD).map(
-        lambda x: x.resample(time_utc="5min").interpolate("linear")
-    )  # Interpolate to 5 minutes in case it's 15 minutes
-    nwp_datapipe = Normalize(nwp_datapipe, mean=UKV_MEAN, std=UKV_STD)
-    topo_datapipe = Normalize(topo_datapipe, calculate_mean_std_from_example=True)
-
-    # Now combine in the MetNet format
-    combined_datapipe = PreProcessMetNet(
-        [
-            nwp_datapipe,
-            sat_hrv_datapipe,
-            sat_datapipe,
-            pv_datapipe,
-        ],
-        location_datapipe=location_datapipe5,
-        center_width=500_000,
-        center_height=1_000_000,
-        context_height=10_000_000,
-        context_width=10_000_000,
-        output_width_pixels=512,
-        output_height_pixels=512,
-        add_sun_features=True,
-    )
-
-    batch = next(iter(combined_datapipe))
-    assert not np.isnan(batch).any()
-    print(batch.shape)
-    batch = next(iter(gsp_datapipe))
-    print(batch.shape)
diff --git a/tests/training/metnet/test_metnet_gsp_national.py b/tests/training/metnet/test_metnet_gsp_national.py
deleted file mode 100644
index 62e2ef8eb..000000000
--- a/tests/training/metnet/test_metnet_gsp_national.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import os
-
-import numpy as np
-import pytest
-from torch.utils.data import DataLoader
-
-import ocf_datapipes
-from ocf_datapipes.training.metnet.metnet_gsp_national import metnet_national_datapipe
-
-
-@pytest.mark.skip("Failing at the moment")
-def test_metnet_gsp_national_datapipe():
-    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
-    datapipe = metnet_national_datapipe(filename, use_pv=False)
-    dataloader = DataLoader(datapipe)
-    for i, batch in enumerate(dataloader):
-        _ = batch
-        if (i + 1) % 50000 == 0:
-            break
-
-
-@pytest.mark.skip("Failing at the moment")
-def test_metnet_gsp_national_image_datapipe():
-    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
-    datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128)
-    dataloader = iter(datapipe)
-    batch = next(dataloader)
-    x, y = batch
-    assert np.isfinite(x).all()
-    assert np.isfinite(y).all()
diff --git a/tests/training/metnet/test_metnet_national.py b/tests/training/metnet/test_metnet_national.py
deleted file mode 100644
index 29329b4c6..000000000
--- a/tests/training/metnet/test_metnet_national.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-
-import numpy as np
-import pytest
-
-import ocf_datapipes
-from ocf_datapipes.training.metnet.metnet_national import metnet_national_datapipe
-
-
-@pytest.mark.skip("Failing at the moment")
-def test_metnet_national_datapipe():
-    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
-
-    datapipe = metnet_national_datapipe(filename, max_num_pv_systems=1).set_length(2)
-
-    batch = next(iter(datapipe))
-    assert np.isfinite(batch[0]).all()
-    assert np.isfinite(batch[1]).all()
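The parentheses in the loop guard of test_metnet_gsp_national_datapipe above are load-bearing: % binds more tightly than +, so the unparenthesised form i + 1 % 50000 evaluates as i + (1 % 50000), which is just i + 1 and never equals zero for non-negative i. A two-line illustration in plain Python:

    i = 49_999
    assert i + 1 % 50_000 == i + 1  # parses as i + (1 % 50_000)
    assert (i + 1) % 50_000 == 0    # the intended every-50,000-batches check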
diff --git a/tests/training/metnet/test_metnet_preprocessor.py b/tests/training/metnet/test_metnet_preprocessor.py
deleted file mode 100644
index e8f9957c2..000000000
--- a/tests/training/metnet/test_metnet_preprocessor.py
+++ /dev/null
@@ -1,161 +0,0 @@
-from ocf_datapipes.select import FilterGSPIDs, PickLocations
-from ocf_datapipes.transform.xarray import CreatePVImage
-from ocf_datapipes.training.metnet.metnet_preprocessor import (
-    PreProcessMetNetIterDataPipe as PreProcessMetNet,
-)
-
-
-def test_metnet_preprocess_no_sun(sat_datapipe, gsp_datapipe):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    datapipe = PreProcessMetNet(
-        [sat_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=False,
-    )
-    data = next(iter(datapipe))
-    print(data.shape)
-
-
-def test_metnet_preprocess(sat_datapipe, gsp_datapipe):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    datapipe = PreProcessMetNet(
-        [sat_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    print(data.shape)
-
-
-def test_metnet_preprocess_both_sat(sat_datapipe, sat_hrv_datapipe, gsp_datapipe):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    datapipe = PreProcessMetNet(
-        [sat_datapipe, sat_hrv_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=False,
-    )
-    data = next(iter(datapipe))
-    print(data.shape)
-
-
-def test_metnet_preprocess_both_sat_other_order(sat_datapipe, sat_hrv_datapipe, gsp_datapipe):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    datapipe = PreProcessMetNet(
-        [sat_hrv_datapipe, sat_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    print(data.shape)
-
-
-def test_metnet_preprocess_both_sat_pv(
-    sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe
-):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    sat_datapipe, image_datapipe = sat_datapipe.fork(2)
-    passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True)
-    datapipe = PreProcessMetNet(
-        [sat_hrv_datapipe, sat_datapipe, passiv_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    assert data.shape == (289, 14, 100, 100)
-
-
-def test_metnet_preprocess_sat_hrv_pv_nwp(
-    sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe, nwp_datapipe
-):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    sat_datapipe, image_datapipe = sat_datapipe.fork(2)
-    passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True)
-    datapipe = PreProcessMetNet(
-        [sat_hrv_datapipe, sat_datapipe, passiv_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    assert data.shape == (289, 14, 100, 100)
-
-
-def test_metnet_preprocess_sat_topo(sat_datapipe, gsp_datapipe, topo_datapipe):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    datapipe = PreProcessMetNet(
-        [sat_datapipe, topo_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    assert data.shape == (25, 12, 100, 100)
-
-
-def test_metnet_preprocess_sat_hrv_pv_nwp_topo(
-    sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe, nwp_datapipe, topo_datapipe
-):
-    gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0])
-    gsp_datapipe = PickLocations(gsp_datapipe)
-    sat_datapipe, image_datapipe = sat_datapipe.fork(2)
-    passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True)
-    datapipe = PreProcessMetNet(
-        [sat_hrv_datapipe, sat_datapipe, passiv_datapipe, topo_datapipe],
-        location_datapipe=gsp_datapipe,
-        center_width=100_000,
-        center_height=100_000,
-        context_height=1_000_000,
-        context_width=1_000_000,
-        output_width_pixels=100,
-        output_height_pixels=100,
-        add_sun_features=True,
-    )
-    data = next(iter(datapipe))
-    assert data.shape == (289, 16, 100, 100)
diff --git a/tests/training/metnet/test_metnet_pv_national.py b/tests/training/metnet/test_metnet_pv_national.py
deleted file mode 100644
index c4392ce16..000000000
--- a/tests/training/metnet/test_metnet_pv_national.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import os
-
-import numpy as np
-import pytest
-
-import ocf_datapipes
-from ocf_datapipes.training.metnet.metnet_pv_national import metnet_national_datapipe
-
-
-@pytest.mark.skip("Failing at the moment")
-def test_metnet_pv_national_datapipe():
-    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
-    datapipe = metnet_national_datapipe(filename, use_nwp=False, max_num_pv_systems=1)
-
-    batch = next(iter(datapipe))
-    assert np.isfinite(batch[0]).all()
-    assert np.isfinite(batch[1]).all()
-    assert np.isfinite(batch[2]).all()
-    assert np.isfinite(batch[3]).all()
diff --git a/tests/training/metnet/test_metnet_pv_site.py b/tests/training/metnet/test_metnet_pv_site.py
deleted file mode 100644
index 1d2f3ef6b..000000000
--- a/tests/training/metnet/test_metnet_pv_site.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import os
-
-import numpy as np
-import pytest
-
-import ocf_datapipes
-from ocf_datapipes.training.metnet.metnet_pv_site import metnet_site_datapipe
-
-
-@pytest.mark.skip("Failing at the moment")
-def test_metnet_site_datapipe():
-    filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
-    datapipe = metnet_site_datapipe(filename, use_nwp=False, pv_in_image=True)
-
-    batch = next(iter(datapipe))
-    assert np.isfinite(batch[0]).all()
-    assert np.isfinite(batch[1]).all()
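Pulling the deleted pieces together, each training pipeline was consumed the same way: build the datapipe from a configuration file, then iterate it, optionally behind a torch DataLoader as in test_metnet_gsp_national_datapipe. A minimal sketch, assuming the same hypothetical test configuration used in the tests above:

    # Sketch only: mirrors test_metnet_gsp_national_image_datapipe above.
    from ocf_datapipes.training.metnet.metnet_gsp_national import metnet_national_datapipe

    datapipe = metnet_national_datapipe(
        "tests/config/test.yaml",  # hypothetical configuration file
        use_pv=False,
        gsp_in_image=True,  # two-tuple output: (MetNet image stack, national GSP target)
        output_size=128,
    )
    x, y = next(iter(datapipe))
    print(x.shape, y.shape)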