Skip to content

Commit

Permalink
Restructure static API functions and add results caching (#449)
Browse files Browse the repository at this point in the history
* Add caching

* Fix latest backend functions

* Fix black

* Fix comments

* Fix bug

* Fix naming

* fix black
  • Loading branch information
dladrichem authored Jun 3, 2024
1 parent 303b883 commit 6a18711
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 228 deletions.
2 changes: 1 addition & 1 deletion flood_adapt/api/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ def calculate_volume(


def get_green_infra_table(measure_type: str) -> pd.DataFrame:
return Database().get_green_infra_table(measure_type)
return Database().static.get_green_infra_table(measure_type)
16 changes: 8 additions & 8 deletions flood_adapt/api/startup.py → flood_adapt/api/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_aggregation_areas() -> list[GeoDataFrame]:
list[GeoDataFrame]
list of GeoDataFrames with the aggregation areas
"""
return Database().get_aggregation_areas()
return Database().static.get_aggregation_areas()


def get_obs_points() -> GeoDataFrame:
Expand All @@ -57,11 +57,11 @@ def get_obs_points() -> GeoDataFrame:
GeoDataFrame
GeoDataFrame with observation points from the site.toml.
"""
return Database().get_obs_points()
return Database().static.get_obs_points()


def get_model_boundary() -> GeoDataFrame:
return Database().get_model_boundary()
return Database().static.get_model_boundary()


def get_model_grid() -> QuadtreeGrid:
Expand All @@ -76,7 +76,7 @@ def get_model_grid() -> QuadtreeGrid:
QuadtreeGrid
QuadtreeGrid with the model grid
"""
return Database().get_model_grid()
return Database().static.get_model_grid()


@staticmethod
Expand All @@ -93,7 +93,7 @@ def get_svi_map() -> Union[GeoDataFrame, None]:
GeoDataFrames with the SVI map, None if not available
"""
try:
return Database().get_static_map(Database().site.attrs.fiat.svi.geom)
return Database().static.get_static_map(Database().site.attrs.fiat.svi.geom)
except Exception:
return None

Expand All @@ -115,7 +115,7 @@ def get_static_map(path: Union[str, Path]) -> Union[GeoDataFrame, None]:
GeoDataFrame with the static map
"""
try:
return Database().get_static_map(path)
return Database().static.get_static_map(path)
except Exception:
return None

Expand All @@ -132,11 +132,11 @@ def get_buildings() -> GeoDataFrame:
GeoDataFrame
GeoDataFrames with the buildings from FIAT exposure
"""
return Database().get_buildings()
return Database().static.get_buildings()


def get_property_types() -> list:
return Database().get_property_types()
return Database().static.get_property_types()


def get_hazard_measure_types():
Expand Down
9 changes: 7 additions & 2 deletions flood_adapt/dbs_classes/dbs_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ def list_objects(self) -> dict[str, Any]:
geometries.append(gpd.read_file(file_path))
# If aggregation area is used read the polygon from the aggregation area name
elif obj.attrs.aggregation_area_name:
if obj.attrs.aggregation_area_type not in self._database.aggr_areas:
if (
obj.attrs.aggregation_area_type
not in self._database.static.get_aggregation_areas()
):
raise ValueError(
f"Aggregation area type {obj.attrs.aggregation_area_type} for measure {obj.attrs.name} does not exist."
)
gdf = self._database.aggr_areas[obj.attrs.aggregation_area_type]
gdf = self._database.static.get_aggregation_areas()[
obj.attrs.aggregation_area_type
]
if obj.attrs.aggregation_area_name not in gdf["name"].to_numpy():
raise ValueError(
f"Aggregation area name {obj.attrs.aggregation_area_name} for measure {obj.attrs.name} does not exist."
Expand Down
235 changes: 235 additions & 0 deletions flood_adapt/dbs_classes/dbs_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from pathlib import Path
from typing import Any, Callable, Tuple, Union

import geopandas as gpd
import pandas as pd
from geopandas import GeoDataFrame
from hydromt_fiat.fiat import FiatModel
from hydromt_sfincs.quadtree import QuadtreeGrid

from flood_adapt.object_model.interface.database import IDatabase


def cache_method_wrapper(func: Callable) -> Callable:
def wrapper(self, *args: Tuple[Any], **kwargs: dict[str, Any]) -> Any:
if func.__name__ not in self._cached_data:
self._cached_data[func.__name__] = {}

args_key = (
str(args) + str(sorted(kwargs.items())) if args or kwargs else "no_args"
)
if args_key in self._cached_data[func.__name__]:
return self._cached_data[func.__name__][args_key]

result = func(self, *args, **kwargs)
self._cached_data[func.__name__][args_key] = result

return result

return wrapper


class DbsStatic:

_cached_data: dict[str, Any] = {}
_database: IDatabase = None

def __init__(self, database: IDatabase):
"""
Initialize any necessary attributes.
"""
self._database = database

@cache_method_wrapper
def get_aggregation_areas(self) -> dict:
"""Get a list of the aggregation areas that are provided in the site configuration.
These are expected to much the ones in the FIAT model
Returns
-------
list[GeoDataFrame]
list of geodataframes with the polygons defining the aggregation areas
"""
aggregation_areas = {}
for aggr_dict in self._database.site.attrs.fiat.aggregation:
aggregation_areas[aggr_dict.name] = gpd.read_file(
self._database.static_path / "site" / aggr_dict.file,
engine="pyogrio",
).to_crs(4326)
# Use always the same column name for name labels
aggregation_areas[aggr_dict.name] = aggregation_areas[
aggr_dict.name
].rename(columns={aggr_dict.field_name: "name"})
# Make sure they are ordered alphabetically
aggregation_areas[aggr_dict.name].sort_values(by="name").reset_index(
drop=True
)
return aggregation_areas

@cache_method_wrapper
def get_model_boundary(self) -> GeoDataFrame:
"""Get the model boundary from the SFINCS model"""
bnd = self._database.static_sfincs_model.get_model_boundary()
return bnd

@cache_method_wrapper
def get_model_grid(self) -> QuadtreeGrid:
"""Get the model grid from the SFINCS model
Returns
-------
QuadtreeGrid
The model grid
"""
grid = self._database.static_sfincs_model.get_model_grid()
return grid

@cache_method_wrapper
def get_obs_points(self) -> GeoDataFrame:
"""Get the observation points from the flood hazard model"""
if self._database.site.attrs.obs_point is not None:
obs_points = self._database.site.attrs.obs_point
names = []
descriptions = []
lat = []
lon = []
for pt in obs_points:
names.append(pt.name)
descriptions.append(pt.description)
lat.append(pt.lat)
lon.append(pt.lon)

# create GeoDataFrame from obs_points in site file
df = pd.DataFrame({"name": names, "description": descriptions})
# TODO: make crs flexible and add this as a parameter to site.toml?
gdf = gpd.GeoDataFrame(
df, geometry=gpd.points_from_xy(lon, lat), crs="EPSG:4326"
)
return gdf

@cache_method_wrapper
def get_static_map(self, path: Union[str, Path]) -> gpd.GeoDataFrame:
"""Get a map from the static folder
Parameters
----------
path : Union[str, Path]
Path to the map relative to the static folder
Returns
-------
gpd.GeoDataFrame
GeoDataFrame with the map in crs 4326
Raises
------
FileNotFoundError
If the file is not found
"""
# Read the map
full_path = self._database.static_path / path
if full_path.is_file():
return gpd.read_file(full_path, engine="pyogrio").to_crs(4326)

# If the file is not found, throw an error
raise FileNotFoundError(f"File {full_path} not found")

@cache_method_wrapper
def get_slr_scn_names(self) -> list:
"""Get the names of the sea level rise scenarios from the slr.csv file
Returns
-------
list
List of scenario names
"""
input_file = self._database.static_path.joinpath("slr", "slr.csv")
df = pd.read_csv(input_file)
return df.columns[2:].to_list()

@cache_method_wrapper
def get_green_infra_table(self, measure_type: str) -> pd.DataFrame:
"""Return a table with different types of green infrastructure measures and their infiltration depths.
This is read by a csv file in the database.
Returns
-------
pd.DataFrame
Table with values
"""
# Read file from database
df = pd.read_csv(
self._database.static_path.joinpath(
"green_infra_table", "green_infra_lookup_table.csv"
)
)

# Get column with values
val_name = "Infiltration depth"
col_name = [name for name in df.columns if val_name in name][0]
if not col_name:
raise KeyError(f"A column with a name containing {val_name} was not found!")

# Get list of types per measure
df["types"] = [
[x.strip() for x in row["types"].split(",")] for i, row in df.iterrows()
]

# Show specific values based on measure type
inds = [i for i, row in df.iterrows() if measure_type in row["types"]]
df = df.drop(columns="types").iloc[inds, :]

return df

@cache_method_wrapper
def get_buildings(self) -> GeoDataFrame:
"""Get the building footprints from the FIAT model.
This should only be the buildings excluding any other types (e.g., roads)
The parameters non_building_names in the site config is used for that
Returns
-------
GeoDataFrame
building footprints with all the FIAT columns
"""
# use hydromt-fiat to load the fiat model
fm = FiatModel(
root=self._database.static_path / "templates" / "fiat",
mode="r",
)
fm.read()
buildings = fm.exposure.select_objects(
primary_object_type="ALL",
non_building_names=self._database.site.attrs.fiat.non_building_names,
return_gdf=True,
)

del fm

return buildings

@cache_method_wrapper
def get_property_types(self) -> list:
"""_summary_
Returns
-------
list
_description_
"""
# use hydromt-fiat to load the fiat model
fm = FiatModel(
root=self._database.static_path / "templates" / "fiat",
mode="r",
)
fm.read()
types = fm.exposure.get_primary_object_type()
for name in self._database.site.attrs.fiat.non_building_names:
if name in types:
types.remove(name)
# Add "all" type for using as identifier
types.append("all")

del fm

return types
2 changes: 2 additions & 0 deletions flood_adapt/dbs_classes/dbs_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class DbsTemplate(AbstractDatabaseElement):
_type = ""
_folder_name = ""
_object_model_class = None
_path = None
_database = None

def __init__(self, database: IDatabase):
"""
Expand Down
Loading

0 comments on commit 6a18711

Please sign in to comment.