-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from Deltares/feature/DEI-103-multiple-classif…
…ication-rule-core Feature/dei 103 multiple classification rule core
- Loading branch information
Showing
16 changed files
with
679 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,3 +132,4 @@ dmypy.json | |
*.yaml | ||
**/*.nc | ||
!tests/**/*.nc | ||
/examples/data_out/ |
103 changes: 103 additions & 0 deletions
103
decoimpact/business/entities/rules/classification_rule.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
""" | ||
Module for ClassificationRule class | ||
Classes: | ||
ClassificationRule | ||
""" | ||
|
||
from typing import Dict, List | ||
|
||
import xarray as _xr | ||
|
||
from decoimpact.business.entities.rules.i_multi_array_based_rule import ( | ||
IMultiArrayBasedRule, | ||
) | ||
from decoimpact.business.entities.rules.rule_base import RuleBase | ||
from decoimpact.business.entities.rules.string_parser_utils import ( | ||
read_str_comparison, | ||
str_range_to_list, | ||
type_of_classification, | ||
) | ||
from decoimpact.crosscutting.i_logger import ILogger | ||
|
||
|
||
class ClassificationRule(RuleBase, IMultiArrayBasedRule): | ||
"""Implementation for the (multiple) classification rule""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
input_variable_names: List[str], | ||
criteria_table: Dict[str, List], | ||
output_variable_name: str = "output", | ||
description: str = "", | ||
): | ||
super().__init__(name, input_variable_names, output_variable_name, description) | ||
self._criteria_table = criteria_table | ||
|
||
@property | ||
def criteria_table(self) -> Dict: | ||
"""Criteria property""" | ||
return self._criteria_table | ||
|
||
def execute( | ||
self, | ||
value_arrays: Dict[str, _xr.DataArray], | ||
logger: ILogger | ||
) -> _xr.DataArray: | ||
"""Determine the classification based on the table with criteria | ||
Args: | ||
values (Dict[str, float]): Dictionary holding the values | ||
for making the rule | ||
Returns: | ||
integer: classification | ||
""" | ||
|
||
# Get all the headers in the criteria_table representing a value to be checked | ||
column_names = list(self._criteria_table.keys()) | ||
column_names.remove("output") | ||
|
||
# Create an empty result_array to be filled | ||
result_array = _xr.zeros_like(value_arrays[column_names[0]]) | ||
|
||
for (row, out) in reversed(list(enumerate(self._criteria_table["output"]))): | ||
criteria_comparison = _xr.full_like(value_arrays[column_names[0]], True) | ||
for column_name in column_names: | ||
# DataArray on which the criteria needs to be checked | ||
data = value_arrays[column_name] | ||
|
||
# Retrieving criteria and applying it in correct format (number, | ||
# range or comparison) | ||
criteria = self.criteria_table[column_name][row] | ||
criteria_class = type_of_classification(criteria) | ||
|
||
comparison = True | ||
if criteria_class == "number": | ||
comparison = data == float(criteria) | ||
|
||
elif criteria_class == "range": | ||
begin, end = str_range_to_list(criteria) | ||
comparison = (data >= begin) & (data <= end) | ||
|
||
elif criteria_class == "larger": | ||
comparison_val = read_str_comparison(criteria, ">") | ||
comparison = (data > float(comparison_val)) | ||
|
||
elif criteria_class == "smaller": | ||
comparison_val = read_str_comparison(criteria, "<") | ||
comparison = (data < float(comparison_val)) | ||
|
||
# Criteria_comparison == 1 -> to check where the value is True | ||
criteria_comparison = _xr.where( | ||
comparison & (criteria_comparison == 1), | ||
True, | ||
False | ||
) | ||
# For the first row set the default to None, for all the other | ||
# rows use the already created dataarray | ||
default_val = None | ||
if (row != len(self._criteria_table["output"])-1): | ||
default_val = result_array | ||
|
||
result_array = _xr.where(criteria_comparison, out, default_val) | ||
return result_array |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
Module for parser strings | ||
""" | ||
|
||
|
||
def str_range_to_list(range_string: str): | ||
"""Convert a string with a range in the form "x:y" of floats to | ||
two elements (begin and end of range). | ||
Args: | ||
range_string (str): String to be converted to a range (begin and end) | ||
Raises: | ||
ValueError: If the string is not properly defined | ||
Returns: | ||
floats: Return the begin and end value of the range | ||
""" | ||
range_string = range_string.strip() | ||
try: | ||
begin, end = range_string.split(":") | ||
return float(begin), float(end) | ||
except ValueError: | ||
raise ValueError(f'Input "{range_string}" is not a valid range') | ||
|
||
|
||
def read_str_comparison(compare_str: str, operator: str): | ||
"""Read the string of a comparison (with specified operator) and | ||
validate if this is in the correct format (<operator><number>, eg: >100) | ||
Args: | ||
compare_str (str): String to be checked | ||
operator (str): Operator to split on | ||
Raises: | ||
ValueError: If the compared value is not a number | ||
Returns: | ||
float: The number from the comparison string | ||
""" | ||
compare_str = compare_str.strip() | ||
try: | ||
compare_list = compare_str.split(operator) | ||
if (len(compare_list) != 2): | ||
raise IndexError( | ||
f'Input "{compare_str}" is not a valid comparison ' | ||
f'with operator: {operator}' | ||
) | ||
compare_val = compare_list[1] | ||
return float(compare_val) | ||
except ValueError: | ||
raise ValueError( | ||
f'Input "{compare_str}" is not a valid comparison with operator: {operator}' | ||
) | ||
|
||
|
||
def type_of_classification(class_val) -> str: | ||
"""Determine which type of classification is required: number, range, or | ||
NA (not applicable) | ||
Args: | ||
class_val (_type_): String to classify | ||
Raises: | ||
ValueError: Error when the string is not properly defined | ||
Returns: | ||
str: Type of classification | ||
""" | ||
|
||
if type(class_val) == int or type(class_val) == float: | ||
return "number" | ||
if type(class_val) == str: | ||
class_val = class_val.strip() | ||
if class_val in ("-", ""): | ||
return "NA" | ||
if ":" in class_val: | ||
str_range_to_list(class_val) | ||
return "range" | ||
if ">" in class_val: | ||
read_str_comparison(class_val, ">") | ||
return "larger" | ||
if "<" in class_val: | ||
read_str_comparison(class_val, "<") | ||
return "smaller" | ||
try: | ||
float(class_val) | ||
return "number" | ||
except ValueError: | ||
raise ValueError(f"No valid criteria is given: {class_val}") | ||
|
||
raise ValueError(f"No valid criteria is given: {class_val}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Module for IClassificationRuleData interface | ||
Interfaces: | ||
IClassificationRuleData | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Dict, List | ||
|
||
from decoimpact.data.api.i_rule_data import IRuleData | ||
|
||
|
||
class IClassificationRuleData(IRuleData, ABC): | ||
"""Data for a combine Results Rule""" | ||
|
||
@property | ||
@abstractmethod | ||
def input_variable_names(self) -> List[str]: | ||
"""Name of the input variable""" | ||
|
||
@property | ||
@abstractmethod | ||
def criteria_table(self) -> Dict[str, List]: | ||
"""Property for the formula""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
""" | ||
Module for (multiple) ClassificationRule class | ||
Classes: | ||
(multiple) ClassificationRuleData | ||
""" | ||
|
||
from typing import Dict, List | ||
|
||
from decoimpact.data.api.i_classification_rule_data import IClassificationRuleData | ||
from decoimpact.data.entities.rule_data import RuleData | ||
|
||
|
||
class ClassificationRuleData(IClassificationRuleData, RuleData): | ||
"""Class for storing data related to formula rule""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
input_variable_names: List[str], | ||
criteria_table: Dict[str, List], | ||
output_variable: str = "output", | ||
description: str = "", | ||
): | ||
super().__init__(name, output_variable, description) | ||
self._input_variable_names = input_variable_names | ||
self._criteria_table = criteria_table | ||
|
||
@property | ||
def criteria_table(self) -> Dict: | ||
"""Criteria property""" | ||
return self._criteria_table | ||
|
||
@property | ||
def input_variable_names(self) -> List[str]: | ||
return self._input_variable_names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
Module for ParserClassificationRule class | ||
Classes: | ||
ParserClassificationRule | ||
""" | ||
from typing import Any, Dict | ||
|
||
from decoimpact.crosscutting.i_logger import ILogger | ||
from decoimpact.data.api.i_rule_data import IRuleData | ||
from decoimpact.data.dictionary_utils import convert_table_element, get_dict_element | ||
from decoimpact.data.entities.classification_rule_data import ClassificationRuleData | ||
from decoimpact.data.parsers.i_parser_rule_base import IParserRuleBase | ||
from decoimpact.data.parsers.validation_utils import validate_table_with_input | ||
|
||
|
||
class ParserClassificationRule(IParserRuleBase): | ||
|
||
"""Class for creating a ClassificationRuleData""" | ||
|
||
@property | ||
def rule_type_name(self) -> str: | ||
"""Type name for the rule""" | ||
return "classification_rule" | ||
|
||
def parse_dict(self, dictionary: Dict[str, Any], logger: ILogger) -> IRuleData: | ||
"""Parses the provided dictionary to a IRuleData | ||
Args: | ||
dictionary (Dict[str, Any]): Dictionary holding the values | ||
for making the rule | ||
Returns: | ||
RuleBase: Rule based on the provided data | ||
""" | ||
name = get_dict_element("name", dictionary) | ||
input_variable_names = get_dict_element("input_variables", dictionary) | ||
criteria_table_list = get_dict_element("criteria_table", dictionary) | ||
criteria_table = convert_table_element(criteria_table_list) | ||
|
||
validate_table_with_input(criteria_table, input_variable_names) | ||
# validate_table_value_formats() | ||
|
||
output_variable_name = get_dict_element("output_variable", dictionary) | ||
description = get_dict_element("description", dictionary) | ||
|
||
return ClassificationRuleData( | ||
name, | ||
input_variable_names, | ||
criteria_table, | ||
output_variable_name, | ||
description | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.