Skip to content

Commit

Permalink
inflow schema 300 (#377)
Browse files Browse the repository at this point in the history
* Update modelchecker to work with schema 223
* Add extra checks for parameters that cannot be Null
* Test tag validity
* Add test for ListOfIntsCheck and fix mistake in ListOfIntsCheck
* Test and fix TagsValidCheck
  • Loading branch information
margrietpalm authored Aug 1, 2024
1 parent 13b00b1 commit 36ee17d
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 280 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Changelog of threedi-modelchecker
2.8.2 (unreleased)
------------------

- Nothing changed yet.
- Adapt modelchecker to work with schema upgrades for inflow (0.223)


2.8.1 (2024-07-24)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"Click",
"GeoAlchemy2>=0.9,!=0.11.*",
"SQLAlchemy>=1.4",
"threedi-schema==0.222.*"
"threedi-schema==0.223.*"
]

[project.optional-dependencies]
Expand Down
19 changes: 19 additions & 0 deletions threedi_modelchecker/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,22 @@ def description(self):
if self.max_value is not None:
parts.append(f"{'>' if self.right_inclusive else '>='}{self.max_value}")
return f"{self.column_name} is {' and/or '.join(parts)}"


class ListOfIntsCheck(BaseCheck):
def get_invalid(self, session):
invalids = []
for record in self.to_check(session).filter(
(self.column != None) & (self.column != "")
):
# check if casting to int works
try:
[int(x) for x in getattr(record, self.column.name).split(",")]
except ValueError:
invalids.append(record)
return invalids

def description(self) -> str:
return (
f"{self.table.name}.{self.column} is not a comma seperated list of integers"
)
119 changes: 53 additions & 66 deletions threedi_modelchecker/checks/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,37 +233,20 @@ def __init__(self, *args, **kwargs):
column=models.SimulationTemplateSettings.use_0d_inflow, *args, **kwargs
)

def to_check(self, session):
"""Return a Query object on which this check is applied"""
return session.query(models.SimulationTemplateSettings).filter(
models.SimulationTemplateSettings.use_0d_inflow != 0
)

def get_invalid(self, session):
surface_count = session.query(func.count(models.Surface.id)).scalar()
impervious_surface_count = session.query(
func.count(models.ImperviousSurface.id)
).scalar()

invalid_rows = []
for row in self.to_check(session):
if (
row.use_0d_inflow == constants.InflowType.IMPERVIOUS_SURFACE
and impervious_surface_count == 0
):
invalid_rows.append(row)
elif (
row.use_0d_inflow == constants.InflowType.SURFACE and surface_count == 0
):
invalid_rows.append(row)
else:
continue
return invalid_rows
settings = session.query(models.SimulationTemplateSettings).one_or_none()
if settings is None:
return []
use_0d_flow = settings.use_0d_inflow
if use_0d_flow != constants.InflowType.NO_INFLOW:
surface_count = session.query(func.count(models.Surface.id)).scalar()
if surface_count == 0:
return [settings]
return []

def description(self):
return (
f"When {self.column_name} is used, there should exist at least one "
"(impervious) surface."
f"When {self.column_name} is used, there should exist at least one surface."
)


Expand Down Expand Up @@ -742,30 +725,27 @@ def description(self) -> str:
return f"{self.column_name} will empty its storage faster than one timestep, which can cause simulation instabilities"


class ImperviousNodeInflowAreaCheck(BaseCheck):
class SurfaceNodeInflowAreaCheck(BaseCheck):
"""Check that total inflow area per connection node is no larger than 10000 square metres"""

def __init__(self, *args, **kwargs):
super().__init__(column=models.ConnectionNode.id, *args, **kwargs)

def get_invalid(self, session: Session) -> List[NamedTuple]:
impervious_surfaces = (
select(models.ImperviousSurfaceMap.connection_node_id)
.select_from(models.ImperviousSurfaceMap)
surfaces = (
select(models.SurfaceMap.connection_node_id)
.select_from(models.SurfaceMap)
.join(
models.ImperviousSurface,
models.ImperviousSurfaceMap.impervious_surface_id
== models.ImperviousSurface.id,
models.Surface,
models.SurfaceMap.surface_id == models.Surface.id,
)
.group_by(models.ImperviousSurfaceMap.connection_node_id)
.having(func.sum(models.ImperviousSurface.area) > 10000)
.group_by(models.SurfaceMap.connection_node_id)
.having(func.sum(models.Surface.area) > 10000)
).subquery()

return (
session.query(models.ConnectionNode)
.filter(
models.ConnectionNode.id == impervious_surfaces.c.connection_node_id
)
.filter(models.ConnectionNode.id == surfaces.c.connection_node_id)
.all()
)

Expand Down Expand Up @@ -804,14 +784,14 @@ def description(self) -> str:
class InflowNoFeaturesCheck(BaseCheck):
"""Check that the surface table in the global use_0d_inflow setting contains at least 1 feature."""

def __init__(self, *args, surface_table, condition=True, **kwargs):
def __init__(self, *args, feature_table, condition=True, **kwargs):
super().__init__(*args, column=models.ModelSettings.id, **kwargs)
self.surface_table = surface_table
self.feature_table = feature_table
self.condition = condition

def get_invalid(self, session: Session):
surface_table_length = session.execute(
select(func.count(self.surface_table.id))
select(func.count(self.feature_table.id))
).scalar()
return (
session.query(models.ModelSettings)
Expand All @@ -820,41 +800,28 @@ def get_invalid(self, session: Session):
)

def description(self) -> str:
return f"model_settings.use_0d_inflow is set to use {self.surface_table.__tablename__}, but {self.surface_table.__tablename__} does not contain any features."
return f"model_settings.use_0d_inflow is set to use {self.feature_table.__tablename__}, but {self.feature_table.__tablename__} does not contain any features."


class NodeSurfaceConnectionsCheck(BaseCheck):
"""Check that no more than 50 surfaces are mapped to a connection node"""

def __init__(
self,
check_type: Literal["impervious", "pervious"] = "impervious",
*args,
**kwargs,
):
def __init__(self, *args, **kwargs):
super().__init__(column=models.ConnectionNode.id, *args, **kwargs)

self.surface_column = None
if check_type == "impervious":
self.surface_column = models.ImperviousSurfaceMap
elif check_type == "pervious":
self.surface_column = models.SurfaceMap
self.surface_column = models.SurfaceMap

def get_invalid(self, session: Session) -> List[NamedTuple]:
if self.surface_column is None:
return []

overloaded_connections = (
select(self.surface_column.connection_node_id)
.group_by(self.surface_column.connection_node_id)
.having(func.count(self.surface_column.connection_node_id) > 50)
).subquery()
select(models.SurfaceMap.connection_node_id)
.group_by(models.SurfaceMap.connection_node_id)
.having(func.count(models.SurfaceMap.connection_node_id) > 50)
)

return (
session.query(models.ConnectionNode)
.filter(
models.ConnectionNode.id == overloaded_connections.c.connection_node_id
)
self.to_check(session)
.filter(models.ConnectionNode.id.in_(overloaded_connections))
.all()
)

Expand Down Expand Up @@ -923,8 +890,8 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:
all_results = select(
self.table.c.id,
self.table.c.area,
self.table.c.the_geom,
func.ST_Area(transform(self.table.c.the_geom)).label("calculated_area"),
self.table.c.geom,
func.ST_Area(transform(self.table.c.geom)).label("calculated_area"),
).subquery()
return (
session.query(all_results)
Expand Down Expand Up @@ -1113,3 +1080,23 @@ def description(self) -> str:
f"{self.table.name} has {self.observed_length} rows, "
f"but should have at most 1 row."
)


class TagsValidCheck(BaseCheck):
def get_invalid(self, session):
invalids = []
for record in self.to_check(session).filter(
(self.column != None) & (self.column != "")
):
query = (
f"SELECT id FROM tags WHERE id IN ({getattr(record, self.column.name)})"
)
match_rows = session.connection().execute(text(query)).fetchall()
found_idx = {row[0] for row in match_rows}
req_idx = {int(x) for x in getattr(record, self.column.name).split(",")}
if found_idx != req_idx:
invalids.append(record)
return invalids

def description(self) -> str:
return f"{self.table.name}.{self.column} refers to tag ids that are not present in Tags, "
Loading

0 comments on commit 36ee17d

Please sign in to comment.