Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion floatcsep/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def reproduce(config, **kwargs):
reproduced_exp.run()

original_config = reproduced_exp.original_config
original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_rundir)
original_exp = Experiment.from_yml(original_config, rundir=reproduced_exp.original_run_dir)
original_exp.stage_models()
original_exp.set_tasks()

Expand Down
88 changes: 31 additions & 57 deletions floatcsep/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import datetime
import json
import os
from typing import Dict, Callable, Union, Sequence, List

import numpy
from csep.core.catalogs import CSEPCatalog
from csep.core.forecasts import GriddedForecast
from csep.models import EvaluationResult
from matplotlib import pyplot

from floatcsep.model import Model
from floatcsep.registry import ExperimentRegistry
from floatcsep.utils import parse_csep_func, timewindow2str
from floatcsep.utils import parse_csep_func


class Evaluation:
Expand Down Expand Up @@ -76,6 +73,9 @@ def __init__(
self.markdown = markdown
self.type = Evaluation._TYPES.get(self.func.__name__)

self.results_repo = None
self.catalog_repo = None

@property
def type(self):
"""
Expand Down Expand Up @@ -123,7 +123,6 @@ def parse_plots(self, plot_func, plot_args, plot_kwargs):
def prepare_args(
self,
timewindow: Union[str, list],
catpath: Union[str, list],
model: Union[Model, Sequence[Model]],
ref_model: Union[Model, Sequence] = None,
region=None,
Expand Down Expand Up @@ -153,7 +152,7 @@ def prepare_args(
# Prepare argument tuple

forecast = model.get_forecast(timewindow, region)
catalog = self.get_catalog(catpath, forecast)
catalog = self.get_catalog(timewindow, forecast)

if isinstance(ref_model, Model):
# Args: (Fc, RFc, Cat)
Expand All @@ -169,29 +168,32 @@ def prepare_args(

return test_args

@staticmethod
def get_catalog(
catalog_path: Union[str, Sequence[str]],
self,
timewindow: Union[str, Sequence[str]],
forecast: Union[GriddedForecast, Sequence[GriddedForecast]],
) -> Union[CSEPCatalog, List[CSEPCatalog]]:
"""
Reads the catalog(s) from the given path(s). References the catalog region to the
forecast region.

Args:
catalog_path (str, list(str)): Path to the existing catalog
timewindow (str): Time window of the testing catalog
forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast
object, onto which the catalog will be confronted for testing.

Returns:
"""
if isinstance(catalog_path, str):
eval_cat = CSEPCatalog.load_json(catalog_path)

if isinstance(timewindow, str):
# eval_cat = CSEPCatalog.load_json(catalog_path)
eval_cat = self.catalog_repo.get_test_cat(timewindow)
eval_cat.region = getattr(forecast, "region")

else:
eval_cat = [CSEPCatalog.load_json(i) for i in catalog_path]
eval_cat = [self.catalog_repo.get_test_cat(i) for i in timewindow]
if (len(forecast) != len(eval_cat)) or (not isinstance(forecast, Sequence)):
raise IndexError("Amount of passed catalogs and forecats must " "be the same")
raise IndexError("Amount of passed catalogs and forecasts must " "be the same")
for cat, fc in zip(eval_cat, forecast):
cat.region = getattr(fc, "region", None)

Expand All @@ -202,7 +204,6 @@ def compute(
timewindow: Union[str, list],
catalog: str,
model: Model,
path: str,
ref_model: Union[Model, Sequence[Model]] = None,
region=None,
) -> None:
Expand All @@ -216,65 +217,38 @@ def compute(
catalog (str): Path to the filtered catalog
model (Model, list[Model]): Model(s) to be evaluated
ref_model: Model to be used as reference
path: Path to store the Evaluation result
region: region to filter a catalog forecast.

Returns:
"""
test_args = self.prepare_args(
timewindow, catpath=catalog, model=model, ref_model=ref_model, region=region
timewindow, model=model, ref_model=ref_model, region=region
)

evaluation_result = self.func(*test_args, **self.func_kwargs)
self.write_result(evaluation_result, path)

@staticmethod
def write_result(result: EvaluationResult, path: str) -> None:
"""Dumps a test result into a json file."""

class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, numpy.integer):
return int(obj)
if isinstance(obj, numpy.floating):
return float(obj)
if isinstance(obj, numpy.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

with open(path, "w") as _file:
json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder)
if self.type in ["sequential", "sequential_comparative"]:
self.results_repo.write_result(evaluation_result, self, model, timewindow[-1])
else:
self.results_repo.write_result(evaluation_result, self, model, timewindow)

def read_results(
self,
window: Union[str, Sequence[datetime.datetime]],
models: List[Model],
tree: ExperimentRegistry,
self, window: Union[str, Sequence[datetime.datetime]], models: List[Model]
) -> List:
"""
Reads an Evaluation result for a given time window and returns a list of the results for
all tested models.
"""
test_results = []

if not isinstance(window, str):
wstr_ = timewindow2str(window)
else:
wstr_ = window

for i in models:
eval_path = tree(wstr_, "evaluations", self, i.name)
with open(eval_path, "r") as file_:
model_eval = EvaluationResult.from_dict(json.load(file_))
test_results.append(model_eval)
test_results = self.results_repo.load_results(self, window, models)

return test_results

def plot_results(
self,
timewindow: Union[str, List],
models: List[Model],
tree: ExperimentRegistry,
registry: ExperimentRegistry,
dpi: int = 300,
show: bool = False,
) -> None:
Expand All @@ -284,7 +258,7 @@ def plot_results(
Args:
timewindow: string representing the desired timewindow to plot
models: a list of :class:`floatcsep:models.Model`
tree: a :class:`floatcsep:models.PathTree` containing path of the results
registry: a :class:`floatcsep:models.PathTree` containing path of the results
dpi: Figure resolution with which to save
show: show in runtime
"""
Expand All @@ -296,8 +270,8 @@ def plot_results(

try:
for time_str in timewindow:
fig_path = tree(time_str, "figures", self.name)
results = self.read_results(time_str, models, tree)
fig_path = registry.get(time_str, "figures", self.name)
results = self.read_results(time_str, models)
ax = func(results, plot_args=fargs, **fkwargs)
if "code" in fargs:
exec(fargs["code"])
Expand All @@ -308,14 +282,14 @@ def plot_results(
except AttributeError as msg:
if self.type in ["consistency", "comparative"]:
for time_str in timewindow:
results = self.read_results(time_str, models, tree)
results = self.read_results(time_str, models)
for result, model in zip(results, models):
fig_name = f"{self.name}_{model.name}"

tree.paths[time_str]["figures"][fig_name] = os.path.join(
registry.paths[time_str]["figures"][fig_name] = os.path.join(
time_str, "figures", fig_name
)
fig_path = tree(time_str, "figures", fig_name)
fig_path = registry.get(time_str, "figures", fig_name)
ax = func(result, plot_args=fargs, **fkwargs, show=False)
if "code" in fargs:
exec(fargs["code"])
Expand All @@ -324,8 +298,8 @@ def plot_results(
pyplot.show()

elif self.type in ["sequential", "sequential_comparative", "batch"]:
fig_path = tree(timewindow[-1], "figures", self.name)
results = self.read_results(timewindow[-1], models, tree)
fig_path = registry.get(timewindow[-1], "figures", self.name)
results = self.read_results(timewindow[-1], models)
ax = func(results, plot_args=fargs, **fkwargs)

if "code" in fargs:
Expand Down
Loading