From 334c461322a4feb157131f7435fbe63c5766737b Mon Sep 17 00:00:00 2001 From: pciturri Date: Sun, 4 Aug 2024 23:42:40 +0200 Subject: [PATCH] ft: Create ForecastRepository class to handle different types of access/storage. Models have a flag to re-stage. refac: Renamed ModelTree as ForecastRegistry and PathTree as ExperimentRegistry. Moved .dir() method to ForecastRegistry. Decoupled the get_source from the directory creation, which is now inside the model class. Removed dataclasses from Registries. Removed __call__ override for registry classes and replaced it with get_path(). Created an abc BaseFileRegistry, from which ExperimentRegistry and ForecastRegistry inherit. Removed query_gcmt, as it was included in pycsep. Renamed iter_attr() from Experiment.as_dict() to parse_nested_dicts, and moved it to utils.py. sty: used black and pydocstringformatter in main modules. build: Added pydocstringformatter options to .toml. tests: updated all unit tests with the new classes. fix: conda manager prints output live to shell. 
--- docs/reference/api_reference.rst | 23 +- floatcsep/accessors.py | 215 +-------- floatcsep/cmd/main.py | 2 +- floatcsep/environments.py | 145 +++--- floatcsep/evaluation.py | 57 +-- floatcsep/experiment.py | 159 ++----- floatcsep/extras.py | 48 +- floatcsep/model.py | 338 ++++--------- floatcsep/readers.py | 13 +- floatcsep/registry.py | 254 +++++----- floatcsep/report.py | 2 +- floatcsep/repository.py | 162 +++++++ floatcsep/utils.py | 217 +++++---- pyproject.toml | 11 +- .../artifacts/models/td_model/input/args.txt | 5 +- tests/integration/test_model_interface.py | 308 ++++++++++-- tests/qa/test_data.py | 28 +- tests/unit/test_accessors.py | 63 +-- tests/unit/test_environments.py | 53 ++- tests/unit/test_experiment.py | 15 +- tests/unit/test_model.py | 445 +++++++----------- tests/unit/test_registry.py | 96 ++++ tests/unit/test_repositories.py | 155 ++++++ 23 files changed, 1506 insertions(+), 1308 deletions(-) create mode 100644 floatcsep/repository.py create mode 100644 tests/unit/test_registry.py create mode 100644 tests/unit/test_repositories.py diff --git a/docs/reference/api_reference.rst b/docs/reference/api_reference.rst index 0d00932..6525bd6 100644 --- a/docs/reference/api_reference.rst +++ b/docs/reference/api_reference.rst @@ -111,11 +111,32 @@ Accessors .. autosummary:: :toctree: generated - query_gcmt from_zenodo from_git +Environments +------------ + +.. :currentmodule:: floatcsep.environments + +.. automodule:: floatcsep.environments + +.. 
autosummary:: + :toctree: generated + + CondaManager + CondaManager.create_environment + CondaManager.env_exists + CondaManager.install_dependencies + CondaManager.run_command + + VenvManager + VenvManager.create_environment + VenvManager.env_exists + VenvManager.install_dependencies + VenvManager.run_command + Extras ------ diff --git a/floatcsep/accessors.py b/floatcsep/accessors.py index c62948a..312d2f5 100644 --- a/floatcsep/accessors.py +++ b/floatcsep/accessors.py @@ -14,40 +14,11 @@ TIMEOUT = 180 -def query_gcmt( - start_time, - end_time, - min_magnitude=5.0, - max_depth=None, - catalog_id=None, - min_latitude=None, - max_latitude=None, - min_longitude=None, - max_longitude=None, -): - - eventlist = _query_gcmt( - start_time=start_time, - end_time=end_time, - min_magnitude=min_magnitude, - min_latitude=min_latitude, - max_latitude=max_latitude, - min_longitude=min_longitude, - max_longitude=max_longitude, - max_depth=max_depth, - ) - - catalog = CSEPCatalog( - data=eventlist, name="gCMT", catalog_id=catalog_id, date_accessed=utc_now_datetime() - ) - return catalog - - def from_zenodo(record_id, folder, force=False): """ Download data from a Zenodo repository. - Downloads if file does not exist, checksum has changed in local respect to - url or force + + Downloads if file does not exist, checksum has changed in local respect to url or force Args: record_id: corresponding to the Zenodo repository @@ -55,7 +26,6 @@ def from_zenodo(record_id, folder, force=False): force: force download even if file exists and checksum passes Returns: - """ # Grab the urls and filenames and checksums r = requests.get(f"https://zenodo.org/api/records/{record_id}", timeout=3) @@ -87,8 +57,7 @@ def from_zenodo(record_id, folder, force=False): def from_git(url, path, branch=None, depth=1, **kwargs): """ - - Clones a shallow repository from a git url + Clones a shallow repository from a git url. 
Args: url (str): url of the repository @@ -115,185 +84,13 @@ def from_git(url, path, branch=None, depth=1, **kwargs): return repo -def _query_gcmt( - start_time, - end_time, - min_magnitude=3.50, - min_latitude=None, - max_latitude=None, - min_longitude=None, - max_longitude=None, - max_depth=1000, - extra_gcmt_params=None, -): - """ - Return GCMT eventlist from IRIS web service. - For details see "https://service.iris.edu/fdsnws/event/1/" - Args: - start_time (datetime.datetime): start time of catalog query - end_time (datetime.datetime): end time of catalog query - min_magnitude (float): minimum magnitude of query - min_latitude (float): minimum latitude of query - max_latitude (float): maximum latitude of query - min_longitude (float): minimum longitude of query - max_longitude (float): maximum longitude of query - max_depth (float): maximum depth of query - extra_gcmt_params (dict): additional parameters to pass to IRIS search - function - - Returns: - eventlist - """ - extra_gcmt_params = extra_gcmt_params or {} - - eventlist = gcmt_search( - minmagnitude=min_magnitude, - minlatitude=min_latitude, - maxlatitude=max_latitude, - minlongitude=min_longitude, - maxlongitude=max_longitude, - starttime=start_time.isoformat(), - endtime=end_time.isoformat(), - maxdepth=max_depth, - **extra_gcmt_params, - ) - - return eventlist - - -def gcmt_search( - format="text", - starttime=None, - endtime=None, - updatedafter=None, - minlatitude=None, - maxlatitude=None, - minlongitude=None, - maxlongitude=None, - latitude=None, - longitude=None, - maxradius=None, - catalog="GCMT", - contributor=None, - maxdepth=1000, - maxmagnitude=10.0, - mindepth=-100, - minmagnitude=0, - offset=1, - orderby="time-asc", - host=None, - verbose=False, -): - """Search the IRIS database for events matching input criteria. 
- This search function is a wrapper around the ComCat Web API described here: - https://service.iris.edu/fdsnws/event/1/ - - This function returns a list of SummaryEvent objects, described elsewhere in this package. - Args: - starttime (datetime): - Python datetime - Limit to events on or after the specified start time. - endtime (datetime): - Python datetime - Limit to events on or before the specified end time. - updatedafter (datetime): - Python datetime - Limit to events updated after the specified time. - minlatitude (float): - Limit to events with a latitude larger than the specified minimum. - maxlatitude (float): - Limit to events with a latitude smaller than the specified maximum. - minlongitude (float): - Limit to events with a longitude larger than the specified minimum. - maxlongitude (float): - Limit to events with a longitude smaller than the specified maximum. - latitude (float): - Specify the latitude to be used for a radius search. - longitude (float): - Specify the longitude to be used for a radius search. - maxradius (float): - Limit to events within the specified maximum number of degrees - from the geographic point defined by the latitude and longitude parameters. - catalog (str): - Limit to events from a specified catalog. - contributor (str): - Limit to events contributed by a specified contributor. - maxdepth (float): - Limit to events with depth less than the specified maximum. - maxmagnitude (float): - Limit to events with a magnitude smaller than the specified maximum. - mindepth (float): - Limit to events with depth more than the specified minimum. - minmagnitude (float): - Limit to events with a magnitude larger than the specified minimum. - offset (int): - Return results starting at the event count specified, starting at 1. - orderby (str): - Order the results. 
The allowed values are: - - time order by origin descending time - - time-asc order by origin ascending time - - magnitude order by descending magnitude - - magnitude-asc order by ascending magnitude - host (str): - Replace default ComCat host (earthquake.usgs.gov) with a custom host. - Returns: - list: List of SummaryEvent() objects. - """ - - # getting the inputargs must be the first line of the method! - inputargs = locals().copy() - newargs = {} - - for key, value in inputargs.items(): - if value is True: - newargs[key] = "true" - continue - if value is False: - newargs[key] = "false" - continue - if value is None: - continue - newargs[key] = value - - del newargs["verbose"] - - events = _search_gcmt(**newargs) - - return events - - -def _search_gcmt(**_newargs): - """ - Performs de-query at ISC API and returns event list and access date - - """ - paramstr = urlencode(_newargs) - url = HOST_CATALOG + paramstr - fh = request.urlopen(url, timeout=TIMEOUT) - data = fh.read().decode("utf8").split("\n") - fh.close() - eventlist = [] - for line in data[1:]: - line_ = line.split("|") - if len(line_) != 1: - id_ = line_[0] - time_ = datetime.fromisoformat(line_[1]) - dt = datetime_to_utc_epoch(time_) - lat = float(line_[2]) - lon = float(line_[3]) - depth = float(line_[4]) - mag = float(line_[10]) - eventlist.append((id_, dt, lat, lon, depth, mag)) - - return eventlist - - def _download_file(url: str, filename: str) -> None: """ - - Downloads files (from zenodo) + Downloads files (from zenodo). Args: url (str): the url where the file is located filename (str): the filename required. 
- """ progress_bar_length = 72 block_size = 1024 @@ -331,9 +128,7 @@ def _download_file(url: str, filename: str) -> None: def _check_hash(filename, checksum): - """ - Checks if existing file hash matches checksum from url - """ + """Checks if existing file hash matches checksum from url.""" algorithm, value = checksum.split(":") if not os.path.exists(filename): return value, "invalid" diff --git a/floatcsep/cmd/main.py b/floatcsep/cmd/main.py index 5429698..35bea66 100644 --- a/floatcsep/cmd/main.py +++ b/floatcsep/cmd/main.py @@ -52,7 +52,7 @@ def reproduce(config, **kwargs): log.info(f"floatCSEP v{__version__} | Reproduce") - reproduced_exp = Experiment.from_yml(config, reprdir="reproduced", **kwargs) + reproduced_exp = Experiment.from_yml(config, repr_dir="reproduced", **kwargs) reproduced_exp.stage_models() reproduced_exp.set_tasks() reproduced_exp.run() diff --git a/floatcsep/environments.py b/floatcsep/environments.py index 9d1c627..14cf095 100644 --- a/floatcsep/environments.py +++ b/floatcsep/environments.py @@ -15,9 +15,9 @@ class EnvironmentManager(ABC): """ - Abstract base class for managing different types of environments. - This class defines the interface for creating, checking existence, - running commands, and installing dependencies in various environment types. + Abstract base class for managing different types of environments. This class defines the + interface for creating, checking existence, running commands, and installing dependencies in + various environment types. """ @abstractmethod @@ -35,8 +35,8 @@ def __init__(self, base_name: str, model_directory: str): @abstractmethod def create_environment(self, force=False): """ - Creates the environment. If 'force' is True, it will remove any existing - environment with the same name before creating a new one. + Creates the environment. If 'force' is True, it will remove any existing environment + with the same name before creating a new one. 
Args: force (bool): Whether to forcefully remove an existing environment. @@ -66,15 +66,15 @@ def run_command(self, command): @abstractmethod def install_dependencies(self): """ - Installs the necessary dependencies for the environment based on the - specified configuration or requirements. + Installs the necessary dependencies for the environment based on the specified + configuration or requirements. """ pass def generate_env_name(self) -> str: """ - Generates a unique environment name by hashing the model directory - and appending it to the base name. + Generates a unique environment name by hashing the model directory and appending it + to the base name. Returns: str: A unique name for the environment. @@ -83,23 +83,23 @@ def generate_env_name(self) -> str: return f"{self.base_name}_{dir_hash}" -class CondaEnvironmentManager(EnvironmentManager): +class CondaManager(EnvironmentManager): """ - Manages a conda (or mamba) environment, providing methods to create, check, - and manipulate conda environments specifically. + Manages a conda (or mamba) environment, providing methods to create, check and manipulate + conda environments specifically. """ def __init__(self, base_name: str, model_directory: str): """ - Initializes the Conda environment manager with the specified base name - and model directory. It also generates the environment name and detects - the package manager (conda or mamba) to install dependencies.. + Initializes the Conda environment manager with the specified base name and model + directory. It also generates the environment name and detects the package manager (conda + or mamba) to install dependencies. Args: base_name (str): The base name, i.e., model name, for the conda environment. model_directory (str): The directory containing the model files. 
""" - self.base_name = base_name.replace(' ', '_') + self.base_name = base_name.replace(" ", "_") self.model_directory = model_directory self.env_name = self.generate_env_name() self.package_manager = self.detect_package_manager() @@ -120,10 +120,9 @@ def detect_package_manager(): def create_environment(self, force=False): """ - Creates a conda environment using either an environment.yml file or - the specified Python version in setup.py/setup.cfg or project/toml. - If 'force' is True, any existing environment with the same name will - be removed first. + Creates a conda environment using either an environment.yml file or the specified + Python version in setup.py/setup.cfg or project/toml. If 'force' is True, any existing + environment with the same name will be removed first. Args: force (bool): Whether to forcefully remove an existing environment. @@ -158,9 +157,7 @@ def create_environment(self, force=False): ) else: python_version = self.detect_python_version() - log.info( - f"Creating sub-conda environment {self.env_name} with Python {python_version}" - ) + log.info(f"Creating sub-conda env {self.env_name} with Python {python_version}") subprocess.run( [ self.package_manager, @@ -177,8 +174,8 @@ def create_environment(self, force=False): def env_exists(self) -> bool: """ - Checks if the conda environment exists by querying the list of - existing conda environments. + Checks if the conda environment exists by querying the list of existing conda + environments. Returns: bool: True if the conda environment exists, False otherwise. @@ -188,9 +185,9 @@ def env_exists(self) -> bool: def detect_python_version(self) -> str: """ - Determines the required Python version from setup files in the model directory. - It checks 'setup.py', 'pyproject.toml', and 'setup.cfg' (in that order), for - version specifications. + Determines the required Python version from setup files in the model directory. 
It + checks 'setup.py', 'pyproject.toml', and 'setup.cfg' (in that order), for version + specifications. Returns: str: The detected or default Python version. @@ -251,8 +248,8 @@ def is_version_compatible(requirement, current_version): def install_dependencies(self): """ - Installs dependencies in the conda environment using pip, based on the - model setup file + Installs dependencies in the conda environment using pip, based on the model setup + file. """ log.info(f"Installing dependencies in conda environment: {self.env_name}") cmd = [ @@ -269,15 +266,17 @@ def install_dependencies(self): def run_command(self, command): """ - Runs a specified command within the conda environment + Runs a specified command within the conda environment. + Args: command (str): The command to be executed in the conda environment. """ cmd = [ "bash", "-c", - f"{self.package_manager} run -n {self.env_name} {command}", + f"{self.package_manager} run --live-stream -n {self.env_name} {command}", ] + process = subprocess.Popen( cmd, stdout=subprocess.PIPE, @@ -285,35 +284,36 @@ def run_command(self, command): universal_newlines=True, ) for line in process.stdout: - log.info(f"[{self.base_name}]: {line[:-1]}") + stripped_line = line.strip() + log.info(f"[{self.base_name}]: " + stripped_line) process.wait() -class VenvEnvironmentManager(EnvironmentManager): +class VenvManager(EnvironmentManager): """ - Manages a virtual environment created using Python's venv module. - Provides methods to create, check, and manipulate virtual environments. + Manages a virtual environment created using Python's venv module. Provides methods to + create, check, and manipulate virtual environments. """ def __init__(self, base_name: str, model_directory: str): """ - Initializes the virtual environment manager with the specified base name - and model directory. + Initializes the virtual environment manager with the specified base name and model + directory. 
Args: base_name (str): The base name (i.e., model name) for the virtual environment. model_directory (str): The directory containing the model files. """ - self.base_name = base_name.replace(' ', '_') + self.base_name = base_name.replace(" ", "_") self.model_directory = model_directory self.env_name = self.generate_env_name() self.env_path = os.path.join(model_directory, self.env_name) def create_environment(self, force=False): """ - Creates a virtual environment in the specified model directory. If 'force' - is True, any existing virtual environment will be removed before creation. + Creates a virtual environment in the specified model directory. If 'force' is True, + any existing virtual environment will be removed before creation. Args: force (bool): Whether to forcefully remove an existing virtual environment. @@ -339,8 +339,8 @@ def env_exists(self) -> bool: def install_dependencies(self): """ - Installs dependencies in the virtual environment using pip, based on the - model directory's configuration. + Installs dependencies in the virtual environment using pip, based on the model + directory's configuration. """ log.info(f"Installing dependencies in virtual environment: {self.env_name}") pip_executable = os.path.join(self.env_path, "bin", "pip") @@ -375,14 +375,14 @@ def run_command(self, command): ) for line in process.stdout: stripped_line = line.strip() - log.info(f'[{self.base_name}]: ' + stripped_line) + log.info(f"[{self.base_name}]: " + stripped_line) process.wait() -class DockerEnvironmentManager(EnvironmentManager): +class DockerManager(EnvironmentManager): """ - Manages a Docker environment, providing methods to create, check, - and manipulate Docker containers for the environment. + Manages a Docker environment, providing methods to create, check and manipulate Docker + containers for the environment. 
""" def __init__(self, base_name: str, model_directory: str): @@ -403,18 +403,17 @@ def install_dependencies(self): class EnvironmentFactory: - """ - Factory class for creating instances of environment managers based on the specified type. - """ + """Factory class for creating instances of environment managers based on the specified + type.""" @staticmethod def get_env( build: str = None, model_name: str = "model", model_path: str = None ) -> EnvironmentManager: """ - Returns an instance of an environment manager based on the specified build type. - It checks the current environment type and can return a conda, venv, or Docker - environment manager. + Returns an instance of an environment manager based on the specified build type. It + checks the current environment type and can return a conda, venv, or Docker environment + manager. Args: build (str): The desired type of environment ('conda', 'venv', or 'docker'). @@ -433,18 +432,20 @@ def get_env( f"Selected build environment ({build}) for this model is different than that of" f" the experiment run. Consider selecting the same environment." 
) - if build == "conda" or (not build and run_env == "conda"): - return CondaEnvironmentManager( + if build in ["conda", "micromamba"] or ( + not build and run_env in ["conda", "micromamba"] + ): + return CondaManager( base_name=f"{model_name}", model_directory=os.path.abspath(model_path), ) elif build == "venv" or (not build and run_env == "venv"): - return VenvEnvironmentManager( + return VenvManager( base_name=f"{model_name}", model_directory=os.path.abspath(model_path), ) elif build == "docker": - return DockerEnvironmentManager( + return DockerManager( base_name=f"{model_name}", model_directory=os.path.abspath(model_path), ) @@ -457,16 +458,40 @@ def get_env( @staticmethod def check_environment_type(): if "VIRTUAL_ENV" in os.environ: + log.info("Detected virtual environment.") return "venv" try: - subprocess.run( + result = subprocess.run( ["conda", "info"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - return "conda" + if result.returncode == 0: + log.info("Detected conda environment.") + return "conda" + else: + log.warning( + "Conda command failed with return code: {}".format(result.returncode) + ) + except FileNotFoundError: + log.warning("Conda not found in PATH.") + + try: + result = subprocess.run( + ["micromamba", "info"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.returncode == 0: + log.info("Detected micromamba environment.") + return "micromamba" + else: + log.warning( + "Micromamba command failed with return code: {}".format(result.returncode) + ) except FileNotFoundError: - pass + log.warning("Micromamba not found in PATH.") + return None diff --git a/floatcsep/evaluation.py b/floatcsep/evaluation.py index 20562f3..9e1d4da 100644 --- a/floatcsep/evaluation.py +++ b/floatcsep/evaluation.py @@ -1,3 +1,4 @@ +import datetime import json import os from typing import Dict, Callable, Union, Sequence, List @@ -9,15 +10,14 @@ from matplotlib import pyplot from floatcsep.model import Model -from floatcsep.registry import 
PathTree +from floatcsep.registry import ExperimentRegistry from floatcsep.utils import parse_csep_func, timewindow2str class Evaluation: """ - - Class representing a Scoring Test, which wraps the evaluation function, - its arguments, parameters and hyper-parameters. + Class representing a Scoring Test, which wraps the evaluation function, its arguments, + parameters and hyperparameters. Args: name (str): Name of the Test @@ -27,7 +27,6 @@ class Evaluation: plot_func (str, ~typing.Callable): Test's plotting function plot_args (list,dict): Positional arguments of the plotting function plot_kwargs (list,dict): Keyword arguments of the plotting function - """ _TYPES = { @@ -80,8 +79,7 @@ def __init__( @property def type(self): """ - Returns the type of the test, mapped from the class attribute - Evaluation._TYPES + Returns the type of the test, mapped from the class attribute Evaluation._TYPES. """ return self._type @@ -131,7 +129,6 @@ def prepare_args( region=None, ) -> tuple: """ - Prepares the positional argument for the Evaluation function. Args: @@ -146,7 +143,6 @@ def prepare_args( Returns: A tuple of the positional arguments required by the evaluation function :meth:`Evaluation.func`. - """ # Subtasks # ======== @@ -179,9 +175,8 @@ def get_catalog( forecast: Union[GriddedForecast, Sequence[GriddedForecast]], ) -> Union[CSEPCatalog, List[CSEPCatalog]]: """ - - Reads the catalog(s) from the given path(s). References the catalog - region to the forecast region. + Reads the catalog(s) from the given path(s). References the catalog region to the + forecast region. Args: catalog_path (str, list(str)): Path to the existing catalog @@ -189,7 +184,6 @@ def get_catalog( object, onto which the catalog will be confronted for testing. 
Returns: - """ if isinstance(catalog_path, str): eval_cat = CSEPCatalog.load_json(catalog_path) @@ -213,13 +207,12 @@ def compute( region=None, ) -> None: """ - Runs the test, structuring the arguments according to the - test-typology/function-signature + test-typology/function-signature Args: - timewindow (list[~datetime.datetime, ~datetime.datetime]): Pair of - datetime objects representing the testing time span + timewindow (list[~datetime.datetime, ~datetime.datetime]): A pair of datetime + objects representing the testing time span catalog (str): Path to the filtered catalog model (Model, list[Model]): Model(s) to be evaluated ref_model: Model to be used as reference @@ -227,7 +220,6 @@ def compute( region: region to filter a catalog forecast. Returns: - """ test_args = self.prepare_args( timewindow, catpath=catalog, model=model, ref_model=ref_model, region=region @@ -238,9 +230,7 @@ def compute( @staticmethod def write_result(result: EvaluationResult, path: str) -> None: - """ - Dumps a test result into a json file. - """ + """Dumps a test result into a json file.""" class NumpyEncoder(json.JSONEncoder): def default(self, obj): @@ -255,10 +245,15 @@ def default(self, obj): with open(path, "w") as _file: json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder) - def read_results(self, window: str, models: List[Model], tree: PathTree) -> List: + def read_results( + self, + window: Union[str, Sequence[datetime.datetime]], + models: List[Model], + tree: ExperimentRegistry, + ) -> List: """ - Reads an Evaluation result for a given time window and returns a list - of the results for all tested models. + Reads an Evaluation result for a given time window and returns a list of the results for + all tested models. 
""" test_results = [] @@ -279,13 +274,12 @@ def plot_results( self, timewindow: Union[str, List], models: List[Model], - tree: PathTree, + tree: ExperimentRegistry, dpi: int = 300, show: bool = False, ) -> None: """ - - Plots all evaluation results + Plots all evaluation results. Args: timewindow: string representing the desired timewindow to plot @@ -293,7 +287,6 @@ def plot_results( tree: a :class:`floatcsep:models.PathTree` containing path of the results dpi: Figure resolution with which to save show: show in runtime - """ if isinstance(timewindow, str): timewindow = [timewindow] @@ -343,8 +336,8 @@ def plot_results( def as_dict(self) -> dict: """ - Represents an Evaluation instance as a dictionary, which can be - serialized and then parsed + Represents an Evaluation instance as a dictionary, which can be serialized and then + parsed """ out = {} included = ["model", "ref_model", "func_kwargs"] @@ -370,9 +363,7 @@ def __str__(self): @classmethod def from_dict(cls, record): - """ - Parses a dictionary and re-instantiate an Evaluation object - """ + """Parses a dictionary and re-instantiate an Evaluation object.""" if len(record) != 1: raise IndexError("A single test has not been passed") name = next(iter(record)) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index 92212af..4f8dbc9 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -5,7 +5,7 @@ import shutil import warnings from os.path import join, abspath, relpath, dirname, isfile, split, exists -from typing import Union, List, Dict, Callable, Mapping, Sequence +from typing import Union, List, Dict, Callable, Sequence import csep import numpy @@ -18,8 +18,8 @@ from floatcsep import report from floatcsep.evaluation import Evaluation from floatcsep.logger import add_fhandler -from floatcsep.model import Model, ModelFactory, TimeDependentModel -from floatcsep.registry import PathTree +from floatcsep.model import Model, TimeDependentModel +from floatcsep.registry import 
ExperimentRegistry from floatcsep.utils import ( NoAliasLoader, parse_csep_func, @@ -30,6 +30,7 @@ timewindow2str, str2timewindow, magnitude_vs_time, + parse_nested_dicts, ) numpy.seterr(all="ignore") @@ -40,9 +41,8 @@ class Experiment: """ - - Main class that handles an Experiment's context. Contains all the - specifications, instructions and methods to carry out an experiment. + Main class that handles an Experiment's context. Contains all the specifications, + instructions and methods to carry out an experiment. Args: name (str): Experiment name @@ -96,7 +96,6 @@ class Experiment: be instantiated using these dicts as keywords. (e.g. ``Experiment( **time_config, **region_config)``, ``Experiment(start_date=start_date, intervals=1, region='csep-italy', ...)`` - """ """ @@ -145,7 +144,7 @@ def __init__( os.makedirs(os.path.join(workdir, rundir), exist_ok=True) self.name = name if name else "floatingExp" - self.path = PathTree(workdir, rundir) + self.path = ExperimentRegistry(workdir, rundir) self.config_file = kwargs.get("config_file", None) self.original_config = kwargs.get("original_config", None) self.original_rundir = kwargs.get("original_rundir", None) @@ -195,7 +194,7 @@ def __init__( def __getattr__(self, item: str) -> object: """ - Override built-in method to return attributes found within + Override built-in method to return attributes found within. :attr:`region_config` or :attr:`time_config` """ @@ -213,9 +212,7 @@ def __getattr__(self, item: str) -> object: ) from None def __dir__(self): - """ - Adds time and region configs keys to instance scope. - """ + """Adds time and region configs keys to instance scope.""" _dir = ( list(super().__dir__()) + list(self.time_config.keys()) + list(self.region_config) @@ -224,23 +221,20 @@ def __dir__(self): def set_models(self, model_config: Union[Dict, str, List], order: List = None) -> List: """ - - Parse the models' configuration file/dict. 
Instantiates all the models - as :class:`floatcsep.model.Model` and store them into - :attr:`Experiment.models`. + Parse the models' configuration file/dict. Instantiates all the models as + :class:`floatcsep.model.Model` and store them into :attr:`Experiment.models`. Args: model_config (dict, list, str): configuration file or dictionary containing the model initialization attributes, as defined in :meth:`~floatcsep.model.Model` order (list): desired order of models - """ models = [] if isinstance(model_config, str): modelcfg_path = self.path.abs(model_config) - _dir = self.path.absdir(model_config) + _dir = self.path.abs_dir(model_config) with open(modelcfg_path, "r") as file_: config_dict = yaml.load(file_, NoAliasLoader) elif isinstance(model_config, (dict, list)): @@ -262,7 +256,7 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - name_: {**element[name_], "model_path": path_, "workdir": self.path.workdir} } model_i[name_].pop("path") - models.append(ModelFactory.create_model(model_i)) + models.append(Model.factory(model_i)) else: model_flavours = list(element.values())[0]["flavours"].items() @@ -281,7 +275,7 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - } model_[name_flav].pop("path") model_[name_flav].pop("flavours") - models.append(ModelFactory.create_model(model_)) + models.append(Model.factory(model_)) # Checks if there is any repeated model. names_ = [i.name for i in models] @@ -299,15 +293,14 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - return models def get_model(self, name: str) -> Model: - """Returns a Model by its name string""" + """Returns a Model by its name string.""" for model in self.models: if model.name == name: return model def stage_models(self) -> None: """ - Stages all the experiment's models. See - :meth:`floatcsep.model.Model.stage` + Stages all the experiment's models. 
See :meth:`floatcsep.model.Model.stage` """ log.info("Staging models") for i in self.models: @@ -323,7 +316,6 @@ def set_tests(self, test_config: Union[str, Dict, List]) -> list: test_config (dict, list, str): configuration file or dictionary containing the evaluation initialization attributes, as defined in :meth:`~floatcsep.evaluation.Evaluation` - """ tests = [] @@ -343,8 +335,8 @@ def set_tests(self, test_config: Union[str, Dict, List]) -> list: @property def catalog(self) -> CSEPCatalog: """ - Returns a CSEP catalog loaded from the given query function or - a stored file if it exists. + Returns a CSEP catalog loaded from the given query function or a stored file if it + exists. """ cat_path = self.path.abs(self._catpath) @@ -406,14 +398,11 @@ def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None: def get_test_cat(self, tstring: str = None) -> CSEPCatalog: """ - - Filters the complete experiment catalog to a test sub-catalog bounded - by the test time-window. Writes it to filepath defined in - :attr:`Experiment.filetree` + Filters the complete experiment catalog to a test sub-catalog bounded by the test + time-window. Writes it to filepath defined in :attr:`Experiment.registry` Args: tstring (str): Time window string - """ if tstring: @@ -437,16 +426,12 @@ def get_test_cat(self, tstring: str = None) -> CSEPCatalog: def set_test_cat(self, tstring: str) -> None: """ - - Filters the complete experiment catalog to a test sub-catalog bounded - by the test time-window. Writes it to filepath defined in - :attr:`Experiment.filetree` + Filters the complete experiment catalog to a test sub-catalog bounded by the test + time-window. 
Writes it to filepath defined in :attr:`Experiment.registry` Args: tstring (str): Time window string - """ - "" testcat_name = self.path(tstring, "catalog") if not exists(testcat_name): @@ -471,8 +456,8 @@ def set_test_cat(self, tstring: str) -> None: def set_input_cat(self, tstring: str, model: Model) -> None: """ + Filters the complete experiment catalog to a input sub-catalog filtered. - Filters the complete experiment catalog to a input sub-catalog filtered to the beginning of thetest time-window. Writes it to filepath defined in :attr:`Model.tree.catalog` @@ -480,11 +465,10 @@ def set_input_cat(self, tstring: str, model: Model) -> None: tstring (str): Time window string model (:class:`~floatcsep.model.Model`): Model to give the input catalog - """ start, end = str2timewindow(tstring) sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"]) - sub_cat.write_ascii(filename=model.path("input_cat")) + sub_cat.write_ascii(filename=model.registry.get_path("input_cat")) def set_tasks(self): """ @@ -501,16 +485,14 @@ def set_tasks(self): * A sequential test requires the forecasts exist for all windows * A batch test requires all forecast exist for a given window. - Returns: - """ # Set the file path structure - self.path.build(self.timewindows, self.models, self.tests) + self.path.build_tree(self.timewindows, self.models, self.tests) log.info("Setting up experiment's tasks") - log.debug("Pre-run: results' paths\n" + yaml.dump(self.path.asdict())) + log.debug("Pre-run: results' paths\n" + yaml.dump(self.path.as_dict())) # Get the time windows strings tw_strings = timewindow2str(self.timewindows) @@ -658,14 +640,13 @@ def set_tasks(self): def run(self) -> None: """ - Run the task tree + Run the task tree. todo: - Cleanup forecast (perhaps add a clean task in self.prepare_tasks, after all test had been run for a given forecast) - Memory monitor? - Queuer? 
- """ log.info(f"Running {self.task_graph.ntasks} tasks") @@ -677,18 +658,14 @@ def run(self) -> None: def read_results(self, test: Evaluation, window: str) -> List: """ - Reads an Evaluation result for a given time window and returns a list - of the results for all tested models. + Reads an Evaluation result for a given time window and returns a list of the results + for all tested models. """ return test.read_results(window, self.models, self.path) def plot_results(self) -> None: - """ - - Plots all evaluation results - - """ + """Plots all evaluation results.""" log.info("Plotting evaluations") timewindows = timewindow2str(self.timewindows) @@ -697,13 +674,11 @@ def plot_results(self) -> None: def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: """ - - Plots the evaluation catalogs + Plots the evaluation catalogs. Args: dpi: Figure resolution with which to save show: show in runtime - """ plot_args = { "basemap": "ESRI_terrain", @@ -715,7 +690,6 @@ def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: "legend": True, } plot_args.update(self.postproc_config.get("plot_catalog", {})) - catalog = self.get_test_cat() if catalog.get_number_of_events() != 0: ax = catalog.plot(plot_args=plot_args, show=show) @@ -744,11 +718,7 @@ def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: ) def plot_forecasts(self) -> None: - """ - - Plots and saves all the generated forecasts - - """ + """Plots and saves all the generated forecasts.""" plot_fc_config = self.postproc_config.get("plot_forecasts") if plot_fc_config: @@ -828,15 +798,11 @@ def plot_forecasts(self) -> None: pyplot.savefig(fig_path, dpi=300, facecolor=(0, 0, 0, 0)) def generate_report(self) -> None: - """ + """Creates a report summarizing the Experiment's results.""" - Creates a report summarizing the Experiment's results - - """ log.info(f"Saving report into {self.path.rundir}") - - self.path.build(self.timewindows, self.models, self.tests) - log.debug("Post-run: results' 
paths\n" + yaml.dump(self.path.asdict())) + self.path.build_tree(self.timewindows, self.models, self.tests) + log.debug("Post-run: results' paths\n" + yaml.dump(self.path.as_dict())) report.generate_report(self) @@ -846,7 +812,7 @@ def make_repr(self): repr_config = self.path("config") # Dropping region to results folder if it is a file - region_path = self.region_config.get("path", None) + region_path = self.region_config.get("path", False) if region_path: if isfile(region_path) and region_path: new_path = join(self.path.rundir, self.region_config["path"]) @@ -892,53 +858,16 @@ def as_dict( floatCSEP readable """ - def _get_value(x): - # For each element type, transforms to desired string/output - if hasattr(x, "as_dict"): - # e.g. model, test, etc. - o = x.as_dict() - else: - try: - try: - o = getattr(x, "__name__") - except AttributeError: - o = getattr(x, "name") - except AttributeError: - if isinstance(x, numpy.ndarray): - o = x.tolist() - else: - o = x - return o - - def iter_attr(val): - # recursive iter through nested dicts/lists - if isinstance(val, Mapping): - return { - item: iter_attr(val_) - for item, val_ in val.items() - if ((item not in exclude) and val_) or extended - } - elif isinstance(val, Sequence) and not isinstance(val, str): - return [iter_attr(i) for i in val] - else: - return _get_value(val) - listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j] listwalk.insert(6, ("catalog", self._catpath)) dictwalk = {i: j for i, j in listwalk} - # if self.model_config is None: - # dictwalk['models'] = iter_attr(self.models) - # if self.test_config is None: - # dictwalk['tests'] = iter_attr(self.tests) - return iter_attr(dictwalk) + return parse_nested_dicts(dictwalk, excluded=exclude, extended=extended) def to_yml(self, filename: str, **kwargs) -> None: """ - - Serializes the :class:`~floatcsep.experiment.Experiment` instance into - a .yml file. 
+ Serializes the :class:`~floatcsep.experiment.Experiment` instance into a .yml file. Note: This instance can then be reinstantiated using @@ -949,7 +878,6 @@ def to_yml(self, filename: str, **kwargs) -> None: **kwargs: Pass to :meth:`~floatcsep.experiment.Experiment.as_dict` Returns: - """ class NoAliasDumper(yaml.Dumper): @@ -968,21 +896,20 @@ def ignore_aliases(self, data): ) @classmethod - def from_yml(cls, config_yml: str, reprdir=None, **kwargs): + def from_yml(cls, config_yml: str, repr_dir=None, **kwargs): """ + Initializes an experiment from a .yml file. It must contain the. - Initializes an experiment from a .yml file. It must contain the attributes described in the :class:`~floatcsep.experiment.Experiment`, :func:`~floatcsep.utils.read_time_config` and :func:`~floatcsep.utils.read_region_config` descriptions Args: config_yml (str): The path to the .yml file - reprdir (str): folder where to reproduce results + repr_dir (str): folder where to reproduce results Returns: An :class:`~floatcsep.experiment.Experiment` class instance - """ log.info("Initializing experiment from .yml file") with open(config_yml, "r") as yml: @@ -997,9 +924,9 @@ def from_yml(cls, config_yml: str, reprdir=None, **kwargs): _dict["path"] = abspath(join(_dir_yml, _dict.get("path", ""))) # replaces rundir case reproduce option is used - if reprdir: + if repr_dir: _dict["original_rundir"] = _dict.get("rundir", "results") - _dict["rundir"] = relpath(join(_dir_yml, reprdir), _dict["path"]) + _dict["rundir"] = relpath(join(_dir_yml, repr_dir), _dict["path"]) _dict["original_config"] = abspath(join(_dict["path"], _dict["config_file"])) else: diff --git a/floatcsep/extras.py b/floatcsep/extras.py index 506f5c0..cf66239 100644 --- a/floatcsep/extras.py +++ b/floatcsep/extras.py @@ -1,17 +1,18 @@ +from typing import Sequence + import numpy import scipy.stats -from matplotlib import pyplot -from csep.models import EvaluationResult +from csep.core.catalogs import CSEPCatalog +from 
csep.core.exceptions import CSEPCatalogException +from csep.core.forecasts import GriddedForecast from csep.core.poisson_evaluations import ( _simulate_catalog, paired_t_test, w_test, _poisson_likelihood_test, ) -from csep.core.exceptions import CSEPCatalogException -from typing import Sequence -from csep.core.forecasts import GriddedForecast -from csep.core.catalogs import CSEPCatalog +from csep.models import EvaluationResult +from matplotlib import pyplot def binomial_spatial_test( @@ -24,6 +25,7 @@ def binomial_spatial_test( ): """ Performs the binary spatial test on the Forecast using the Observed Catalogs. + Note: The forecast and the observations should be scaled to the same time period before calling this function. This increases transparency as no assumptions are being made about the length of the forecasts. This is particularly important for gridded forecasts that supply their forecasts as rates. @@ -204,8 +206,8 @@ def sequential_information_gain( random_numbers: Sequence = None, ): """ - Args: + gridded_forecasts: list csep.core.forecasts.GriddedForecast benchmark_forecasts: list csep.core.forecasts.GriddedForecast observed_catalogs: list csep.core.catalogs.Catalog @@ -274,8 +276,8 @@ def vector_poisson_t_w_test( catalog: CSEPCatalog, ): """ + Computes Student's t-test for the information gain per earthquake over. - Computes Student's t-test for the information gain per earthquake over a list of forecasts and w-test for normality Uses all ref_forecasts to perform pair-wise t-tests against the @@ -436,7 +438,9 @@ def negative_binomial_number_test(gridded_forecast, observed_catalog, variance): def binomial_joint_log_likelihood_ndarray(forecast, catalog): """ - Computes Bernoulli log-likelihood scores, assuming that earthquakes follow a binomial distribution. + Computes Bernoulli log-likelihood scores, assuming that earthquakes follow a binomial. + + distribution. 
Args: forecast: Forecast of a Model (Gridded) (Numpy Array) @@ -444,7 +448,7 @@ def binomial_joint_log_likelihood_ndarray(forecast, catalog): It can be anything greater than zero catalog: Observed (Gridded) seismicity (Numpy Array): An Observation has to be Number of Events in Each Bin - It has to be a either zero or positive integer only (No Floating Point) + It has to be either zero or positive integer only (No Floating Point) """ # First, we mask the forecast in cells where we could find log=0.0 singularities: forecast_masked = numpy.ma.masked_where(forecast.ravel() <= 0.0, forecast.ravel()) @@ -472,15 +476,18 @@ def _binomial_likelihood_test( normalize_likelihood=False, ): """ - Computes binary conditional-likelihood test from CSEP using an efficient simulation based approach. + Computes binary conditional likelihood test from CSEP using a simulation based. + + approach. + Args: + forecast_data (numpy.ndarray): nd array where [:, -1] are the magnitude bins. observed_data (numpy.ndarray): same format as observation. num_simulations: default number of simulations to use for likelihood based simulations seed: used for reproducibility of the prng random_numbers (numpy.ndarray): can supply an explicit list of random numbers, primarily used for software testing - use_observed_counts (bool): if true, will simulate catalogs using the observed events, if false will draw from poisson - distribution + use_observed_counts (bool): if true, will simulate catalogs using the observed events, if false will draw from poisson distribution """ # Array-masking that avoids log singularities: @@ -553,7 +560,9 @@ def binomial_conditional_likelihood_test( verbose=False, ): """ - Performs the binary conditional likelihood test on Gridded Forecast using an Observed Catalog. + Performs the binary conditional likelihood test on Gridded Forecast using an Observed. + + Catalog. Normalizes the forecast so the forecasted rate are consistent with the observations. 
This modification eliminates the strong impact differences in the number distribution have on the forecasted rates. @@ -674,10 +683,7 @@ def _binary_t_test_ndarray( def log_likelihood_point_process(observation, forecast, cell_area): - """ - Log-likelihood for point process - - """ + """Log-likelihood for point process.""" forecast_density = forecast / cell_area.reshape(-1, 1) observation = observation.ravel() forecast_density = forecast_density.ravel() @@ -703,11 +709,11 @@ def _standard_deviation( ): """ Calculate Variance using forecast 1 and forecast 2. + But It is calculated using the forecast values corresponding to the non-zero observations. The same process is repeated as repeated during calculation of Point Process LL. After we get forecast rates for non-zeros observations, then Pooled Variance is calculated. - Parameters ---------- gridded_forecast1 : forecast @@ -718,7 +724,6 @@ def _standard_deviation( Returns ------- Variance - """ N_obs = numpy.sum(gridded_observation1) @@ -767,6 +772,7 @@ def paired_ttest_point_process( ): """ Function for T test based on Point process LL. 
+ Works for comparing forecasts for different grids Parameters @@ -995,7 +1001,7 @@ def plot_negbinom_consistency_test( def _get_marker_style(obs_stat, p, one_sided_lower): - """Returns matplotlib marker style as fmt string""" + """Returns matplotlib marker style as fmt string.""" if obs_stat < p[0] or obs_stat > p[1]: # red circle fmt = "ro" diff --git a/floatcsep/model.py b/floatcsep/model.py index 64906d5..f0e6e00 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -5,17 +5,16 @@ from datetime import datetime from typing import List, Callable, Union, Mapping, Sequence -import csep import git import numpy from csep.core.forecasts import GriddedForecast, CatalogForecast -from csep.utils.time_utils import decimal_year from floatcsep.accessors import from_zenodo, from_git from floatcsep.environments import EnvironmentFactory from floatcsep.readers import ForecastParsers, HDF5Serializer -from floatcsep.registry import ModelTree -from floatcsep.utils import timewindow2str, str2timewindow +from floatcsep.registry import ForecastRegistry +from floatcsep.repository import ForecastRepository +from floatcsep.utils import timewindow2str, str2timewindow, parse_nested_dicts log = logging.getLogger("floatLogger") @@ -39,7 +38,6 @@ class Model(ABC): def __init__( self, name: str, - model_path: str, zenodo_id: int = None, giturl: str = None, repo_hash: str = None, @@ -49,32 +47,21 @@ def __init__( ): self.name = name - self.model_path = model_path self.zenodo_id = zenodo_id self.giturl = giturl self.repo_hash = repo_hash self.authors = authors self.doi = doi - self.path = None + self.registry = None self.forecasts = {} + self.force_stage = False self.__dict__.update(**kwargs) - @property - def dir(self) -> str: - """ - Returns: - The directory containing the model source. 
- """ - if os.path.isdir(self.path("path")): - return self.path("path") - else: - return os.path.dirname(self.path("path")) - @abstractmethod def stage(self, timewindows=None) -> None: - """Prepares the stage for a model run. Can be""" + """Prepares the stage for a model run.""" pass @abstractmethod @@ -87,9 +74,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: """Creates a forecast based on the model's logic.""" pass - def get_source( - self, zenodo_id: int = None, giturl: str = None, force: bool = False, **kwargs - ) -> None: + def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> None: """ Search, download or clone the model source in the filesystem, zenodo. @@ -101,22 +86,17 @@ def get_source( `https://zenodo.org/record/{zenodo_id}` giturl (str): git remote repository URL from which to clone the source - force (bool): Forces to re-query the model from a web repository **kwargs: see :func:`~floatcsep.utils.from_zenodo` and :func:`~floatcsep.utils.from_git` """ - if os.path.exists(self.path("path")) and not force: - return - - os.makedirs(self.dir, exist_ok=True) if zenodo_id: log.info(f"Retrieving model {self.name} from zenodo id: " f"{zenodo_id}") try: from_zenodo( zenodo_id, - self.dir if self.path.fmt else self.path("path"), - force=force, + self.registry.dir if self.registry.fmt else self.registry.get_path("path"), + force=True, ) except (KeyError, TypeError) as msg: raise KeyError(f"Zenodo identifier is not valid: {msg}") @@ -124,68 +104,44 @@ def get_source( elif giturl: log.info(f"Retrieving model {self.name} from git url: " f"{giturl}") try: - from_git(giturl, self.dir if self.path.fmt else self.path("path"), **kwargs) + from_git( + giturl, + self.registry.dir if self.registry.fmt else self.registry.get_path("path"), + **kwargs, + ) except (git.NoSuchPathError, git.CommandError) as msg: raise git.NoSuchPathError(f"git url was not found {msg}") else: raise FileNotFoundError("Model has no path or identified") - if not 
os.path.exists(self.dir) or not os.path.exists(self.path("path")): + if not os.path.exists(self.registry.dir) or not os.path.exists( + self.registry.get_path("path") + ): raise FileNotFoundError( - f"Directory '{self.dir}' or file {self.path}' do not exist. " + f"Directory '{self.registry.dir}' or file {self.registry}' do not exist. " f"Please check the specified 'path' matches the repo " f"structure" ) - def as_dict(self, excluded=("name", "forecasts", "workdir")): + def as_dict(self, excluded=("name", "repository", "workdir")): """ Returns: Dictionary with relevant attributes. Model can be re-instantiated from this dict """ - def _get_value(x): - # For each element type, transforms to desired string/output - if hasattr(x, "as_dict"): - # e.g. model, evaluation, filetree, etc. - o = x.as_dict() - else: - try: - try: - o = getattr(x, "__name__") - except AttributeError: - o = getattr(x, "name") - except AttributeError: - if isinstance(x, numpy.ndarray): - o = x.tolist() - else: - o = x - return o - - def iter_attr(val): - # recursive iter through nested dicts/lists - if isinstance(val, Mapping): - return { - item: iter_attr(val_) - for item, val_ in val.items() - if ((item not in excluded) and val_) - } - elif isinstance(val, Sequence) and not isinstance(val, str): - return [iter_attr(i) for i in val] - else: - return _get_value(val) - list_walk = [ (i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j ] dict_walk = {i: j for i, j in list_walk} + dict_walk["path"] = dict_walk.pop("registry").path - return {self.name: iter_attr(dict_walk)} + return {self.name: parse_nested_dicts(dict_walk, excluded=excluded)} @classmethod def from_dict(cls, record: dict, **kwargs): """ - Returns a Model instance from a dictionary containing the required atrributes. Can be + Returns a Model instance from a dictionary containing the required attributes. Can be used to quickly instantiate from a .yml file. 
Args: @@ -207,6 +163,30 @@ def from_dict(cls, record: dict, **kwargs): name = next(iter(record)) return cls(name=name, **record[name], **kwargs) + @classmethod + def factory(cls, model_cfg: dict) -> "Model": + """Factory method. Instantiate first on any explicit option provided in the model + configuration. + """ + model_path = [*model_cfg.values()][0]["model_path"] + workdir = [*model_cfg.values()][0].get("workdir", "") + model_class = [*model_cfg.values()][0].get("class", "") + + if model_class in ("ti", "time_independent"): + return TimeIndependentModel.from_dict(model_cfg) + + elif model_class in ("td", "time_dependent"): + return TimeDependentModel.from_dict(model_cfg) + + if os.path.isfile(os.path.join(workdir, model_path)): + return TimeIndependentModel.from_dict(model_cfg) + + elif "func" in [*model_cfg.values()][0]: + return TimeDependentModel.from_dict(model_cfg) + + else: + return TimeIndependentModel.from_dict(model_cfg) + class TimeIndependentModel(Model): """ @@ -220,18 +200,38 @@ class TimeIndependentModel(Model): """ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs): - super().__init__(name, model_path, **kwargs) + super().__init__(name, **kwargs) + self.forecast_unit = forecast_unit self.store_db = store_db + self.registry = ForecastRegistry(kwargs.get("workdir", os.getcwd()), model_path) + self.repository = ForecastRepository.factory( + self.registry, model_class=self.__class__.__name__, **kwargs + ) - self.path = ModelTree(kwargs.get("workdir", os.getcwd()), model_path) + def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: + """ + Acquire the forecast data if it is not in the file system. Sets the paths internally + (or database pointers) to the forecast data. - def init_db(self, dbpath: str = "", force: bool = False) -> None: + Args: + timewindows (list): time_windows that the forecast data represents. """ - Initializes the database if `use_db` is True. 
- If the model source is a file, serializes the forecast into a HDF5 file. If source is a - generating function or code, creates an empty DB + if self.force_stage or not self.registry.file_exists("path"): + os.makedirs(self.registry.dir, exist_ok=True) + self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash) + + if self.store_db: + self.init_db() + + self.registry.build_tree(timewindows=timewindows, model_class=self.__class__.__name__) + + def init_db(self, dbpath: str = "", force: bool = False) -> None: + """ + Initializes the database if `use_db` is True. If the model source is a file, + serializes the forecast into a HDF5 file. If source is a generating function or code, + creates an empty DB. Args: dbpath (str): Path to drop the HDF5 database. Defaults to same path @@ -240,72 +240,32 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None: exists """ - parser = getattr(ForecastParsers, self.path.fmt) - rates, region, mag = parser(self.path("path")) + parser = getattr(ForecastParsers, self.registry.fmt) + rates, region, mag = parser(self.registry.get_path("path")) db_func = HDF5Serializer.grid2hdf5 if not dbpath: - dbpath = self.path.path.replace(self.path.fmt, "hdf5") - self.path.database = dbpath + dbpath = self.registry.path.replace(self.registry.fmt, "hdf5") + self.registry.database = dbpath - if not os.path.isfile(self.path.abs(dbpath)) or force: + if not os.path.isfile(self.registry.abs(dbpath)) or force: log.info(f"Serializing model {self.name} into HDF5 format") db_func( rates, region, mag, - hdf5_filename=self.path.abs(dbpath), + hdf5_filename=self.registry.abs(dbpath), unit=self.forecast_unit, ) - def rm_db(self) -> None: - """Clean up the generated HDF5 File.""" - pass - - def stage(self, timewindows: Union[str, List[datetime]] = None) -> None: - """ - Acquire the forecast data if it is not in the file system. - Sets internally the paths (or database pointers) to the forecast data. 
- - Args: - timewindows (str, list): time_windows that the forecast data represents. - - """ - self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash) - if self.store_db: - self.init_db() - - self.path.build_tree( - timewindows=timewindows, - model_class="ti", - prefix=self.__dict__.get("prefix", self.name), - ) - def get_forecast( self, tstring: Union[str, list] = None, region=None - ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]: - """ - Wrapper that just returns a forecast when requested. - """ + ) -> Union[GriddedForecast, List[GriddedForecast]]: + """Wrapper that just returns a forecast when requested.""" - if isinstance(tstring, str): - # If only one time_window string is passed - try: - # If they are retrieved from the Evaluation class - return self.forecasts[tstring] - except KeyError: - # In case they are called from postprocess - self.create_forecast(tstring) - return self.forecasts[tstring] - else: - # If multiple time_window strings are passed - forecasts = [] - for tw in tstring: - if tw in self.forecasts.keys(): - forecasts.append(self.forecasts[tw]) - if not forecasts: - raise KeyError(f"Forecasts {*tstring,} have not been created yet") - return forecasts + return self.repository.load_forecast( + tstring, name=self.name, region=region, forecast_unit=self.forecast_unit + ) def create_forecast(self, tstring: str, **kwargs) -> None: """ @@ -322,57 +282,13 @@ def create_forecast(self, tstring: str, **kwargs) -> None: formatted as 'YY1-MM1-DD1_YY2-MM2-DD2'. **kwargs: """ - start_date, end_date = str2timewindow(tstring) - self.forecast_from_file(start_date, end_date, **kwargs) - - def forecast_from_file(self, start_date: datetime, end_date: datetime, **kwargs) -> None: - """ - Generates a forecast from a file, by parsing and scaling it to. - - the desired time window. 
H - - Args: - start_date (~datetime.datetime): Start of the forecast - end_date (~datetime.datetime): End of the forecast - **kwargs: Keyword arguments for - :class:`csep.core.forecasts.GriddedForecast` - """ - - time_horizon = decimal_year(end_date) - decimal_year(start_date) - tstring = timewindow2str([start_date, end_date]) - - f_path = self.path("forecasts", tstring) - f_parser = getattr(ForecastParsers, self.path.fmt) - - rates, region, mags = f_parser(f_path) - - forecast = GriddedForecast( - name=f"{self.name}", - data=rates, - region=region, - magnitudes=mags, - start_time=start_date, - end_time=end_date, - ) - - scale = time_horizon / self.forecast_unit - if scale != 1.0: - forecast = forecast.scale(scale) - - log.debug( - f"Model {self.name}:\n" - f"\tForecast expected count: {forecast.event_count:.2f}" - f" with scaling parameter: {time_horizon:.1f}" - ) - - self.forecasts[tstring] = forecast + return class TimeDependentModel(Model): """ Model that creates varying forecasts depending on a time window. Requires either a collection of Forecasts or a function that returns a Forecast. 
- """ def __init__( @@ -384,18 +300,20 @@ def __init__( **kwargs, ) -> None: - super().__init__(name, model_path, **kwargs) + super().__init__(name, **kwargs) self.func = func self.func_kwargs = func_kwargs or {} - self.path = ModelTree(kwargs.get("workdir", os.getcwd()), model_path) + self.registry = ForecastRegistry(kwargs.get("workdir", os.getcwd()), model_path) + self.repository = ForecastRepository.factory( + self.registry, model_class=self.__class__.__name__, **kwargs + ) self.build = kwargs.get("build", None) - self.run_prefix = "" if self.func: self.environment = EnvironmentFactory.get_env( - self.build, self.name, self.path.abs(self.model_path) + self.build, self.name, self.registry.abs(model_path) ) def stage(self, timewindows=None) -> None: @@ -407,14 +325,16 @@ def stage(self, timewindows=None) -> None: - Initialize database - Run model quality assurance (unit tests, runnable from floatcsep) """ - self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash) + if self.force_stage or not self.registry.file_exists("path"): + os.makedirs(self.registry.dir, exist_ok=True) + self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash) if hasattr(self, "environment"): self.environment.create_environment() - self.path.build_tree( + self.registry.build_tree( timewindows=timewindows, - model_class="td", + model_class=self.__class__.__name__, prefix=self.__dict__.get("prefix", self.name), args_file=self.__dict__.get("args_file", None), input_cat=self.__dict__.get("input_cat", None), @@ -423,27 +343,8 @@ def stage(self, timewindows=None) -> None: def get_forecast( self, tstring: Union[str, list] = None, region=None ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]: - """Wrapper that just returns a forecast, hiding the access method under the hood""" - - if isinstance(tstring, str): - # If one time window string is passed - fc_path = self.path("forecasts", tstring) - # A region must be given to the forecast - 
return csep.load_catalog_forecast( - fc_path, region=region, apply_filters=True, filter_spatial=True - ) - - else: - forecasts = [] - for t in tstring: - fc_path = self.path("forecasts", t) - # A region must be given to the forecast - forecasts.append( - csep.load_catalog_forecast( - fc_path, region=region, apply_filters=True, filter_spatial=True - ) - ) - return forecasts + """Wrapper that just returns a forecast, hiding the access method under the hood.""" + return self.repository.load_forecast(tstring, region=region) def create_forecast(self, tstring: str, **kwargs) -> None: """ @@ -463,26 +364,20 @@ def create_forecast(self, tstring: str, **kwargs) -> None: start_date, end_date = str2timewindow(tstring) # Model src is a func or binary - - fc_path = self.path("forecasts", tstring) - if kwargs.get("force") or not os.path.exists(fc_path): - self.forecast_from_func(start_date, end_date, **self.func_kwargs, **kwargs) - else: - log.info(f"Forecast of {tstring} of model {self.name} already " f"exists") - - def forecast_from_func(self, start_date: datetime, end_date: datetime, **kwargs) -> None: + if not kwargs.get("force") and self.registry.forecast_exists(tstring): + log.info(f"Forecast for {tstring} of model {self.name} already exists") + return self.prepare_args(start_date, end_date, **kwargs) log.info( f"Running {self.name} using {self.environment.__class__.__name__}:" f" {timewindow2str([start_date, end_date])}" ) - - self.run_model() + self.environment.run_command(f'{self.func} {self.registry.get_path("args_file")}') def prepare_args(self, start, end, **kwargs): - filepath = self.path("args_file") + filepath = self.registry.get_path("args_file") fmt = os.path.splitext(filepath)[1] if fmt == ".txt": @@ -506,6 +401,7 @@ def replace_arg(arg, val, fp): replace_arg("end_date", end.isoformat(), filepath) for i, j in kwargs.items(): replace_arg(i, j, filepath) + elif fmt == ".json": with open(filepath, "r") as file_: args = json.load(file_) @@ -516,31 +412,3 @@ def 
replace_arg(arg, val, fp): with open(filepath, "w") as file_: json.dump(args, file_, indent=2) - - def run_model(self): - - self.environment.run_command(f'{self.func} {self.path("args_file")}') - - -class ModelFactory: - @staticmethod - def create_model(model_cfg) -> Model: - - model_path = [*model_cfg.values()][0]["model_path"] - workdir = [*model_cfg.values()][0].get("workdir", "") - model_class = [*model_cfg.values()][0].get("class", "") - - if model_class == "ti": - return TimeIndependentModel.from_dict(model_cfg) - - elif model_class == "td": - return TimeDependentModel.from_dict(model_cfg) - - if os.path.isfile(os.path.join(workdir, model_path)): - return TimeIndependentModel.from_dict(model_cfg) - - elif "func" in [*model_cfg.values()][0]: - return TimeDependentModel.from_dict(model_cfg) - - else: - return TimeIndependentModel.from_dict(model_cfg) diff --git a/floatcsep/readers.py b/floatcsep/readers.py index 526a918..aa02804 100644 --- a/floatcsep/readers.py +++ b/floatcsep/readers.py @@ -1,13 +1,14 @@ +import argparse +import logging import os.path +import time +import xml.etree.ElementTree as eTree + import h5py -import pandas -import argparse import numpy -import xml.etree.ElementTree as eTree -from csep.models import Polygon +import pandas from csep.core.regions import QuadtreeGrid2D, CartesianGrid2D -import time -import logging +from csep.models import Polygon log = logging.getLogger(__name__) diff --git a/floatcsep/registry.py b/floatcsep/registry.py index b36c119..8437645 100644 --- a/floatcsep/registry.py +++ b/floatcsep/registry.py @@ -1,36 +1,19 @@ -import dataclasses +import logging import os +from abc import ABC, abstractmethod +from datetime import datetime from os.path import join, abspath, relpath, normpath, dirname, exists -from dataclasses import dataclass, field -from typing import Sequence -from floatcsep.utils import timewindow2str +from typing import Sequence, Union +from floatcsep.utils import timewindow2str -@dataclass -class 
ModelTree: - workdir: str - path: str - database: str = None - args_file: str = None - input_cat: str = None - forecasts: dict = field(default_factory=dict) - inventory: dict = field(default_factory=dict) +log = logging.getLogger("floatLogger") - def __call__(self, *args): - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] +class BaseFileRegistry(ABC): - return self.abs(val) - - @property - def fmt(self) -> str: - if self.database: - return os.path.splitext(self.database)[1][1:] - else: - return os.path.splitext(self.path)[1][1:] + def __init__(self, workdir: str): + self.workdir = workdir @staticmethod def _parse_arg(arg): @@ -45,44 +28,129 @@ def _parse_arg(arg): else: raise Exception("Arg is not found") - def __eq__(self, other): - return self.path == other - + @abstractmethod def as_dict(self): - return self.path + pass - def asdict(self): - return dataclasses.asdict(self) + @abstractmethod + def build_tree(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def get_path(self, *args): + pass def abs(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to - the experiment working dir.""" _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) return _path - def absdir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to - the experiment working dir.""" - + def abs_dir(self, *paths: Sequence[str]) -> str: _path = normpath(abspath(join(self.workdir, *paths))) _dir = dirname(_path) return _dir - def fileexists(self, *args): - file_abspath = self.__call__(*args) + def rel(self, *paths: Sequence[str]) -> str: + """Gets the relative path of a file, when it was defined relative to. + + the experiment working dir. 
+ """ + + _abspath = normpath(abspath(join(self.workdir, *paths))) + _relpath = relpath(_abspath, self.workdir) + return _relpath + + def rel_dir(self, *paths: Sequence[str]) -> str: + """Gets the absolute path of a file, when it was defined relative to. + + the experiment working dir. + """ + + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + + return relpath(_dir, self.workdir) + + def file_exists(self, *args): + file_abspath = self.get_path(*args) return exists(file_abspath) + +class ForecastRegistry(BaseFileRegistry): + def __init__( + self, + workdir: str, + path: str, + database: str = None, + args_file: str = None, + input_cat: str = None, + ): + super().__init__(workdir) + + self.path = path + self.database = database + self.args_file = args_file + self.input_cat = input_cat + self.forecasts = {} + self.inventory = {} + + def get_path(self, *args): + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(val) + + @property + def dir(self) -> str: + """ + Returns: + + The directory containing the model source. 
+ """ + if os.path.isdir(self.get_path("path")): + return self.get_path("path") + else: + return os.path.dirname(self.get_path("path")) + + @property + def fmt(self) -> str: + if self.database: + return os.path.splitext(self.database)[1][1:] + else: + return os.path.splitext(self.path)[1][1:] + + def as_dict(self): + return { + "workdir": self.workdir, + "path": self.path, + "database": self.database, + "args_file": self.args_file, + "input_cat": self.input_cat, + "forecasts": self.forecasts, + "inventory": self.inventory, + } + + def forecast_exists(self, timewindow: Union[str, list]): + + if isinstance(timewindow, str): + return self.file_exists("forecasts", timewindow) + else: + return [self.file_exists("forecasts", i) for i in timewindow] + def build_tree( - self, timewindows=None, model_class="ti", prefix=None, args_file=None, input_cat=None + self, + timewindows: Sequence[Sequence[datetime]] = None, + model_class: str = "TimeIndependentModel", + prefix=None, + args_file=None, + input_cat=None, ) -> None: """ - - Creates the run directory, and reads the file structure inside + Creates the run directory, and reads the file structure inside. Args: timewindows (list(str)): List of time windows or strings. 
- model_class (str): Time-indendent (ti) or time-dependent (td) + model_class (str): Model's class name prefix (str): prefix of the model forecast filenames if TD args_file (str, bool): input arguments path of the model if TD input_cat (str, bool): input catalog path of the model if TD @@ -93,17 +161,17 @@ def build_tree( exist already target_paths: flag to each element of the gefe (catalog and evaluation results) - """ - if timewindows is None: - return + windows = timewindow2str(timewindows) - if model_class == "ti": + + if model_class == "TimeIndependentModel": fname = self.database if self.database else self.path - fc_files = {win: fname for win in windows} - fc_exists = {win: exists(fc_files[win]) for win in windows} + self.forecasts = {win: fname for win in windows} + self.inventory = {win: exists(self.forecasts[win]) for win in windows} + + elif model_class == "TimeDependentModel": - elif model_class == "td": args = args_file if args_file else join("input", "args.txt") self.args_file = join(self.path, args) input_cat = input_cat if input_cat else join("input", "catalog.csv") @@ -117,25 +185,25 @@ def build_tree( os.makedirs(folder_, exist_ok=True) # set forecast names - fc_files = { + self.forecasts = { win: join(dirtree["forecasts"], f"{prefix}_{win}.csv") for win in windows } - fc_exists = { + self.inventory = { win: any(file for file in list(os.listdir(dirtree["forecasts"]))) for win in windows } - self.forecasts = fc_files - self.inventory = fc_exists +class ExperimentRegistry(BaseFileRegistry): + def __init__(self, workdir: str, rundir: str = "results"): + super().__init__(workdir) + self.rundir = rundir + self.paths = {} + self.result_exists = {} -@dataclass -class PathTree: - workdir: str - rundir: str = "results" - paths: dict = field(default_factory=dict) - result_exists: dict = field(default_factory=dict) + def get_path(self, *args): + pass def __call__(self, *args): val = self.paths @@ -144,69 +212,12 @@ def __call__(self, *args): val = 
val[parsed_arg] return self.abs(self.rundir, val) - @staticmethod - def _parse_arg(arg): - if isinstance(arg, (list, tuple)): - return timewindow2str(arg) - elif isinstance(arg, str): - return arg - elif hasattr(arg, "name"): - return arg.name - elif hasattr(arg, "__name__"): - return arg.__name__ - else: - raise Exception("Arg is not found") - - def __eq__(self, other): - return self.workdir == other - def as_dict(self): return self.workdir - def asdict(self): - return dataclasses.asdict(self) - - def abs(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to - the experiment working dir.""" - - _path = normpath(abspath(join(self.workdir, *paths))) - return _path - - def rel(self, *paths: Sequence[str]) -> str: - """Gets the relative path of a file, when it was defined relative to - the experiment working dir.""" - - _abspath = normpath(abspath(join(self.workdir, *paths))) - _relpath = relpath(_abspath, self.workdir) - return _relpath - - def absdir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to - the experiment working dir.""" - - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - return _dir - - def reldir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to - the experiment working dir.""" - - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - _reldir = relpath(_dir, self.workdir) - return _reldir - - def fileexists(self, *args): - - file_abspath = self.__call__(*args) - return exists(file_abspath) - - def build(self, timewindows=None, models=None, tests=None) -> None: + def build_tree(self, timewindows=None, models=None, tests=None) -> None: """ - - Creates the run directory, and reads the file structure inside + Creates the run directory, and reads the file structure inside. 
Args: timewindows: List of time windows, or representing string. @@ -219,7 +230,6 @@ def build(self, timewindows=None, models=None, tests=None) -> None: exist already target_paths: flag to each element of the gefe (catalog and evaluation results) - """ # grab names for creating directories windows = timewindow2str(timewindows) diff --git a/floatcsep/report.py b/floatcsep/report.py index 2d5d364..1860229 100644 --- a/floatcsep/report.py +++ b/floatcsep/report.py @@ -1,7 +1,7 @@ from floatcsep.utils import MarkdownReport, timewindow2str """ -Use the MarkdownReport class to create output for the experiment +Use the MarkdownReport class to create output for the experiment. 1. string templates are stored for each evaluation 2. string templates are stored for each forecast diff --git a/floatcsep/repository.py b/floatcsep/repository.py new file mode 100644 index 0000000..be981bb --- /dev/null +++ b/floatcsep/repository.py @@ -0,0 +1,162 @@ +import logging +from abc import ABC, abstractmethod +from typing import Sequence, Union + +import csep +from csep.core.forecasts import GriddedForecast +from csep.utils.time_utils import decimal_year + +from floatcsep.readers import ForecastParsers +from floatcsep.registry import ForecastRegistry +from floatcsep.utils import str2timewindow +from floatcsep.utils import timewindow2str + +log = logging.getLogger("floatLogger") + + +class ForecastRepository(ABC): + + @abstractmethod + def __init__(self, registry: ForecastRegistry): + self.registry = registry + self.lazy_load = False + self.forecasts = {} + + @abstractmethod + def load_forecast(self, tstring: Union[str, Sequence[str]], **kwargs): + pass + + @abstractmethod + def _load_single_forecast(self, tstring: str, **kwargs): + pass + + @abstractmethod + def remove(self, tstring: Union[str, Sequence[str]]): + pass + + def __eq__(self, other) -> bool: + + if not isinstance(other, ForecastRepository): + return False + + if len(self.forecasts) != len(other.forecasts): + return False + + 
for key in self.forecasts.keys(): + if key not in other.forecasts.keys(): + return False + if self.forecasts[key] != other.forecasts[key]: + return False + return True + + @classmethod + def factory( + cls, registry: ForecastRegistry, model_class: str, forecast_type: str = None, **kwargs + ) -> "ForecastRepository": + """Factory method. Instantiate first on explicit option provided in the model + configuration. Then, defaults to gridded forecast for TimeIndependentModel and catalog + forecasts for TimeDependentModel + """ + + if forecast_type == "catalog": + return CatalogForecastRepository(registry, **kwargs) + elif forecast_type == "gridded": + return GriddedForecastRepository(registry, **kwargs) + + if model_class == "TimeIndependentModel": + return GriddedForecastRepository(registry, **kwargs) + elif model_class == "TimeDependentModel": + return CatalogForecastRepository(registry, **kwargs) + else: + raise ValueError(f"Unknown forecast type: {forecast_type}") + + +class CatalogForecastRepository(ForecastRepository): + + def __init__(self, registry: ForecastRegistry, **kwargs): + self.registry = registry + self.lazy_load = kwargs.get("lazy_load", True) + self.forecasts = {} + + def load_forecast(self, tstring: Union[str, list], region=None): + + if isinstance(tstring, str): + return self._load_single_forecast(tstring, region) + else: + return [self._load_single_forecast(t, region) for t in tstring] + + def _load_single_forecast(self, t: str, region=None): + fc_path = self.registry.get_path("forecasts", t) + return csep.load_catalog_forecast( + fc_path, region=region, apply_filters=True, filter_spatial=True + ) + + def remove(self, tstring: Union[str, Sequence[str]]): + pass + + +class GriddedForecastRepository(ForecastRepository): + + def __init__(self, registry: ForecastRegistry, **kwargs): + self.registry = registry + self.lazy_load = kwargs.get("lazy_load", False) + self.forecasts = {} + + def load_forecast( + self, tstring: Union[str, list] = None, name="", 
region=None, forecast_unit=1 + ) -> Union[GriddedForecast, Sequence[GriddedForecast]]: + """Returns a forecast when requested.""" + if isinstance(tstring, str): + return self._get_or_load_forecast(tstring, name, forecast_unit) + else: + return [self._get_or_load_forecast(tw, name, forecast_unit) for tw in tstring] + + def _get_or_load_forecast( + self, tstring: str, name: str, forecast_unit: int + ) -> GriddedForecast: + """Helper method to get or load a single forecast.""" + if tstring in self.forecasts: + log.debug(f"Loading {name} forecast for {tstring} from memory") + return self.forecasts[tstring] + else: + log.debug(f"Loading {name} forecast for {tstring} on the fly") + forecast = self._load_single_forecast(tstring, forecast_unit, name) + if not self.lazy_load: + self.forecasts[tstring] = forecast + return forecast + + def _load_single_forecast(self, tstring: str, fc_unit=1, name_=""): + + start_date, end_date = str2timewindow(tstring) + + time_horizon = decimal_year(end_date) - decimal_year(start_date) + tstring_ = timewindow2str([start_date, end_date]) + + f_path = self.registry.get_path("forecasts", tstring_) + f_parser = getattr(ForecastParsers, self.registry.fmt) + + rates, region, mags = f_parser(f_path) + + forecast_ = GriddedForecast( + name=f"{name_}", + data=rates, + region=region, + magnitudes=mags, + start_time=start_date, + end_time=end_date, + ) + + scale = time_horizon / fc_unit + if scale != 1.0: + forecast_ = forecast_.scale(scale) + + log.debug( + f"Model {name_}:\n" + f"\tForecast expected count: {forecast_.event_count:.2f}" + f" with scaling parameter: {time_horizon:.1f}" + ) + + return forecast_ + + def remove(self, tstring: Union[str, Sequence[str]]): + pass diff --git a/floatcsep/utils.py b/floatcsep/utils.py index 73fbfb6..7213acd 100644 --- a/floatcsep/utils.py +++ b/floatcsep/utils.py @@ -1,51 +1,56 @@ # python libraries import copy +import filecmp +import functools import hashlib - -import numpy -import re +import itertools +import 
logging import multiprocessing import os +import re +from collections import OrderedDict +from datetime import datetime, date +from functools import partial +from typing import Sequence, Union, Mapping + +import csep.core +import csep.utils import mercantile -import shapely.geometry -import scipy.stats -import itertools -import functools -import yaml +import numpy import pandas +import scipy.stats import seaborn -import filecmp -from datetime import datetime, date -from functools import partial -from typing import Sequence, Union -from matplotlib import pyplot -from matplotlib.lines import Line2D -from collections import OrderedDict +import shapely.geometry # pyCSEP libraries import six -import csep.core -import csep.utils +import yaml +from csep.core.forecasts import GriddedForecast from csep.core.regions import CartesianGrid2D, compute_vertices -from csep.utils.plots import plot_spatial_dataset -from csep.models import Polygon from csep.core.regions import QuadtreeGrid2D, geographical_area_from_bounds +from csep.models import Polygon from csep.utils.calc import cleaner_range - -# floatCSEP libraries +from csep.utils.plots import plot_spatial_dataset +from matplotlib import pyplot +from matplotlib.lines import Line2D import floatcsep.accessors import floatcsep.extras import floatcsep.readers +# floatCSEP libraries + _UNITS = ["years", "months", "weeks", "days"] _PD_FORMAT = ["YS", "MS", "W", "D"] +log = logging.getLogger("floatLogger") + + def parse_csep_func(func): """ + Searchs in pyCSEP and floatCSEP a function or method whose name matches the. - Searchs in pyCSEP and floatCSEP a function or method whose name matches the provided string. Args: @@ -56,7 +61,6 @@ def parse_csep_func(func): The callable function/method object. 
If it was already callable, returns the same input - """ def recgetattr(obj, attr, *args): @@ -94,8 +98,7 @@ def _getattr(obj_, attr_): def parse_timedelta_string(window, exp_class="ti"): """ - - Parses a float or string representing the testing time window length + Parses a float or string representing the testing time window length. Note: @@ -111,7 +114,6 @@ def parse_timedelta_string(window, exp_class="ti"): Formatted :py:class:`str` representing the length and unit (year, month, week, day) of the time window - """ if isinstance(window, str): @@ -134,7 +136,6 @@ def parse_timedelta_string(window, exp_class="ti"): def read_time_cfg(time_config, **kwargs): """ - Builds the temporal configuration of an experiment. Args: @@ -147,7 +148,6 @@ def read_time_cfg(time_config, **kwargs): Returns: A dictionary containing the experiment time attributes and the time windows to be evaluated - """ _attrs = ["start_date", "end_date", "intervals", "horizon", "offset", "growth", "exp_class"] time_config = copy.deepcopy(time_config) @@ -176,7 +176,6 @@ def read_time_cfg(time_config, **kwargs): def read_region_cfg(region_config, **kwargs): """ - Builds the region configuration of an experiment. Args: @@ -188,7 +187,6 @@ def read_region_cfg(region_config, **kwargs): Returns: A dictionary containing the region attributes of the experiment - """ region_config = copy.deepcopy(region_config) _attrs = ["region", "mag_min", "mag_max", "mag_bin", "magnitudes", "depth_min", "depth_max"] @@ -244,13 +242,13 @@ def read_region_cfg(region_config, **kwargs): def timewindow2str(datetimes: Union[Sequence[datetime], Sequence[Sequence[datetime]]]): """ - Converts a time window (list/tuple of datetimes) to a string that + Converts a time window (list/tuple of datetimes) to a string that. + represents it. Can be a single timewindow or a list of time windows. 
Args: datetimes: Returns: - """ if isinstance(datetimes[0], datetime): return "_".join([j.date().isoformat() for j in datetimes]) @@ -261,13 +259,13 @@ def timewindow2str(datetimes: Union[Sequence[datetime], Sequence[Sequence[dateti def str2timewindow(tw_string: Union[str, Sequence[str]]): """ - Converts a string representation of a time window into a list of datetimes + Converts a string representation of a time window into a list of datetimes. + representing the time window edges. Args: tw_string: Returns: - """ if isinstance(tw_string, str): start_date, end_date = [datetime.fromisoformat(i) for i in tw_string.split("_")] @@ -285,7 +283,6 @@ def timewindows_ti( start_date=None, end_date=None, intervals=None, horizon=None, growth="incremental", **_ ): """ - Creates the testing intervals for a time-independent experiment. Note: @@ -307,7 +304,6 @@ def timewindows_ti( List of tuples containing the lower and upper boundaries of each testing window, as :py:class:`datetime.datetime`. - """ frequency = None @@ -342,7 +338,6 @@ def timewindows_td( start_date=None, end_date=None, timeintervals=None, timehorizon=None, timeoffset=None, **_ ): """ - Creates the testing intervals for a time-dependent experiment. Note: @@ -365,7 +360,6 @@ def timewindows_td( Returns: List of tuples containing the lower and upper boundaries of each testing window, as :py:class:`datetime.datetime`. - """ frequency = None @@ -413,16 +407,55 @@ def timewindows_td( # return timewindows +def parse_nested_dicts( + nested_dict: dict, excluded: Sequence = (), extended: bool = False +) -> dict: + """ + Parses nested dictionaries to flatten them + """ + + def _get_value(x): + # For each element type, transforms to desired string/output + if hasattr(x, "as_dict"): + # e.g. model, test, etc. 
+ o = x.as_dict() + else: + try: + try: + o = getattr(x, "__name__") + except AttributeError: + o = getattr(x, "name") + except AttributeError: + if isinstance(x, numpy.ndarray): + o = x.tolist() + else: + o = x + return o + + def iter_attr(val): + # recursive iter through nested dicts/lists + if isinstance(val, Mapping): + return { + item: iter_attr(val_) + for item, val_ in val.items() + if ((item not in excluded) and val_) or extended + } + elif isinstance(val, Sequence) and not isinstance(val, str): + return [iter_attr(i) for i in val] + else: + return _get_value(val) + + return iter_attr(nested_dict) + + class Task: def __init__(self, instance, method, **kwargs): """ - Base node of the workload distribution. - Wraps lazily objects, methods and their arguments for them to be - executed later. For instance, can wrap a floatcsep.Model, its method - 'create_forecast' and the argument 'time_window', which can be executed - later with Task.call() when, for example, task dependencies (parent - nodes) have been completed. + Base node of the workload distribution. Wraps lazily objects, methods and their + arguments for them to be executed later. For instance, can wrap a floatcsep.Model, its + method 'create_forecast' and the argument 'time_window', which can be executed later + with Task.call() when, for example, task dependencies (parent nodes) have been completed. Args: instance: can be floatcsep.Experiment, floatcsep.Model, floatcsep.Evaluation @@ -438,17 +471,16 @@ def __init__(self, instance, method, **kwargs): def sign_match(self, obj=None, met=None, kw_arg=None): """ - Checks if the Task matchs a given signature for simplicity. + Checks if the Task matches a given signature for simplicity. Purpose is to check from the outside if the Task is from a given object - (Model, Experiment, etc), matching either name or object or description + (Model, Experiment, etc.), matching either name or object or description Args: obj: Instance or instance's name str. 
Instance is preferred met: Name of the method kw_arg: Only the value (not key) of the kwargs dictionary Returns: - """ if self.obj == obj or obj == getattr(self.obj, "name", None): @@ -488,11 +520,11 @@ def check_exist(self): class TaskGraph: """ - Context manager of floatcsep workload distribution + Context manager of floatcsep workload distribution. + Assign tasks to a node and defines their dependencies (parent nodes). Contains a 'tasks' dictionary whose dict_keys are the Task to be executed with dict_values as the Task's dependencies. - """ def __init__(self): @@ -511,19 +543,21 @@ def ntasks(self, n): def add(self, task): """ - Simply adds a defined task to the graph + Simply adds a defined task to the graph. + Args: task: floatcsep.utils.Task Returns: - """ self.tasks[task] = [] self.ntasks += 1 def add_dependency(self, task, dinst=None, dmeth=None, dkw=None): """ - Adds a dependency to a task already inserted to the TaskGraph. Searchs + Adds a dependency to a task already inserted to the TaskGraph. + + Searchs within the pre-added tasks a signature match by their name/instance, method and keyword_args. @@ -534,7 +568,6 @@ def add_dependency(self, task, dinst=None, dmeth=None, dkw=None): dkw: keyword argument of the dependency Returns: - """ deps = [] for i, other_tasks in enumerate(self.tasks.keys()): @@ -546,8 +579,8 @@ def add_dependency(self, task, dinst=None, dmeth=None, dkw=None): def run(self): """ Iterates through all the graph tasks and runs them. 
- Returns: + Returns: """ for task, deps in self.tasks.items(): task.run() @@ -561,7 +594,7 @@ def check_exist(self): class MarkdownReport: - """Class to generate a Markdown report from a study""" + """Class to generate a Markdown report from a study.""" def __init__(self, outname="report.md"): self.outname = outname @@ -571,7 +604,7 @@ def __init__(self, outname="report.md"): self.markdown = [] def add_introduction(self, adict): - """Generate document header from dictionary""" + """Generate document header from dictionary.""" first = ( f"# CSEP Testing Results: {adict['simulation_name']} \n" f"**Forecast Name:** {adict['forecast_name']} \n" @@ -590,7 +623,8 @@ def add_introduction(self, adict): def add_text(self, text): """ - Text should be a list of strings where each string will be on its own + Text should be a list of strings where each string will be on its own. + line. Each add_text command represents a paragraph. Args: @@ -611,11 +645,18 @@ def add_figure( width=None, ): """ - This function expects a list of filepaths. If you want the output + This function expects a list of filepaths. + + If you want the output stacked, select a value of ncols. ncols should be divisible by filepaths. todo: modify formatted_paths to work when not divis. 
Args: + width: + caption: + text: + add_ext: + ncols: title: name of the figure level (int): value 1-6 depending on the heading relative_filepaths (str or List[Tuple[str]]): list of paths in @@ -681,7 +722,7 @@ def add_to_row(_row): self.toc.append((title, level, locator)) def add_heading(self, title, level=1, text="", add_toc=True): - # multipying char simply repeats it + # multiplying char simply repeats it if isinstance(text, str): text = [text] cell = [] @@ -711,7 +752,7 @@ def add_title(self, title, text): self.add_heading(title, 1, text, add_toc=False) def table_of_contents(self): - """generates table of contents based on contents of document.""" + """Generates table of contents based on contents of document.""" if len(self.toc) == 0: return toc = ["# Table of Contents"] @@ -725,7 +766,8 @@ def table_of_contents(self): def add_table(self, data, use_header=True): """ - Generates table from HTML and styles using bootstrap class + Generates table from HTML and styles using bootstrap class. + Args: data List[Tuple[str]]: should be (nrows, ncols) in size. all rows should be the same sizes @@ -736,8 +778,7 @@ def add_table(self, data, use_header=True): table = ['
', f""] def make_header(row): - header = [] - header.append("") + header = [""] for item in row: header.append(f"") header.append("") @@ -776,7 +817,7 @@ def ignore_aliases(self): class ExperimentComparison: def __init__(self, original, reproduced, **kwargs): - """ """ + """""" self.original = original self.reproduced = reproduced @@ -1084,7 +1125,7 @@ def magnitude_vs_time(catalog): def plot_matrix_comparative_test(evaluation_results, p=0.05, order=True, plot_args={}): - """Produces matrix plot for comparative tests for all models + """Produces matrix plot for comparative tests for all models. Args: evaluation_results (list of result objects): paired t-test results @@ -1162,7 +1203,8 @@ def plot_matrix_comparative_test(evaluation_results, p=0.05, order=True, plot_ar def forecast_mapping(forecast_gridded, target_grid, ncpu=None): """ - Aggregates conventional forecast onto quadtree region + Aggregates conventional forecast onto quadtree region. + This is generic function, which can map any forecast on to another grid. Wrapper function over "_forecat_mapping_generic" Forecast mapping onto Target Grid @@ -1174,7 +1216,6 @@ def forecast_mapping(forecast_gridded, target_grid, ncpu=None): both grids are Quadtree and Target grid is high-resolution at every level than the other grid. """ - from csep.core.forecasts import GriddedForecast bounds_target = target_grid.bounds bounds = forecast_gridded.region.bounds @@ -1189,6 +1230,7 @@ def forecast_mapping(forecast_gridded, target_grid, ncpu=None): def plot_quadtree_forecast(qtree_forecast): """ Currently, only a single-resolution plotting capability is available. + So we aggregate multi-resolution forecast on a single-resolution grid and then plot it @@ -1225,12 +1267,13 @@ def plot_quadtree_forecast(qtree_forecast): def plot_forecast_lowres(forecast, plot_args, k=4): """ - Plot a reduced resolution plot. The forecast values are kept the same, + Plot a reduced resolution plot. 
+ + The forecast values are kept the same, but cells are enlarged :param forecast: GriddedForecast object :param plot_args: arguments to be passed to plot_spatial_dataset :param k: Resampling factor. Selects cells every k row and k columns. - """ print("\tPlotting Forecast") @@ -1244,7 +1287,8 @@ def plot_forecast_lowres(forecast, plot_args, k=4): def quadtree_csv_loader(csv_fname): - """Load quadtree forecasted stored as csv file + """Load quadtree forecasted stored as csv file. + The format expects forecast as a comma separated file, in which first column corresponds to quadtree grid cell (quadkey). The second and thrid columns indicate depth range. @@ -1275,16 +1319,15 @@ def quadtree_csv_loader(csv_fname): def geographical_area_from_qk(quadk): - """ - Wrapper around function geographical_area_from_bounds - """ + """Wrapper around function geographical_area_from_bounds.""" bounds = tile_bounds(quadk) return geographical_area_from_bounds(bounds[0], bounds[1], bounds[2], bounds[3]) def tile_bounds(quad_cell_id): """ - It takes in a single Quadkkey and returns lat,longs of two diagonal corners + It takes in a single Quadkkey and returns lat,longs of two diagonal corners. + using mercantile Parameters ---------- @@ -1295,7 +1338,6 @@ def tile_bounds(quad_cell_id): ------- bounds : Mercantile object Latitude and Longitude of bottom left AND top right corners. 
- """ bounds = mercantile.bounds(mercantile.quadkey_to_tile(quad_cell_id)) @@ -1303,24 +1345,21 @@ def tile_bounds(quad_cell_id): def create_polygon(fg): - """ - Required for parallel processing - """ + """Required for parallel processing.""" return shapely.geometry.Polygon( [(fg[0], fg[1]), (fg[2], fg[1]), (fg[2], fg[3]), (fg[0], fg[3])] ) def calc_cell_area(cell): - """ - Required for parallel processing - """ + """Required for parallel processing.""" return geographical_area_from_bounds(cell[0], cell[1], cell[2], cell[3]) def _map_overlapping_cells(fcst_grid_poly, fcst_cell_area, fcst_rate_poly, target_poly): # , """ - This functions work for Cells that do not directly conside with target + This functions work for Cells that do not directly conside with target. + polygon cells. This function uses 3 variables i.e. fcst_grid_poly, fcst_cell_area, fcst_rate_poly @@ -1358,7 +1397,9 @@ def _map_overlapping_cells(fcst_grid_poly, fcst_cell_area, fcst_rate_poly, targe def _map_exact_inside_cells(fcst_grid, fcst_rate, boundary): """ - Uses 2 Global variables. fcst_grid, fcst_rate + Uses 2 Global variables. + + fcst_grid, fcst_rate Takes a cell_boundary and finds all those fcst_grid cells that fit exactly inside it and then sum-up the rates of all those cells fitting inside it to get forecast rate for boundary_cell @@ -1382,7 +1423,8 @@ def _map_exact_inside_cells(fcst_grid, fcst_rate, boundary): def _forecast_mapping_generic(target_grid, fcst_grid, fcst_rate, ncpu=None): """ - This function can perofrmns both aggregation and de-aggregation/ + This function can perofrmns both aggregation and de-aggregation/. + It is a wrapper function that uses 4 functions in respective order i.e. 
_map_exact_cells, _map_overlapping_cells, calc_cell_area, create_polygon @@ -1501,7 +1543,8 @@ def _set_dockerfile(name): def _global_region(dh=0.1, name="global", magnitudes=None): - """Creates a global region used for evaluating gridded models on the + """Creates a global region used for evaluating gridded models on the. + global scale. Modified from csep.core.regions.global_region @@ -1543,7 +1586,7 @@ def _check_zero_bins(exp, catalog, test_date): "o", markersize=10, ) - pyplot.savefig(f"{model.path}/{model.name}.png", dpi=300) + pyplot.savefig(f"{model.registry}/{model.name}.png", dpi=300) for model in exp.models: forecast = model.create_forecast(exp.start_date, test_date) catalog.filter_spatial(forecast.region) @@ -1569,4 +1612,4 @@ def _check_zero_bins(exp, catalog, test_date): "o", markersize=10, ) - pyplot.savefig(f"{model.path}/{model.name}.png", dpi=300) + pyplot.savefig(f"{model.registry}/{model.name}.png", dpi=300) diff --git a/pyproject.toml b/pyproject.toml index 5eb4733..ef57c31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,4 +11,13 @@ testpaths = [ [tool.black] line-length = 96 skip-string-normalization = false -target-version = ["py39", "py310", "py311"] \ No newline at end of file +target-version = ["py39", "py310", "py311"] + +[tool.pydocstringformatter] +write = true +exclude = ["examples/", "docs"] +strip-whitespaces = true +split-summary-body = false +style = "pep257" +max-line-length = 96 +linewrap-full-docstring = true \ No newline at end of file diff --git a/tests/artifacts/models/td_model/input/args.txt b/tests/artifacts/models/td_model/input/args.txt index e1cbf9d..cd87c0b 100644 --- a/tests/artifacts/models/td_model/input/args.txt +++ b/tests/artifacts/models/td_model/input/args.txt @@ -1,3 +1,2 @@ -start_date = 2000-01-01T00:00:00 -end_date = 2000-01-02T00:00:00 -n_sims = 200 +start_date = foo +end_date = bar \ No newline at end of file diff --git a/tests/integration/test_model_interface.py 
b/tests/integration/test_model_interface.py index af3c0fb..b8531e6 100644 --- a/tests/integration/test_model_interface.py +++ b/tests/integration/test_model_interface.py @@ -1,8 +1,136 @@ +import filecmp import os.path -import numpy.testing +import shutil +import tempfile from datetime import datetime from unittest import TestCase -from floatcsep.model import TimeIndependentModel +from unittest.mock import patch + +import numpy.testing +import csep.core.regions +from csep.core.forecasts import GriddedForecast + +from floatcsep.environments import EnvironmentManager +from floatcsep.model import TimeIndependentModel, TimeDependentModel +from floatcsep.utils import timewindow2str + + +class TestModelRegistryIntegration(TestCase): + + def setUp(self): + self.time_independent_model = TimeIndependentModel( + name="TestTIModel", + model_path=os.path.abspath( + os.path.join(os.path.dirname(__file__), "../artifacts/models/model.csv") + ), + forecast_unit=1, + store_db=False, + ) + self.time_dependent_model = TimeDependentModel( + name="mock", + model_path=os.path.abspath( + os.path.join(os.path.dirname(__file__), "../artifacts/models/td_model") + ), + func="run_model", + ) + + def test_time_independent_model_stage(self): + timewindows = [ + [datetime(2023, 1, 1), datetime(2023, 1, 2)], + ] + self.time_independent_model.stage(timewindows=timewindows) + print("a", self.time_independent_model.registry.as_dict()) + self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.forecasts) + self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.inventory) + + def test_time_independent_model_get_forecast(self): + tstring = "2023-01-01_2023-01-02" + self.time_independent_model.repository.forecasts[tstring] = "forecast" + forecast = self.time_independent_model.get_forecast(tstring) + self.assertEqual(forecast, "forecast") + + def test_time_independent_model_get_forecast_real(self): + tstring = "2023-01-01_2023-01-02" + timewindows = [ + 
[datetime(2023, 1, 1), datetime(2023, 1, 2)], + ] + self.time_independent_model.stage(timewindows=timewindows) + forecast = self.time_independent_model.get_forecast(tstring) + self.assertIsInstance(forecast, GriddedForecast) + self.assertAlmostEqual(forecast.data[0, 0], 0.002739726027357392) # 1 / 365 days + + @patch("floatcsep.environments.VenvManager.create_environment") + @patch("floatcsep.environments.CondaManager.create_environment") + def test_time_dependent_model_stage(self, mock_venv, mock_conda): + mock_venv.return_value = None + mock_conda.return_value = None + timewindows = [ + [datetime(2020, 1, 1), datetime(2020, 1, 2)], + [datetime(2020, 1, 2), datetime(2020, 1, 3)], + ] + tstrings = ["2020-01-01_2020-01-02", "2020-01-02_2020-01-03"] + self.time_dependent_model.stage(timewindows=timewindows) + + self.assertIn(tstrings[0], self.time_dependent_model.registry.forecasts) + self.assertIn(tstrings[1], self.time_dependent_model.registry.forecasts) + self.assertTrue(self.time_dependent_model.registry.inventory[tstrings[0]]) + self.assertTrue(self.time_dependent_model.registry.inventory[tstrings[1]]) + + @patch("floatcsep.environments.VenvManager.create_environment") + @patch("floatcsep.environments.CondaManager.create_environment") + def test_time_dependent_model_get_forecast(self, mock_venv, mock_conda): + mock_venv.return_value = None + mock_conda.return_value = None + timewindows = [ + [datetime(2020, 1, 1), datetime(2020, 1, 2)], + [datetime(2020, 1, 2), datetime(2020, 1, 3)], + ] + self.time_dependent_model.stage(timewindows) + tstring = "2020-01-01_2020-01-02" + forecast = self.time_dependent_model.get_forecast(tstring) + self.assertIsNotNone(forecast) + self.assertAlmostEqual(list(forecast.catalogs)[1].get_longitudes()[0], 1) + + +class TestModelRepositoryIntegration(TestCase): + + @classmethod + def setUpClass(cls) -> None: + path = os.path.dirname(__file__) + cls._path = path + cls._dir = os.path.normpath(os.path.join(path, "../artifacts", "models")) 
+ + @staticmethod + def init_model(name, model_path, **kwargs): + """Instantiates a model without using the @register deco, + but mocks Model.Registry() attrs""" + + model = TimeIndependentModel(name=name, model_path=model_path, **kwargs) + + return model + + def test_get_forecast_from_repository(self): + """reads from file, scale in runtime""" + _rates = numpy.array([[1.0, 0.1], [1.0, 0.1]]) + _mags = numpy.array([5.0, 5.1]) + origins = numpy.array([[0.0, 0.0], [0.0, 1.0]]) + _region = csep.core.regions.CartesianGrid2D.from_origins(origins) + + def forecast_(_): + return _rates, _region, _mags + + start = datetime(1900, 1, 1) + end = datetime(2000, 1, 1) + timestring = timewindow2str([start, end]) + + name = "mock" + fname = os.path.join(self._dir, "model.csv") + + with patch("floatcsep.readers.ForecastParsers.csv", forecast_): + model = self.init_model(name, fname) + model.registry.build_tree([[start, end]]) + forecast = model.get_forecast(timestring) + numpy.testing.assert_almost_equal(220.0, forecast.data.sum()) class TestModelFromFile(TestCase): @@ -22,70 +150,55 @@ def setUpClass(cls) -> None: @staticmethod def init_model(name, path, **kwargs): - model = TimeIndependentModel(name, path, **kwargs) - return model + def run_forecast_test(self, name, fname, start, end, expected_sum, use_db=False): + model = self.init_model(name=name, path=fname, use_db=use_db) + model.stage([[start, end]]) + model.get_forecast(timewindow2str([start, end])) + numpy.testing.assert_almost_equal( + expected_sum, + model.repository.forecasts[ + f"{start.strftime('%Y-%m-%d')}_{end.strftime('%Y-%m-%d')}" + ].data.sum(), + ) + def test_forecast_ti_from_csv(self): """Parses forecast from csv file""" name = "mock" fname = os.path.join(self._dir, "model.csv") - model = self.init_model(name, fname) start = datetime(1900, 1, 1) end = datetime(2000, 1, 1) - model.stage([[start, end]]) - model.forecast_from_file(start, end) - numpy.testing.assert_almost_equal( - 440.0, 
model.forecasts["1900-01-01_2000-01-01"].data.sum() - ) + expected_sum = 440.0 + self.run_forecast_test(name, fname, start, end, expected_sum) def test_forecast_ti_from_xml(self): """Parses forecast from XML file""" - name = "ALM" fname = self._alm_fn - numpy.seterr(all="ignore") - model = self.init_model(name, fname) start = datetime(1900, 1, 1) end = datetime(2000, 1, 1) - model.stage([[start, end]]) - model.forecast_from_file(start, end) - - numpy.testing.assert_almost_equal( - 1618.5424321406535, model.forecasts["1900-01-01_2000-01-01"].data.sum() - ) + expected_sum = 1618.5424321406535 + self.run_forecast_test(name, fname, start, end, expected_sum) def test_forecast_ti_from_xml2hdf5(self): """reads from xml, drops to db, makes forecast from db""" name = "ALM" fname = self._alm_fn - numpy.seterr(all="ignore") - - model = self.init_model(name=name, path=fname, use_db=True) start = datetime(1900, 1, 1) end = datetime(2000, 1, 1) - model.stage([[start, end]]) - model.forecast_from_file(start, end) - - numpy.testing.assert_almost_equal( - 1618.5424321406535, model.forecasts["1900-01-01_2000-01-01"].data.sum() - ) + expected_sum = 1618.5424321406535 + self.run_forecast_test(name, fname, start, end, expected_sum, use_db=True) def test_forecast_ti_from_hdf5(self): """reads from hdf5, scale in runtime""" name = "mock" fname = os.path.join(self._dir, "model_h5.hdf5") - model = self.init_model(name=name, path=fname, use_db=True) - model.stage() - start = datetime(2020, 1, 1) end = datetime(2023, 1, 1) - model.stage([[start, end]]) - model.forecast_from_file(start, end) - numpy.testing.assert_almost_equal( - 13.2, model.forecasts["2020-01-01_2023-01-01"].data.sum() - ) + expected_sum = 13.2 + self.run_forecast_test(name, fname, start, end, expected_sum, use_db=True) @classmethod def tearDownClass(cls) -> None: @@ -98,3 +211,126 @@ def tearDownClass(cls) -> None: ) if os.path.isfile(alm_db): os.remove(alm_db) + + +class TestModelFromGit(TestCase): + + @classmethod + def 
setUpClass(cls) -> None: + path = os.path.dirname(__file__) + cls._path = path + cls._dir = os.path.normpath(os.path.join(path, "../artifacts", "models")) + + @staticmethod + def init_model(name, model_path, **kwargs): + """Instantiates a model without using the @register deco, + but mocks Model.Registry() attrs""" + + model = TimeDependentModel(name=name, model_path=model_path, **kwargs) + + return model + + @patch.object(EnvironmentManager, "create_environment") + @patch("floatcsep.registry.ForecastRegistry.build_tree") + def test_from_git(self, mock_build_tree, mock_create_environment): + """clones model from git, checks with test artifacts""" + mock_build_tree.return_value = None + mock_create_environment.return_value = None + name = "mock_git" + _dir = "git_template" + path_ = os.path.join(tempfile.tempdir, _dir) + if os.path.exists(path_): + shutil.rmtree(path_) + giturl = ( + "https://git.gfz-potsdam.de/csep-group/" "rise_italy_experiment/models/template.git" + ) + model_a = self.init_model(name=name, model_path=path_, giturl=giturl) + model_a.stage() + + path = os.path.join(self._dir, "template") + model_b = self.init_model(name=name, model_path=path) + model_b.stage() + self.assertEqual(model_a.name, model_b.name) + dircmp = filecmp.dircmp(model_a.registry.dir, model_b.registry.dir).common + self.assertGreater(len(dircmp), 8) + shutil.rmtree(path_) + + def test_fail_git(self): + name = "mock_git" + filename_ = "attr.c" + dir_ = os.path.join(tempfile.tempdir, "git_notreal") + if os.path.isdir(dir_): + shutil.rmtree(dir_) + path_ = os.path.join(dir_, filename_) + + # Initialize from git url + model = self.init_model( + name=name, model_path=path_, giturl="https://github.com/github/testrepo" + ) + + with self.assertRaises(FileNotFoundError): + model.get_source(model.zenodo_id, model.giturl, branch="master") + + +class TestModelFromZenodo(TestCase): + + @classmethod + def setUpClass(cls) -> None: + path = os.path.dirname(__file__) + cls._path = path + cls._dir 
= os.path.normpath(os.path.join(path, "../artifacts", "models")) + + @staticmethod + def init_model(name, model_path, **kwargs): + """Instantiates a model without using the @register deco, + but mocks Model.Registry() attrs""" + + model = TimeIndependentModel(name=name, model_path=model_path, **kwargs) + return model + + @patch("floatcsep.registry.ForecastRegistry.build_tree") + def test_zenodo(self, mock_buildtree): + """downloads model from zenodo, checks with test artifacts""" + mock_buildtree.return_value = None + + name = "mock_zenodo" + filename_ = "dummy.txt" + dir_ = os.path.join(tempfile.tempdir, "mock") + + if os.path.isdir(dir_): + shutil.rmtree(dir_) + path_ = os.path.join(dir_, filename_) + + zenodo_id = 13117711 + # Initialize from zenodo id + model_a = self.init_model(name=name, model_path=path_, zenodo_id=zenodo_id) + model_a.stage() + + # Initialize from the artifact files (same as downloaded) + dir_art = os.path.join(self._path, "../artifacts", "models", "zenodo_test") + path = os.path.join(dir_art, filename_) + model_b = self.init_model(name=name, model_path=path, zenodo_id=zenodo_id) + model_b.stage() + + self.assertEqual( + os.path.basename(model_a.registry.get_path("path")), + os.path.basename(model_b.registry.get_path("path")), + ) + self.assertEqual(model_a.name, model_b.name) + self.assertTrue( + filecmp.cmp(model_a.registry.get_path("path"), model_b.registry.get_path("path")) + ) + + def test_zenodo_fail(self): + name = "mock_zenodo" + filename_ = "model_notreal.csv" # File not found in repository + dir_ = os.path.join(tempfile.tempdir, "zenodo_notreal") + if os.path.isdir(dir_): + shutil.rmtree(dir_) + path_ = os.path.join(dir_, filename_) + + # Initialize from zenodo id + model = self.init_model(name=name, model_path=path_, zenodo_id=13117711) + + with self.assertRaises(FileNotFoundError): + model.get_source(model.zenodo_id, model.giturl) diff --git a/tests/qa/test_data.py b/tests/qa/test_data.py index 96dd1d2..0e984ab 100644 --- 
a/tests/qa/test_data.py +++ b/tests/qa/test_data.py @@ -10,17 +10,15 @@ class DataTest(unittest.TestCase): @staticmethod def get_runpath(case): return os.path.abspath( - os.path.join(__file__, '../../..', - 'examples', - f'case_{case}', - f'config.yml') + os.path.join(__file__, "../../..", "examples", f"case_{case}", f"config.yml") ) @staticmethod def get_rerunpath(case): return os.path.abspath( - os.path.join(__file__, '../../..', 'examples', f'case_{case}', - 'results', f'repr_config.yml') + os.path.join( + __file__, "../../..", "examples", f"case_{case}", "results", f"repr_config.yml" + ) ) @staticmethod @@ -41,37 +39,37 @@ def get_eval_dist(self): class RunExamples(DataTest): def test_case_a(self, *args): - cfg = self.get_runpath('a') + cfg = self.get_runpath("a") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_b(self, *args): - cfg = self.get_runpath('b') + cfg = self.get_runpath("b") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_c(self, *args): - cfg = self.get_runpath('c') + cfg = self.get_runpath("c") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_d(self, *args): - cfg = self.get_runpath('d') + cfg = self.get_runpath("d") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_e(self, *args): - cfg = self.get_runpath('e') + cfg = self.get_runpath("e") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_f(self, *args): - cfg = self.get_runpath('f') + cfg = self.get_runpath("f") self.run_evaluation(cfg) self.assertEqual(1, 1) def test_case_g(self, *args): - cfg = self.get_runpath('g') + cfg = self.get_runpath("g") self.run_evaluation(cfg) self.assertEqual(1, 1) @@ -82,11 +80,11 @@ def test_case_g(self, *args): class ReproduceExamples(DataTest): def test_case_c(self, *args): - cfg = self.get_rerunpath('c') + cfg = self.get_rerunpath("c") self.repr_evaluation(cfg) self.assertEqual(1, 1) def test_case_f(self, *args): - cfg = self.get_rerunpath('f') + cfg = self.get_rerunpath("f") 
self.repr_evaluation(cfg) self.assertEqual(1, 1) diff --git a/tests/unit/test_accessors.py b/tests/unit/test_accessors.py index 260009b..48823bc 100644 --- a/tests/unit/test_accessors.py +++ b/tests/unit/test_accessors.py @@ -1,79 +1,18 @@ import os.path import vcr from datetime import datetime -from floatcsep.accessors import query_gcmt, _query_gcmt, from_zenodo, from_git, _check_hash +from floatcsep.accessors import from_zenodo, from_git, _check_hash import unittest from unittest import mock root_dir = os.path.dirname(os.path.abspath(__file__)) -def gcmt_dir(): - data_dir = os.path.join(root_dir, "../artifacts", "gcmt") - return data_dir - - def zenodo_dir(): data_dir = os.path.join(root_dir, "../artifacts", "zenodo") return data_dir -class TestCatalogGetter(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.makedirs(gcmt_dir(), exist_ok=True) - cls._fname = os.path.join(gcmt_dir(), "test_cat") - - def test_gcmt_search(self): - tape_file = os.path.join(gcmt_dir(), "vcr_search.yaml") - with vcr.use_cassette(tape_file): - # Maule, Chile - eventlist = _query_gcmt( - start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=6 - ) - event = eventlist[0] - assert event[0] == "2844986" - - def test_gcmt_summary(self): - tape_file = os.path.join(gcmt_dir(), "vcr_summary.yaml") - with vcr.use_cassette(tape_file): - eventlist = _query_gcmt( - start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=7 - ) - event = eventlist[0] - cmp = "('2844986', 1267252514000, -35.98, -73.15, 23.2, 8.8)" - assert str(event) == cmp - assert event[0] == "2844986" - assert datetime.fromtimestamp(event[1] / 1000.0) == datetime.fromtimestamp( - 1267252514 - ) - assert event[2] == -35.98 - assert event[3] == -73.15 - assert event[4] == 23.2 - assert event[5] == 8.8 - - def test_catalog_query_plot(self): - start_datetime = datetime(2020, 1, 1) - end_datetime = datetime(2020, 3, 1) - catalog = query_gcmt( - 
start_time=start_datetime, end_time=end_datetime, min_magnitude=5.95 - ) - catalog.plot( - set_global=True, plot_args={"filename": self._fname, "basemap": "stock_img"} - ) - assert os.path.isfile(self._fname + ".png") - assert os.path.isfile(self._fname + ".pdf") - - @classmethod - def tearDownClass(cls) -> None: - try: - os.remove(os.path.join(gcmt_dir(), cls._fname + ".pdf")) - os.remove(os.path.join(gcmt_dir(), cls._fname + ".png")) - except OSError: - pass - - class TestZenodoGetter(unittest.TestCase): @classmethod diff --git a/tests/unit/test_environments.py b/tests/unit/test_environments.py index cef9511..dfe90f6 100644 --- a/tests/unit/test_environments.py +++ b/tests/unit/test_environments.py @@ -7,10 +7,10 @@ import hashlib import logging from floatcsep.environments import ( - CondaEnvironmentManager, + CondaManager, EnvironmentFactory, - VenvEnvironmentManager, - DockerEnvironmentManager, + VenvManager, + DockerManager, ) @@ -22,9 +22,7 @@ def setUpClass(cls): raise unittest.SkipTest("Conda is not available in the environment.") def setUp(self): - self.manager = CondaEnvironmentManager( - base_name="test_env", model_directory="/tmp/test_model" - ) + self.manager = CondaManager(base_name="test_env", model_directory="/tmp/test_model") os.makedirs("/tmp/test_model", exist_ok=True) with open("/tmp/test_model/environment.yml", "w") as f: f.write("name: test_env\ndependencies:\n - python=3.8\n - numpy") @@ -43,7 +41,7 @@ def tearDown(self): @patch("subprocess.run") @patch("shutil.which", return_value="conda") def test_generate_env_name(self, mock_which, mock_run): - manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager = CondaManager("test_base", "/path/to/model") expected_name = "test_base_" + hashlib.md5("/path/to/model".encode()).hexdigest()[:8] print(expected_name) self.assertEqual(manager.generate_env_name(), expected_name) @@ -53,13 +51,13 @@ def test_env_exists(self, mock_run): hashed = 
hashlib.md5("/path/to/model".encode()).hexdigest()[:8] mock_run.return_value.stdout = f"test_base_{hashed}\n".encode() - manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager = CondaManager("test_base", "/path/to/model") self.assertTrue(manager.env_exists()) @patch("subprocess.run") @patch("os.path.exists", return_value=True) def test_create_environment(self, mock_exists, mock_run): - manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager = CondaManager("test_base", "/path/to/model") manager.create_environment(force=False) package_manager = manager.detect_package_manager() expected_calls = [ @@ -97,15 +95,15 @@ def test_create_environment(self, mock_exists, mock_run): @patch("subprocess.run") def test_create_environment_force(self, mock_run): - manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager = CondaManager("test_base", "/path/to/model") manager.env_exists = MagicMock(return_value=True) manager.create_environment(force=True) self.assertEqual(mock_run.call_count, 2) # One for remove, one for create @patch("subprocess.run") - @patch.object(CondaEnvironmentManager, "detect_package_manager", return_value="conda") + @patch.object(CondaManager, "detect_package_manager", return_value="conda") def test_install_dependencies(self, mock_detect_package_manager, mock_run): - manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager = CondaManager("test_base", "/path/to/model") manager.install_dependencies() mock_run.assert_called_once_with( [ @@ -126,10 +124,11 @@ def test_install_dependencies(self, mock_detect_package_manager, mock_run): @patch( "builtins.open", new_callable=mock_open, - read_data="[metadata]\nname = test\n\n[options]\ninstall_requires =\n numpy\npython_requires = >=3.9,<3.12\n", + read_data="[metadata]\nname = test\n\n[options]\ninstall_requires =\n " + "numpy\npython_requires = >=3.9,<3.12\n", ) def test_detect_python_version_setup_cfg(self, mock_open, mock_exists, 
mock_which): - manager = CondaEnvironmentManager("test_base", "../artifacts/models/td_model") + manager = CondaManager("test_base", "../artifacts/models/td_model") python_version = manager.detect_python_version() # Extract major and minor version parts @@ -180,7 +179,7 @@ def test_get_env_conda(self, mock_check_env, mock_abspath): env_manager = EnvironmentFactory.get_env( build="conda", model_name="test_model", model_path="/path/to/model" ) - self.assertIsInstance(env_manager, CondaEnvironmentManager) + self.assertIsInstance(env_manager, CondaManager) self.assertEqual(env_manager.base_name, "test_model") self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") @@ -190,7 +189,19 @@ def test_get_env_venv(self, mock_check_env, mock_abspath): env_manager = EnvironmentFactory.get_env( build="venv", model_name="test_model", model_path="/path/to/model" ) - self.assertIsInstance(env_manager, VenvEnvironmentManager) + self.assertIsInstance(env_manager, VenvManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="micromamba") + def test_get_env_micromamba(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build="micromamba", model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance( + env_manager, CondaManager + ) # Assuming Micromamba uses CondaManager self.assertEqual(env_manager.base_name, "test_model") self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") @@ -200,7 +211,7 @@ def test_get_env_docker(self, mock_check_env, mock_abspath): env_manager = EnvironmentFactory.get_env( build="docker", model_name="test_model", model_path="/path/to/model" ) - self.assertIsInstance(env_manager, DockerEnvironmentManager) + self.assertIsInstance(env_manager, DockerManager) 
self.assertEqual(env_manager.base_name, "test_model") self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") @@ -210,7 +221,7 @@ def test_get_env_default_conda(self, mock_check_env, mock_abspath): env_manager = EnvironmentFactory.get_env( build=None, model_name="test_model", model_path="/path/to/model" ) - self.assertIsInstance(env_manager, CondaEnvironmentManager) + self.assertIsInstance(env_manager, CondaManager) self.assertEqual(env_manager.base_name, "test_model") self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") @@ -220,7 +231,7 @@ def test_get_env_default_venv(self, mock_check_env, mock_abspath): env_manager = EnvironmentFactory.get_env( build=None, model_name="test_model", model_path="/path/to/model" ) - self.assertIsInstance(env_manager, VenvEnvironmentManager) + self.assertIsInstance(env_manager, VenvManager) self.assertEqual(env_manager.base_name, "test_model") self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") @@ -256,9 +267,7 @@ def setUpClass(cls): def setUp(self): self.model_directory = "/tmp/test_model" - self.manager = VenvEnvironmentManager( - base_name="test_env", model_directory=self.model_directory - ) + self.manager = VenvManager(base_name="test_env", model_directory=self.model_directory) os.makedirs(self.model_directory, exist_ok=True) with open(os.path.join(self.model_directory, "setup.py"), "w") as f: f.write("from setuptools import setup\nsetup(name='test_model', version='0.1')") diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index 2a1ebc0..f311dc9 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -2,11 +2,9 @@ import tempfile import numpy from unittest import TestCase -from unittest.mock import patch from datetime import datetime from floatcsep.experiment import Experiment from csep.core import poisson_evaluations -from csep.core.catalogs import CSEPCatalog _dir = os.path.dirname(__file__) _model_cfg = 
os.path.normpath(os.path.join(_dir, "../artifacts", "models", "model_cfg.yml")) @@ -31,8 +29,8 @@ def setUpClass(cls) -> None: def assertEqualExperiment(self, exp_a, exp_b): self.assertEqual(exp_a.name, exp_b.name) - self.assertEqual(exp_a.path, os.getcwd()) - self.assertEqual(exp_a.path, exp_b.path) + self.assertEqual(exp_a.path.workdir, os.getcwd()) + self.assertEqual(exp_a.path.workdir, exp_b.path.workdir) self.assertEqual(exp_a.start_date, exp_b.start_date) self.assertEqual(exp_a.timewindows, exp_b.timewindows) self.assertEqual(exp_a.exp_class, exp_b.exp_class) @@ -41,13 +39,6 @@ def assertEqualExperiment(self, exp_a, exp_b): numpy.testing.assert_equal(exp_a.depths, exp_b.depths) self.assertEqual(exp_a.catalog, exp_b.catalog) - @staticmethod - def init_no_wrap(name, path, **kwargs): - - model = Experiment.__new__(Experiment) - Experiment.__init__.__wrapped__(self=model, name=name, path=path, **kwargs) - return Experiment - def test_init(self): exp_a = Experiment(**_time_config, **_region_config, catalog=_cat) exp_b = Experiment(time_config=_time_config, region_config=_region_config, catalog=_cat) @@ -139,7 +130,7 @@ def test_stage_models(self): exp.stage_models() dbpath = os.path.relpath(os.path.join(_dir, "../artifacts", "models", "model.hdf5")) - self.assertEqual(exp.models[0].path.database, dbpath) + self.assertEqual(exp.models[0].registry.database, dbpath) def test_set_tests(self): test_cfg = os.path.normpath( diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index f63a2d9..5f8f41f 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -1,17 +1,12 @@ -import filecmp import os.path -import shutil -import tempfile -from datetime import datetime from unittest import TestCase -from unittest.mock import patch, MagicMock, mock_open - -import csep.core.regions -import numpy.testing -from floatcsep.environments import EnvironmentManager -from floatcsep.model import TimeIndependentModel, TimeDependentModel -from floatcsep.utils import 
str2timewindow +from floatcsep.model import TimeIndependentModel +from floatcsep.registry import ForecastRegistry +from floatcsep.repository import GriddedForecastRepository +from unittest.mock import patch, MagicMock, mock_open +from floatcsep.model import TimeDependentModel +from datetime import datetime class TestModel(TestCase): @@ -23,27 +18,30 @@ def setUpClass(cls) -> None: cls._dir = os.path.normpath(os.path.join(path, "../artifacts", "models")) @staticmethod - def assertEqualModel(exp_a, exp_b): + def assertEqualModel(model_a, model_b): - keys_a = list(exp_a.__dict__.keys()) - keys_b = list(exp_a.__dict__.keys()) + keys_a = list(model_a.__dict__.keys()) + keys_b = list(model_a.__dict__.keys()) if keys_a != keys_b: raise AssertionError("Models are not equal") for i in keys_a: - if not (getattr(exp_a, i) == getattr(exp_b, i)): + if isinstance(getattr(model_a, i), ForecastRegistry): + continue + if not (getattr(model_a, i) == getattr(model_b, i)): + print(getattr(model_a, i), getattr(model_b, i)) raise AssertionError("Models are not equal") class TestTimeIndependentModel(TestModel): @staticmethod - def init_model(name, path, **kwargs): + def init_model(name, model_path, **kwargs): """Instantiates a model without using the @register deco, but mocks Model.Registry() attrs""" - model = TimeIndependentModel(name, path, **kwargs) + model = TimeIndependentModel(name=name, model_path=model_path, **kwargs) return model @@ -53,54 +51,22 @@ def test_from_filesystem(self): fname = os.path.join(self._dir, "model.csv") # Initialize without Registry - model = self.init_model(name=name, path=fname) + model = self.init_model(name=name, model_path=fname) self.assertEqual(name, model.name) - self.assertEqual(fname, model.path) + self.assertEqual(fname, model.registry.path) self.assertEqual(1, model.forecast_unit) - def test_zenodo(self): - """downloads model from zenodo, checks with test artifacts""" - - name = "mock_zenodo" - filename_ = "dummy.txt" - dir_ = 
os.path.join(tempfile.tempdir, "mock") - - if os.path.isdir(dir_): - shutil.rmtree(dir_) - path_ = os.path.join(dir_, filename_) - - zenodo_id = 13117711 - # Initialize from zenodo id - model_a = self.init_model(name=name, path=path_, zenodo_id=zenodo_id) - model_a.stage() - - # Initialize from the artifact files (same as downloaded) - dir_art = os.path.join(self._path, "../artifacts", "models", "zenodo_test") - path = os.path.join(dir_art, filename_) - model_b = self.init_model(name=name, path=path, zenodo_id=zenodo_id) - model_b.stage() - - self.assertEqual( - os.path.basename(model_a.path("path")), - os.path.basename(model_b.path("path")), - ) - self.assertEqual(model_a.name, model_b.name) - self.assertTrue(filecmp.cmp(model_a.path("path"), model_b.path("path"))) - - def test_zenodo_fail(self): - name = "mock_zenodo" - filename_ = "model_notreal.csv" # File not found in repository - dir_ = os.path.join(tempfile.tempdir, "zenodo_notreal") - if os.path.isdir(dir_): - shutil.rmtree(dir_) - path_ = os.path.join(dir_, filename_) - - # Initialize from zenodo id - model = self.init_model(name=name, path=path_, zenodo_id=13117711) - - with self.assertRaises(FileNotFoundError): - model.get_source(model.zenodo_id, model.giturl) + @patch("os.makedirs") + @patch("floatcsep.model.TimeIndependentModel.get_source") + @patch("floatcsep.registry.ForecastRegistry.build_tree") + def test_stage_creates_directory(self, mock_build_tree, mock_get_source, mock_makedirs): + """Test stage method creates directory.""" + model = self.init_model("mock", "mockfile.csv") + model.force_stage = True # Simulate forcing the staging process + model.stage() + mock_makedirs.assert_called_once() + mock_get_source.assert_called_once() def test_from_dict(self): """test that '__init__' and 'from_dict' instantiates @@ -128,10 +94,11 @@ def test_from_dict(self): model_b = TimeIndependentModel.from_dict(py_dict) self.assertEqual(name, model_a.name) - self.assertEqual(fname, model_a.path.path) - 
self.assertEqual("csv", model_a.path.fmt) - self.assertEqual(self._dir, model_a.dir) + self.assertEqual(fname, model_a.registry.path) + self.assertEqual("csv", model_a.registry.fmt) + self.assertEqual(self._dir, model_a.registry.dir) + # print(model_a.__dict__, model_b.__dict__) self.assertEqualModel(model_a, model_b) with self.assertRaises(IndexError): @@ -141,61 +108,27 @@ def test_from_dict(self): {"model_1": {"name": "quack"}, "model_2": {"name": "moo"}} ) - @patch("floatcsep.model.TimeIndependentModel.forecast_from_file") - def test_create_forecast(self, mock_file): - - model = self.init_model("mock", "mockfile.csv") - model.create_forecast("2020-01-01_2021-01-01") - self.assertTrue(mock_file.called) - - def test_forecast_from_file(self): - """reads from file, scale in runtime""" - _rates = numpy.array([[1.0, 0.1], [1.0, 0.1]]) - _mags = numpy.array([5.0, 5.1]) - origins = numpy.array([[0.0, 0.0], [0.0, 1.0]]) - _region = csep.core.regions.CartesianGrid2D.from_origins(origins) - - def forecast_(_): - return _rates, _region, _mags - - name = "mock" - fname = os.path.join(self._dir, "model.csv") - - with patch("floatcsep.readers.ForecastParsers.csv", forecast_): - model = self.init_model(name, fname) - start = datetime(1900, 1, 1) - end = datetime(2000, 1, 1) - model.path.build_tree([[start, end]]) - model.forecast_from_file(start, end) - numpy.testing.assert_almost_equal( - 220.0, model.forecasts["1900-01-01_2000-01-01"].data.sum() - ) - - def test_get_forecast(self): + @patch.object(GriddedForecastRepository, "load_forecast") + def test_get_forecast(self, repo_mock): + repo_mock.return_value = 1 model = self.init_model("mock", "mockfile.csv") - model.forecasts = {"a": 1, "moo": 1, "cuack": 1} - - self.assertEqual(1, model.get_forecast("a")) - self.assertEqual([1, 1], model.get_forecast(["moo", "cuack"])) - with self.assertRaises(KeyError): - model.get_forecast(["woof"]) - with self.assertRaises(ValueError): - model.get_forecast("meaow") + self.assertEqual(1, 
model.get_forecast("1900-01-01_2000-01-01")) + repo_mock.assert_called_once_with( + "1900-01-01_2000-01-01", name="mock", region=None, forecast_unit=1 + ) def test_todict(self): fname = os.path.join(self._dir, "model.csv") dict_ = { - "path": fname, "forecast_unit": 5, "authors": ["Darwin, C.", "Bell, J.", "Et, Al."], "doi": "10.1010/10101010", "giturl": "should not be accessed, bc filesystem exists", "zenodo_id": "should not be accessed, bc filesystem exists", - "model_class": "ti", } - model = self.init_model(name="mock", **dict_) + model = self.init_model(name="mock", model_path=fname, **dict_) model_dict = model.as_dict() eq = True @@ -207,187 +140,171 @@ def test_todict(self): eq = False excl = ["path", "giturl", "forecast_unit"] keys = list(model.as_dict(excluded=excl).keys()) + for i in excl: if i in keys and i != "path": # path always gets printed eq = False self.assertTrue(eq) - def test_init_db(self): - pass - - def test_rm_db(self): - pass + @patch("os.path.isfile", return_value=False) + @patch("floatcsep.model.HDF5Serializer.grid2hdf5") + def test_init_db(self, mock_grid2hdf5, mock_isfile): + """Test init_db method creates database.""" + filepath = os.path.join(self._dir, "model.csv") + model = self.init_model("mock", filepath) + model.init_db(force=True) class TestTimeDependentModel(TestModel): - @staticmethod - def init_model(name, path, **kwargs): - """Instantiates a model without using the @register deco, - but mocks Model.Registry() attrs""" + def setUp(self): + # Patches + self.patcher_registry = patch("floatcsep.model.ForecastRegistry") + self.patcher_repository = patch("floatcsep.model.ForecastRepository.factory") + self.patcher_environment = patch("floatcsep.model.EnvironmentFactory.get_env") + self.patcher_get_source = patch( + "floatcsep.model.Model.get_source" + ) # Patch the get_source method on Model + + # Start patches + self.mock_registry = self.patcher_registry.start() + self.mock_repository_factory = self.patcher_repository.start() + 
self.mock_environment = self.patcher_environment.start() + self.mock_get_source = self.patcher_get_source.start() + + # Mock instances + self.mock_registry_instance = MagicMock() + self.mock_registry.return_value = self.mock_registry_instance + + self.mock_repository_instance = MagicMock() + self.mock_repository_factory.return_value = self.mock_repository_instance + + self.mock_environment_instance = MagicMock() + self.mock_environment.return_value = self.mock_environment_instance + + # Set attributes on the mock objects + self.mock_registry_instance.workdir = "/path/to/workdir" + self.mock_registry_instance.path = "/path/to/model" + self.mock_registry_instance.get_path.return_value = ( + "/path/to/args_file.txt" # Mocking the return of the registry call + ) - model = TimeDependentModel(name, path, **kwargs) + # Test data + self.name = "TestModel" + self.model_path = "/path/to/model" + self.func = "run_forecast" - return model + # Instantiate the model + self.model = TimeDependentModel( + name=self.name, model_path=self.model_path, func=self.func + ) - @patch.object(EnvironmentManager, "create_environment") - def test_from_git(self, mock_create_environment): - """clones model from git, checks with test artifacts""" - name = "mock_git" - _dir = "git_template" - path_ = os.path.join(tempfile.tempdir, _dir) - giturl = ( - "https://git.gfz-potsdam.de/csep-group/" "rise_italy_experiment/models/template.git" + def tearDown(self): + patch.stopall() + + def test_init(self): + # Assertions to check if the components were instantiated correctly + self.mock_registry.assert_called_once_with( + os.getcwd(), self.model_path + ) # Ensure the registry is initialized correctly + self.mock_repository_factory.assert_called_once_with( + self.mock_registry_instance, model_class="TimeDependentModel" ) - model_a = self.init_model(name=name, path=path_, giturl=giturl) - model_a.stage() - path = os.path.join(self._dir, "template") - model_b = self.init_model(name=name, path=path) - 
model_b.stage() - self.assertEqual(model_a.name, model_b.name) - dircmp = filecmp.dircmp(model_a.dir, model_b.dir).common - self.assertGreater(len(dircmp), 8) - shutil.rmtree(path_) - - def test_fail_git(self): - name = "mock_git" - filename_ = "attr.c" - dir_ = os.path.join(tempfile.tempdir, "git_notreal") - if os.path.isdir(dir_): - shutil.rmtree(dir_) - path_ = os.path.join(dir_, filename_) - - # Initialize from git url - model = self.init_model( - name=name, path=path_, giturl="https://github.com/github/testrepo" + self.mock_environment.assert_called_once_with( + None, self.name, self.mock_registry_instance.abs(self.model_path) ) - with self.assertRaises(FileNotFoundError): - model.get_source(model.zenodo_id, model.giturl, branch="master") - - @patch("floatcsep.model.TimeDependentModel.forecast_from_func") - def test_create_forecast(self, mock_func): + self.assertEqual(self.model.name, self.name) + self.assertEqual(self.model.func, self.func) + self.assertEqual(self.model.registry, self.mock_registry_instance) + self.assertEqual(self.model.repository, self.mock_repository_instance) + self.assertEqual(self.model.environment, self.mock_environment_instance) - model = self.init_model("mock", "mockbins", model_class="td") + def test_stage(self): + self.model.force_stage = True # Force staging to occur - model.path.build_tree([str2timewindow("2020-01-01_2021-01-01")]) - model.create_forecast("2020-01-01_2021-01-01") - self.assertTrue(mock_func.called) + self.model.stage(timewindows=["2020-01-01_2020-12-31"]) - @patch("csep.load_catalog_forecast") # Mocking the load_catalog_forecast function - def test_get_forecast_single(self, mock_load_forecast): - # Arrange - model_path = os.path.join(self._dir, "td_model") - model = self.init_model(name="mock", path=model_path) - tstring = "2020-01-01_2020-01-02" # Example time window string - model.stage([str2timewindow(tstring)]) + self.mock_get_source.assert_called_once_with( + self.model.zenodo_id, self.model.giturl, 
branch=self.model.repo_hash + ) + self.mock_registry_instance.build_tree.assert_called_once_with( + timewindows=["2020-01-01_2020-12-31"], + model_class="TimeDependentModel", + prefix=self.model.__dict__.get("prefix", self.name), + args_file=self.model.__dict__.get("args_file", None), + input_cat=self.model.__dict__.get("input_cat", None), + ) + self.mock_environment_instance.create_environment.assert_called_once() - region = "TestRegion" + def test_get_forecast(self): + tstring = "2020-01-01_2020-12-31" + self.model.get_forecast(tstring) - # Mock the return value of load_catalog_forecast - mock_forecast = MagicMock() - mock_load_forecast.return_value = mock_forecast + self.mock_repository_instance.load_forecast.assert_called_once_with( + tstring, region=None + ) - # Act - result = model.get_forecast(tstring=tstring, region=region) + @patch("floatcsep.model.TimeDependentModel.prepare_args") + def test_create_forecast(self, prep_args_mock): + tstring = "2020-01-01_2020-12-31" + prep_args_mock.return_value = None + self.model.registry.forecast_exists.return_value = False + self.model.create_forecast(tstring, force=True) - # Assert - mock_load_forecast.assert_called_once_with( - model.path("forecasts", tstring), - region=region, - apply_filters=True, - filter_spatial=True, + self.mock_environment_instance.run_command.assert_called_once_with( + f'{self.func} {self.model.registry.get_path("args_file")}' ) - self.assertEqual(result, mock_forecast) - - @patch("csep.load_catalog_forecast") - def test_get_forecast_multiple(self, mock_load_forecast): - - model_path = os.path.join(self._dir, "td_model") - model = self.init_model(name="mock", path=model_path) - tstrings = [ - "2020-01-01_2020-01-02", - "2020-01-02_2020-01-03", - ] # Example list of time window strings - region = "TestRegion" - model.stage(str2timewindow(tstrings)) - # Mock the return values of load_catalog_forecast for each forecast - mock_forecast1 = MagicMock() - mock_forecast2 = MagicMock() - 
mock_load_forecast.side_effect = [mock_forecast1, mock_forecast2] - - # Act - result = model.get_forecast(tstring=tstrings, region=region) - - # Assert - self.assertEqual(len(result), 2) - mock_load_forecast.assert_any_call( - model.path("forecasts", tstrings[0]), - region=region, - apply_filters=True, - filter_spatial=True, - ) - mock_load_forecast.assert_any_call( - model.path("forecasts", tstrings[1]), - region=region, - apply_filters=True, - filter_spatial=True, - ) - self.assertEqual(result[0], mock_forecast1) - self.assertEqual(result[1], mock_forecast2) - - def test_argprep(self): - model_path = os.path.join(self._dir, "td_model") - with open(os.path.join(model_path, "input", "args.txt"), "w") as args: - args.write("start_date = foo\nend_date = bar") - - model = self.init_model("a", model_path, func="func", build='docker') - start = datetime(2000, 1, 1) - end = datetime(2000, 1, 2) - model.stage([[start, end]]) - model.prepare_args(start, end) - - with open(os.path.join(model_path, "input", "args.txt"), "r") as args: - self.assertEqual(args.readline(), f"start_date = {start.isoformat()}\n") - self.assertEqual(args.readline(), f"end_date = {end.isoformat()}\n") - model.prepare_args(start, end, n_sims=400) - with open(os.path.join(model_path, "input", "args.txt"), "r") as args: - self.assertEqual(args.readlines()[2], f"n_sims = 400\n") - - model.prepare_args(start, end, n_sims=200) - with open(os.path.join(model_path, "input", "args.txt"), "r") as args: - self.assertEqual(args.readlines()[2], f"n_sims = 200\n") - - @patch("floatcsep.model.open", new_callable=mock_open, read_data='{"key": "value"}') + + @patch("builtins.open", new_callable=mock_open) + @patch("json.load") @patch("json.dump") - def test_argprep_json(self, mock_json_dump, mock_file): - model = self.init_model(name="TestModel", path=os.path.join(self._dir, "td_model")) - model.path = MagicMock(return_value="path/to/model/args.json") - start = MagicMock() - end = MagicMock() - 
start.isoformat.return_value = "2023-01-01" - end.isoformat.return_value = "2023-01-31" - - kwargs = {"key1": "value1", "key2": "value2"} - - # Act - model.prepare_args(start, end, **kwargs) - - # Assert - # Check that the file was opened for reading - mock_file.assert_any_call("path/to/model/args.json", "r") - - # Check that the file was opened for writing - mock_file.assert_any_call("path/to/model/args.json", "w") - - # Check that the JSON data was updated correctly - expected_data = { - "key": "value", - "start_date": "2023-01-01", - "end_date": "2023-01-31", - "key1": "value1", - "key2": "value2", + def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): + start_date = datetime(2020, 1, 1) + end_date = datetime(2020, 12, 31) + + # Mock json.load to return a dictionary + mock_json_load.return_value = { + "start_date": "2020-01-01T00:00:00", + "end_date": "2020-12-31T00:00:00", + "custom_arg": "value", } - # Check that json.dump was called with the expected data - mock_json_dump.assert_called_with(expected_data, mock_file(), indent=2) + # Simulate reading a .txt file + mock_open_file().readlines.return_value = [ + "start_date = 2020-01-01T00:00:00\n", + "end_date = 2020-12-31T00:00:00\n", + "custom_arg = value\n", + ] + + # Call the method + args_file_path = self.model.registry.get_path("args_file") + self.model.prepare_args(start_date, end_date, custom_arg="value") + + mock_open_file.assert_any_call(args_file_path, "r") + mock_open_file.assert_any_call(args_file_path, "w") + handle = mock_open_file() + handle.writelines.assert_any_call( + [ + "start_date = 2020-01-01T00:00:00\n", + "end_date = 2020-12-31T00:00:00\n", + "custom_arg = value\n", + ] + ) + + json_file_path = "/path/to/args_file.json" + self.model.registry.get_path.return_value = json_file_path + self.model.prepare_args(start_date, end_date, custom_arg="value") + + mock_open_file.assert_any_call(json_file_path, "r") + mock_json_load.assert_called_once() + 
mock_open_file.assert_any_call(json_file_path, "w") + mock_json_dump.assert_called_once_with( + { + "start_date": "2020-01-01T00:00:00", + "end_date": "2020-12-31T00:00:00", + "custom_arg": "value", + }, + mock_open_file(), + indent=2, + ) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py new file mode 100644 index 0000000..0bb8e14 --- /dev/null +++ b/tests/unit/test_registry.py @@ -0,0 +1,96 @@ +import unittest +from datetime import datetime +from unittest.mock import patch, MagicMock +from floatcsep.registry import ForecastRegistry + + +class TestForecastRegistry(unittest.TestCase): + + def setUp(self): + self.registry_file = ForecastRegistry( + workdir="/test/workdir", path="/test/workdir/model.txt" + ) + self.registry_folder = ForecastRegistry( + workdir="/test/workdir", path="/test/workdir/model" + ) + + def test_call(self): + self.registry_file._parse_arg = MagicMock(return_value="path") + result = self.registry_file.get_path("path") + self.assertEqual(result, "/test/workdir/model.txt") + + @patch("os.path.isdir") + def test_dir(self, mock_isdir): + mock_isdir.return_value = False + self.assertEqual(self.registry_file.dir, "/test/workdir") + + mock_isdir.return_value = True + self.assertEqual(self.registry_folder.dir, "/test/workdir/model") + + def test_fmt(self): + self.registry_file.database = "test.db" + self.assertEqual(self.registry_file.fmt, "db") + self.registry_file.database = None + self.assertEqual(self.registry_file.fmt, "txt") + + def test_parse_arg(self): + self.assertEqual(self.registry_file._parse_arg("arg"), "arg") + self.assertRaises(Exception, self.registry_file._parse_arg, 123) + + def test_as_dict(self): + self.assertEqual( + self.registry_file.as_dict(), + { + "args_file": None, + "database": None, + "forecasts": {}, + "input_cat": None, + "inventory": {}, + "path": "/test/workdir/model.txt", + "workdir": "/test/workdir", + }, + ) + + def test_abs(self): + result = self.registry_file.abs("file.txt") + 
self.assertTrue(result.endswith("/test/workdir/file.txt")) + + def test_absdir(self): + result = self.registry_file.abs_dir("model.txt") + self.assertTrue(result.endswith("/test/workdir")) + + @patch("floatcsep.registry.exists") + def test_fileexists(self, mock_exists): + mock_exists.return_value = True + self.registry_file.get_path = MagicMock(return_value="/test/path/file.txt") + self.assertTrue(self.registry_file.file_exists("file.txt")) + + @patch("os.makedirs") + @patch("os.listdir") + def test_build_tree_time_independent(self, mock_listdir, mock_makedirs): + timewindows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] + self.registry_file.build_tree( + timewindows=timewindows, model_class="TimeIndependentModel" + ) + self.assertIn("2023-01-01_2023-01-02", self.registry_file.forecasts) + self.assertIn("2023-01-01_2023-01-02", self.registry_file.inventory) + + @patch("os.makedirs") + @patch("os.listdir") + def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs): + mock_listdir.return_value = ["forecast_1.csv"] + timewindows = [ + [datetime(2023, 1, 1), datetime(2023, 1, 2)], + [datetime(2023, 1, 2), datetime(2023, 1, 3)], + ] + self.registry_folder.build_tree( + timewindows=timewindows, model_class="TimeDependentModel", prefix="forecast" + ) + self.assertIn("2023-01-01_2023-01-02", self.registry_folder.forecasts) + self.assertTrue(self.registry_folder.inventory["2023-01-01_2023-01-02"]) + self.assertIn("2023-01-02_2023-01-03", self.registry_folder.forecasts) + self.assertTrue(self.registry_folder.inventory["2023-01-02_2023-01-03"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py new file mode 100644 index 0000000..3b5029d --- /dev/null +++ b/tests/unit/test_repositories.py @@ -0,0 +1,155 @@ +import datetime +import unittest +from unittest.mock import MagicMock, patch, PropertyMock +from floatcsep.registry import ForecastRegistry +from floatcsep.repository import 
CatalogForecastRepository, GriddedForecastRepository +from floatcsep.readers import ForecastParsers +from csep.core.forecasts import GriddedForecast + + +class TestCatalogForecastRepository(unittest.TestCase): + + def setUp(self): + self.registry = MagicMock(spec=ForecastRegistry) + self.registry.__call__ = MagicMock(return_value="a_duck") + + @patch("csep.load_catalog_forecast") + def test_initialization(self, mock_load_catalog_forecast): + repo = CatalogForecastRepository(self.registry, lazy_load=True) + self.assertTrue(repo.lazy_load) + + @patch("csep.load_catalog_forecast") + def test_load_forecast(self, mock_load_catalog_forecast): + repo = CatalogForecastRepository(self.registry) + mock_load_catalog_forecast.return_value = "forecatto" + forecast = repo.load_forecast("2023-01-01_2023-01-02") + self.assertEqual(forecast, "forecatto") + + # Test load_forecast with list + forecasts = repo.load_forecast(["2023-01-01_2023-01-01", "2023-01-02_2023-01-03"]) + self.assertEqual(forecasts, ["forecatto", "forecatto"]) + + @patch("csep.load_catalog_forecast") + def test_load_single_forecast(self, mock_load_catalog_forecast): + # Test _load_single_forecast + repo = CatalogForecastRepository(self.registry) + mock_load_catalog_forecast.return_value = "forecatto" + forecast = repo._load_single_forecast("2023-01-01_2023-01-01") + self.assertEqual(forecast, "forecatto") + + +class TestGriddedForecastRepository(unittest.TestCase): + + def setUp(self): + self.registry = MagicMock(spec=ForecastRegistry) + self.registry.fmt = "hdf5" + self.registry.__call__ = MagicMock(return_value="a_duck") + + def test_initialization(self): + repo = GriddedForecastRepository(self.registry, lazy_load=False) + self.assertFalse(repo.lazy_load) + + @patch.object(ForecastParsers, "hdf5") + def test_load_forecast(self, mock_parser): + # Mock parser return values + mock_parser.return_value = ("rates", "region", "mags") + + repo = GriddedForecastRepository(self.registry) + with patch.object( + repo, 
"_get_or_load_forecast", return_value="forecatto" + ) as mock_method: + forecast = repo.load_forecast("2023-01-01_2023-01-02") + self.assertEqual(forecast, "forecatto") + mock_method.assert_called_once_with("2023-01-01_2023-01-02", "", 1) + + # Test load_forecast with list + with patch.object( + repo, "_get_or_load_forecast", return_value="forecatto" + ) as mock_method: + forecasts = repo.load_forecast(["2023-01-01_2023-01-02", "2023-01-02_2023-01-03"]) + self.assertEqual(forecasts, ["forecatto", "forecatto"]) + self.assertEqual(mock_method.call_count, 2) + + @patch.object(ForecastParsers, "hdf5") + def test_get_or_load_forecast(self, mock_parser): + mock_parser.return_value = ("rates", "region", "mags") + repo = GriddedForecastRepository(self.registry, lazy_load=False) + with patch.object( + repo, "_load_single_forecast", return_value="forecatta" + ) as mock_method: + # Test when forecast is not in memory + forecast = repo._get_or_load_forecast("2023-01-01_2023-01-02", "test_name", 1) + self.assertEqual(forecast, "forecatta") + mock_method.assert_called_once_with("2023-01-01_2023-01-02", 1, "test_name") + self.assertIn("2023-01-01_2023-01-02", repo.forecasts) + + # Test when forecast is in memory + forecast = repo._get_or_load_forecast("2023-01-01_2023-01-02", "test_name", 1) + self.assertEqual(forecast, "forecatta") + mock_method.assert_called_once() # Should not be called again + + @patch.object(GriddedForecast, "__init__", return_value=None) + @patch.object(GriddedForecast, "event_count", new_callable=PropertyMock) + @patch.object(GriddedForecast, "scale") + @patch.object(ForecastParsers, "hdf5") + def test_load_single_forecast(self, mock_parser, mock_scale, mock_count, mock_init): + # Mock parser return values + mock_count.return_value = 2 + mock_parser.return_value = ("rates", "region", "mags") + mock_scale.return_value = mock_scale + + # Test _load_single_forecast + repo = GriddedForecastRepository(self.registry, lazy_load=False) + with 
patch("csep.utils.time_utils.decimal_year", side_effect=[2023.0, 2024.0]): + forecast = repo._load_single_forecast("2023-01-01_2024-01-01", 1, "axe") + self.assertIsInstance(forecast, GriddedForecast) + mock_init.assert_called_once_with( + name="axe", + data="rates", + region="region", + magnitudes="mags", + start_time=datetime.datetime(2023, 1, 1), + end_time=datetime.datetime(2024, 1, 1), + ) + + @patch.object(ForecastParsers, "hdf5") + def test_lazy_load_behavior(self, mock_parser): + mock_parser.return_value = ("rates", "region", "mags") + # Test lazy_load behavior + repo = GriddedForecastRepository(self.registry, lazy_load=False) + with patch.object( + repo, "_load_single_forecast", return_value="forecatto" + ) as mock_method: + # Load forecast and check if it is stored + forecast = repo.load_forecast("2023-01-01_2023-01-02") + self.assertEqual(forecast, "forecatto") + self.assertIn("2023-01-01_2023-01-02", repo.forecasts) + + # Change to lazy_load=True and check if forecast is not stored + repo.lazy_load = True + forecast = repo.load_forecast("2023-01-02_2023-01-03") + self.assertEqual(forecast, "forecatto") + self.assertNotIn("2023-01-02_2023-01-03", repo.forecasts) + + @patch("floatcsep.registry.ForecastRegistry") + def test_equal(self, MockForecastRegistry): + + self.registry = MockForecastRegistry() + + self.repo1 = CatalogForecastRepository(self.registry) + self.repo2 = CatalogForecastRepository(self.registry) + self.repo3 = CatalogForecastRepository(self.registry) + self.repo4 = CatalogForecastRepository(self.registry) + + self.repo1.forecasts = {"1": 1, "2": 2} + self.repo2.forecasts = {"1": 1, "2": 2} + self.repo3.forecasts = {"1": 2, "2": 2} + self.repo4.forecasts = {"3": 1, "2": 2} + + self.assertEqual(self.repo1, self.repo2) + self.assertNotEqual(self.repo1, self.repo3) + self.assertNotEqual(self.repo1, self.repo3) + + +if __name__ == "__main__": + unittest.main()
{item}