diff --git a/hydrolib/core/io/crosssection/models.py b/hydrolib/core/io/crosssection/models.py index e16b8f588..5047d613f 100644 --- a/hydrolib/core/io/crosssection/models.py +++ b/hydrolib/core/io/crosssection/models.py @@ -7,6 +7,8 @@ from hydrolib.core.io.friction.models import FrictionType from hydrolib.core.io.ini.models import INIBasedModel, INIGeneral, INIModel from hydrolib.core.io.ini.util import ( + LocationValidationConfiguration, + LocationValidationFieldNames, get_enum_validator, get_from_subclass_defaults, get_location_specification_rootvalidator, @@ -683,7 +685,10 @@ class Comments(INIBasedModel.Comments): definitionid: str = Field(alias="definitionId") _location_validator = get_location_specification_rootvalidator( - allow_nodeid=False, numfield_name=None, xfield_name="x", yfield_name="y" + config=LocationValidationConfiguration( + validate_node=False, validate_num_coordinates=False + ), + fields=LocationValidationFieldNames(x_coordinates="x", y_coordinates="y"), ) diff --git a/hydrolib/core/io/ext/models.py b/hydrolib/core/io/ext/models.py index 87b28dadd..4b40f158f 100644 --- a/hydrolib/core/io/ext/models.py +++ b/hydrolib/core/io/ext/models.py @@ -11,8 +11,8 @@ from hydrolib.core.io.ini.models import INIBasedModel, INIGeneral, INIModel from hydrolib.core.io.ini.serializer import SerializerConfig, write_ini from hydrolib.core.io.ini.util import ( + LocationValidationConfiguration, get_location_specification_rootvalidator, - get_number_of_coordinates_validator, get_split_string_on_delimiter_validator, make_list_validator, ) @@ -141,9 +141,8 @@ def is_intermediate_link(self) -> bool: "xcoordinates", "ycoordinates" ) - _location_validator = get_location_specification_rootvalidator(allow_nodeid=True) - _number_of_coordinates_validator = get_number_of_coordinates_validator( - minimum_required_number_of_coordinates=1 + _location_validator = get_location_specification_rootvalidator( + config=LocationValidationConfiguration(minimum_num_coordinates=1) ) def _get_identifier(self, data: dict) -> Optional[str]: diff --git a/hydrolib/core/io/ini/util.py b/hydrolib/core/io/ini/util.py index 3a4cf5c10..0fca0b861 100644 --- a/hydrolib/core/io/ini/util.py +++ b/hydrolib/core/io/ini/util.py @@ -8,7 +8,8 @@ from pydantic.fields import ModelField from pydantic.main import BaseModel -from hydrolib.core.utils import operator_str, str_is_empty_or_none +from hydrolib.core.io.common.models import LocationType +from hydrolib.core.utils import operator_str, str_is_empty_or_none, to_list def get_split_string_on_delimiter_validator(*field_name: str): @@ -251,33 +252,73 @@ def get_from_subclass_defaults(cls: Type[BaseModel], fieldname: str, value: str) return value +class LocationValidationConfiguration(BaseModel): + """Class that holds the various configuration settings needed for location validation.""" + + validate_node: bool = True + """bool, optional: Whether or not node location specification should be validated. Defaults to True.""" + + validate_coordinates: bool = True + """bool, optional: Whether or not coordinate location specification should be validated. Defaults to True.""" + + validate_branch: bool = True + """bool, optional: Whether or not branch location specification should be validated. Defaults to True.""" + + validate_num_coordinates: bool = True + """bool, optional: Whether or not the number of coordinates should be validated or not. This option is only relevant when `validate_coordinates` is True. Defaults to True.""" + + minimum_num_coordinates: int = 0 + """int, optional: The minimum required number of coordinates. This option is only relevant when `validate_coordinates` is True. Defaults to 0.""" + + +class LocationValidationFieldNames(BaseModel): + """Class that holds the various field names needed for location validation.""" + + node_id: str = "nodeId" + """str, optional: The node id field name. Defaults to `nodeId`.""" + + branch_id: str = "branchId" + """str, optional: The branch id field name. Defaults to `branchId`.""" + + chainage: str = "chainage" + """str, optional: The chainage field name. Defaults to `chainage`.""" + + x_coordinates: str = "xCoordinates" + """str, optional: The x-coordinates field name. Defaults to `xCoordinates`.""" + + y_coordinates: str = "yCoordinates" + """str, optional: The y-coordinates field name. Defaults to `yCoordinates`.""" + + num_coordinates: str = "numCoordinates" + """str, optional: The number of coordinates field name. Defaults to `numCoordinates`.""" + + location_type: str = "locationType" + """str, optional: The location type field name. Defaults to `locationType`.""" + + def get_location_specification_rootvalidator( - allow_nodeid: bool = True, - numfield_name: str = "numCoordinates", - xfield_name: str = "xCoordinates", - yfield_name: str = "yCoordinates", + config: Optional[LocationValidationConfiguration] = None, + fields: Optional[LocationValidationFieldNames] = None, ): """ Get a root validator that checks for correct location specification in typical 1D2D input in an IniBasedModel class. - It checks for presence of at least one of: nodeId (if allowed), - branchId+chainage or num/x/yCoordinates. - Also, completeness of the given location is checked (e.g., no chainage - missing when branchId given), as well as the locationType. + Validates for presence of at least one of: nodeId, branchId with chainage, + xCoordinates with yCoordinates, or xCoordinates with yCoordinates and numCoordinates. + Validates for the locationType for nodeId and branchId. Args: - allow_nodeid (bool): Allow nodeId in input. Defaults to True. - numfield_name (str): Field name (in input file) for the coordinates - count. Will be lowercased in values dict. Use None when this - class has no count field at all. Defaults to "numCoordinates". - xfield_name (str): Field name (in input file) for the x coordinates. - Will be lowercased in values dict. Defaults to "xCoordinates". - yfield_name (str): Field name (in input file) for the y coordinates. - Will be lowercased in values dict. Defaults to "yCoordinates". - + config (LocationValidationConfiguration, optional): Configuration for the location validation. Default is None. + field (LocationValidationFieldNames, optional): Fields names that should be used for the location validation. Default is None. """ + if config is None: + config = LocationValidationConfiguration() + + if fields is None: + fields = LocationValidationFieldNames() + def validate_location_specification(cls, values: Dict) -> Dict: """ Verify whether the location given for this object matches the expectations. @@ -286,156 +327,138 @@ def validate_location_specification(cls, values: Dict) -> Dict: values (Dict): Dictionary of object's validated fields. Raises: - ValueError: When neither nodeid, branchid or coordinates have been given. - ValueError: When either x or y coordinates were expected but not given. - ValueError: When locationtype should be 1d but other was specified. + ValueError: When exactly one of the following combinations were not given: + - nodeId + - branchId with chainage + - xCoordinates with yCoordinates + - xCoordinates with yCoordinates and numCoordinates. + ValueError: When numCoordinates does not meet the requirement minimum amount or does not match the amount of xCoordinates or yCoordinates. + ValueError: When locationType should be 1d but other was specified. Returns: Dict: Validated dictionary of input class fields. """ - def validate_coordinates(coord_name: str) -> None: - if values.get(coord_name.lower(), None) is None: - raise ValueError("{} should be given.".format(coord_name)) - - # If nodeid or branchid and Chainage are present - node_id: str = values.get("nodeid", None) - branch_id: str = values.get("branchid", None) - n_coords: int = ( - values.get(numfield_name.lower(), 0) - if not str_is_empty_or_none(numfield_name) - else None - ) + has_node_id = not str_is_empty_or_none(values.get(fields.node_id.lower())) + has_branch_id = not str_is_empty_or_none(values.get(fields.branch_id.lower())) + has_chainage = values.get(fields.chainage.lower()) is not None + has_x_coordinates = values.get(fields.x_coordinates.lower()) is not None + has_y_coordinates = values.get(fields.y_coordinates.lower()) is not None + has_num_coordinates = values.get(fields.num_coordinates.lower()) is not None - chainage: float = values.get("chainage", None) + # ----- Local validation functions + def get_length(field: str): + value = values[field.lower()] + return len(to_list(value)) - # First validation - at least one of the following should be specified. - if str_is_empty_or_none(node_id) and (str_is_empty_or_none(branch_id)): - if n_coords == 0: - raise ValueError( - f"Either {'nodeId, ' if allow_nodeid else ''}branchId (with chainage) or {numfield_name + ' with ' if numfield_name else ''}{xfield_name} and {yfield_name} are required." - ) - else: - # Validation: when ids are absent, coordinates should be valid. - validate_coordinates(xfield_name) - validate_coordinates(yfield_name) - return values - else: - # Validation: nodeId only when it is allowed - if not str_is_empty_or_none(node_id) and not allow_nodeid: - raise ValueError(f"nodeId is not allowed for {cls.__name__} objects") - # Validation: chainage should be given with branchid - if not str_is_empty_or_none(branch_id) and chainage is None: - raise ValueError( - "Chainage should be provided when branchId is specified." - ) - # Validation: when nodeid, or branchid specified, expected 1d. - location_type = values.get("locationtype", None) + def validate_location_type(expected_location_type: LocationType) -> None: + location_type = values.get(fields.location_type.lower(), None) if str_is_empty_or_none(location_type): - values["locationtype"] = "1d" - elif location_type.lower() != "1d": + values[fields.location_type.lower()] = expected_location_type + elif location_type != expected_location_type: raise ValueError( - "locationType should be 1d when nodeId (or branchId and chainage) is specified." + f"{fields.location_type} should be {expected_location_type} but was {location_type}" ) - return values - - return root_validator(allow_reuse=True)(validate_location_specification) - + def validate_coordinates_with_num_coordinates() -> None: + length_x_coordinates = get_length(fields.x_coordinates) + length_y_coordinates = get_length(fields.y_coordinates) + num_coordinates = values[fields.num_coordinates.lower()] -def get_number_of_coordinates_validator( - numfield_name: str = "numCoordinates", - xfield_name: str = "xCoordinates", - yfield_name: str = "yCoordinates", - minimum_required_number_of_coordinates: int = 0, -): - """ - Get a validator that validates whether the given coordinates match in number - to the expected value given by numCoordinates and that numCoordinates is - greater than or equal to the minimum required number of coordinates. + if not num_coordinates == length_x_coordinates == length_y_coordinates: + raise ValueError( + f"{fields.num_coordinates} should be equal to the amount of {fields.x_coordinates} and {fields.y_coordinates}" + ) - Args: - numfield_name (str, optional): Field name (in input file) for the coordinates - count. Will be lowercased in values dict. Defaults to "numCoordinates". - xfield_name (str, optional): Field name (in input file) for the x coordinates. - Will be lowercased in values dict. Defaults to "xCoordinates". - yfield_name (str, optional): Field name (in input file) for the y coordinates. - Will be lowercased in values dict. Defaults to "yCoordinates". - minimum_required_number_of_coordinates (int, optional): Minimum number of - coordinates required in order to validate. Defaults to 0. - """ + validate_minimum_num_coordinates(num_coordinates) - def validate_number_of_coordinates(cls, values: Dict) -> Dict: - """ - Validates whether the given coordinates match in number to the - expected value given for numCoordinates and is greater than or - equal to the minimum required number of coordinates. + def validate_coordinates() -> None: + len_x_coordinates = get_length(fields.x_coordinates) + len_y_coordinates = get_length(fields.y_coordinates) - Args: - values (Dict): Dictionary of object's validated fields. + if len_x_coordinates != len_y_coordinates: + raise ValueError( + f"{fields.x_coordinates} and {fields.y_coordinates} should have an equal amount of coordinates" + ) - Raises: - ValueError: When the number of coordinates is not specified but the coordinates are. - ValueError: When the number of coordinates is provided but the x-coordinates or - y-coordinates are not. - ValueError: When the number of x-coordinates or the number of y-coordinates - does not match the number of coordinates. - ValueError: When the number of x-coordinates or the number of y-coordinates - is less than the number of required coordinates. + validate_minimum_num_coordinates(len_x_coordinates) - Returns: - Dict: Validated dictionary of input class fields. - """ + def validate_minimum_num_coordinates(actual_num: int) -> None: + if actual_num < config.minimum_num_coordinates: + raise ValueError( + f"{fields.x_coordinates} and {fields.y_coordinates} should have at least {config.minimum_num_coordinates} coordinate(s)" + ) - def get_value(field_name: str) -> Any: - return ( - values.get(field_name.lower(), None) - if not str_is_empty_or_none(field_name) - else None + def is_valid_node_specification() -> bool: + has_other = ( + has_branch_id + or has_chainage + or has_x_coordinates + or has_y_coordinates + or has_num_coordinates ) + return has_node_id and not has_other + + def is_valid_branch_specification() -> bool: + has_other = ( + has_node_id + or has_x_coordinates + or has_y_coordinates + or has_num_coordinates + ) + return has_branch_id and has_chainage and not has_other - def all_values_are_none() -> bool: - return ( - number_of_coordinates is None - and xcoordinates is None - and ycoordinates is None + def is_valid_coordinates_specification() -> bool: + has_other = ( + has_node_id or has_branch_id or has_chainage or has_num_coordinates ) + return has_x_coordinates and has_y_coordinates and not has_other - def some_values_are_none() -> bool: + def is_valid_coordinates_with_num_coordinates_specification() -> bool: + has_other = has_node_id or has_branch_id or has_chainage return ( - number_of_coordinates is None - or xcoordinates is None - or ycoordinates is None + has_x_coordinates + and has_y_coordinates + and has_num_coordinates + and not has_other ) - def validate_x_and_ycoordinate_number() -> None: - number_of_xcoordinates = len(xcoordinates) - number_of_ycoordinates = len(ycoordinates) + # ----- - if ( - number_of_xcoordinates != number_of_coordinates - or number_of_ycoordinates != number_of_coordinates - or number_of_xcoordinates < minimum_required_number_of_coordinates - ): - raise ValueError( - f"Number of x-coordinates and y-coordinates should match number of" - "coordinates and should be atleast {minimum_required_number_of_coordinates}." - ) + error_parts: List[str] = [] - number_of_coordinates = get_value(numfield_name) - xcoordinates = get_value(xfield_name) - ycoordinates = get_value(yfield_name) + if config.validate_node: + if is_valid_node_specification(): + validate_location_type(LocationType.oned) + return values - if all_values_are_none(): - return values + error_parts.append(fields.node_id) - if some_values_are_none(): - raise ValueError( - f"When using coordinates, the fields {numfield_name}, {xfield_name} and {yfield_name} should be given." - ) + if config.validate_branch: + if is_valid_branch_specification(): + validate_location_type(LocationType.oned) + return values - validate_x_and_ycoordinate_number() + error_parts.append(f"{fields.branch_id} and {fields.chainage}") - return values + if config.validate_coordinates: + if config.validate_num_coordinates: + if is_valid_coordinates_with_num_coordinates_specification(): + validate_coordinates_with_num_coordinates() + return values - return root_validator(allow_reuse=True)(validate_number_of_coordinates) + error_parts.append( + f"{fields.x_coordinates}, {fields.y_coordinates} and {fields.num_coordinates}" + ) + + else: + if is_valid_coordinates_specification(): + validate_coordinates() + return values + + error_parts.append(f"{fields.x_coordinates} and {fields.y_coordinates}") + + error = " or ".join(error_parts) + " should be provided" + raise ValueError(error) + + return root_validator(allow_reuse=True)(validate_location_specification) diff --git a/hydrolib/core/io/obs/models.py b/hydrolib/core/io/obs/models.py index 9d24c6c29..a0a57cdcd 100644 --- a/hydrolib/core/io/obs/models.py +++ b/hydrolib/core/io/obs/models.py @@ -5,6 +5,8 @@ from hydrolib.core.io.common.models import LocationType from hydrolib.core.io.ini.models import INIBasedModel, INIGeneral, INIModel from hydrolib.core.io.ini.util import ( + LocationValidationConfiguration, + LocationValidationFieldNames, get_enum_validator, get_location_specification_rootvalidator, make_list_validator, @@ -70,8 +72,12 @@ class Comments(INIBasedModel.Comments): y: Optional[float] = Field(None, alias="y") _type_validator = get_enum_validator("locationtype", enum=LocationType) + _location_validator = get_location_specification_rootvalidator( - allow_nodeid=False, numfield_name=None, xfield_name="x", yfield_name="y" + config=LocationValidationConfiguration( + validate_node=False, validate_num_coordinates=False + ), + fields=LocationValidationFieldNames(x_coordinates="x", y_coordinates="y"), ) def _get_identifier(self, data: dict) -> Optional[str]: diff --git a/hydrolib/core/io/obscrosssection/models.py b/hydrolib/core/io/obscrosssection/models.py index 0805865a3..5eb1e165e 100644 --- a/hydrolib/core/io/obscrosssection/models.py +++ b/hydrolib/core/io/obscrosssection/models.py @@ -4,8 +4,8 @@ from hydrolib.core.io.ini.models import INIBasedModel, INIGeneral, INIModel from hydrolib.core.io.ini.util import ( + LocationValidationConfiguration, get_location_specification_rootvalidator, - get_number_of_coordinates_validator, get_split_string_on_delimiter_validator, ) @@ -71,10 +71,10 @@ class Comments(INIBasedModel.Comments): "xcoordinates", "ycoordinates" ) - _location_validator = get_location_specification_rootvalidator(allow_nodeid=False) - - _number_of_coordinates_validator = get_number_of_coordinates_validator( - minimum_required_number_of_coordinates=2 + _location_validator = get_location_specification_rootvalidator( + config=LocationValidationConfiguration( + validate_node=False, minimum_num_coordinates=2 + ) ) def _get_identifier(self, data: dict) -> Optional[str]: diff --git a/tests/io/ini/test_util.py b/tests/io/ini/test_util.py index b7657f9fe..67ae1872f 100644 --- a/tests/io/ini/test_util.py +++ b/tests/io/ini/test_util.py @@ -3,92 +3,198 @@ import pytest from pydantic.error_wrappers import ValidationError -from hydrolib.core.io.ini.models import INIBasedModel -from hydrolib.core.io.ini.util import get_number_of_coordinates_validator - - -class TestCoordinatesValidator: - class DummyModel(INIBasedModel): - """Dummy model to test the validation of the number of coordinates.""" - - numcoordinates: Optional[int] +from hydrolib.core.basemodel import BaseModel +from hydrolib.core.io.ini.util import ( + LocationValidationConfiguration, + LocationValidationFieldNames, + get_location_specification_rootvalidator, +) + + +class TestLocationValidationConfiguration: + def test_default(self): + config = LocationValidationConfiguration() + assert config.validate_node == True + assert config.validate_coordinates == True + assert config.validate_branch == True + assert config.validate_num_coordinates == True + assert config.minimum_num_coordinates == 0 + + +class TestLocationValidationFieldNames: + def test_default(self): + fields = LocationValidationFieldNames() + assert fields.node_id == "nodeId" + assert fields.branch_id == "branchId" + assert fields.chainage == "chainage" + assert fields.x_coordinates == "xCoordinates" + assert fields.y_coordinates == "yCoordinates" + assert fields.num_coordinates == "numCoordinates" + assert fields.location_type == "locationType" + + +class TestLocationSpecificationValidator: + class DummyModel(BaseModel): + """Dummy model to test the validation of the location specification.""" + + nodeid: Optional[str] + branchid: Optional[str] + chainage: Optional[str] xcoordinates: Optional[List[float]] ycoordinates: Optional[List[float]] + numcoordinates: Optional[int] + locationtype: Optional[str] - _number_of_coordinates_validator = get_number_of_coordinates_validator( - minimum_required_number_of_coordinates=2 + validator = get_location_specification_rootvalidator( + config=LocationValidationConfiguration(minimum_num_coordinates=3) ) - def test_all_values_none_does_not_throw(self): - model = TestCoordinatesValidator.DummyModel() - - assert model.numcoordinates is None - assert model.xcoordinates is None - assert model.ycoordinates is None - - def test_coordinates_given_but_none_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["numcoordinates"] = None - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_no_xcoordinates_given_while_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["xcoordinates"] = None - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_no_ycoordinates_given_while_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["ycoordinates"] = None - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_fewer_xcoordinates_than_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["xcoordinates"] = [1, 2] - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_more_xcoordinates_than_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["xcoordinates"] = [1, 2, 3, 4] - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_fewer_ycoordinates_than_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["ycoordinates"] = [1, 2] - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_more_ycoordinates_than_expected_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["ycoordinates"] = [1, 2, 3, 4] - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def test_fewer_than_minimum_required_number_of_coordinates_throws_value_error(self): - values = self._create_valid_dummy_model_values() - values["numcoordinates"] = 1 - values["xcoordinates"] = [1.23] - values["ycoordinates"] = [9.87] - - with pytest.raises(ValidationError): - TestCoordinatesValidator.DummyModel(**values) - - def _create_valid_dummy_model_values(self) -> Dict: - values = dict( - numcoordinates=3, - xcoordinates=[1.23, 4.56, 7.89], - ycoordinates=[9.87, 6.54, 3.21], + @pytest.mark.parametrize( + "values", + [ + {}, + { + "nodeid": "some_nodeid", + "branchid": "some_branchid", + "chainage": 1.23, + "xcoordinates": [4.56, 5.67, 6.78], + "ycoordinates": [7.89, 8.91, 9.12], + "numcoordinates": 3, + }, + { + "xcoordinates": [4.56, 5.67, 6.78], + "ycoordinates": [7.89, 8.91, 9.12], + }, + { + "nodeid": "some_nodeid", + "branchid": "some_branchid", + "chainage": 1.23, + }, + { + "branchid": "some_branchid", + "chainage": 1.23, + "xcoordinates": [4.56, 5.67, 6.78], + }, + { + "branchid": "some_branchid", + }, + ], + ) + def test_incorrect_fields_provided_raises_error(self, values: dict): + with pytest.raises(ValidationError) as error: + TestLocationSpecificationValidator.DummyModel(**values) + + expected_message = "nodeId or branchId and chainage or xCoordinates, yCoordinates and numCoordinates should be provided" + assert expected_message in str(error.value) + + def test_too_few_coordinates_raises_error(self): + values = { + "xcoordinates": [1.23, 2.34], + "ycoordinates": [3.45, 4.56], + "numcoordinates": 2, + } + + with pytest.raises(ValidationError) as error: + TestLocationSpecificationValidator.DummyModel(**values) + + expected_message = ( + "xCoordinates and yCoordinates should have at least 3 coordinate(s)" ) - - return values + assert expected_message in str(error.value) + + def test_coordinate_amount_does_not_match_numcoordinates_raises_error(self): + values = { + "xcoordinates": [1.23, 2.34, 3.45], + "ycoordinates": [4.56, 5.67, 6.78], + "numcoordinates": 4, + } + + with pytest.raises(ValidationError) as error: + TestLocationSpecificationValidator.DummyModel(**values) + + expected_message = "numCoordinates should be equal to the amount of xCoordinates and yCoordinates" + assert expected_message in str(error.value) + + @pytest.mark.parametrize( + "values", + [ + pytest.param( + { + "nodeid": "some_nodeid", + "locationtype": "2d", + }, + id="nodeid", + ), + pytest.param( + { + "branchid": "some_branchid", + "chainage": 1.23, + "locationtype": "2d", + }, + id="branchid", + ), + ], + ) + def test_incorrect_location_type_raises_error(self, values: dict): + with pytest.raises(ValidationError) as error: + TestLocationSpecificationValidator.DummyModel(**values) + + expected_message = "locationType should be 1d but was 2d" + assert expected_message in str(error.value) + + @pytest.mark.parametrize( + "values", + [ + pytest.param( + { + "nodeid": "some_nodeid", + }, + id="nodeid", + ), + pytest.param( + { + "branchid": "some_branchid", + "chainage": 1.23, + }, + id="branchid", + ), + pytest.param( + { + "xcoordinates": [4.56, 5.67, 6.78], + "ycoordinates": [7.89, 8.91, 9.12], + "numcoordinates": 3, + }, + id="coordinates", + ), + ], + ) + def test_correct_fields_initializes(self, values: dict): + validated_values = TestLocationSpecificationValidator.DummyModel.validator( + values + ) + assert validated_values == values + + @pytest.mark.parametrize( + "values, expected_values", + [ + pytest.param( + { + "nodeid": "some_nodeid", + }, + {"nodeid": "some_nodeid", "locationtype": "1d"}, + id="nodeid", + ), + pytest.param( + {"branchid": "some_branchid", "chainage": 1.23}, + {"branchid": "some_branchid", "chainage": 1.23, "locationtype": "1d"}, + id="branchid", + ), + ], + ) + def test_correct_1d_fields_locationtype_is_added( + self, values: dict, expected_values: dict + ): + validated_values = TestLocationSpecificationValidator.DummyModel.validator( + values + ) + assert validated_values == expected_values diff --git a/tests/io/test_crosssection.py b/tests/io/test_crosssection.py index 3b3ad472f..c23ae3301 100644 --- a/tests/io/test_crosssection.py +++ b/tests/io/test_crosssection.py @@ -311,12 +311,38 @@ def test_crossdef_model_from_file(self): dict(branchid="", chainage=None, x=None, y=None), id="All Empty", ), + pytest.param( + dict(branchid="some_branchid", chainage=None, x=None, y=None), + id="Only branchid given", + ), + pytest.param( + dict(branchid=None, chainage=1.0, x=None, y=None), + id="Only chainage given", + ), + pytest.param( + dict(branchid=None, chainage=None, x=[1, 2, 3], y=None), + id="Only x given", + ), + pytest.param( + dict(branchid=None, chainage=None, x=None, y=[1, 2, 3]), + id="Only y given", + ), + pytest.param( + dict(branchid="some_branchid", chainage=1.0, x=[1, 2, 3], y=None), + id="branchid and chainage given, but with something else", + ), + pytest.param( + dict(branchid="some_branchid", chainage=None, x=[1, 2, 3], y=[1, 2, 3]), + id="x and y given, but with something else", + ), ], ) - def test_given_no_values_raises_valueerror(self, dict_values: dict): + def test_wrong_values_raises_valueerror(self, dict_values: dict): with pytest.raises(ValueError) as exc_err: CrossSection._location_validator(values=dict_values) - assert str(exc_err.value) == "x should be given." + assert ( + str(exc_err.value) == "branchId and chainage or x and y should be provided" + ) def test_given_valid_coordinates(self): test_dict = dict( @@ -327,16 +353,3 @@ def test_given_valid_coordinates(self): ) return_value = CrossSection._location_validator(test_dict) assert return_value == test_dict - - def test_given_branchid_and_no_chainage_raises_valueerror(self): - with pytest.raises(ValueError) as exc_err: - CrossSection._location_validator( - dict( - branchid="aBranchId", - chainage=None, - ) - ) - assert ( - str(exc_err.value) - == "Chainage should be provided when branchId is specified." - ) diff --git a/tests/io/test_ext.py b/tests/io/test_ext.py index fe69ea2bb..50f4fa963 100644 --- a/tests/io/test_ext.py +++ b/tests/io/test_ext.py @@ -19,6 +19,8 @@ class TestLateral: """Class to test all methods contained in the hydrolib.core.io.ext.models.Lateral class""" + location_error: str = "nodeId or branchId and chainage or xCoordinates, yCoordinates and numCoordinates should be provided" + class TestValidateCoordinates: """ Class to test the paradigms for validate_coordinates. @@ -28,8 +30,6 @@ def _create_valid_lateral_values(self) -> Dict: values = dict( id="randomId", name="randomName", - branchid="randomBranchName", - chainage=1.234, numcoordinates=2, xcoordinates=[1.1, 2.2], ycoordinates=[1.1, 2.2], @@ -123,11 +123,11 @@ class TestValidateLocationTypeDependencies: "dict_values", [ pytest.param( - dict(nodeid=None, branch_id=None, n_coords=None, chainage=None), + dict(nodeid=None, branchid=None, chainage=None), id="All None", ), pytest.param( - dict(nodeid="", branch_id="", n_coords=0, chainage=None), + dict(nodeid="", branchid="", chainage=None), id="All Empty", ), ], @@ -135,10 +135,7 @@ class TestValidateLocationTypeDependencies: def test_given_no_values_raises_valueerror(self, dict_values: dict): with pytest.raises(ValueError) as exc_err: Lateral._location_validator(values=dict_values) - assert ( - str(exc_err.value) - == "Either nodeId, branchId (with chainage) or numCoordinates with xCoordinates and yCoordinates are required." - ) + assert str(exc_err.value) == TestModels.TestLateral.location_error @pytest.mark.parametrize( "missing_coordinates", [("xCoordinates"), ("yCoordinates")] @@ -157,7 +154,7 @@ def test_given_numcoords_but_missing_coordinates( test_dict[missing_coordinates.lower()] = None with pytest.raises(ValueError) as exc_error: Lateral._location_validator(test_dict) - assert str(exc_error.value) == f"{missing_coordinates} should be given." + assert str(exc_error.value) == TestModels.TestLateral.location_error def test_given_numcoordinates_and_valid_coordinates(self): test_dict = dict( @@ -180,10 +177,7 @@ def test_given_branchid_and_no_chainage_raises_valueerror(self): chainage=None, ) ) - assert ( - str(exc_err.value) - == "Chainage should be provided when branchId is specified." - ) + assert str(exc_err.value) == TestModels.TestLateral.location_error @pytest.mark.parametrize( "dict_values", @@ -199,17 +193,13 @@ def test_given_1d_args_and_location_type_other_then_raises_valueerror( self, dict_values: dict ): test_values = dict( - numcoordinates=2, - xcoordinates=[42, 24], - ycoordinates=[24, 42], locationtype="wrongType", ) test_dict = {**dict_values, **test_values} with pytest.raises(ValueError) as exc_err: Lateral._location_validator(test_dict) assert ( - str(exc_err.value) - == "locationType should be 1d when nodeId (or branchId and chainage) is specified." + str(exc_err.value) == "locationType should be 1d but was wrongType" ) @pytest.mark.parametrize( @@ -224,9 +214,6 @@ def test_given_1d_args_and_location_type_other_then_raises_valueerror( ) def test_given_1d_args_and_1d_location_type(self, dict_values: dict): test_values = dict( - numcoordinates=2, - xcoordinates=[42, 24], - ycoordinates=[24, 42], locationtype="1d", ) test_dict = {**dict_values, **test_values} @@ -277,7 +264,7 @@ def test_given_coordinates_but_no_numcoordinates_raises( ycoordinates=y_coord, ) - expected_error_mssg = "When using coordinates, the fields numCoordinates, xCoordinates and yCoordinates should be given." + expected_error_mssg = TestModels.TestLateral.location_error assert expected_error_mssg in str(exc_mssg.value) @pytest.mark.parametrize( @@ -314,8 +301,8 @@ def test_given_partial_coordinates_raises(self, missing_coord: str): lateral_dict[missing_coord.lower()] = None with pytest.raises(ValidationError) as exc_mssg: Lateral(**lateral_dict) - - assert f"{missing_coord} should be given." in str(exc_mssg.value) + expected_error_mssg = TestModels.TestLateral.location_error + assert expected_error_mssg in str(exc_mssg.value) def test_given_unknown_locationtype_raises(self): with pytest.raises(ValidationError) as exc_mssg: @@ -339,10 +326,6 @@ def test_given_unknown_locationtype_raises(self): dict(branchid="aBranchId", chainage=42), id="branchid + chainage given.", ), - pytest.param( - dict(nodeid="aNodeId", branchid="aBranchId", chainage=42), - id="all given.", - ), pytest.param( dict(nodeid="", branchid="aBranchId", chainage=42), id="Empty nodeid.", @@ -356,9 +339,6 @@ def test_given_valid_location_args_constructs_lateral( default_values = dict( id="42", discharge=1.23, - numcoordinates=2, - xcoordinates=[42, 24], - ycoordinates=[24, 42], locationtype="1d", ) test_dict = {**default_values, **location_values} diff --git a/tests/io/test_obscrosssection.py b/tests/io/test_obscrosssection.py index 80998d125..d21c7ab01 100644 --- a/tests/io/test_obscrosssection.py +++ b/tests/io/test_obscrosssection.py @@ -33,14 +33,6 @@ class TestObservationCrossSection: True, id="Using branchId without specifying numCoordinates should validate.", ), - pytest.param( - True, - 2, - [1.1, 2.2], - [1.1, 2.2], - True, - id="Using branchId while also specifying numCoordinates should validate.", - ), pytest.param( True, 2,