Skip to content

Commit

Permalink
Improved header-data alignment checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik-Geo committed Mar 20, 2024
1 parent 5002958 commit 522ef74
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Setup Pixi
uses: prefix-dev/setup-pixi@v0.5.1
- name: Run linter
Expand Down
46 changes: 41 additions & 5 deletions geost/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __new__(cls, *args, **kwargs):
def __repr__(self):
    """Short textual summary: the class name plus the number of header points."""
    cls_name = type(self).__name__
    return "{}:\n# header = {}".format(cls_name, self.n_points)

def reset_header(self):
def reset_header(self, *args):
"""
Create a new header based on the 'data' dataframe
(:py:attr:`~geost.base.PointDataCollection.data`). Can be used to reset the
Expand Down Expand Up @@ -200,9 +200,8 @@ def header(self, header):
header through validation and warn the user of any potential problems.
"""
headerschema.validate(header)
if any(~header["nr"].isin(self.data["nr"].unique())):
warn("Header does not cover all unique objects in data")
self._header = header
self.__check_header_to_data_alignment()

@data.setter
def data(self, data):
Expand All @@ -221,6 +220,34 @@ def data(self, data):
inclined_dataschema.validate(data)

self._data = data
self.__check_header_to_data_alignment()

def __check_header_to_data_alignment(self):
    """
    Two-way check to warn of any misalignment between the header and data
    attributes. Two-way, i.e. warn if the header includes more objects than
    present in the data and if the data includes more unique objects than
    listed in the header.

    This check is performed every time the object is instantiated AND if any
    change is made to either the header or data attributes (see their
    respective setters).
    """
    # The setters run during __init__ as well; until both attributes have been
    # assigned there is nothing to compare yet, so skip silently.
    if not (hasattr(self, "_header") and hasattr(self, "_data")):
        return

    # Header entries with no matching object in the data table.
    if not self.header["nr"].isin(self.data["nr"].unique()).all():
        warn(
            "Header covers more objects than present in the data table, "
            "consider running the method 'reset_header' to update the header."
        )
    # Data objects with no matching entry in the header table. A single
    # vectorised isin replaces the original per-object scan, which rebuilt
    # an isin mask for every unique "nr" (O(n*m)).
    if not self.data["nr"].isin(self.header["nr"]).all():
        warn(
            "Header does not cover all unique objects in data, consider running "
            "the method 'reset_header' to update the header."
        )

def change_vertical_reference(self, to: str):
"""
Expand Down Expand Up @@ -730,7 +757,9 @@ def slice_depth_interval(

return result

def slice_by_values(self, column: str, selection_values: Union[str, Iterable]):
def slice_by_values(
self, column: str, selection_values: Union[str, Iterable], invert: bool = False
):
"""
Slice rows from data based on matching condition. E.g. only return rows with
a certain lithology in the collection object.
Expand All @@ -742,6 +771,9 @@ def slice_by_values(self, column: str, selection_values: Union[str, Iterable]):
values.
selection_values : Union[str, Iterable]
Values to look for in the column.
invert : bool
Invert the slicing action, so remove layers with selected values instead of
keeping them.
Returns
-------
Expand All @@ -754,7 +786,11 @@ def slice_by_values(self, column: str, selection_values: Union[str, Iterable]):
selection_values = [selection_values]

data_sliced = self.data.copy()
data_sliced = data_sliced[data_sliced[column].isin(selection_values)]
if invert:
data_sliced = data_sliced[~data_sliced[column].isin(selection_values)]
elif not invert:
data_sliced = data_sliced[data_sliced[column].isin(selection_values)]

header_sliced = self.header.loc[
self.header["nr"].isin(data_sliced["nr"].unique())
]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_vtk_prepare_borehole(self, borehole_collection):
)
assert_array_equal(prepared_borehole, target)

@pytest.mark.skip(reason="CI known for not working with pyvisyta")
@pytest.mark.unittest # .skip(reason="Gitlab CI not working with pyvisyta")
def test_vtk_borehole_to_multiblock(self, borehole_collection):
multiblock = vtk.borehole_to_multiblock(
borehole_collection.data,
Expand All @@ -126,7 +126,7 @@ def test_vtk_borehole_to_multiblock(self, borehole_collection):
assert multiblock[0].n_cells == 22
assert multiblock[0].n_points == 260

@pytest.mark.skip(reason="CI known for not working with pyvisyta")
@pytest.mark.unittest # .skip(reason="Gitlab CI not working with pyvisyta")
def test_to_vtm(self, borehole_collection):
out_file = self.export_folder.joinpath("test_output_file.vtm")
out_folder = self.export_folder.joinpath("test_output_file")
Expand Down
26 changes: 24 additions & 2 deletions tests/test_pointcollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ def header_missing_object(self):

return dataframe

@pytest.fixture
def header_surplus_objects(self):
    """Header table listing one borehole ("B-03") that is absent from the data."""
    records = [
        ("B-01", 139370, 455540, 1, -4),
        ("B-02", 100000, 400000, 0, -8),
        ("B-03", 110000, 410000, -1, -9),
    ]
    return pd.DataFrame(records, columns=["nr", "x", "y", "mv", "end"])

@pytest.mark.unittest
def test_change_vertical_reference(self, borehole_df_ok):
borehole_collection_ok = BoreholeCollection(borehole_df_ok)
Expand Down Expand Up @@ -241,7 +255,15 @@ def test_validation_fail(self, capfd, borehole_df_bad_validation):
assert 'data in column "end" failed check "< mv" for 1 rows: [1]' in out

@pytest.mark.integrationtest
def test_header_mismatch(self, capfd, borehole_df_ok, header_missing_object):
BoreholeCollection(borehole_df_ok, header=header_missing_object)
def test_header_mismatch(
    self, capfd, borehole_df_ok, header_missing_object, header_surplus_objects
):
    """Both directions of header/data misalignment emit the expected warning."""
    # Situation #1: the data table contains unique objects the header omits.
    BoreholeCollection(borehole_df_ok, header=header_missing_object)
    captured, _ = capfd.readouterr()
    assert "Header does not cover all unique objects in data" in captured

    # Situation #2: the header lists objects the data table does not contain.
    BoreholeCollection(borehole_df_ok, header=header_surplus_objects)
    captured, _ = capfd.readouterr()
    assert "Header covers more objects than present in the data table" in captured

0 comments on commit 522ef74

Please sign in to comment.