From 5dca5479f14becc0248afd08fd996c549a70640a Mon Sep 17 00:00:00 2001 From: dladrichem <136334482+dladrichem@users.noreply.github.com> Date: Tue, 9 Apr 2024 14:29:03 +0200 Subject: [PATCH] Refactor of the save/edit/delete enz. functions for all tabs (#367) * Fix dbs classes * Fix rebase issues * Add lock_count and start unit testing * Add tests * Temp fixes * Add test template * Add unit tests * Fixed deleting function * Fix scenarios in dbs controller * Remove tests for other branch * Fix scenario logging * Fix ruff and black * Fix spellcheck * Fix comments and remove locking * Make check_higher_level_usage function public * Add rerunning validators --- docs/index.qmd | 2 +- flood_adapt/api/benefits.py | 10 +- flood_adapt/api/events.py | 64 +- flood_adapt/api/measures.py | 12 +- flood_adapt/api/output.py | 4 +- flood_adapt/api/projections.py | 47 +- flood_adapt/api/scenarios.py | 10 +- flood_adapt/api/strategies.py | 8 +- flood_adapt/dbs_classes/dbs_benefit.py | 84 ++ flood_adapt/dbs_classes/dbs_event.py | 103 ++ flood_adapt/dbs_classes/dbs_interface.py | 133 +++ flood_adapt/dbs_classes/dbs_measure.py | 106 ++ flood_adapt/dbs_classes/dbs_projection.py | 57 ++ flood_adapt/dbs_classes/dbs_scenario.py | 125 +++ flood_adapt/dbs_classes/dbs_strategy.py | 57 ++ flood_adapt/dbs_classes/dbs_template.py | 274 +++++ flood_adapt/dbs_controller.py | 960 +----------------- .../object_model/interface/benefits.py | 10 +- .../object_model/interface/database.py | 112 +- .../object_model/interface/measures.py | 8 + .../object_model/interface/projections.py | 10 +- .../object_model/interface/scenarios.py | 10 +- .../object_model/interface/strategies.py | 10 +- flood_adapt/object_model/scenario.py | 8 +- tests/test_api/test_events.py | 4 +- tests/test_api/test_projections.py | 4 +- tests/test_api/test_scenarios.py | 4 +- tests/test_api/test_strategy.py | 4 +- 28 files changed, 1082 insertions(+), 1158 deletions(-) create mode 100644 flood_adapt/dbs_classes/dbs_benefit.py create mode 100644 flood_adapt/dbs_classes/dbs_event.py create mode 100644 flood_adapt/dbs_classes/dbs_interface.py create mode 100644 flood_adapt/dbs_classes/dbs_measure.py create mode 100644 flood_adapt/dbs_classes/dbs_projection.py create mode 100644 flood_adapt/dbs_classes/dbs_scenario.py create mode 100644 flood_adapt/dbs_classes/dbs_strategy.py create mode 100644 flood_adapt/dbs_classes/dbs_template.py diff --git a/docs/index.qmd b/docs/index.qmd index ab4b0257d..8cbc70b7d 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -122,7 +122,7 @@ Create a warning . Cross reference to another page or figure: (remove quotation marks) -"[Write here the text wou want to see]"("add the reference") +"[Write here the text you want to see]"("add the reference") ### Cross reference page "[Setup Guide]"(/3_setup_guide/index.qmd)" diff --git a/flood_adapt/api/benefits.py b/flood_adapt/api/benefits.py index 4ed8464a0..538b59b14 100644 --- a/flood_adapt/api/benefits.py +++ b/flood_adapt/api/benefits.py @@ -10,11 +10,11 @@ def get_benefits(database: IDatabase) -> dict[str, Any]: # sorting and filtering either with PyQt table or in the API - return database.get_benefits() + return database.benefits.list_objects() def get_benefit(name: str, database: IDatabase) -> IBenefit: - return database.get_benefit(name) + return database.benefits.get(name) def create_benefit(attrs: dict[str, Any], database: IDatabase) -> IBenefit: @@ -22,15 +22,15 @@ def create_benefit(attrs: dict[str, Any], database: IDatabase) -> IBenefit: def save_benefit(benefit: IBenefit, database: IDatabase) -> None: - database.save_benefit(benefit) + database.benefits.save(benefit) def edit_benefit(benefit: IBenefit, database: IDatabase) -> None: - database.edit_benefit(benefit) + database.benefits.edit(benefit) def delete_benefit(name: str, database: IDatabase) -> None: - database.delete_benefit(name) + database.benefits.delete(name) def check_benefit_scenarios(benefit: IBenefit, database: IDatabase) -> pd.DataFrame: diff --git a/flood_adapt/api/events.py b/flood_adapt/api/events.py index ac8e3c3ef..c6170fb48 100644 --- a/flood_adapt/api/events.py +++ b/flood_adapt/api/events.py @@ -23,11 +23,11 @@ def get_events(database: IDatabase) -> dict[str, Any]: # use PyQt table / sorting and filtering either with PyQt table or in the API - return database.get_events() + return database.events.list_objects() def get_event(name: str, database: IDatabase) -> IEvent: - return database.get_event(name) + return database.events.get(name) def get_event_mode(name: str, database: IDatabase) -> str: @@ -100,7 +100,7 @@ def create_historical_hurricane_event(attrs: dict[str, Any]) -> IHistoricalHurri def save_event_toml(event: IEvent, database: IDatabase) -> None: - database.save_event(event) + database.events.save(event) def save_timeseries_csv( @@ -110,17 +110,17 @@ def save_timeseries_csv( def edit_event(event: IEvent, database: IDatabase) -> None: - database.edit_event(event) + database.events.edit(event) def delete_event(name: str, database: IDatabase) -> None: - database.delete_event(name) + database.events.delete(name) def copy_event( old_name: str, database: IDatabase, new_name: str, new_description: str ) -> None: - database.copy_event(old_name, new_name, new_description) + database.events.copy(old_name, new_name, new_description) def download_wl_data( @@ -163,55 +163,3 @@ def plot_wind( def save_cyclone_track(event: IEvent, track: TropicalCyclone, database: IDatabase): database.write_cyc(event, track) - - -# def get_event(name: str) -> dict(): # get attributes -# pass - - -# # on click add event -# def create_new_event(template: str) -> dict(): # get attributes -# pass - - -# def set_event(event: dict): # set attributes -# pass - - -# # in event pop-up window on click OK -# def save_event(name: str): -# pass - - -# # on click hurricane: -# def get_hurricane_tracks(): -# pass - - -# # on click historical from nearshore: -# def create_historical_nearshore_event() -> ( -# dict() -# ): # gives back empty object to populate pop-up window, different options for discharge are in the class #TODO: ask Julian -# pass - - -# # on click plot water level boundary -# def get_waterlevel_timeseries(event: dict) -> dict(): -# pass - - -# # on click plot rainfall -# def get_rainfall_timeseries(event: dict): -# pass - - -# # on click delete event -# def check_delete_event() -> ( -# bool -# ): # , str: # str contains full error message, empty if False -# pass - - -# # on click copy event -# def copy_event(name_orig: str, name_copy: str): -# pass diff --git a/flood_adapt/api/measures.py b/flood_adapt/api/measures.py index ace867e64..888e5d4db 100644 --- a/flood_adapt/api/measures.py +++ b/flood_adapt/api/measures.py @@ -19,11 +19,11 @@ def get_measures(database: IDatabase) -> dict[str, Any]: - return database.get_measures() + return database.measures.list_objects() def get_measure(name: str, database: IDatabase) -> IMeasure: - return database.get_measure(name) + return database.measures.get(name) def create_measure( @@ -64,21 +64,21 @@ def create_measure( def save_measure(measure: IMeasure, database: IDatabase) -> None: - database.save_measure(measure) + database.measures.save(measure) def edit_measure(measure: IMeasure, database: IDatabase) -> None: - database.edit_measure(measure) + database.measures.edit(measure) def delete_measure(name: str, database: IDatabase) -> None: - database.delete_measure(name) + database.measures.delete(name) def copy_measure( old_name: str, database: IDatabase, new_name: str, new_description: str ) -> None: - database.copy_measure(old_name, new_name, new_description) + database.measures.copy(old_name, new_name, new_description) # Green infrastructure diff --git a/flood_adapt/api/output.py b/flood_adapt/api/output.py index ae1aa758a..debc43b34 100644 --- a/flood_adapt/api/output.py +++ b/flood_adapt/api/output.py @@ -59,7 +59,7 @@ def get_obs_point_timeseries(name: str, database: IDatabase) -> gpd.GeoDataFrame The HTML strings of the water level timeseries """ # Get the direct_impacts objects from the scenario - hazard = database.get_scenario(name).direct_impacts.hazard + hazard = database.scenarios.get(name).direct_impacts.hazard # Check if the scenario has run if not hazard.has_run_check(): @@ -93,7 +93,7 @@ def get_infographic(name: str, database: IDatabase) -> str: The HTML string of the infographic. """ # Get the direct_impacts objects from the scenario - impact = database.get_scenario(name).direct_impacts + impact = database.scenarios.get(name).direct_impacts # Check if the scenario has run if not impact.has_run_check(): diff --git a/flood_adapt/api/projections.py b/flood_adapt/api/projections.py index 9e2df6f36..9860deec0 100644 --- a/flood_adapt/api/projections.py +++ b/flood_adapt/api/projections.py @@ -7,11 +7,11 @@ def get_projections(database: IDatabase) -> dict[str, Any]: # sorting and filtering either with PyQt table or in the API - return database.get_projections() + return database.projections.list_objects() def get_projection(name: str, database: IDatabase) -> IProjection: - return database.get_projection(name) + return database.projections.get(name) def create_projection(attrs: dict[str, Any]) -> IProjection: @@ -19,21 +19,21 @@ def create_projection(attrs: dict[str, Any]) -> IProjection: def save_projection(projection: IProjection, database: IDatabase) -> None: - database.save_projection(projection) + database.projections.save(projection) def edit_projection(projection: IProjection, database: IDatabase) -> None: - database.edit_projection(projection) + database.projections.edit(projection) def delete_projection(name: str, database: IDatabase) -> None: - database.delete_projection(name) + database.projections.delete(name) def copy_projection( old_name: str, database: IDatabase, new_name: str, new_description: str ) -> None: - database.copy_projection(old_name, new_name, new_description) + database.projections.copy(old_name, new_name, new_description) def get_slr_scn_names(database: IDatabase) -> list: @@ -46,38 +46,3 @@ def interp_slr(database: IDatabase, slr_scenario: str, year: float) -> float: def plot_slr_scenarios(database: IDatabase) -> str: return database.plot_slr_scenarios() - - -# # on click add projection -# def create_new_projection(template: str) -> dict(): # get attributes -# pass - - -# def set_projection(event: dict): # set attributes -# pass - - -# # on click edit projection -# def get_projection(name: str) -> dict(): # get attributes -# # incl physical and spcio-economic -# pass - - -# def set_projection(event: dict): # set attributes -# pass - - -# # on click copy projection -# # get_projection -# # set_projection - - -# # on click delete projection -# def remove_projection(name: str) -> dict(): # get attributes -# # remove object from database object and toml file, both socio-economic and physical -# pass - - -# # in projection pop-up window on click OK -# def save_projection(name: str): -# pass diff --git a/flood_adapt/api/scenarios.py b/flood_adapt/api/scenarios.py index a25d7ef84..96dade258 100644 --- a/flood_adapt/api/scenarios.py +++ b/flood_adapt/api/scenarios.py @@ -7,11 +7,11 @@ def get_scenarios(database: IDatabase) -> dict[str, Any]: # sorting and filtering either with PyQt table or in the API - return database.get_scenarios() + return database.scenarios.list_objects() def get_scenario(name: str, database: IDatabase) -> IScenario: - return database.get_scenario(name) + return database.scenarios.get(name) def create_scenario(attrs: dict[str, Any], database: IDatabase) -> IScenario: @@ -34,18 +34,18 @@ def save_scenario(scenario: IScenario, database: IDatabase) -> (bool, str): The error message if the scenario was not saved successfully. """ try: - database.save_scenario(scenario) + database.scenarios.save(scenario) return True, "" except Exception as e: return False, str(e) def edit_scenario(scenario: IScenario, database: IDatabase) -> None: - database.edit_scenario(scenario) + database.scenarios.edit(scenario) def delete_scenario(name: str, database: IDatabase) -> None: - database.delete_scenario(name) + database.scenarios.delete(name) def run_scenario(name: Union[str, list[str]], database: IDatabase) -> None: diff --git a/flood_adapt/api/strategies.py b/flood_adapt/api/strategies.py index 0df94ca0b..da50fe7c7 100644 --- a/flood_adapt/api/strategies.py +++ b/flood_adapt/api/strategies.py @@ -7,11 +7,11 @@ def get_strategies(database: IDatabase) -> dict[str, Any]: # sorting and filtering either with PyQt table or in the API - return database.get_strategies() + return database.strategies.list_objects() def get_strategy(name: str, database: IDatabase) -> IStrategy: - return database.get_strategy(name) + return database.strategies.get(name) def create_strategy(attrs: dict[str, Any], database: IDatabase) -> IStrategy: @@ -19,8 +19,8 @@ def create_strategy(attrs: dict[str, Any], database: IDatabase) -> IStrategy: def save_strategy(strategy: IStrategy, database: IDatabase) -> None: - database.save_strategy(strategy) + database.strategies.save(strategy) def delete_strategy(name: str, database: IDatabase) -> None: - database.delete_strategy(name) + database.strategies.delete(name) diff --git a/flood_adapt/dbs_classes/dbs_benefit.py b/flood_adapt/dbs_classes/dbs_benefit.py new file mode 100644 index 000000000..b3a3330cc --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_benefit.py @@ -0,0 +1,84 @@ +import shutil + +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.benefit import Benefit +from flood_adapt.object_model.interface.benefits import IBenefit + + +class DbsBenefit(DbsTemplate): + _type = "benefit" + _folder_name = "benefits" + _object_model_class = Benefit + + def save(self, benefit: IBenefit, overwrite: bool = False): + """Saves a benefit object in the database. + + Parameters + ---------- + measure : IBenefit + object of scenario type + overwrite : bool, optional + whether to overwrite existing benefit with same name, by default False + + Raises + ------ + ValueError + Raise error if name is already in use. Names of benefits assessments should be unique. + """ + + # Check if all scenarios are created + if not all(benefit.scenarios["scenario created"] != "No"): + raise ValueError( + f"'{benefit.attrs.name}' name cannot be created before all necessary scenarios are created." + ) + + # Save the benefit + super().save(benefit, overwrite=overwrite) + + def delete(self, name: str, toml_only: bool = False): + """Deletes an already existing benefit in the database. + + Parameters + ---------- + name : str + name of the benefit + toml_only : bool, optional + whether to only delete the toml file or the entire folder. If the folder is empty after deleting the toml, + it will always be deleted. By default False + + Raises + ------ + ValueError + Raise error if benefit has already model output + """ + + # First delete the benefit + super().delete(name, toml_only=toml_only) + + # Delete output if edited + output_path = self._database.output_path / "Benefits" / name + + if output_path.exists(): + shutil.rmtree(output_path, ignore_errors=True) + + def edit(self, benefit: IBenefit): + """Edits an already existing benefit in the database. + + Parameters + ---------- + benefit : IBenefit + benefit to be edited in the database + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + # Check if it is possible to edit the benefit. + super().edit(benefit) + + # Delete output if edited + output_path = self._database.output_path / "Benefits" / benefit.attrs.name + + if output_path.exists(): + shutil.rmtree(output_path, ignore_errors=True) diff --git a/flood_adapt/dbs_classes/dbs_event.py b/flood_adapt/dbs_classes/dbs_event.py new file mode 100644 index 000000000..8bd760c11 --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_event.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Any + +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.hazard.event.event import Event +from flood_adapt.object_model.hazard.event.event_factory import EventFactory +from flood_adapt.object_model.hazard.hazard import Hazard +from flood_adapt.object_model.interface.events import IEvent +from flood_adapt.object_model.scenario import Scenario + + +class DbsEvent(DbsTemplate): + _type = "event" + _folder_name = "events" + _object_model_class = Event + + def get(self, name: str) -> IEvent: + """Returns an event object. + + Parameters + ---------- + name : str + name of the event to be returned + + Returns + ------- + IEvent + event object + """ + # Get event path + event_path = self._path / f"{name}" / f"{name}.toml" + + # Check if the object exists + if not Path(event_path).is_file(): + raise ValueError(f"{self._type.capitalize()} '{name}' does not exist.") + + # Load event + event_template = Event.get_template(event_path) + event = EventFactory.get_event(event_template).load_file(event_path) + return event + + def list_objects(self) -> dict[str, Any]: + """Returns a dictionary with info on the events that currently + exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'name', 'description', 'path' and 'last_modification_date' info + """ + events = self._get_object_list() + objects = [Hazard.get_event_object(path) for path in events["path"]] + events["name"] = [obj.attrs.name for obj in objects] + events["description"] = [obj.attrs.description for obj in objects] + events["objects"] = objects + return events + + def _check_standard_objects(self, name: str) -> bool: + """Checks if an event is a standard event. + + Parameters + ---------- + name : str + name of the event to be checked + + Returns + ------- + bool + True if the event is a standard event, False otherwise + """ + # Check if event is a standard event + if self._database.site.attrs.standard_objects.events: + if name in self._database.site.attrs.standard_objects.events: + return True + + return False + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if an event is used in a scenario. + + Parameters + ---------- + name : str + name of the event to be checked + + Returns + ------- + list[str] + list of scenarios that use the event + """ + # Get all the scenarios + scenarios = [ + Scenario.load_file(path) + for path in self._database.scenarios.list_objects()["path"] + ] + + # Check if event is used in a scenario + used_in_scenario = [ + scenario.attrs.name + for scenario in scenarios + if name == scenario.attrs.event + ] + return used_in_scenario diff --git a/flood_adapt/dbs_classes/dbs_interface.py b/flood_adapt/dbs_classes/dbs_interface.py new file mode 100644 index 000000000..97106c20b --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_interface.py @@ -0,0 +1,133 @@ +from abc import ABC, abstractmethod +from typing import Any, Union + +from flood_adapt.object_model.interface.benefits import IBenefit +from flood_adapt.object_model.interface.events import IEvent +from flood_adapt.object_model.interface.measures import IMeasure +from flood_adapt.object_model.interface.projections import IProjection +from flood_adapt.object_model.interface.scenarios import IScenario +from flood_adapt.object_model.interface.strategies import IStrategy + +ObjectModel = Union[IScenario, IEvent, IProjection, IStrategy, IMeasure, IBenefit] + + +class AbstractDatabaseElement(ABC): + def __init__(self): + """ + Initialize any necessary attributes. + """ + pass + + @abstractmethod + def get(self, name: str) -> ObjectModel: + """Returns the object of the type of the database with the given name. + + Parameters + ---------- + name : str + name of the object to be returned + + Returns + ------- + ObjectModel + object of the type of the specified object model + """ + pass + + @abstractmethod + def list_objects(self) -> dict[str, Any]: + """Returns a dictionary with info on the objects that currently + exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'name', 'description', 'path' and 'last_modification_date' info, as well as the objects themselves + """ + pass + + @abstractmethod + def copy(self, old_name: str, new_name: str, new_description: str): + """Copies (duplicates) an existing object, and gives it a new name. + + Parameters + ---------- + old_name : str + name of the existing measure + new_name : str + name of the new measure + new_description : str + description of the new measure + """ + pass + + @abstractmethod + def save(self, object_model: ObjectModel, overwrite: bool = False): + """Saves an object in the database. This only saves the toml file. If the object also contains a geojson file, + this should be saved separately. + + Parameters + ---------- + object_model : ObjectModel + object to be saved in the database + overwrite : OverwriteMode, optional + whether to overwrite the object if it already exists in the + database, by default False + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + pass + + @abstractmethod + def edit(self, object_model: ObjectModel): + """Edits an already existing object in the database. + + Parameters + ---------- + object : ObjectModel + object to be edited in the database + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + pass + + @abstractmethod + def delete(self, name: str, toml_only: bool = False): + """Deletes an already existing object in the database. + + Parameters + ---------- + name : str + name of the object to be deleted + toml_only : bool, optional + whether to only delete the toml file or the entire folder. If the folder is empty after deleting the toml, + it will always be deleted. By default False + + Raises + ------ + ValueError + Raise error if object to be deleted is already in use. + """ + pass + + @abstractmethod + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if an object is used in a higher level object. + + Parameters + ---------- + name : str + name of the object to be checked + + Returns + ------- + list[str] + list of higher level objects that use the object + """ + pass diff --git a/flood_adapt/dbs_classes/dbs_measure.py b/flood_adapt/dbs_classes/dbs_measure.py new file mode 100644 index 000000000..4e3bad07f --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_measure.py @@ -0,0 +1,106 @@ +from typing import Any + +import geopandas as gpd + +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.interface.measures import IMeasure +from flood_adapt.object_model.measure import Measure +from flood_adapt.object_model.measure_factory import MeasureFactory +from flood_adapt.object_model.strategy import Strategy + + +class DbsMeasure(DbsTemplate): + _type = "measure" + _folder_name = "measures" + _object_model_class = Measure + + def get(self, name: str) -> IMeasure: + """Returns a measure object. + + Parameters + ---------- + name : str + name of the measure to be returned + + Returns + ------- + IMeasure + measure object + """ + measure_path = self._path / name / f"{name}.toml" + measure = MeasureFactory.get_measure_object(measure_path) + return measure + + def list_objects(self) -> dict[str, Any]: + """Returns a dictionary with info on the measures that currently + exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'name', 'description', 'path' and 'last_modification_date' info + """ + measures = self._get_object_list() + objects = [MeasureFactory.get_measure_object(path) for path in measures["path"]] + measures["name"] = [obj.attrs.name for obj in objects] + measures["description"] = [obj.attrs.description for obj in objects] + measures["objects"] = objects + + geometries = [] + for path, obj in zip(measures["path"], objects): + # If polygon is used read the polygon file + if obj.attrs.polygon_file: + file_path = path.parent.joinpath(obj.attrs.polygon_file) + if not file_path.exists(): + raise FileNotFoundError( + f"Polygon file {obj.attrs.polygon_file} for measure {obj.attrs.name} does not exist." + ) + geometries.append(gpd.read_file(file_path)) + # If aggregation area is used read the polygon from the aggregation area name + elif obj.attrs.aggregation_area_name: + if obj.attrs.aggregation_area_type not in self._database.aggr_areas: + raise ValueError( + f"Aggregation area type {obj.attrs.aggregation_area_type} for measure {obj.attrs.name} does not exist." + ) + gdf = self._database.aggr_areas[obj.attrs.aggregation_area_type] + if obj.attrs.aggregation_area_name not in gdf["name"].to_numpy(): + raise ValueError( + f"Aggregation area name {obj.attrs.aggregation_area_name} for measure {obj.attrs.name} does not exist." + ) + geometries.append( + gdf.loc[gdf["name"] == obj.attrs.aggregation_area_name, :] + ) + # Else assign a None value + else: + geometries.append(None) + + measures["geometry"] = geometries + return measures + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if a measure is used in a strategy. + + Parameters + ---------- + name : str + name of the measure to be checked + + Returns + ------- + list[str] + list of strategies that use the measure + """ + # Get all the strategies + strategies = [ + Strategy.load_file(path) + for path in self._database.strategies.list_objects()["path"] + ] + + # Check if measure is used in a strategy + used_in_strategy = [ + strategy.attrs.name + for strategy in strategies + for measure in strategy.attrs.measures + if name == measure + ] + return used_in_strategy diff --git a/flood_adapt/dbs_classes/dbs_projection.py b/flood_adapt/dbs_classes/dbs_projection.py new file mode 100644 index 000000000..6766906ff --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_projection.py @@ -0,0 +1,57 @@ +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.projection import Projection +from flood_adapt.object_model.scenario import Scenario + + +class DbsProjection(DbsTemplate): + _type = "projection" + _folder_name = "projections" + _object_model_class = Projection + + def _check_standard_objects(self, name: str) -> bool: + """Checks if a projection is a standard projection. + + Parameters + ---------- + name : str + name of the projection to be checked + + Raises + ------ + ValueError + Raise error if projection is a standard projection. + """ + # Check if projection is a standard projection + if self._database.site.attrs.standard_objects.projections: + if name in self._database.site.attrs.standard_objects.projections: + return True + + return False + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if a projection is used in a scenario. + + Parameters + ---------- + name : str + name of the projection to be checked + + Returns + ------- + list[str] + list of scenarios that use the projection + """ + # Get all the scenarios + scenarios = [ + Scenario.load_file(path) + for path in self._database.scenarios.list_objects()["path"] + ] + + # Check if projection is used in a scenario + used_in_scenario = [ + scenario.attrs.name + for scenario in scenarios + if name == scenario.attrs.projection + ] + + return used_in_scenario diff --git a/flood_adapt/dbs_classes/dbs_scenario.py b/flood_adapt/dbs_classes/dbs_scenario.py new file mode 100644 index 000000000..f15e07baa --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_scenario.py @@ -0,0 +1,125 @@ +import shutil +from typing import Any + +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.benefit import Benefit +from flood_adapt.object_model.interface.scenarios import IScenario +from flood_adapt.object_model.scenario import Scenario + + +class DbsScenario(DbsTemplate): + _type = "scenario" + _folder_name = "scenarios" + _object_model_class = Scenario + + def get(self, name: str) -> IScenario: + """Returns a scenario object. + + Parameters + ---------- + name : str + name of the scenario to be returned + + Returns + ------- + IScenario + scenario object + """ + return super().get(name).init_object_model() + + def list_objects(self) -> dict[str, Any]: + """Returns a dictionary with info on the events that currently + exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'name', 'description', 'path' and 'last_modification_date' info + """ + scenarios = super().list_objects() + objects = scenarios["objects"] + scenarios["Projection"] = [obj.attrs.projection for obj in objects] + scenarios["Event"] = [obj.attrs.event for obj in objects] + scenarios["Strategy"] = [obj.attrs.strategy for obj in objects] + scenarios["finished"] = [ + obj.init_object_model().direct_impacts.has_run for obj in objects + ] + + return scenarios + + def delete(self, name: str, toml_only: bool = False): + """Deletes an already existing scenario in the database. + + Parameters + ---------- + name : str + name of the scenario to be deleted + toml_only : bool, optional + whether to only delete the toml file or the entire folder. If the folder is empty after deleting the toml, + it will always be deleted. By default False + + Raises + ------ + ValueError + Raise error if scenario to be deleted is already in use. + """ + + # First delete the scenario + super().delete(name, toml_only) + + # Then delete the results + results_path = self._database.output_path / "Scenarios" / name + if results_path.exists(): + shutil.rmtree(results_path, ignore_errors=False) + + def edit(self, scenario: IScenario): + """Edits an already existing scenario in the database. + + Parameters + ---------- + scenario : IScenario + scenario to be edited in the database + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + # Check if it is possible to edit the scenario. This then also covers checking whether the + # scenario is already used in a higher level object. If this is the case, it cannot be edited. + super().edit(scenario) + + # Delete output if edited + output_path = self._database.output_path / "Scenarios" / scenario.attrs.name + + if output_path.exists(): + shutil.rmtree(output_path, ignore_errors=True) + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if a scenario is used in a benefit. + + Parameters + ---------- + name : str + name of the scenario to be checked + + Returns + ------- + list[str] + list of benefits that use the scenario + """ + # Get all the benefits + benefits = [ + Benefit.load_file(path) + for path in self._database.benefits.list_objects()["path"] + ] + + # Check in which benefits this scenario is used + used_in_benefit = [ + benefit.attrs.name + for benefit in benefits + for scenario in benefit.check_scenarios()["scenario created"].to_list() + if name == scenario + ] + + return used_in_benefit diff --git a/flood_adapt/dbs_classes/dbs_strategy.py b/flood_adapt/dbs_classes/dbs_strategy.py new file mode 100644 index 000000000..bc66683ec --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_strategy.py @@ -0,0 +1,57 @@ +from flood_adapt.dbs_classes.dbs_template import DbsTemplate +from flood_adapt.object_model.scenario import Scenario +from flood_adapt.object_model.strategy import Strategy + + +class DbsStrategy(DbsTemplate): + _type = "strategy" + _folder_name = "strategies" + _object_model_class = Strategy + + def _check_standard_objects(self, name: str) -> bool: + """Checks if a strategy is a standard strategy. + + Parameters + ---------- + name : str + name of the strategy to be checked + + Raises + ------ + ValueError + Raise error if strategy is a standard strategy. + """ + # Check if strategy is a standard strategy + if self._database.site.attrs.standard_objects.strategies: + if name in self._database.site.attrs.standard_objects.strategies: + return True + + return False + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if a strategy is used in a scenario. + + Parameters + ---------- + name : str + name of the strategy to be checked + + Returns + ------- + list[str] + list of scenarios that use the strategy + """ + # Get all the scenarios + scenarios = [ + Scenario.load_file(path) + for path in self._database.scenarios.list_objects()["path"] + ] + + # Check if strategy is used in a scenario + used_in_scenario = [ + scenario.attrs.name + for scenario in scenarios + if name == scenario.attrs.strategy + ] + + return used_in_scenario diff --git a/flood_adapt/dbs_classes/dbs_template.py b/flood_adapt/dbs_classes/dbs_template.py new file mode 100644 index 000000000..a896ac99d --- /dev/null +++ b/flood_adapt/dbs_classes/dbs_template.py @@ -0,0 +1,274 @@ +import shutil +from datetime import datetime +from pathlib import Path +from typing import Union + +from flood_adapt.dbs_classes.dbs_interface import AbstractDatabaseElement +from flood_adapt.object_model.interface.benefits import IBenefit +from flood_adapt.object_model.interface.database import IDatabase +from flood_adapt.object_model.interface.events import IEvent +from flood_adapt.object_model.interface.measures import IMeasure +from flood_adapt.object_model.interface.projections import IProjection +from flood_adapt.object_model.interface.scenarios import IScenario +from flood_adapt.object_model.interface.strategies import IStrategy + +ObjectModel = Union[IScenario, IEvent, IProjection, IStrategy, IMeasure, IBenefit] + + +class DbsTemplate(AbstractDatabaseElement): + _type = "" + _folder_name = "" + _object_model_class = None + + def __init__(self, database: IDatabase): + """ + Initialize any necessary attributes. + """ + self.input_path = database.input_path + self._path = self.input_path / self._folder_name + self._database = database + + def get(self, name: str) -> ObjectModel: + """Returns an object of the type of the database with the given name. + + Parameters + ---------- + name : str + name of the object to be returned + + Returns + ------- + ObjectModel + object of the type of the specified object model + """ + # Make the full path to the object + full_path = self._path / name / f"{name}.toml" + + # Check if the object exists + if not Path(full_path).is_file(): + raise ValueError(f"{self._type.capitalize()} '{name}' does not exist.") + + # Load and return the object + object_model = self._object_model_class.load_file(full_path) + return object_model + + def list_objects(self): + """Returns a dictionary with info on the objects that currently + exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'name', 'description', 'path' and 'last_modification_date' info, as well as the objects themselves + """ + # Check if all objects exist + object_list = self._get_object_list() + if not all(Path(path).is_file() for path in object_list["path"]): + raise ValueError( + f"Error in {self._type} database. Some {self._type} are missing from the database." + ) + + # Load all objects + objects = [ + self._object_model_class.load_file(path) for path in object_list["path"] + ] + + # From the loaded objects, get the name and description and add them to the object_list + object_list["name"] = [obj.attrs.name for obj in objects] + object_list["description"] = [obj.attrs.description for obj in objects] + object_list["objects"] = objects + return object_list + + def copy(self, old_name: str, new_name: str, new_description: str): + """Copies (duplicates) an existing object, and gives it a new name. + + Parameters + ---------- + old_name : str + name of the existing measure + new_name : str + name of the new measure + new_description : str + description of the new measure + """ + # Check if the provided old_name is valid + if old_name not in self.list_objects()["name"]: + raise ValueError(f"'{old_name}' {self._type} does not exist.") + + # First do a get and change the name and description + copy_object = self.get(old_name) + copy_object.attrs.name = new_name + copy_object.attrs.description = new_description + + # After changing the name and description, receate the model to re-trigger the validators + copy_object.attrs = type(copy_object.attrs)(**copy_object.attrs.dict()) + + # Then a save. Checking whether the name is already in use is done in the save function + self.save(copy_object) + + # Then save all the accompanied files + src = self._path / old_name + dest = self._path / new_name + for file in src.glob("*"): + if "toml" not in file.name: + shutil.copy(file, dest / file.name) + + def save(self, object_model: ObjectModel, overwrite: bool = False): + """Saves an object in the database. This only saves the toml file. If the object also contains a geojson file, + this should be saved separately. + + Parameters + ---------- + object_model : ObjectModel + object to be saved in the database + overwrite : bool, optional + whether to overwrite the object if it already exists in the + database, by default False + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + + object_exists = object_model.attrs.name in self.list_objects()["name"] + + # If you want to overwrite the object, and the object already exists, first delete it. If it exists and you + # don't want to overwrite, raise an error. + if overwrite and object_exists: + self.delete(object_model.attrs.name, toml_only=True) + elif not overwrite and object_exists: + raise ValueError( + f"'{object_model.attrs.name}' name is already used by another {self._type}. Choose a different name" + ) + + # If the folder doesnt exist yet, make the folder and save the object + if not (self._path / object_model.attrs.name).exists(): + (self._path / object_model.attrs.name).mkdir() + + object_model.save( + self._path / object_model.attrs.name / f"{object_model.attrs.name}.toml" + ) + + def edit(self, object_model: ObjectModel): + """Edits an already existing object in the database. + + Parameters + ---------- + object : ObjectModel + object to be edited in the database + + Raises + ------ + ValueError + Raise error if name is already in use. + """ + # Check if the object exists + if object_model.attrs.name not in self.list_objects()["name"]: + raise ValueError( + f"'{object_model.attrs.name}' {self._type} does not exist. You cannot edit an {self._type} that does not exist." + ) + + # Check if it is possible to delete the object by saving with overwrite. This then + # also covers checking whether the object is a standard object, is already used in + # a higher level object. If any of these are the case, it cannot be deleted. + self.save(object_model, overwrite=True) + + def delete(self, name: str, toml_only: bool = False): + """Deletes an already existing object in the database. + + Parameters + ---------- + name : str + name of the object to be deleted + toml_only : bool, optional + whether to only delete the toml file or the entire folder. If the folder is empty after deleting the toml, + it will always be deleted. By default False + + Raises + ------ + ValueError + Raise error if object to be deleted is already in use. + """ + # Check if the object is a standard object. If it is, raise an error + if self._check_standard_objects(name): + raise ValueError( + f"'{name}' cannot be deleted/modified since it is a standard {self._type}." + ) + + # Check if object is used in a higher level object. If it is, raise an error + if used_in := self.check_higher_level_usage(name): + raise ValueError( + f"'{name}' {self._type} cannot be deleted/modified since it is already used in: {', '.join(used_in)}" + ) + + # Once all checks are passed, delete the object + path = self._path / name + if toml_only: + # Only delete the toml file + toml_path = path / f"{name}.toml" + if toml_path.exists(): + toml_path.unlink() + # If the folder is empty, delete the folder + if not list(path.iterdir()): + path.rmdir() + else: + # Delete the entire folder + shutil.rmtree(path, ignore_errors=True) + + def _check_standard_objects(self, name: str) -> bool: + """Checks if an object is a standard object. + + Parameters + ---------- + name : str + name of the object to be checked + + Returns + ------- + bool + True if the object is a standard object, False otherwise + """ + # If this function is not implemented for the object type, it cannot be a standard object. + # By default, it is not a standard object. + return False + + def check_higher_level_usage(self, name: str) -> list[str]: + """Checks if an object is used in a higher level object. + + Parameters + ---------- + name : str + name of the object to be checked + + Returns + ------- + list[str] + list of higher level objects that use the object + """ + # If this function is not implemented for the object type, it cannot be used in a higher + # level object. By default, return an empty list + return [] + + def _get_object_list(self) -> dict[Path, datetime]: + """Given an object type (e.g., measures) get a dictionary with all the toml paths + and last modification dates that exist in the database. + + Returns + ------- + dict[str, Any] + Includes 'path' and 'last_modification_date' info + """ + base_path = self.input_path / self._folder_name + directories = list(base_path.iterdir()) + paths = [Path(dir / f"{dir.name}.toml") for dir in directories] + last_modification_date = [ + datetime.fromtimestamp(file.stat().st_mtime) for file in paths + ] + + objects = { + "path": paths, + "last_modification_date": last_modification_date, + } + + return objects diff --git a/flood_adapt/dbs_controller.py b/flood_adapt/dbs_controller.py index 8cab9273a..1fdb5913e 100644 --- a/flood_adapt/dbs_controller.py +++ b/flood_adapt/dbs_controller.py @@ -15,26 +15,22 @@ from hydromt_fiat.fiat import FiatModel from hydromt_sfincs.quadtree import QuadtreeGrid +from flood_adapt.dbs_classes.dbs_benefit import DbsBenefit +from flood_adapt.dbs_classes.dbs_event import DbsEvent +from flood_adapt.dbs_classes.dbs_measure import DbsMeasure +from flood_adapt.dbs_classes.dbs_projection import DbsProjection +from flood_adapt.dbs_classes.dbs_scenario import DbsScenario +from flood_adapt.dbs_classes.dbs_strategy import DbsStrategy from flood_adapt.integrator.sfincs_adapter import SfincsAdapter -from flood_adapt.object_model.benefit import Benefit -from flood_adapt.object_model.hazard.event.event import Event from flood_adapt.object_model.hazard.event.event_factory import EventFactory from flood_adapt.object_model.hazard.event.synthetic import Synthetic -from flood_adapt.object_model.hazard.hazard import Hazard from flood_adapt.object_model.interface.benefits import IBenefit from flood_adapt.object_model.interface.database import IDatabase from flood_adapt.object_model.interface.events import IEvent -from flood_adapt.object_model.interface.measures import IMeasure -from flood_adapt.object_model.interface.projections import IProjection -from flood_adapt.object_model.interface.scenarios import IScenario from flood_adapt.object_model.interface.site import ISite -from flood_adapt.object_model.interface.strategies import IStrategy from flood_adapt.object_model.io.unitfulvalue import UnitfulLength, UnitTypesLength -from flood_adapt.object_model.measure_factory import MeasureFactory -from flood_adapt.object_model.projection import Projection from flood_adapt.object_model.scenario import Scenario from flood_adapt.object_model.site import Site -from flood_adapt.object_model.strategy import Strategy class Database(IDatabase): @@ -75,6 +71,39 @@ def __init__( ) self.static_sfincs_model = SfincsAdapter(model_root=sfincs_path, site=self.site) + # Initialize the different database objects + self._events = DbsEvent(self) + self._scenarios = DbsScenario(self) + self._strategies = DbsStrategy(self) + self._measures = DbsMeasure(self) + self._projections = DbsProjection(self) + self._benefits = DbsBenefit(self) + + # Property methods + @property + def events(self) -> DbsEvent: + return self._events + + @property + def scenarios(self) -> DbsScenario: + return self._scenarios + + @property + def strategies(self) -> DbsStrategy: + return self._strategies + + @property + def measures(self) -> DbsMeasure: + return self._measures + + @property + def projections(self) -> DbsProjection: + return self._projections + + @property + def benefits(self) -> DbsBenefit: + return self._benefits + # General methods def get_aggregation_areas(self) -> dict: """Get a list of the aggregation areas that are provided in the site configuration. @@ -671,196 +700,6 @@ def get_property_types(self) -> list: types.append("all") return types - # Measure methods - def get_measure(self, name: str) -> IMeasure: - """Get the respective measure object using the name of the measure. - - Parameters - ---------- - name : str - name of the measure - - Returns - ------- - IMeasure - object of one of the measure types (e.g., IElevate) - """ - measure_path = self.input_path / "measures" / name / f"{name}.toml" - measure = MeasureFactory.get_measure_object(measure_path) - return measure - - def save_measure(self, measure: IMeasure) -> None: - """Saves a measure object in the database. - - Parameters - ---------- - measure : IMeasure - object of one of the measure types (e.g., IElevate) - - Raises - ------ - ValueError - Raise error if name is already in use. Names of measures should be unique. - """ - names = self.get_measures()["name"] - if measure.attrs.name in names: - raise ValueError( - f"'{measure.attrs.name}' name is already used by another measure. Choose a different name" - ) - else: - # TODO: how to save the extra files? e.g., polygons - (self.input_path / "measures" / measure.attrs.name).mkdir() - measure.save( - self.input_path - / "measures" - / measure.attrs.name - / f"{measure.attrs.name}.toml" - ) - - def edit_measure(self, measure: IMeasure): - """Edits an already existing measure in the database. - - Parameters - ---------- - measure : IMeasure - object of one of the measure types (e.g., IElevate) - """ - name = measure.attrs.name - # Get all the strategies - strategies = [ - Strategy.load_file(path) for path in self.get_strategies()["path"] - ] - - # Check if measure is used in a strategy - used_in_strategy = [ - strategy.attrs.name - for strategy in strategies - for measure in strategy.attrs.measures - if name == measure - ] - - # If measure is used in a strategy, raise error - if used_in_strategy: - text = "strategy" if len(strategies) == 1 else "strategies" - raise ValueError( - f"'{name}' measure cannot be edited since it is already used in {text}: {', '.join(used_in_strategy)}" - ) - else: - measure.save( - self.input_path - / "measures" - / measure.attrs.name - / f"{measure.attrs.name}.toml" - ) - - def delete_measure(self, name: str): - """Deletes an already existing measure in the database. - - Parameters - ---------- - name : str - name of the measure - - Raises - ------ - ValueError - Raise error if measure to be deleted is already used in a strategy. - """ - - # Get all the strategies - strategies = [ - Strategy.load_file(path) for path in self.get_strategies()["path"] - ] - - # Check if measure is used in a strategy - used_in_strategy = [ - strategy.attrs.name - for strategy in strategies - for measure in strategy.attrs.measures - if name == measure - ] - - # If measure is used in a strategy, raise error - if used_in_strategy: - text = "strategy" if len(strategies) == 1 else "strategies" - raise ValueError( - f"'{name}' measure cannot be deleted since it is already used in {text}: {', '.join(used_in_strategy)}" - ) - else: - measure_path = self.input_path / "measures" / name - shutil.rmtree(measure_path, ignore_errors=True) - - def copy_measure(self, old_name: str, new_name: str, new_description: str): - """Copies (duplicates) an existing measures, and gives it a new name. - - Parameters - ---------- - old_name : str - name of the existing measure - new_name : str - name of the new measure - new_description : str - description of the new measure - """ - # First do a get - measure = self.get_measure(old_name) - measure.attrs.name = new_name - measure.attrs.description = new_description - # Then a save - self.save_measure(measure) - # Then save all the accompanied files - src = self.input_path / "measures" / old_name - dest = self.input_path / "measures" / new_name - for file in src.glob("*"): - if "toml" not in file.name: - shutil.copy(file, dest / file.name) - - # Event methods - def get_event(self, name: str) -> IEvent: - """Get the respective event object using the name of the event. - - Parameters - ---------- - name : str - name of the event - - Returns - ------- - IMeasure - object of one of the events - """ - event_path = self.input_path / "events" / f"{name}" / f"{name}.toml" - event_template = Event.get_template(event_path) - event = EventFactory.get_event(event_template).load_file(event_path) - return event - - def save_event(self, event: IEvent) -> None: - """Saves a synthetic event object in the database. - - Parameters - ---------- - event : IEvent - object of one of the synthetic event types - - Raises - ------ - ValueError - Raise error if name is already in use. Names of measures should be unique. - """ - names = self.get_events()["name"] - if event.attrs.name in names: - raise ValueError( - f"'{event.attrs.name}' name is already used by another event. Choose a different name" - ) - else: - (self.input_path / "events" / event.attrs.name).mkdir() - event.save( - self.input_path - / "events" - / event.attrs.name - / f"{event.attrs.name}.toml" - ) - def write_to_csv(self, name: str, event: IEvent, df: pd.DataFrame): df.to_csv( Path(self.input_path, "events", event.attrs.name, f"{name}.csv"), @@ -877,587 +716,6 @@ def write_cyc(self, event: IEvent, track: TropicalCyclone): # cht_cyclone function to write TropicalCyclone as .cyc file track.write_track(filename=cyc_file, fmt="ddb_cyc") - def edit_event(self, event: IEvent): - """Edits an already existing event in the database. - - Parameters - ---------- - event : IEvent - object of the event - """ - name = event.attrs.name - - # Check if event is a standard event - if self.site.attrs.standard_objects.events: - if name in self.site.attrs.standard_objects.events: - raise ValueError( - f"'{name}' event cannot be deleted since it is a standard event." - ) - - # Get all the scenarios - scenarios = [Scenario.load_file(path) for path in self.get_scenarios()["path"]] - - # Check if event is used in a scenario - used_in_scenario = [ - scenario.attrs.name - for scenario in scenarios - if name == scenario.attrs.event - ] - - # If event is used in a scenario, raise error - if used_in_scenario: - text = "scenario" if len(used_in_scenario) == 1 else "scenarios" - raise ValueError( - f"'{name}' event cannot be edited since it is already used in {text}: {', '.join(used_in_scenario)}" - ) - else: - event.save( - self.input_path - / "events" - / event.attrs.name - / f"{event.attrs.name}.toml" - ) - - def delete_event(self, name: str): - """Deletes an already existing event in the database. - - Parameters - ---------- - name : str - name of the event - - Raises - ------ - ValueError - Raise error if event to be deleted is already used in a scenario. - """ - - # Check if event is a standard event - if self.site.attrs.standard_objects.events: - if name in self.site.attrs.standard_objects.events: - raise ValueError( - f"'{name}' event cannot be deleted since it is a standard event." - ) - - # Get all the scenarios - scenarios = [Scenario.load_file(path) for path in self.get_scenarios()["path"]] - - # Check if event is used in a scenario - used_in_scenario = [ - scenario.attrs.name - for scenario in scenarios - if name == scenario.attrs.event - ] - - # If event is used in a scenario, raise error - if used_in_scenario: - text = "scenario" if len(used_in_scenario) == 1 else "scenarios" - raise ValueError( - f"'{name}' event cannot be deleted since it is already used in {text}: {', '.join(used_in_scenario)}" - ) - else: - event_path = self.input_path / "events" / name - shutil.rmtree(event_path, ignore_errors=True) - - def copy_event(self, old_name: str, new_name: str, new_description: str): - """Copies (duplicates) an existing event, and gives it a new name. - - Parameters - ---------- - old_name : str - name of the existing event - new_name : str - name of the new event - new_description : str - description of the new event - """ - # First do a get - event = self.get_event(old_name) - event.attrs.name = new_name - event.attrs.description = new_description - # Then a save - self.save_event(event) - # Then save all the accompanied files - src = self.input_path / "events" / old_name - dest = self.input_path / "events" / new_name - for file in src.glob("*"): - if "toml" not in file.name: - shutil.copy(file, dest / file.name) - - # Projection methods - def get_projection(self, name: str) -> IProjection: - """Get the respective projection object using the name of the projection. - - Parameters - ---------- - name : str - name of the projection - - Returns - ------- - IProjection - object of one of the projection types - """ - projection_path = self.input_path / "projections" / name / f"{name}.toml" - projection = Projection.load_file(projection_path) - return projection - - def save_projection(self, projection: IProjection) -> None: - """Saves a projection object in the database. - - Parameters - ---------- - projection : IProjection - object of one of the projection types - - Raises - ------ - ValueError - Raise error if name is already in use. Names of projections should be unique. - """ - names = self.get_projections()["name"] - if projection.attrs.name in names: - raise ValueError( - f"'{projection.attrs.name}' name is already used by another projection. Choose a different name" - ) - - projection_path = Path(self.input_path / "projections" / projection.attrs.name) - os.mkdir(projection_path) - - # Handle user uploaded shapefile - if projection.attrs.socio_economic_change.new_development_shapefile is not None: - # Original path to the shapefile - src_file = Path( - projection.attrs.socio_economic_change.new_development_shapefile - ) - - # New destination path to the shapefile - dst_path = projection_path / f"{projection.attrs.name}.geojson" - projection.attrs.socio_economic_change.new_development_shapefile = ( - f"{projection.attrs.name}.geojson" - ) - - # Read the shapefile and save it as a geojson - gdf = gpd.read_file(src_file, engine="pyogrio") - with open(dst_path, "w") as f: - f.write(gdf.to_crs(4326).to_json(drop_id=True)) - - # Save the projection toml file - projection.save(projection_path / f"{projection.attrs.name}.toml") - - def edit_projection(self, projection: IProjection): - """Edits an already existing projection in the database. - - Parameters - ---------- - projection : IProjection - object of one of the projection types (e.g., IElevate) - """ - name = projection.attrs.name - - # Check if projection is a standard projection - if self.site.attrs.standard_objects.projections: - if name in self.site.attrs.standard_objects.projections: - raise ValueError( - f"'{name}' projection cannot be deleted since it is a standard projection." - ) - - # Get all the scenarios - scenarios = [Scenario.load_file(path) for path in self.get_scenarios()["path"]] - - # Check if projection is used in a scenario - used_in_scenario = [ - scenario.attrs.name - for scenario in scenarios - if name == scenario.attrs.projection - ] - - # If projection is used in a scenario, raise error - if used_in_scenario: - text = "scenario" if len(used_in_scenario) == 1 else "scenarios" - raise ValueError( - f"'{name}' projection cannot be edited since it is already used in {text}: {', '.join(used_in_scenario)}" - ) - else: - projection.save( - self.input_path - / "projections" - / projection.attrs.name - / f"{projection.attrs.name}.toml" - ) - - def delete_projection(self, name: str): - """Deletes an already existing projection in the database. - - Parameters - ---------- - name : str - name of the projection - - Raises - ------ - ValueError - Raise error if projection to be deleted is already used in a scenario. - """ - - # Check if projection is a standard projection - if self.site.attrs.standard_objects.projections: - if name in self.site.attrs.standard_objects.projections: - raise ValueError( - f"'{name}' projection cannot be deleted since it is a standard projection." - ) - - # Get all the scenarios - scenarios = [Scenario.load_file(path) for path in self.get_scenarios()["path"]] - - # Check if projection is used in a scenario - used_in_scenario = [ - scenario.attrs.name - for scenario in scenarios - if name == scenario.attrs.projection - ] - - # If projection is used in a scenario, raise error - if used_in_scenario: - text = "scenario" if len(used_in_scenario) == 1 else "scenarios" - raise ValueError( - f"'{name}' projection cannot be deleted since it is already used in {text}: {', '.join(used_in_scenario)}" - ) - else: - projection_path = self.input_path / "projections" / name - shutil.rmtree(projection_path, ignore_errors=True) - - def copy_projection(self, old_name: str, new_name: str, new_description: str): - """Copies (duplicates) an existing projection, and gives it a new name. - - Parameters - ---------- - old_name : str - name of the existing projection - new_name : str - name of the new projection - new_description : str - description of the new projection - """ - # First do a get - projection = self.get_projection(old_name) - projection.attrs.name = new_name - projection.attrs.description = new_description - # Then a save - self.save_projection(projection) - # Then save all the accompanied files - src = self.input_path / "projections" / old_name - dest = self.input_path / "projections" / new_name - for file in src.glob("*"): - if "toml" not in file.name: - shutil.copy(file, dest / file.name) - - # Strategy methods - def get_strategy(self, name: str) -> IStrategy: - """Get the respective strategy object using the name of the strategy. - - Parameters - ---------- - name : str - name of the strategy - - Returns - ------- - IStrategy - strategy object - """ - strategy_path = self.input_path / "strategies" / name / f"{name}.toml" - strategy = Strategy.load_file(strategy_path) - return strategy - - def save_strategy(self, strategy: IStrategy) -> None: - """Saves a strategy object in the database. - - Parameters - ---------- - measure : IStrategy - object of strategy type - - Raises - ------ - ValueError - Raise error if name is already in use. Names of strategies should be unique. - """ - names = self.get_strategies()["name"] - if strategy.attrs.name in names: - raise ValueError( - f"'{strategy.attrs.name}' name is already used by another strategy. Choose a different name" - ) - else: - (self.input_path / "strategies" / strategy.attrs.name).mkdir() - strategy.save( - self.input_path - / "strategies" - / strategy.attrs.name - / f"{strategy.attrs.name}.toml" - ) - - def delete_strategy(self, name: str): - """Deletes an already existing strategy in the database. - - Parameters - ---------- - name : str - name of the strategy - - Raises - ------ - ValueError - Raise error if strategy to be deleted is already used in a scenario. - """ - - # Check if strategy is a standard strategy - if self.site.attrs.standard_objects.strategies: - if name in self.site.attrs.standard_objects.strategies: - raise ValueError( - f"'{name}' strategy cannot be deleted since it is a standard strategy." - ) - - # Get all the scenarios - scenarios = [Scenario.load_file(path) for path in self.get_scenarios()["path"]] - - # Check if strategy is used in a scenario - used_in_scenario = [ - scenario.attrs.name - for scenario in scenarios - if name == scenario.attrs.strategy - ] - - # If strategy is used in a scenario, raise error - if used_in_scenario: - text = "scenario" if len(used_in_scenario) == 1 else "scenarios" - raise ValueError( - f"'{name}' strategy cannot be deleted since it is already used in {text}: {', '.join(used_in_scenario)}" - ) - else: - strategy_path = self.input_path / "strategies" / name - shutil.rmtree(strategy_path, ignore_errors=True) - - # scenario methods - def get_scenario(self, name: str) -> IScenario: - """Get the respective scenario object using the name of the scenario. - - Parameters - ---------- - name : str - name of the scenario - - Returns - ------- - IScenario - Scenario object - """ - scenario_path = self.input_path / "scenarios" / name / f"{name}.toml" - scenario = Scenario.load_file(scenario_path) - scenario.init_object_model() - return scenario - - def save_scenario(self, scenario: IScenario) -> None: - """Saves a scenario object in the database. - - Parameters - ---------- - measure : IScenario - object of scenario type - - Raises - ------ - ValueError - Raise error if name is already in use. Names of scenarios should be unique. - """ - names = self.get_scenarios()["name"] - if scenario.attrs.name in names: - raise ValueError( - f"'{scenario.attrs.name}' name is already used by another scenario. Choose a different name" - ) - # TODO add check to see if a scenario with the same attributes but different name already exists - else: - (self.input_path / "scenarios" / scenario.attrs.name).mkdir() - scenario.save( - self.input_path - / "scenarios" - / scenario.attrs.name - / f"{scenario.attrs.name}.toml" - ) - - def edit_scenario(self, scenario: IScenario): - """Edits an already existing scenario in the database. - - Parameters - ---------- - scenario : IScenario - object of one of the scenario types (e.g., IScenario) - """ - name = scenario.attrs.name - - # Get all the benefits - benefits = [Benefit.load_file(path) for path in self.get_benefits()["path"]] - - # Check in which benefits this scenario is used - used_in_benefit = [ - benefit.attrs.name - for benefit in benefits - for scenario in self.check_benefit_scenarios(benefit)[ - "scenario created" - ].to_list() - if name == scenario - ] - - # If scenario is used in a benefit, raise error - if used_in_benefit: - text = "benefit" if len(used_in_benefit) == 1 else "Benefits" - raise ValueError( - f"'{name}' scenario cannot be edited since it is already used in {text}: {', '.join(used_in_benefit)}" - ) - else: - scenario.save( - self.input_path - / "scenarios" - / scenario.attrs.name - / f"{scenario.attrs.name}.toml" - ) - - def delete_scenario(self, name: str): - """Deletes an already existing scenario in the database. - - Parameters - ---------- - name : str - name of the scenario - - Raises - ------ - ValueError - Raise error if scenario has already model output - """ - - # Get all the benefits - benefits = [Benefit.load_file(path) for path in self.get_benefits()["path"]] - - # Check in which benefits this scenario is used - used_in_benefit = [ - benefit.attrs.name - for benefit in benefits - for scenario in self.check_benefit_scenarios(benefit)[ - "scenario created" - ].to_list() - if name == scenario - ] - - # If scenario is used in a benefit, raise error - if used_in_benefit: - text = "benefit" if len(used_in_benefit) == 1 else "Benefits" - raise ValueError( - f"'{name}' scenario cannot be deleted since it is already used in {text}: {', '.join(used_in_benefit)}" - ) - else: - scenario_path = self.input_path / "scenarios" / name - shutil.rmtree(scenario_path, ignore_errors=False) - - results_path = self.input_path.parent / "output" / "Scenarios" / name - if results_path.exists(): - shutil.rmtree(results_path, ignore_errors=False) - - def get_benefit(self, name: str) -> IBenefit: - """Get the respective benefit object using the name of the benefit. - - Parameters - ---------- - name : str - name of the benefit - - Returns - ------- - IBenefit - Benefit object - """ - benefit_path = self.input_path / "Benefits" / name / f"{name}.toml" - benefit = Benefit.load_file(benefit_path) - return benefit - - def save_benefit(self, benefit: IBenefit) -> None: - """Saves a benefit object in the database. - - Parameters - ---------- - measure : IBenefit - object of scenario type - - Raises - ------ - ValueError - Raise error if name is already in use. Names of benefits assessments should be unique. - """ - names = self.get_benefits()["name"] - if benefit.attrs.name in names: - raise ValueError( - f"'{benefit.attrs.name}' name is already used by another benefit. Choose a different name" - ) - elif not all(benefit.scenarios["scenario created"] != "No"): - raise ValueError( - f"'{benefit.attrs.name}' name cannot be created before all necessary scenarios are created." - ) - else: - (self.input_path / "Benefits" / benefit.attrs.name).mkdir() - benefit.save( - self.input_path - / "Benefits" - / benefit.attrs.name - / f"{benefit.attrs.name}.toml" - ) - - def edit_benefit(self, benefit: IBenefit): - """Edits an already existing benefit in the database. - - Parameters - ---------- - benefit : IBenefit - object of one of the benefit types (e.g., IBenefit) - """ - benefit.save( - self.input_path - / "Benefits" - / benefit.attrs.name - / f"{benefit.attrs.name}.toml" - ) - - # Delete output if edited - output_path = ( - self.input_path.parent / "output" / "Benefits" / benefit.attrs.name - ) - - if output_path.exists(): - shutil.rmtree(output_path, ignore_errors=True) - - def delete_benefit(self, name: str) -> None: - """Deletes an already existing benefit in the database. - - Parameters - ---------- - name : str - name of the benefit - - Raises - ------ - ValueError - Raise error if benefit has already model output - """ - benefit_path = self.input_path / "Benefits" / name - benefit = Benefit.load_file(benefit_path / f"{name}.toml") - shutil.rmtree(benefit_path, ignore_errors=True) - # Delete output if edited - output_path = ( - self.input_path.parent / "output" / "Benefits" / benefit.attrs.name - ) - - if output_path.exists(): - shutil.rmtree(output_path, ignore_errors=True) - def check_benefit_scenarios(self, benefit: IBenefit) -> pd.DataFrame: """Returns a dataframe with the scenarios needed for this benefit assessment run @@ -1491,7 +749,7 @@ def create_benefit_scenarios(self, benefit: IBenefit) -> None: scenario_obj = Scenario.load_dict(scenario_dict, self.input_path) - self.save_scenario(scenario_obj) + self._scenarios.save(scenario_obj) # Update the scenarios check benefit.check_scenarios() @@ -1507,132 +765,16 @@ def run_benefit(self, benefit_name: Union[str, list[str]]) -> None: if not isinstance(benefit_name, list): benefit_name = [benefit_name] for name in benefit_name: - benefit = self.get_benefit(name) + benefit = self._benefits.get(name) benefit.run_cost_benefit() def update(self) -> None: - self.projections = self.get_projections() - self.events = self.get_events() - self.measures = self.get_measures() - self.strategies = self.get_strategies() - self.scenarios = self.get_scenarios() - self.benefits = self.get_benefits() - - def get_projections(self) -> dict[str, Any]: - """Returns a dictionary with info on the projections that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'description', 'path' and 'last_modification_date' info - """ - projections = self.get_object_list(object_type="projections") - objects = [Projection.load_file(path) for path in projections["path"]] - projections["name"] = [obj.attrs.name for obj in objects] - projections["description"] = [obj.attrs.description for obj in objects] - return projections - - def get_events(self) -> dict[str, Any]: - """Returns a dictionary with info on the events that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'description', 'path' and 'last_modification_date' info - """ - events = self.get_object_list(object_type="events") - objects = [Hazard.get_event_object(path) for path in events["path"]] - events["name"] = [obj.attrs.name for obj in objects] - events["description"] = [obj.attrs.description for obj in objects] - return events - - def get_measures(self) -> dict[str, Any]: - """Returns a dictionary with info on the measures that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'description', 'path' and 'last_modification_date' info - """ - measures = self.get_object_list(object_type="measures") - objects = [MeasureFactory.get_measure_object(path) for path in measures["path"]] - measures["name"] = [obj.attrs.name for obj in objects] - measures["description"] = [obj.attrs.description for obj in objects] - - geometries = [] - for path, obj in zip(measures["path"], objects): - # If polygon is used read the polygon file - if obj.attrs.polygon_file: - geometries.append( - gpd.read_file(path.parent.joinpath(obj.attrs.polygon_file)) - ) - # If aggregation area is used read the polygon from the aggregation area name - elif obj.attrs.aggregation_area_name: - gdf = self.aggr_areas[obj.attrs.aggregation_area_type] - geometries.append( - gdf.loc[gdf["name"] == obj.attrs.aggregation_area_name, :] - ) - # Else assign a None value - else: - geometries.append(None) - - measures["geometry"] = geometries - return measures - - def get_strategies(self) -> dict[str, Any]: - """Returns a dictionary with info on the strategies that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'description', 'path' and 'last_modification_date' info - """ - strategies = self.get_object_list(object_type="strategies") - objects = [Strategy.load_file(path) for path in strategies["path"]] - strategies["name"] = [obj.attrs.name for obj in objects] - strategies["description"] = [obj.attrs.description for obj in objects] - return strategies - - def get_scenarios(self) -> dict[str, Any]: - """Returns a dictionary with info on the events that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'description', 'path' and 'last_modification_date' info - """ - scenarios = self.get_object_list(object_type="scenarios") - objects = [Scenario.load_file(path) for path in scenarios["path"]] - scenarios["name"] = [obj.attrs.name for obj in objects] - scenarios["description"] = [obj.attrs.description for obj in objects] - scenarios["Projection"] = [obj.attrs.projection for obj in objects] - scenarios["Event"] = [obj.attrs.event for obj in objects] - scenarios["Strategy"] = [obj.attrs.strategy for obj in objects] - scenarios["finished"] = [ - obj.init_object_model().direct_impacts.has_run for obj in objects - ] - - return scenarios - - def get_benefits(self) -> dict[str, Any]: - """Returns a dictionary with info on the (cost-)benefit assessments that currently - exist in the database. - - Returns - ------- - dict[str, Any] - Includes 'name', 'path' and 'last_modification_date' info - """ - benefits = self.get_object_list(object_type="benefits") - objects = [Benefit.load_file(path) for path in benefits["path"]] - benefits["name"] = [obj.attrs.name for obj in objects] - - return benefits + self.projections = self._projections.list_objects() + self.events = self._events.list_objects() + self.measures = self._measures.list_objects() + self.strategies = self._strategies.list_objects() + self.scenarios = self._scenarios.list_objects() + self.benefits = self._benefits.list_objects() def get_outputs(self) -> dict[str, Any]: """Returns a dictionary with info on the outputs that currently @@ -1643,7 +785,7 @@ def get_outputs(self) -> dict[str, Any]: dict[str, Any] Includes 'name', 'path', 'last_modification_date' and "finished" info """ - all_scenarios = pd.DataFrame(self.get_scenarios()) + all_scenarios = pd.DataFrame(self._scenarios.list_objects()) if len(all_scenarios) > 0: df = all_scenarios[all_scenarios["finished"]] else: @@ -1858,13 +1000,13 @@ def has_run_hazard(self, scenario_name: str) -> None: scenario_name : str name of the scenario to check if needs to be rerun for hazard """ - scenario = self.get_scenario(scenario_name) + scenario = self._scenarios.get(scenario_name) simulations = list( self.input_path.parent.joinpath("output", "Scenarios").glob("*") ) - scns_simulated = [self.get_scenario(sim.name) for sim in simulations] + scns_simulated = [self._scenarios.get(sim.name) for sim in simulations] for scn in scns_simulated: if scn.direct_impacts.hazard == scenario.direct_impacts.hazard: @@ -1907,7 +1049,7 @@ def run_scenario(self, scenario_name: Union[str, list[str]]) -> None: for scn in scenario_name: try: self.has_run_hazard(scn) - scenario = self.get_scenario(scn) + scenario = self.scenarios.get(scn) scenario.run() except RuntimeError as e: if "SFINCS model failed to run." in str(e): diff --git a/flood_adapt/object_model/interface/benefits.py b/flood_adapt/object_model/interface/benefits.py index 411048a47..83288c987 100644 --- a/flood_adapt/object_model/interface/benefits.py +++ b/flood_adapt/object_model/interface/benefits.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Union import pandas as pd -from pydantic import BaseModel +from pydantic import BaseModel, validator class CurrentSituationModel(BaseModel): @@ -16,6 +16,7 @@ class BenefitModel(BaseModel): name: str description: Optional[str] = "" + lock_count: int = 0 strategy: str event_set: str projection: str @@ -26,6 +27,13 @@ class BenefitModel(BaseModel): implementation_cost: Optional[float] = None annual_maint_cost: Optional[float] = None + @validator("lock_count") + def validate_lock_count(cls, lock_count: int) -> int: + """Validate lock_count""" + if lock_count < 0: + raise ValueError("lock_count must be a positive integer") + return lock_count + class IBenefit(ABC): attrs: BenefitModel diff --git a/flood_adapt/object_model/interface/database.py b/flood_adapt/object_model/interface/database.py index 522a4c048..41b60fa50 100644 --- a/flood_adapt/object_model/interface/database.py +++ b/flood_adapt/object_model/interface/database.py @@ -9,11 +9,7 @@ from flood_adapt.object_model.interface.benefits import IBenefit from flood_adapt.object_model.interface.events import IEvent -from flood_adapt.object_model.interface.measures import IMeasure -from flood_adapt.object_model.interface.projections import IProjection -from flood_adapt.object_model.interface.scenarios import IScenario from flood_adapt.object_model.interface.site import ISite -from flood_adapt.object_model.interface.strategies import IStrategy class IDatabase(ABC): @@ -30,6 +26,9 @@ def __init__( @abstractmethod def get_aggregation_areas(self) -> dict: ... + @abstractmethod + def get_model_boundary(self) -> dict[str, Any]: ... + @abstractmethod def get_obs_points(self) -> GeoDataFrame: ... @@ -62,122 +61,17 @@ def plot_wind(self, event: IEvent, input_wind_df: pd.DataFrame = None) -> str: . @abstractmethod def get_buildings(self) -> GeoDataFrame: ... - @abstractmethod - def get_projection(self, name: str) -> IProjection: ... - - @abstractmethod - def save_projection(self, measure: IProjection) -> None: ... - - @abstractmethod - def edit_projection(self, measure: IProjection) -> None: ... - - @abstractmethod - def delete_projection(self, name: str): ... - - @abstractmethod - def copy_projection(self, old_name: str, new_name: str, new_description: str): ... - - @abstractmethod - def get_event(self, name: str) -> IEvent: ... - - @abstractmethod - def save_event(self, measure: IEvent) -> None: ... - @abstractmethod def write_to_csv(self, name: str, event: IEvent, df: pd.DataFrame) -> None: ... @abstractmethod def write_cyc(self, event: IEvent, track: TropicalCyclone): ... - @abstractmethod - def edit_event(self, measure: IEvent) -> None: ... - - @abstractmethod - def delete_event(self, name: str): ... - - @abstractmethod - def copy_event(self, old_name: str, new_name: str, new_description: str): ... - - @abstractmethod - def get_measure(self, name: str) -> IMeasure: ... - - @abstractmethod - def save_measure(self, measure: IMeasure) -> None: ... - - @abstractmethod - def edit_measure(self, measure: IMeasure): ... - - @abstractmethod - def delete_measure(self, name: str): ... - - @abstractmethod - def copy_measure(self, old_name: str, new_name: str, new_description: str): ... - - @abstractmethod - def get_strategy(self, name: str) -> IStrategy: ... - - @abstractmethod - def save_strategy(self, measure: IStrategy) -> None: ... - - @abstractmethod - def delete_strategy(self, name: str): ... - - @abstractmethod - def get_scenario(self, name: str) -> IScenario: ... - - @abstractmethod - def save_scenario(self, measure: IScenario) -> None: ... - - @abstractmethod - def edit_scenario(self, measure: IScenario) -> None: ... - - @abstractmethod - def delete_scenario(self, name: str): ... - - @abstractmethod - def get_benefit(self, name: str) -> IBenefit: ... - - @abstractmethod - def save_benefit(self, benefit: IBenefit) -> None: ... - - @abstractmethod - def edit_benefit(self, measure: IBenefit) -> None: ... - - @abstractmethod - def delete_benefit(self, name: str) -> None: ... - @abstractmethod def check_benefit_scenarios(self, benefit: IBenefit) -> None: ... @abstractmethod def create_benefit_scenarios(self, benefit: IBenefit) -> None: ... - @abstractmethod - def run_benefit(self, benefit_name: Union[str, list[str]]) -> None: ... - - @abstractmethod - def get_model_boundary(self) -> dict[str, Any]: ... - - @abstractmethod - def get_projections(self) -> dict[str, Any]: ... - - @abstractmethod - def get_events(self) -> dict[str, Any]: ... - - @abstractmethod - def get_measures(self) -> dict[str, Any]: ... - - @abstractmethod - def get_strategies(self) -> dict[str, Any]: ... - - @abstractmethod - def get_scenarios(self) -> dict[str, Any]: ... - - @abstractmethod - def get_benefits(self) -> dict[str, Any]: ... - - @abstractmethod - def get_outputs(self) -> dict[str, Any]: ... - @abstractmethod def run_scenario(self, scenario_name: Union[str, list[str]]) -> None: ... diff --git a/flood_adapt/object_model/interface/measures.py b/flood_adapt/object_model/interface/measures.py index 8f430febf..f08d7f60f 100644 --- a/flood_adapt/object_model/interface/measures.py +++ b/flood_adapt/object_model/interface/measures.py @@ -51,8 +51,16 @@ class MeasureModel(BaseModel): name: str = Field(..., min_length=1) description: Optional[str] = "" + lock_count: int = 0 type: Union[HazardType, ImpactType] + @validator("lock_count") + def validate_lock_count(cls, lock_count: int) -> int: + """Validate lock_count""" + if lock_count < 0: + raise ValueError("lock_count must be a positive integer") + return lock_count + class HazardMeasureModel(MeasureModel): """BaseModel describing the expected variables and data types of attributes common to all impact measures""" diff --git a/flood_adapt/object_model/interface/projections.py b/flood_adapt/object_model/interface/projections.py index 902e67b08..36cf39a20 100644 --- a/flood_adapt/object_model/interface/projections.py +++ b/flood_adapt/object_model/interface/projections.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from flood_adapt.object_model.io.unitfulvalue import ( UnitfulLength, @@ -34,9 +34,17 @@ class SocioEconomicChangeModel(BaseModel): class ProjectionModel(BaseModel): name: str description: Optional[str] = "" + lock_count: int = 0 physical_projection: PhysicalProjectionModel socio_economic_change: SocioEconomicChangeModel + @validator("lock_count") + def validate_lock_count(cls, lock_count: int) -> int: + """Validate lock_count""" + if lock_count < 0: + raise ValueError("lock_count must be a positive integer") + return lock_count + class IProjection(ABC): attrs: ProjectionModel diff --git a/flood_adapt/object_model/interface/scenarios.py b/flood_adapt/object_model/interface/scenarios.py index 610eaaf27..8fcb51d97 100644 --- a/flood_adapt/object_model/interface/scenarios.py +++ b/flood_adapt/object_model/interface/scenarios.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator class ScenarioModel(BaseModel): @@ -10,10 +10,18 @@ class ScenarioModel(BaseModel): name: str description: Optional[str] = "" + lock_count: int = 0 event: str projection: str strategy: str + @validator("lock_count") + def validate_lock_count(cls, lock_count: int) -> int: + """Validate lock_count""" + if lock_count < 0: + raise ValueError("lock_count must be a positive integer") + return lock_count + class IScenario(ABC): attrs: ScenarioModel diff --git a/flood_adapt/object_model/interface/strategies.py b/flood_adapt/object_model/interface/strategies.py index a1e091892..4b7e347c0 100644 --- a/flood_adapt/object_model/interface/strategies.py +++ b/flood_adapt/object_model/interface/strategies.py @@ -2,14 +2,22 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator class StrategyModel(BaseModel): name: str description: Optional[str] = "" + lock_count: int = 0 measures: Optional[list[str]] = [] + @validator("lock_count") + def validate_lock_count(cls, lock_count: int) -> int: + """Validate lock_count""" + if lock_count < 0: + raise ValueError("lock_count must be a positive integer") + return lock_count + class IStrategy(ABC): attrs: StrategyModel diff --git a/flood_adapt/object_model/scenario.py b/flood_adapt/object_model/scenario.py index 454dcbcd7..0286d52a2 100644 --- a/flood_adapt/object_model/scenario.py +++ b/flood_adapt/object_model/scenario.py @@ -64,12 +64,8 @@ def save(self, filepath: Union[str, os.PathLike]): def run(self): """run direct impact models for the scenario""" self.init_object_model() - # start log file in scenario results folder - for parent in reversed(self.results_path.parents): - if not parent.exists(): - os.mkdir(parent) - if not self.results_path.exists(): - os.mkdir(self.results_path) + os.makedirs(self.results_path, exist_ok=True) + # Initiate the logger for all the integrator scripts. self.initiate_root_logger( self.results_path.joinpath(f"logfile_{self.attrs.name}.log") diff --git a/tests/test_api/test_events.py b/tests/test_api/test_events.py index d231e5f71..eafbbe2a2 100644 --- a/tests/test_api/test_events.py +++ b/tests/test_api/test_events.py @@ -60,7 +60,7 @@ def test_synthetic_event(test_db, test_dict): event = api_events.create_synthetic_event(test_dict) # If the name is not used before the measure is save in the database api_events.save_event_toml(event, test_db) - test_db.get_events() + test_db.events.list_objects() # Try to delete a measure which is already used in a scenario # with pytest.raises(ValueError): @@ -68,4 +68,4 @@ def test_synthetic_event(test_db, test_dict): # If user presses delete event the measure is deleted api_events.delete_event("test1", test_db) - test_db.get_events() + test_db.events.list_objects() diff --git a/tests/test_api/test_projections.py b/tests/test_api/test_projections.py index 48e64de49..9a50423e1 100644 --- a/tests/test_api/test_projections.py +++ b/tests/test_api/test_projections.py @@ -32,7 +32,7 @@ def test_projection(test_db): projection = api_projections.create_projection(test_dict) # If the name is not used before the measure is save in the database api_projections.save_projection(projection, test_db) - test_db.get_projections() + test_db.projections.list_objects() # Try to delete a measure which is already used in a scenario # with pytest.raises(ValueError): @@ -40,4 +40,4 @@ def test_projection(test_db): # If user presses delete projection the measure is deleted api_projections.delete_projection("test_proj_1", test_db) - test_db.get_projections() + test_db.projections.list_objects() diff --git a/tests/test_api/test_scenarios.py b/tests/test_api/test_scenarios.py index 37ed66744..a126ad5b7 100644 --- a/tests/test_api/test_scenarios.py +++ b/tests/test_api/test_scenarios.py @@ -26,11 +26,11 @@ def test_scenario(test_db): test_dict["name"] = "test1" scenario = api_scenarios.create_scenario(test_dict, test_db) api_scenarios.save_scenario(scenario, test_db) - test_db.get_scenarios() + test_db.scenarios.list_objects() # If user presses delete scenario the measure is deleted api_scenarios.delete_scenario("test1", test_db) - test_db.get_scenarios() + test_db.scenarios.list_objects() @pytest.mark.skip(reason="Part of test_has_hazard_run") diff --git a/tests/test_api/test_strategy.py b/tests/test_api/test_strategy.py index 78abde09c..866a565a6 100644 --- a/tests/test_api/test_strategy.py +++ b/tests/test_api/test_strategy.py @@ -44,7 +44,7 @@ def test_strategy(test_db): strategy = api_strategies.create_strategy(test_dict, test_db) # If the name is not used before the measure is save in the database api_strategies.save_strategy(strategy, test_db) - test_db.get_strategies() + test_db.strategies.list_objects() # Try to delete a measure which is already used in a scenario # with pytest.raises(ValueError): @@ -52,4 +52,4 @@ def test_strategy(test_db): # If user presses delete strategy the measure is deleted api_strategies.delete_strategy("test_strat_1", test_db) - test_db.get_strategies() + test_db.strategies.list_objects()