diff --git a/src/nethsec/firewall/__init__.py b/src/nethsec/firewall/__init__.py index 891f0593..1eb7517f 100644 --- a/src/nethsec/firewall/__init__.py +++ b/src/nethsec/firewall/__init__.py @@ -195,6 +195,7 @@ def add_trusted_zone(uci, name, networks = [], link = ""): - be able to access lan and wan zone - be accessible from lan zone + If a zone with the same name already exists, do not recreate it. Changes are saved to staging area. Arguments: @@ -211,6 +212,11 @@ def add_trusted_zone(uci, name, networks = [], link = ""): if len(name) > 12: return None, None + # avoid duplicated zones + zones = utils.get_all_by_type(uci, 'firewall', 'zone') + for z in zones: + if zones[z].get("name", "") == name: + return None, None forwardings = list() zname = utils.get_random_id() name = utils.sanitize(name) diff --git a/tests/test_firewall.py b/tests/test_firewall.py index b84db979..3209e3e1 100644 --- a/tests/test_firewall.py +++ b/tests/test_firewall.py @@ -302,6 +302,19 @@ def test_add_trusted_zone(tmp_path): assert u.get("firewall", forwardings[1], 'ns_link') == link assert u.get("firewall", forwardings[2], 'ns_link') == link +def test_duplicated_add_trusted_zone(tmp_path): + u = _setup_db(tmp_path) + (zone, forwardings) = firewall.add_trusted_zone(u, 'mytrusted') + assert zone is None + assert forwardings is None + + trusted = 0 + for s in u.get_all('firewall'): + if u.get('firewall', s) == 'forwarding': + if u.get('firewall', s, 'src', default='') == "mytrusted" and u.get('firewall', s, 'dest', default='') == "lan": + trusted = trusted + 1 + assert trusted == 1 + def test_add_trusted_zone_with_networks(tmp_path): u = _setup_db(tmp_path) interface = firewall.add_vpn_interface(u, 'testvpn2', 'tuntest2') @@ -394,7 +407,7 @@ def test_get_all_linked(tmp_path): sections = firewall.add_template_service_group(u, "ns_web_secure", "blue", "yellow", link=link) rule = firewall.add_service(u, "my_service", "443", "tcp", link=link) interface = firewall.add_vpn_interface(u, 'p2p', 'ppp10', link=link) - (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked', link=link) + (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked2', link=link) linked = firewall.get_all_linked(u, link) for s in sections: assert s in linked['firewall'] @@ -411,7 +424,7 @@ def test_disable_linked_rules(tmp_path): sections = firewall.add_template_service_group(u, "ns_web_secure", "blue", "yellow", link=link) rule = firewall.add_service(u, "my_service", "443", "tcp", link=link) interface = firewall.add_vpn_interface(u, 'p2p', 'ppp10', link=link) - (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked', link=link) + (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked4', link=link) disabled = firewall.disable_linked_rules(u, link) for s in sections: assert u.get("firewall", s, "enabled") == "0" @@ -429,7 +442,7 @@ def test_delete_linked_sections(tmp_path): sections = firewall.add_template_service_group(u, "ns_web_secure", "blue", "yellow", link=link) rule = firewall.add_service(u, "my_service", "443", "tcp", link=link) interface = firewall.add_vpn_interface(u, 'p2p', 'ppp10', link=link) - (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked', link=link) + (zone, forwardings) = firewall.add_trusted_zone(u, 'mylinked3', link=link) deleted = firewall.delete_linked_sections(u, link) assert len(deleted) > 0 with pytest.raises(UciExceptionNotFound):