diff --git a/src/nethsec/mwan/__init__.py b/src/nethsec/mwan/__init__.py index bae19309..fefe4873 100644 --- a/src/nethsec/mwan/__init__.py +++ b/src/nethsec/mwan/__init__.py @@ -123,7 +123,7 @@ def __store_member(e_uci: EUci, interface_name: str, metric: int, weight: int) - def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None, source_address: str = None, source_port: str = None, - destination_address: str = None, destination_port: str = None) -> str: + destination_address: str = None, destination_port: str = None, sticky: bool = False) -> str: """ Stores a rule for mwan3 @@ -136,6 +136,7 @@ def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None, source_port: source ports to match or range destination_address: destination addresses to match destination_port: destination ports to match or range + sticky: whether to use sticky connections Returns: name of the rule created @@ -156,7 +157,10 @@ def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None, e_uci.set('mwan3', rule_config_name, 'rule') e_uci.set('mwan3', rule_config_name, 'label', name) e_uci.set('mwan3', rule_config_name, 'use_policy', policy) - e_uci.set('mwan3', rule_config_name, 'sticky', '0') + # test if sticky is True or False, if not raise an error + if type(sticky) is not bool: + raise ValidationError('sticky', 'sticky_not_valid', sticky) + e_uci.set('mwan3', rule_config_name, 'sticky', sticky) if protocol is not None: e_uci.set('mwan3', rule_config_name, 'proto', protocol) if source_address is not None: @@ -420,6 +424,8 @@ def index_rules(e_uci: EUci) -> list[dict]: rule_data['destination_address'] = rule_value['dest_ip'] if 'dest_port' in rule_value: rule_data['destination_port'] = rule_value['dest_port'] + if 'sticky' in rule_value: + rule_data['sticky'] = rule_value['sticky'] == '1' data.append(rule_data) return data @@ -487,7 +493,7 @@ def delete_rule(e_uci: EUci, name: str): def edit_rule(e_uci: EUci, name: str, policy: str, label: str, protocol: str = None, source_address: str = None, source_port: str = None, - destination_address: str = None, destination_port: str = None): + destination_address: str = None, destination_port: str = None, sticky: bool = False): """ Edits a mwan3 rule. @@ -501,6 +507,7 @@ def edit_rule(e_uci: EUci, name: str, policy: str, label: str, protocol: str = N source_port: port or port range destination_address: CIDR notation of destination address destination_port: port or port range + sticky: whether to use sticky connections Raises: ValidationError: if name is not valid or policy is not valid @@ -512,6 +519,10 @@ def edit_rule(e_uci: EUci, name: str, policy: str, label: str, protocol: str = N raise ValidationError('policy', 'invalid', policy) e_uci.set('mwan3', name, 'use_policy', policy) e_uci.set('mwan3', name, 'label', label) + # test if sticky is True of False, if not raise an error + if type(sticky) is not bool: + raise ValidationError('sticky', 'sticky_not_valid', sticky) + e_uci.set('mwan3', name, 'sticky', sticky) if protocol is not None: e_uci.set('mwan3', name, 'proto', protocol) if protocol != 'tcp' and protocol != 'udp': diff --git a/tests/test_mwan.py b/tests/test_mwan.py index 8265c694..c1569d9b 100644 --- a/tests/test_mwan.py +++ b/tests/test_mwan.py @@ -295,7 +295,7 @@ def test_store_rule(e_uci, mocker): } ]) assert mwan.store_rule(e_uci, 'rule 1', 'ns_default', 'udp', '192.168.1.1/24', '1:1024', '10.0.0.2/12', - '22,443') == 'mwan3.ns_rule_1' + '22,443', True) == 'mwan3.ns_rule_1' assert e_uci.get('mwan3', 'ns_rule_1') == 'rule' assert e_uci.get('mwan3', 'ns_rule_1', 'label') == 'rule 1' assert e_uci.get('mwan3', 'ns_rule_1', 'use_policy') == 'ns_default' @@ -304,7 +304,7 @@ def test_store_rule(e_uci, mocker): assert e_uci.get('mwan3', 'ns_rule_1', 'src_port') == '1:1024' assert e_uci.get('mwan3', 'ns_rule_1', 'dest_ip') == '10.0.0.2/12' assert e_uci.get('mwan3', 'ns_rule_1', 'dest_port') == '22,443' - assert e_uci.get('mwan3', 'ns_rule_1', 'sticky') == '0' + assert e_uci.get('mwan3', 'ns_rule_1', 'sticky') == '1' def test_unique_rule(e_uci, mocker): @@ -431,7 +431,8 @@ def test_index_rules(e_uci, mocker): 'policy': { 'name': 'ns_default', 'label': 'default', - } + }, + "sticky": False, } assert index[1] == { 'name': 'ns_rule_1', @@ -439,7 +440,8 @@ def test_index_rules(e_uci, mocker): 'policy': { 'name': 'ns_default', 'label': 'default', - } + }, + "sticky": False, }