diff --git a/floatcsep/cmd/main.py b/floatcsep/cmd/main.py index 35bea66..c20e64b 100644 --- a/floatcsep/cmd/main.py +++ b/floatcsep/cmd/main.py @@ -58,7 +58,7 @@ def reproduce(config, **kwargs): reproduced_exp.run() original_config = reproduced_exp.original_config - original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_rundir) + original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_run_dir) original_exp.stage_models() original_exp.set_tasks() diff --git a/floatcsep/evaluation.py b/floatcsep/evaluation.py index 9e1d4da..912d40b 100644 --- a/floatcsep/evaluation.py +++ b/floatcsep/evaluation.py @@ -1,17 +1,14 @@ import datetime -import json import os from typing import Dict, Callable, Union, Sequence, List -import numpy from csep.core.catalogs import CSEPCatalog from csep.core.forecasts import GriddedForecast -from csep.models import EvaluationResult from matplotlib import pyplot from floatcsep.model import Model from floatcsep.registry import ExperimentRegistry -from floatcsep.utils import parse_csep_func, timewindow2str +from floatcsep.utils import parse_csep_func class Evaluation: @@ -76,6 +73,9 @@ def __init__( self.markdown = markdown self.type = Evaluation._TYPES.get(self.func.__name__) + self.results_repo = None + self.catalog_repo = None + @property def type(self): """ @@ -123,7 +123,6 @@ def parse_plots(self, plot_func, plot_args, plot_kwargs): def prepare_args( self, timewindow: Union[str, list], - catpath: Union[str, list], model: Union[Model, Sequence[Model]], ref_model: Union[Model, Sequence] = None, region=None, @@ -153,7 +152,7 @@ def prepare_args( # Prepare argument tuple forecast = model.get_forecast(timewindow, region) - catalog = self.get_catalog(catpath, forecast) + catalog = self.get_catalog(timewindow, forecast) if isinstance(ref_model, Model): # Args: (Fc, RFc, Cat) @@ -169,9 +168,9 @@ def prepare_args( return test_args - @staticmethod def get_catalog( - catalog_path: Union[str, Sequence[str]], + self, + timewindow: Union[str, Sequence[str]], forecast: Union[GriddedForecast, Sequence[GriddedForecast]], ) -> Union[CSEPCatalog, List[CSEPCatalog]]: """ @@ -179,19 +178,22 @@ def get_catalog( forecast region. Args: - catalog_path (str, list(str)): Path to the existing catalog + timewindow (str): Time window of the testing catalog forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast object, onto which the catalog will be confronted for testing. Returns: """ - if isinstance(catalog_path, str): - eval_cat = CSEPCatalog.load_json(catalog_path) + + if isinstance(timewindow, str): + # eval_cat = CSEPCatalog.load_json(catalog_path) + eval_cat = self.catalog_repo.get_test_cat(timewindow) eval_cat.region = getattr(forecast, "region") + else: - eval_cat = [CSEPCatalog.load_json(i) for i in catalog_path] + eval_cat = [self.catalog_repo.get_test_cat(i) for i in timewindow] if (len(forecast) != len(eval_cat)) or (not isinstance(forecast, Sequence)): - raise IndexError("Amount of passed catalogs and forecats must " "be the same") + raise IndexError("Amount of passed catalogs and forecasts must " "be the same") for cat, fc in zip(eval_cat, forecast): cat.region = getattr(fc, "region", None) @@ -202,7 +204,6 @@ def compute( timewindow: Union[str, list], catalog: str, model: Model, - path: str, ref_model: Union[Model, Sequence[Model]] = None, region=None, ) -> None: @@ -216,57 +217,30 @@ def compute( catalog (str): Path to the filtered catalog model (Model, list[Model]): Model(s) to be evaluated ref_model: Model to be used as reference - path: Path to store the Evaluation result region: region to filter a catalog forecast. Returns: """ test_args = self.prepare_args( - timewindow, catpath=catalog, model=model, ref_model=ref_model, region=region + timewindow, model=model, ref_model=ref_model, region=region ) evaluation_result = self.func(*test_args, **self.func_kwargs) - self.write_result(evaluation_result, path) - - @staticmethod - def write_result(result: EvaluationResult, path: str) -> None: - """Dumps a test result into a json file.""" - class NumpyEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, numpy.integer): - return int(obj) - if isinstance(obj, numpy.floating): - return float(obj) - if isinstance(obj, numpy.ndarray): - return obj.tolist() - return json.JSONEncoder.default(self, obj) - - with open(path, "w") as _file: - json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder) + if self.type in ["sequential", "sequential_comparative"]: + self.results_repo.write_result(evaluation_result, self, model, timewindow[-1]) + else: + self.results_repo.write_result(evaluation_result, self, model, timewindow) def read_results( - self, - window: Union[str, Sequence[datetime.datetime]], - models: List[Model], - tree: ExperimentRegistry, + self, window: Union[str, Sequence[datetime.datetime]], models: List[Model] ) -> List: """ Reads an Evaluation result for a given time window and returns a list of the results for all tested models. """ - test_results = [] - - if not isinstance(window, str): - wstr_ = timewindow2str(window) - else: - wstr_ = window - for i in models: - eval_path = tree(wstr_, "evaluations", self, i.name) - with open(eval_path, "r") as file_: - model_eval = EvaluationResult.from_dict(json.load(file_)) - test_results.append(model_eval) + test_results = self.results_repo.load_results(self, window, models) return test_results @@ -274,7 +248,7 @@ def plot_results( self, timewindow: Union[str, List], models: List[Model], - tree: ExperimentRegistry, + registry: ExperimentRegistry, dpi: int = 300, show: bool = False, ) -> None: @@ -284,7 +258,7 @@ def plot_results( Args: timewindow: string representing the desired timewindow to plot models: a list of :class:`floatcsep:models.Model` - tree: a :class:`floatcsep:models.PathTree` containing path of the results + registry: a :class:`floatcsep:models.PathTree` containing path of the results dpi: Figure resolution with which to save show: show in runtime """ @@ -296,8 +270,8 @@ def plot_results( try: for time_str in timewindow: - fig_path = tree(time_str, "figures", self.name) - results = self.read_results(time_str, models, tree) + fig_path = registry.get(time_str, "figures", self.name) + results = self.read_results(time_str, models) ax = func(results, plot_args=fargs, **fkwargs) if "code" in fargs: exec(fargs["code"]) @@ -308,14 +282,14 @@ def plot_results( except AttributeError as msg: if self.type in ["consistency", "comparative"]: for time_str in timewindow: - results = self.read_results(time_str, models, tree) + results = self.read_results(time_str, models) for result, model in zip(results, models): fig_name = f"{self.name}_{model.name}" - tree.paths[time_str]["figures"][fig_name] = os.path.join( + registry.paths[time_str]["figures"][fig_name] = os.path.join( time_str, "figures", fig_name ) - fig_path = tree(time_str, "figures", fig_name) + fig_path = registry.get(time_str, "figures", fig_name) ax = func(result, plot_args=fargs, **fkwargs, show=False) if "code" in fargs: exec(fargs["code"]) @@ -324,8 +298,8 @@ def plot_results( pyplot.show() elif self.type in ["sequential", "sequential_comparative", "batch"]: - fig_path = tree(timewindow[-1], "figures", self.name) - results = self.read_results(timewindow[-1], models, tree) + fig_path = registry.get(timewindow[-1], "figures", self.name) + results = self.read_results(timewindow[-1], models) ax = func(results, plot_args=fargs, **fkwargs) if "code" in fargs: diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index 4f8dbc9..80c3487 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -20,6 +20,7 @@ from floatcsep.logger import add_fhandler from floatcsep.model import Model, TimeDependentModel from floatcsep.registry import ExperimentRegistry +from floatcsep.repository import ResultsRepository, CatalogRepository from floatcsep.utils import ( NoAliasLoader, parse_csep_func, @@ -144,11 +145,14 @@ def __init__( os.makedirs(os.path.join(workdir, rundir), exist_ok=True) self.name = name if name else "floatingExp" - self.path = ExperimentRegistry(workdir, rundir) + self.registry = ExperimentRegistry(workdir, rundir) + self.results_repo = ResultsRepository(self.registry) + self.catalog_repo = CatalogRepository(self.registry) + self.config_file = kwargs.get("config_file", None) self.original_config = kwargs.get("original_config", None) - self.original_rundir = kwargs.get("original_rundir", None) - self.rundir = rundir + self.original_run_dir = kwargs.get("original_rundir", None) + self.run_dir = rundir self.seed = kwargs.get("seed", None) self.time_config = read_time_cfg(time_config, **kwargs) self.region_config = read_region_cfg(region_config, **kwargs) @@ -158,9 +162,9 @@ def __init__( logger = kwargs.get("logging", True) if logger: filename = "experiment.log" if logger is True else logger - self.path.logger = os.path.join(workdir, rundir, filename) - log.info(f"Logging at {self.path.logger}") - add_fhandler(self.path.logger) + self.registry.logger = os.path.join(workdir, rundir, filename) + log.info(f"Logging at {self.registry.logger}") + add_fhandler(self.registry.logger) log.debug(f"-------- BEGIN OF RUN --------") log.info(f"Setting up experiment {self.name}:") @@ -180,7 +184,8 @@ def __init__( self.postproc_config = postproc_config if postproc_config else {} self.default_test_kwargs = default_test_kwargs - self.catalog = catalog + self.catalog_repo.set_catalog(catalog, self.time_config, self.region_config) + self.models = self.set_models( models or kwargs.get("model_config"), kwargs.get("order", None) ) @@ -233,13 +238,13 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - models = [] if isinstance(model_config, str): - modelcfg_path = self.path.abs(model_config) - _dir = self.path.abs_dir(model_config) + modelcfg_path = self.registry.abs(model_config) + _dir = self.registry.abs_dir(model_config) with open(modelcfg_path, "r") as file_: config_dict = yaml.load(file_, NoAliasLoader) elif isinstance(model_config, (dict, list)): config_dict = model_config - _dir = self.path.workdir + _dir = self.registry.workdir elif model_config is None: return models else: @@ -251,9 +256,13 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - if not any("flavours" in i for i in element.values()): name_ = next(iter(element)) - path_ = self.path.rel(_dir, element[name_]["path"]) + path_ = self.registry.rel(_dir, element[name_]["path"]) model_i = { - name_: {**element[name_], "model_path": path_, "workdir": self.path.workdir} + name_: { + **element[name_], + "model_path": path_, + "workdir": self.registry.workdir, + } } model_i[name_].pop("path") models.append(Model.factory(model_i)) @@ -263,14 +272,14 @@ def set_models(self, model_config: Union[Dict, str, List], order: List = None) - for flav, flav_path in model_flavours: name_super = next(iter(element)) path_super = element[name_super].get("path", "") - path_sub = self.path.rel(_dir, path_super, flav_path) + path_sub = self.registry.rel(_dir, path_super, flav_path) # updates name of submodel name_flav = f"{name_super}@{flav}" model_ = { name_flav: { **element[name_super], "model_path": path_sub, - "workdir": self.path.workdir, + "workdir": self.registry.workdir, } } model_[name_flav].pop("path") @@ -320,110 +329,28 @@ def set_tests(self, test_config: Union[str, Dict, List]) -> list: tests = [] if isinstance(test_config, str): - with open(self.path.abs(test_config), "r") as config: + + with open(self.registry.abs(test_config), "r") as config: config_dict = yaml.load(config, NoAliasLoader) + for eval_dict in config_dict: - tests.append(Evaluation.from_dict(eval_dict)) + eval_i = Evaluation.from_dict(eval_dict) + eval_i.results_repo = self.results_repo + eval_i.catalog_repo = self.catalog_repo + tests.append(eval_i) + elif isinstance(test_config, (dict, list)): + for eval_dict in test_config: - tests.append(Evaluation.from_dict(eval_dict)) + eval_i = Evaluation.from_dict(eval_dict) + eval_i.results_repo = self.results_repo + eval_i.catalog_repo = self.catalog_repo + tests.append(eval_i) log.info(f"\tEvaluations: {[i.name for i in tests]}") return tests - @property - def catalog(self) -> CSEPCatalog: - """ - Returns a CSEP catalog loaded from the given query function or a stored file if it - exists. - """ - cat_path = self.path.abs(self._catpath) - - if callable(self._catalog): - if isfile(self._catpath): - return CSEPCatalog.load_json(self._catpath) - bounds = { - "start_time": min([item for sublist in self.timewindows for item in sublist]), - "end_time": max([item for sublist in self.timewindows for item in sublist]), - "min_magnitude": self.magnitudes.min(), - "max_depth": self.depths.max(), - } - if self.region: - bounds.update( - { - i: j - for i, j in zip( - ["min_longitude", "max_longitude", "min_latitude", "max_latitude"], - self.region.get_bbox(), - ) - } - ) - - catalog = self._catalog(catalog_id="catalog", **bounds) - - if self.region: - catalog.filter_spatial(region=self.region, in_place=True) - catalog.region = None - catalog.write_json(self._catpath) - - return catalog - - elif isfile(cat_path): - try: - return CSEPCatalog.load_json(cat_path) - except json.JSONDecodeError: - return csep.load_catalog(cat_path) - - @catalog.setter - def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None: - - if cat is None: - self._catalog = None - self._catpath = None - - elif isfile(self.path.abs(cat)): - log.info(f"\tCatalog: '{cat}'") - self._catalog = self.path.rel(cat) - self._catpath = self.path.rel(cat) - - else: - # catalog can be a function - self._catalog = parse_csep_func(cat) - self._catpath = self.path.abs("catalog.json") - if isfile(self._catpath): - log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'") - else: - log.info(f"\tCatalog: '{cat}'") - - def get_test_cat(self, tstring: str = None) -> CSEPCatalog: - """ - Filters the complete experiment catalog to a test sub-catalog bounded by the test - time-window. Writes it to filepath defined in :attr:`Experiment.registry` - - Args: - tstring (str): Time window string - """ - - if tstring: - start, end = str2timewindow(tstring) - else: - start = self.start_date - end = self.end_date - sub_cat = self.catalog.filter( - [ - f"origin_time < {end.timestamp() * 1000}", - f"origin_time >= {start.timestamp() * 1000}", - f"magnitude >= {self.mag_min}", - f"magnitude < {self.mag_max}", - ], - in_place=False, - ) - if self.region: - sub_cat.filter_spatial(region=self.region, in_place=True) - - return sub_cat - def set_test_cat(self, tstring: str) -> None: """ Filters the complete experiment catalog to a test sub-catalog bounded by the test @@ -433,32 +360,13 @@ def set_test_cat(self, tstring: str) -> None: tstring (str): Time window string """ - testcat_name = self.path(tstring, "catalog") - if not exists(testcat_name): - log.debug( - f"Filtering catalog to testing sub-catalog and saving to " f"{testcat_name}" - ) - start, end = str2timewindow(tstring) - sub_cat = self.catalog.filter( - [ - f"origin_time < {end.timestamp() * 1000}", - f"origin_time >= {start.timestamp() * 1000}", - f"magnitude >= {self.mag_min}", - f"magnitude < {self.mag_max}", - ], - in_place=False, - ) - if self.region: - sub_cat.filter_spatial(region=self.region, in_place=True) - sub_cat.write_json(filename=testcat_name) - else: - log.debug(f"Using stored test sub-catalog from {testcat_name}") + self.catalog_repo.set_test_cat(tstring) def set_input_cat(self, tstring: str, model: Model) -> None: """ Filters the complete experiment catalog to a input sub-catalog filtered. - to the beginning of thetest time-window. Writes it to filepath defined + to the beginning of the test time-window. Writes it to filepath defined in :attr:`Model.tree.catalog` Args: @@ -466,9 +374,8 @@ def set_input_cat(self, tstring: str, model: Model) -> None: model (:class:`~floatcsep.model.Model`): Model to give the input catalog """ - start, end = str2timewindow(tstring) - sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"]) - sub_cat.write_ascii(filename=model.registry.get_path("input_cat")) + + self.catalog_repo.set_input_cat(tstring, model) def set_tasks(self): """ @@ -489,10 +396,10 @@ def set_tasks(self): """ # Set the file path structure - self.path.build_tree(self.timewindows, self.models, self.tests) + self.registry.build_tree(self.timewindows, self.models, self.tests) log.info("Setting up experiment's tasks") - log.debug("Pre-run: results' paths\n" + yaml.dump(self.path.as_dict())) + log.debug("Pre-run: results' paths\n" + yaml.dump(self.registry.as_dict())) # Get the time windows strings tw_strings = timewindow2str(self.timewindows) @@ -540,10 +447,9 @@ def set_tasks(self): instance=test_k, method="compute", timewindow=time_i, - catalog=self.path(time_i, "catalog"), + catalog=self.registry.get(time_i, "catalog"), model=model_j, region=self.region, - path=self.path(time_i, "evaluations", test_k, model_j), ) task_graph.add(task_ijk) # the forecast needs to have been created @@ -558,11 +464,10 @@ def set_tasks(self): instance=test_k, method="compute", timewindow=time_i, - catalog=self.path(time_i, "catalog"), + catalog=self.registry.get(time_i, "catalog"), model=model_j, ref_model=self.get_model(test_k.ref_model), region=self.region, - path=self.path(time_i, "evaluations", test_k, model_j), ) task_graph.add(task_ik) task_graph.add_dependency( @@ -581,10 +486,9 @@ def set_tasks(self): instance=test_k, method="compute", timewindow=tw_strings, - catalog=[self.path(i, "catalog") for i in tw_strings], + catalog=[self.registry.get(i, "catalog") for i in tw_strings], model=model_j, region=self.region, - path=self.path(tw_strings[-1], "evaluations", test_k, model_j), ) task_graph.add(task_k) for tw_i in tw_strings: @@ -599,11 +503,10 @@ def set_tasks(self): instance=test_k, method="compute", timewindow=tw_strs, - catalog=[self.path(i, "catalog") for i in tw_strs], + catalog=[self.registry.get(i, "catalog") for i in tw_strs], model=model_j, ref_model=self.get_model(test_k.ref_model), region=self.region, - path=self.path(tw_strs[-1], "evaluations", test_k, model_j), ) task_graph.add(task_k) for tw_i in tw_strings: @@ -624,11 +527,10 @@ def set_tasks(self): instance=test_k, method="compute", timewindow=time_str, - catalog=self.path(time_str, "catalog"), + catalog=self.registry.get(time_str, "catalog"), ref_model=self.models, model=model_j, region=self.region, - path=self.path(time_str, "evaluations", test_k, model_j), ) task_graph.add(task_k) for m_j in self.models: @@ -662,7 +564,7 @@ def read_results(self, test: Evaluation, window: str) -> List: for all tested models. """ - return test.read_results(window, self.models, self.path) + return test.read_results(window, self.models) def plot_results(self) -> None: """Plots all evaluation results.""" @@ -670,7 +572,7 @@ def plot_results(self) -> None: timewindows = timewindow2str(self.timewindows) for test in self.tests: - test.plot_results(timewindows, self.models, self.path) + test.plot_results(timewindows, self.models, self.registry) def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: """ @@ -690,31 +592,33 @@ def plot_catalog(self, dpi: int = 300, show: bool = False) -> None: "legend": True, } plot_args.update(self.postproc_config.get("plot_catalog", {})) - catalog = self.get_test_cat() + catalog = self.catalog_repo.get_test_cat() if catalog.get_number_of_events() != 0: ax = catalog.plot(plot_args=plot_args, show=show) ax.get_figure().tight_layout() - ax.get_figure().savefig(self.path("catalog_figure"), dpi=dpi) + ax.get_figure().savefig(self.registry.get("catalog_figure"), dpi=dpi) ax2 = magnitude_vs_time(catalog) ax2.get_figure().tight_layout() - ax2.get_figure().savefig(self.path("magnitude_time"), dpi=dpi) + ax2.get_figure().savefig(self.registry.get("magnitude_time"), dpi=dpi) if self.postproc_config.get("all_time_windows"): timewindow = self.timewindows for tw in timewindow: - catpath = self.path(tw, "catalog") + catpath = self.registry.get(tw, "catalog") catalog = CSEPCatalog.load_json(catpath) if catalog.get_number_of_events() != 0: ax = catalog.plot(plot_args=plot_args, show=show) ax.get_figure().tight_layout() - ax.get_figure().savefig(self.path(tw, "figures", "catalog"), dpi=dpi) + ax.get_figure().savefig( + self.registry.get(tw, "figures", "catalog"), dpi=dpi + ) ax2 = magnitude_vs_time(catalog) ax2.get_figure().tight_layout() ax2.get_figure().savefig( - self.path(tw, "figures", "magnitude_time"), dpi=dpi + self.registry.get(tw, "figures", "magnitude_time"), dpi=dpi ) def plot_forecasts(self) -> None: @@ -758,7 +662,7 @@ def plot_forecasts(self) -> None: winstr = timewindow2str(window) for model in self.models: - fig_path = self.path(winstr, "forecasts", model.name) + fig_path = self.registry.get(winstr, "forecasts", model.name) start = decimal_year(window[0]) end = decimal_year(window[1]) time = f"{round(end - start, 3)} years" @@ -800,36 +704,38 @@ def plot_forecasts(self) -> None: def generate_report(self) -> None: """Creates a report summarizing the Experiment's results.""" - log.info(f"Saving report into {self.path.rundir}") - self.path.build_tree(self.timewindows, self.models, self.tests) - log.debug("Post-run: results' paths\n" + yaml.dump(self.path.as_dict())) + log.info(f"Saving report into {self.registry.rundir}") + self.registry.build_tree(self.timewindows, self.models, self.tests) + log.debug("Post-run: results' paths\n" + yaml.dump(self.registry.as_dict())) report.generate_report(self) def make_repr(self): log.info("Creating reproducibility config file") - repr_config = self.path("config") + repr_config = self.registry.get("config") # Dropping region to results folder if it is a file region_path = self.region_config.get("path", False) if region_path: if isfile(region_path) and region_path: - new_path = join(self.path.rundir, self.region_config["path"]) + new_path = join(self.registry.rundir, self.region_config["path"]) shutil.copy2(region_path, new_path) self.region_config.pop("path") - self.region_config["region"] = self.path.rel(new_path) + self.region_config["region"] = self.registry.rel(new_path) # Dropping catalog to results folder - target_cat = join(self.path.workdir, self.path.rundir, split(self._catpath)[-1]) + target_cat = join( + self.registry.workdir, self.registry.rundir, split(self.catalog_repo._catpath)[-1] + ) if not exists(target_cat): - shutil.copy2(self.path.abs(self._catpath), target_cat) - self._catpath = self.path.rel(target_cat) + shutil.copy2(self.registry.abs(self.catalog_repo._catpath), target_cat) + self._catpath = self.registry.rel(target_cat) relative_path = os.path.relpath( - self.path.workdir, os.path.join(self.path.workdir, self.path.rundir) + self.registry.workdir, os.path.join(self.registry.workdir, self.registry.rundir) ) - self.path.workdir = relative_path + self.registry.workdir = relative_path self.to_yml(repr_config, extended=True) def as_dict( @@ -843,6 +749,8 @@ def as_dict( "tasks", "models", "tests", + "results_repo", + "catalog_repo", ), extended: bool = False, ) -> dict: @@ -859,9 +767,10 @@ def as_dict( """ listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j] - listwalk.insert(6, ("catalog", self._catpath)) + listwalk.insert(6, ("catalog", self.catalog_repo._catpath)) dictwalk = {i: j for i, j in listwalk} + dictwalk["path"] = dictwalk.pop("registry").workdir return parse_nested_dicts(dictwalk, excluded=exclude, extended=extended) diff --git a/floatcsep/model.py b/floatcsep/model.py index f0e6e00..67c3eee 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -95,7 +95,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non try: from_zenodo( zenodo_id, - self.registry.dir if self.registry.fmt else self.registry.get_path("path"), + self.registry.dir if self.registry.fmt else self.registry.get("path"), force=True, ) except (KeyError, TypeError) as msg: @@ -106,7 +106,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non try: from_git( giturl, - self.registry.dir if self.registry.fmt else self.registry.get_path("path"), + self.registry.dir if self.registry.fmt else self.registry.get("path"), **kwargs, ) except (git.NoSuchPathError, git.CommandError) as msg: @@ -115,7 +115,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non raise FileNotFoundError("Model has no path or identified") if not os.path.exists(self.registry.dir) or not os.path.exists( - self.registry.get_path("path") + self.registry.get("path") ): raise FileNotFoundError( f"Directory '{self.registry.dir}' or file {self.registry}' do not exist. " @@ -241,7 +241,7 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None: """ parser = getattr(ForecastParsers, self.registry.fmt) - rates, region, mag = parser(self.registry.get_path("path")) + rates, region, mag = parser(self.registry.get("path")) db_func = HDF5Serializer.grid2hdf5 if not dbpath: @@ -373,11 +373,11 @@ def create_forecast(self, tstring: str, **kwargs) -> None: f"Running {self.name} using {self.environment.__class__.__name__}:" f" {timewindow2str([start_date, end_date])}" ) - self.environment.run_command(f'{self.func} {self.registry.get_path("args_file")}') + self.environment.run_command(f'{self.func} {self.registry.get("args_file")}') def prepare_args(self, start, end, **kwargs): - filepath = self.registry.get_path("args_file") + filepath = self.registry.get("args_file") fmt = os.path.splitext(filepath)[1] if fmt == ".txt": diff --git a/floatcsep/registry.py b/floatcsep/registry.py index 8437645..f4b42b1 100644 --- a/floatcsep/registry.py +++ b/floatcsep/registry.py @@ -37,7 +37,7 @@ def build_tree(self, *args, **kwargs) -> None: pass @abstractmethod - def get_path(self, *args): + def get(self, *args): pass def abs(self, *paths: Sequence[str]) -> str: @@ -71,7 +71,7 @@ def rel_dir(self, *paths: Sequence[str]) -> str: return relpath(_dir, self.workdir) def file_exists(self, *args): - file_abspath = self.get_path(*args) + file_abspath = self.get(*args) return exists(file_abspath) @@ -93,7 +93,7 @@ def __init__( self.forecasts = {} self.inventory = {} - def get_path(self, *args): + def get(self, *args): val = self.__dict__ for i in args: parsed_arg = self._parse_arg(i) @@ -107,10 +107,10 @@ def dir(self) -> str: The directory containing the model source. """ - if os.path.isdir(self.get_path("path")): - return self.get_path("path") + if os.path.isdir(self.get("path")): + return self.get("path") else: - return os.path.dirname(self.get_path("path")) + return os.path.dirname(self.get("path")) @property def fmt(self) -> str: @@ -202,8 +202,12 @@ def __init__(self, workdir: str, rundir: str = "results"): self.paths = {} self.result_exists = {} - def get_path(self, *args): - pass + def get(self, *args): + val = self.paths + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(self.rundir, val) def __call__(self, *args): val = self.paths diff --git a/floatcsep/report.py b/floatcsep/report.py index 1860229..44aedfe 100644 --- a/floatcsep/report.py +++ b/floatcsep/report.py @@ -39,11 +39,14 @@ def generate_report(experiment, timewindow=-1): report.add_heading("Authoritative Data", level=2) # Generate catalog plot - if experiment.catalog is not None: + if experiment.catalog_repo.catalog is not None: experiment.plot_catalog() report.add_figure( f"Input catalog", - [experiment.path("catalog_figure"), experiment.path("magnitude_time")], + [ + experiment.registry.get("catalog_figure"), + experiment.registry.get("magnitude_time"), + ], level=3, ncols=1, caption="Evaluation catalog from " @@ -67,14 +70,16 @@ def generate_report(experiment, timewindow=-1): # Include results from Experiment for test in experiment.tests: - fig_path = experiment.path(timestr, "figures", test) + fig_path = experiment.registry.get(timestr, "figures", test) width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}", fig_path, level=3, caption=test.markdown, add_ext=True, width=width ) for model in experiment.models: try: - fig_path = experiment.path(timestr, "figures", f"{test.name}_{model.name}") + fig_path = experiment.registry.get( + timestr, "figures", f"{test.name}_{model.name}" + ) width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}: {model.name}", @@ -87,4 +92,4 @@ def generate_report(experiment, timewindow=-1): except KeyError: pass report.table_of_contents() - report.save(experiment.path.abs(experiment.path.rundir)) + report.save(experiment.registry.abs(experiment.registry.rundir)) diff --git a/floatcsep/repository.py b/floatcsep/repository.py index be981bb..e92e63b 100644 --- a/floatcsep/repository.py +++ b/floatcsep/repository.py @@ -1,18 +1,28 @@ +import datetime +import json import logging from abc import ABC, abstractmethod -from typing import Sequence, Union +from os.path import isfile, exists +from typing import Sequence, Union, List, TYPE_CHECKING, Callable import csep +import numpy +from csep.core.catalogs import CSEPCatalog from csep.core.forecasts import GriddedForecast +from csep.models import EvaluationResult from csep.utils.time_utils import decimal_year from floatcsep.readers import ForecastParsers -from floatcsep.registry import ForecastRegistry -from floatcsep.utils import str2timewindow +from floatcsep.registry import ForecastRegistry, ExperimentRegistry +from floatcsep.utils import str2timewindow, parse_csep_func from floatcsep.utils import timewindow2str log = logging.getLogger("floatLogger") +if TYPE_CHECKING: + from floatcsep.evaluation import Evaluation + from floatcsep.model import Model + class ForecastRepository(ABC): @@ -86,7 +96,7 @@ def load_forecast(self, tstring: Union[str, list], region=None): return [self._load_single_forecast(t, region) for t in tstring] def _load_single_forecast(self, t: str, region=None): - fc_path = self.registry.get_path("forecasts", t) + fc_path = self.registry.get("forecasts", t) return csep.load_catalog_forecast( fc_path, region=region, apply_filters=True, filter_spatial=True ) @@ -132,7 +142,7 @@ def _load_single_forecast(self, tstring: str, fc_unit=1, name_=""): time_horizon = decimal_year(end_date) - decimal_year(start_date) tstring_ = timewindow2str([start_date, end_date]) - f_path = self.registry.get_path("forecasts", tstring_) + f_path = self.registry.get("forecasts", tstring_) f_parser = getattr(ForecastParsers, self.registry.fmt) rates, region, mags = f_parser(f_path) @@ -160,3 +170,256 @@ def _load_single_forecast(self, tstring: str, fc_unit=1, name_=""): def remove(self, tstring: Union[str, Sequence[str]]): pass + + +class ResultsRepository: + + def __init__(self, registry: ExperimentRegistry): + self.registry = registry + self.a = 1 + + def _load_result( + self, + test: "Evaluation", + window: Union[str, Sequence[datetime.datetime]], + model: "Model", + ) -> EvaluationResult: + + if not isinstance(window, str): + wstr_ = timewindow2str(window) + else: + wstr_ = window + + eval_path = self.registry.get(wstr_, "evaluations", test, model) + + with open(eval_path, "r") as file_: + model_eval = EvaluationResult.from_dict(json.load(file_)) + + return model_eval + + def load_results( + self, + test, + window: Union[str, Sequence[datetime.datetime]], + models: List, + ) -> List: + """ + Reads an Evaluation result for a given time window and returns a list of the results for + all tested models. + """ + test_results = [] + + for model in models: + model_eval = self._load_result(test, window, model) + test_results.append(model_eval) + + return test_results + + def write_result(self, result: EvaluationResult, test, model, window) -> None: + + path = self.registry.get(window, "evaluations", test, model) + + class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, numpy.integer): + return int(obj) + if isinstance(obj, numpy.floating): + return float(obj) + if isinstance(obj, numpy.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + with open(path, "w") as _file: + json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder) + + +class CatalogRepository: + + def __init__(self, registry: ExperimentRegistry): + self.registry = registry + self.time_config = {} + self.region_config = {} + + def __dir__(self): + """Adds time and region configs keys to instance scope.""" + + _dir = ( + list(super().__dir__()) + list(self.time_config.keys()) + list(self.region_config) + ) + return sorted(_dir) + + def __getattr__(self, item: str) -> object: + """ + Override built-in method to return attributes found within. + :attr:`region_config` or :attr:`time_config` + """ + + try: + return self.__dict__[item] + except KeyError: + try: + return self.time_config[item] + except KeyError: + try: + return self.region_config[item] + except KeyError: + raise AttributeError( + f"Experiment '{self.name}'" f" has no attribute '{item}'" + ) from None + + def as_dict(self): + return + + def set_catalog( + self, catalog: Union[str, Callable, CSEPCatalog], time_config: dict, region_config: dict + ): + """ + Sets the catalog to be used for the experiment. + + Args: + catalog: Experiment's main catalog. + region_config: Experiment instantiation + time_config: + """ + self.catalog = catalog + self.time_config = time_config + self.region_config = region_config + + @property + def catalog(self) -> CSEPCatalog: + """ + Returns a CSEP catalog loaded from the given query function or a stored file if it + exists. + """ + cat_path = self.registry.abs(self._catpath) + + if callable(self._catalog): + if isfile(self._catpath): + return CSEPCatalog.load_json(self._catpath) + bounds = { + "start_time": min([item for sublist in self.timewindows for item in sublist]), + "end_time": max([item for sublist in self.timewindows for item in sublist]), + "min_magnitude": self.magnitudes.min(), + "max_depth": self.depths.max(), + } + if self.region: + bounds.update( + { + i: j + for i, j in zip( + ["min_longitude", "max_longitude", "min_latitude", "max_latitude"], + self.region.get_bbox(), + ) + } + ) + + catalog = self._catalog(catalog_id="catalog", **bounds) + + if self.region: + catalog.filter_spatial(region=self.region, in_place=True) + catalog.region = None + catalog.write_json(self._catpath) + + return catalog + + elif isfile(cat_path): + try: + return CSEPCatalog.load_json(cat_path) + except json.JSONDecodeError: + return csep.load_catalog(cat_path) + + @catalog.setter + def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None: + + if cat is None: + self._catalog = None + self._catpath = None + + elif isfile(self.registry.abs(cat)): + log.info(f"\tCatalog: '{cat}'") + self._catalog = self.registry.rel(cat) + self._catpath = self.registry.rel(cat) + + else: + # catalog can be a function + self._catalog = parse_csep_func(cat) + self._catpath = self.registry.abs("catalog.json") + if isfile(self._catpath): + log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'") + else: + log.info(f"\tCatalog: '{cat}'") + + def get_test_cat(self, tstring: str = None) -> CSEPCatalog: + """ + Filters the complete experiment catalog to a test sub-catalog bounded by the test + time-window. Writes it to filepath defined in :attr:`Experiment.registry` + + Args: + tstring (str): Time window string + """ + + if tstring: + start, end = str2timewindow(tstring) + else: + start = self.start_date + end = self.end_date + print(self.catalog) + sub_cat = self.catalog.filter( + [ + f"origin_time < {end.timestamp() * 1000}", + f"origin_time >= {start.timestamp() * 1000}", + f"magnitude >= {self.mag_min}", + f"magnitude < {self.mag_max}", + ], + in_place=False, + ) + if self.region: + sub_cat.filter_spatial(region=self.region, in_place=True) + + return sub_cat + + def set_test_cat(self, tstring: str) -> None: + """ + Filters the complete experiment catalog to a test sub-catalog bounded by the test + time-window. Writes it to filepath defined in :attr:`Experiment.registry` + + Args: + tstring (str): Time window string + """ + + testcat_name = self.registry.get(tstring, "catalog") + if not exists(testcat_name): + log.debug( + f"Filtering catalog to testing sub-catalog and saving to " f"{testcat_name}" + ) + start, end = str2timewindow(tstring) + sub_cat = self.catalog.filter( + [ + f"origin_time < {end.timestamp() * 1000}", + f"origin_time >= {start.timestamp() * 1000}", + f"magnitude >= {self.mag_min}", + f"magnitude < {self.mag_max}", + ], + in_place=False, + ) + if self.region: + sub_cat.filter_spatial(region=self.region, in_place=True) + sub_cat.write_json(filename=testcat_name) + else: + log.debug(f"Using stored test sub-catalog from {testcat_name}") + + def set_input_cat(self, tstring: str, model: "Model") -> None: + """ + Filters the complete experiment catalog to input sub-catalog filtered. + + to the beginning of thetest time-window. Writes it to filepath defined + in :attr:`Model.tree.catalog` + + Args: + tstring (str): Time window string + model (:class:`~floatcsep.model.Model`): Model to give the input + catalog + """ + start, end = str2timewindow(tstring) + sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"]) + sub_cat.write_ascii(filename=model.registry.get("input_cat")) diff --git a/floatcsep/utils.py b/floatcsep/utils.py index 7213acd..1409e8b 100644 --- a/floatcsep/utils.py +++ b/floatcsep/utils.py @@ -931,8 +931,8 @@ def get_filecomp(self): for tw in win_orig: results[test.name][tw] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.path(tw, "evaluations", test, model) - repr_path = self.reproduced.path(tw, "evaluations", test, model) + orig_path = self.original.registry.get(tw, "evaluations", test, model) + repr_path = self.reproduced.registry.get(tw, "evaluations", test, model) results[test.name][tw][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), @@ -941,8 +941,12 @@ def get_filecomp(self): else: results[test.name] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.path(win_orig[-1], "evaluations", test, model) - repr_path = self.reproduced.path(win_orig[-1], "evaluations", test, model) + orig_path = self.original.registry.get( + win_orig[-1], "evaluations", test, model + ) + repr_path = self.reproduced.registry.get( + win_orig[-1], "evaluations", test, model + ) results[test.name][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), "byte2byte": filecmp.cmp(orig_path, repr_path), @@ -961,7 +965,7 @@ def write_report(self): data = self.file_comp outname = os.path.join("reproducibility_report.md") save_path = os.path.dirname( - os.path.join(self.reproduced.path.workdir, self.reproduced.path.rundir) + os.path.join(self.reproduced.registry.workdir, self.reproduced.registry.rundir) ) report = MarkdownReport(outname=outname) report.add_title(f"Reproducibility Report - {self.original.name}", "") diff --git a/tests/integration/test_model_interface.py b/tests/integration/test_model_interface.py index b8531e6..2d9aa19 100644 --- a/tests/integration/test_model_interface.py +++ b/tests/integration/test_model_interface.py @@ -313,13 +313,11 @@ def test_zenodo(self, mock_buildtree): model_b.stage() self.assertEqual( - os.path.basename(model_a.registry.get_path("path")), - os.path.basename(model_b.registry.get_path("path")), + os.path.basename(model_a.registry.get("path")), + os.path.basename(model_b.registry.get("path")), ) self.assertEqual(model_a.name, model_b.name) - self.assertTrue( - filecmp.cmp(model_a.registry.get_path("path"), model_b.registry.get_path("path")) - ) + self.assertTrue(filecmp.cmp(model_a.registry.get("path"), model_b.registry.get("path"))) def test_zenodo_fail(self): name = "mock_zenodo" diff --git a/tests/unit/test_evaluation.py b/tests/unit/test_evaluation.py index 4b26298..d9b4abd 100644 --- a/tests/unit/test_evaluation.py +++ b/tests/unit/test_evaluation.py @@ -1,8 +1,5 @@ import unittest -from typing import Sequence, List from floatcsep.evaluation import Evaluation -from csep.core.forecasts import GriddedForecast -from csep.core.catalogs import CSEPCatalog class TestEvaluation(unittest.TestCase): @@ -16,14 +13,6 @@ def mock_eval(): @staticmethod def init_noreg(name, func, **kwargs): - """Instantiates a model without using the @register deco, - but mocks Model.Registry() attrs""" - # deprecated - # evaluation = Evaluation.__new__(Evaluation) - # Evaluation.__init__.__wrapped__(self=evaluation, - # name=name, - # func=func, - # **kwargs) evaluation = Evaluation(name=name, func=func, **kwargs) return evaluation @@ -41,6 +30,8 @@ def test_init(self): "plot_kwargs": None, "markdown": "", "_type": None, + "results_repo": None, + "catalog_repo": None, } self.assertEqual(dict_, eval_.__dict__) @@ -50,37 +41,6 @@ def test_discrete_args(self): def test_sequential_args(self): pass - def test_prepare_catalog(self): - from unittest.mock import MagicMock, Mock, patch - - def read_cat(_): - cat = Mock() - cat.name = "csep" - return cat - - with patch("csep.core.catalogs.CSEPCatalog.load_json", read_cat): - region = "CSEPRegion" - forecast = MagicMock(name="forecast", region=region) - - catt = Evaluation.get_catalog("path_to_cat", forecast) - self.assertEqual("csep", catt.name) - self.assertEqual(region, catt.region) - - region2 = "definitelyNotCSEPregion" - forecast2 = Mock(name="forecast", region=region2) - cats = Evaluation.get_catalog(["path1", "path2"], [forecast, forecast2]) - - self.assertIsInstance(cats, list) - self.assertEqual(cats[0].name, "csep") - self.assertEqual(cats[0].region, "CSEPRegion") - self.assertEqual(cats[1].region, "definitelyNotCSEPregion") - - with self.assertRaises(AttributeError): - Evaluation.get_catalog("path1", [forecast, forecast2]) - with self.assertRaises(IndexError): - Evaluation.get_catalog(["path1", "path2"], forecast) - assert True - def test_write_result(self): pass diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index f311dc9..0abdc79 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -29,8 +29,8 @@ def setUpClass(cls) -> None: def assertEqualExperiment(self, exp_a, exp_b): self.assertEqual(exp_a.name, exp_b.name) - self.assertEqual(exp_a.path.workdir, os.getcwd()) - self.assertEqual(exp_a.path.workdir, exp_b.path.workdir) + self.assertEqual(exp_a.registry.workdir, os.getcwd()) + self.assertEqual(exp_a.registry.workdir, exp_b.registry.workdir) self.assertEqual(exp_a.start_date, exp_b.start_date) self.assertEqual(exp_a.timewindows, exp_b.timewindows) self.assertEqual(exp_a.exp_class, exp_b.exp_class) @@ -65,7 +65,7 @@ def test_to_dict(self): dict_ = { "name": "test", "path": os.getcwd(), - "rundir": "results", + "run_dir": "results", "time_config": { "exp_class": "ti", "start_date": datetime(2020, 1, 1), diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 5f8f41f..5991629 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -185,7 +185,7 @@ def setUp(self): # Set attributes on the mock objects self.mock_registry_instance.workdir = "/path/to/workdir" self.mock_registry_instance.path = "/path/to/model" - self.mock_registry_instance.get_path.return_value = ( + self.mock_registry_instance.get.return_value = ( "/path/to/args_file.txt" # Mocking the return of the registry call ) @@ -253,7 +253,7 @@ def test_create_forecast(self, prep_args_mock): self.model.create_forecast(tstring, force=True) self.mock_environment_instance.run_command.assert_called_once_with( - f'{self.func} {self.model.registry.get_path("args_file")}' + f'{self.func} {self.model.registry.get("args_file")}' ) @patch("builtins.open", new_callable=mock_open) @@ -278,9 +278,8 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ] # Call the method - args_file_path = self.model.registry.get_path("args_file") + args_file_path = self.model.registry.get("args_file") self.model.prepare_args(start_date, end_date, custom_arg="value") - mock_open_file.assert_any_call(args_file_path, "r") mock_open_file.assert_any_call(args_file_path, "w") handle = mock_open_file() @@ -293,7 +292,7 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ) json_file_path = "/path/to/args_file.json" - self.model.registry.get_path.return_value = json_file_path + self.model.registry.get.return_value = json_file_path self.model.prepare_args(start_date, end_date, custom_arg="value") mock_open_file.assert_any_call(json_file_path, "r") diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 0bb8e14..5914976 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -16,7 +16,7 @@ def setUp(self): def test_call(self): self.registry_file._parse_arg = MagicMock(return_value="path") - result = self.registry_file.get_path("path") + result = self.registry_file.get("path") self.assertEqual(result, "/test/workdir/model.txt") @patch("os.path.isdir") @@ -62,7 +62,7 @@ def test_absdir(self): @patch("floatcsep.registry.exists") def test_fileexists(self, mock_exists): mock_exists.return_value = True - self.registry_file.get_path = MagicMock(return_value="/test/path/file.txt") + self.registry_file.get = MagicMock(return_value="/test/path/file.txt") self.assertTrue(self.registry_file.file_exists("file.txt")) @patch("os.makedirs") diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index 3b5029d..1f3a28b 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -1,11 +1,18 @@ import datetime import unittest -from unittest.mock import MagicMock, patch, PropertyMock -from floatcsep.registry import ForecastRegistry -from floatcsep.repository import CatalogForecastRepository, GriddedForecastRepository -from floatcsep.readers import ForecastParsers +from unittest.mock import MagicMock, patch, PropertyMock, mock_open + from csep.core.forecasts import GriddedForecast +from floatcsep.readers import ForecastParsers +from floatcsep.registry import ForecastRegistry +from floatcsep.repository import ( + CatalogForecastRepository, + GriddedForecastRepository, + ResultsRepository, + CatalogRepository, +) + class TestCatalogForecastRepository(unittest.TestCase): @@ -151,5 +158,60 @@ def test_equal(self, MockForecastRegistry): self.assertNotEqual(self.repo1, self.repo3) +class TestResultsRepository(unittest.TestCase): + + @patch("floatcsep.repository.ExperimentRegistry") + def setUp(self, MockRegistry): + self.mock_registry = MockRegistry() + self.results_repo = ResultsRepository(self.mock_registry) + + def test_initialization(self): + self.assertEqual(self.results_repo.registry, self.mock_registry) + + @patch("floatcsep.repository.EvaluationResult.from_dict") + @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data='{"key": "value"}') + def test_load_result(self, mock_open, mock_from_dict): + mock_from_dict.return_value = "mocked_result" + result = self.results_repo._load_result("test", "window", "model") + self.assertEqual(result, "mocked_result") + + @patch.object(ResultsRepository, "_load_result", return_value="mocked_result") + def test_load_results(self, mock_load_result): + results = self.results_repo.load_results("test", "window", ["model1", "model2"]) + self.assertEqual(results, ["mocked_result", "mocked_result"]) + + @patch("json.dump") + @patch("builtins.open", new_callable=unittest.mock.mock_open) + def test_write_result(self, mock_open, mock_json_dump): + mock_result = MagicMock() + self.results_repo.write_result(mock_result, "test", "model", "window") + mock_open.assert_called_once() + mock_json_dump.assert_called_once() + + +class TestCatalogRepository(unittest.TestCase): + + @patch("floatcsep.repository.ExperimentRegistry") + def setUp(self, MockRegistry): + self.mock_registry = MockRegistry() + self.catalog_repo = CatalogRepository(self.mock_registry) + + def test_initialization(self): + self.assertEqual(self.catalog_repo.registry, self.mock_registry) + + @patch("floatcsep.repository.isfile", return_value=True) + def test_set_catalog(self, mock_isfile): + # Mock the registry's rel method to return the same path for simplicity + self.mock_registry.rel.return_value = "catalog_path" + + self.catalog_repo.set_catalog("catalog_path", {}, {}) + + # Check if _catpath is set correctly + self.assertEqual(self.catalog_repo._catpath, "catalog_path") + + # Check if _catalog is set correctly + self.assertEqual(self.catalog_repo._catalog, "catalog_path") + + if __name__ == "__main__": unittest.main()