diff --git a/config/dns.py b/config/dns.py new file mode 100644 index 0000000000..1b0ba6cb2d --- /dev/null +++ b/config/dns.py @@ -0,0 +1,96 @@ + +import click +from swsscommon.swsscommon import ConfigDBConnector +from .validated_config_db_connector import ValidatedConfigDBConnector +import ipaddress + + +ADHOC_VALIDATION = True +NAMESERVERS_MAX_NUM = 3 + + +def to_ip_address(address): + """Check if the given IP address is valid""" + try: + ip = ipaddress.ip_address(address) + + if ADHOC_VALIDATION: + if ip.is_reserved or ip.is_multicast or ip.is_loopback: + return + + invalid_ips = [ + ipaddress.IPv4Address('0.0.0.0'), + ipaddress.IPv4Address('255.255.255.255'), + ipaddress.IPv6Address("0::0"), + ipaddress.IPv6Address("0::1") + ] + if ip in invalid_ips: + return + + return ip + except Exception: + return + + +def get_nameservers(db): + nameservers = db.get_table('DNS_NAMESERVER') + return [ipaddress.ip_address(ip) for ip in nameservers] + + +# 'dns' group ('config dns ...') +@click.group() +@click.pass_context +def dns(ctx): + """Static DNS configuration""" + config_db = ValidatedConfigDBConnector(ConfigDBConnector()) + config_db.connect() + ctx.obj = {'db': config_db} + + +# dns nameserver config +@dns.group('nameserver') +@click.pass_context +def nameserver(ctx): + """Static DNS nameservers configuration""" + pass + + +# dns nameserver add +@nameserver.command('add') +@click.argument('ip_address_str', metavar='', required=True) +@click.pass_context +def add_dns_nameserver(ctx, ip_address_str): + """Add static DNS nameserver entry""" + ip_address = to_ip_address(ip_address_str) + if not ip_address: + ctx.fail(f"{ip_address_str} invalid nameserver ip address") + + db = ctx.obj['db'] + + nameservers = get_nameservers(db) + if ip_address in nameservers: + ctx.fail(f"{ip_address} nameserver is already configured") + + if len(nameservers) >= NAMESERVERS_MAX_NUM: + ctx.fail(f"The maximum number ({NAMESERVERS_MAX_NUM}) of nameservers exceeded.") + + db.set_entry('DNS_NAMESERVER', ip_address, {}) + +# dns nameserver delete +@nameserver.command('del') +@click.argument('ip_address_str', metavar='', required=True) +@click.pass_context +def del_dns_nameserver(ctx, ip_address_str): + """Delete static DNS nameserver entry""" + + ip_address = to_ip_address(ip_address_str) + if not ip_address: + ctx.fail(f"{ip_address_str} invalid nameserver ip address") + + db = ctx.obj['db'] + + nameservers = get_nameservers(db) + if ip_address not in nameservers: + ctx.fail(f"DNS nameserver {ip_address} is not configured") + + db.set_entry('DNS_NAMESERVER', ip_address, None) diff --git a/config/main.py b/config/main.py index f6bec33f8f..8547592e0a 100644 --- a/config/main.py +++ b/config/main.py @@ -54,6 +54,7 @@ from .config_mgmt import ConfigMgmtDPB, ConfigMgmt from . import mclag from . import syslog +from . import dns # mock masic APIs for unit test try: @@ -1200,6 +1201,9 @@ def config(ctx): # syslog module config.add_command(syslog.syslog) +# DNS module +config.add_command(dns.dns) + @config.command() @click.option('-y', '--yes', is_flag=True, callback=_abort_if_false, expose_value=False, prompt='Existing files will be overwritten, continue?') diff --git a/show/dns.py b/show/dns.py new file mode 100644 index 0000000000..3aea482438 --- /dev/null +++ b/show/dns.py @@ -0,0 +1,30 @@ +import click +import utilities_common.cli as clicommon +from natsort import natsorted +from tabulate import tabulate + +from swsscommon.swsscommon import ConfigDBConnector +from utilities_common.cli import pass_db + + +# 'dns' group ("show dns ...") +@click.group(cls=clicommon.AliasedGroup) +@click.pass_context +def dns(ctx): + """Show details of the static DNS configuration """ + config_db = ConfigDBConnector() + config_db.connect() + ctx.obj = {'db': config_db} + + +# 'nameserver' subcommand ("show dns nameserver") +@dns.command() +@click.pass_context +def nameserver(ctx): + """ Show static DNS configuration """ + header = ["Nameserver"] + db = ctx.obj['db'] + + nameservers = db.get_table('DNS_NAMESERVER') + + click.echo(tabulate([(ns,) for ns in nameservers.keys()], header, tablefmt='simple', stralign='right')) diff --git a/show/main.py b/show/main.py index 7f79cd4779..095a272e51 100755 --- a/show/main.py +++ b/show/main.py @@ -63,6 +63,7 @@ from . import warm_restart from . import plugins from . import syslog +from . import dns # Global Variables PLATFORM_JSON = 'platform.json' @@ -289,6 +290,7 @@ def cli(ctx): cli.add_command(vxlan.vxlan) cli.add_command(system_health.system_health) cli.add_command(warm_restart.warm_restart) +cli.add_command(dns.dns) # syslog module cli.add_command(syslog.syslog) diff --git a/tests/dns_test.py b/tests/dns_test.py new file mode 100644 index 0000000000..00b04ca98b --- /dev/null +++ b/tests/dns_test.py @@ -0,0 +1,193 @@ +import os +import pytest + +from click.testing import CliRunner + +import config.main as config +import show.main as show +from utilities_common.db import Db + +test_path = os.path.dirname(os.path.abspath(__file__)) + +dns_show_nameservers_header = """\ + Nameserver +------------ +""" + +dns_show_nameservers = """\ + Nameserver +-------------------- + 1.1.1.1 +2001:4860:4860::8888 +""" + +class TestDns(object): + + valid_nameservers = ( + ("1.1.1.1",), + ("1.1.1.1", "8.8.8.8", "10.10.10.10",), + ("1.1.1.1", "2001:4860:4860::8888"), + ("2001:4860:4860::8888", "2001:4860:4860::8844", "2001:4860:4860::8800") + ) + + invalid_nameservers = ( + "0.0.0.0", + "255.255.255.255", + "224.0.0.0", + "0::0", + "0::1", + "1.1.1.x", + "2001:4860:4860.8888", + "ff02::1" + ) + + config_dns_ns_add = config.config.commands["dns"].commands["nameserver"].commands["add"] + config_dns_ns_del = config.config.commands["dns"].commands["nameserver"].commands["del"] + show_dns_ns = show.cli.commands["dns"].commands["nameserver"] + + @classmethod + def setup_class(cls): + print("SETUP") + os.environ["UTILITIES_UNIT_TESTING"] = "1" + + @classmethod + def teardown_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "0" + print("TEARDOWN") + + @pytest.mark.parametrize('nameservers', valid_nameservers) + def test_dns_config_nameserver_add_del_with_valid_ip_addresses(self, nameservers): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + for ip in nameservers: + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert ip in db.cfgdb.get_table('DNS_NAMESERVER') + + for ip in nameservers: + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert ip not in db.cfgdb.get_table('DNS_NAMESERVER') + + @pytest.mark.parametrize('nameserver', invalid_nameservers) + def test_dns_config_nameserver_add_del_with_invalid_ip_addresses(self, nameserver): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, [nameserver], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert "invalid nameserver ip address" in result.output + + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [nameserver], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert "invalid nameserver ip address" in result.output + + @pytest.mark.parametrize('nameservers', valid_nameservers) + def test_dns_config_nameserver_add_existing_ip(self, nameservers): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + for ip in nameservers: + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert ip in db.cfgdb.get_table('DNS_NAMESERVER') + + # Execute command once more + result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert "nameserver is already configured" in result.output + + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + + @pytest.mark.parametrize('nameservers', valid_nameservers) + def test_dns_config_nameserver_del_unexisting_ip(self, nameservers): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + for ip in nameservers: + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert "is not configured" in result.output + + def test_dns_config_nameserver_add_max_number(self): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + nameservers = ("1.1.1.1", "2.2.2.2", "3.3.3.3") + for ip in nameservers: + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, ["4.4.4.4"], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code != 0 + assert "nameservers exceeded" in result.output + + for ip in nameservers: + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + + def test_dns_show_nameserver_empty_table(self): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + # show dns nameserver + result = runner.invoke(self.show_dns_ns, [], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert result.output == dns_show_nameservers_header + + def test_dns_show_nameserver(self): + db = Db() + runner = CliRunner() + obj = {'db': db.cfgdb} + + nameservers = ("1.1.1.1", "2001:4860:4860::8888") + + for ip in nameservers: + # config dns nameserver add + result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert ip in db.cfgdb.get_table('DNS_NAMESERVER') + + # show dns nameserver + result = runner.invoke(self.show_dns_ns, [], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert result.output == dns_show_nameservers + + for ip in nameservers: + # config dns nameserver del + result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj) + print(result.exit_code, result.output) + assert result.exit_code == 0 + assert ip not in db.cfgdb.get_table('DNS_NAMESERVER')