Skip to content

Commit

Permalink
[dns] Implement config and show commands for static DNS.
Browse files Browse the repository at this point in the history
Implement unit tests for all added commands.
Coverage for config/dns.py : 94%
Coverage for show/dns.py : 86%
  • Loading branch information
oleksandrivantsiv committed Mar 14, 2023
1 parent 6f84aae commit c2fd9b9
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 0 deletions.
92 changes: 92 additions & 0 deletions config/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

import click
from swsscommon.swsscommon import ConfigDBConnector
import ipaddress


NAMESERVERS_MAX_NUM = 3


def is_valid_ip_address(address):
"""Check if the given IP address is valid"""
try:
ip = ipaddress.ip_address(address)
if ip.is_reserved or ip.is_multicast or ip.is_loopback:
return

invalid_ips = [
ipaddress.IPv4Address('0.0.0.0'),
ipaddress.IPv4Address('255.255.255.255'),
ipaddress.IPv6Address("0::0"),
ipaddress.IPv6Address("0::1")
]
if ip in invalid_ips:
return

return ip
except Exception:
return


def get_nameservers(db):
nameservers = db.get_table('DNS_NAMESERVER')
return [ipaddress.ip_address(ip) for ip in nameservers]


# 'dns' group ('config dns ...')
@click.group()
@click.pass_context
def dns(ctx):
"""Static DNS configuration"""
config_db = ConfigDBConnector()
config_db.connect()
ctx.obj = {'db': config_db}


# dns nameserver config
@dns.group('nameserver')
@click.pass_context
def nameserver(ctx):
"""Static DNS namesevers configuration"""
pass


# dns nameserver add
@nameserver.command('add')
@click.argument('ip_address_str', metavar='<ip_address>', required=True)
@click.pass_context
def add_dns_nameserver(ctx, ip_address_str):
"""Add static DNS namesever entry"""
ip_address = is_valid_ip_address(ip_address_str)
if not ip_address:
ctx.fail(f"{ip_address_str} invalid nameserver ip address")

db = ctx.obj['db']

nameservers = get_nameservers(db)
if ip_address in nameservers:
ctx.fail(f"{ip_address} nameserver is already configured")

if len(nameservers) >= NAMESERVERS_MAX_NUM:
ctx.fail(f"The maximum number ({NAMESERVERS_MAX_NUM}) of nameservers exceeded.")

db.set_entry('DNS_NAMESERVER', ip_address, {})

# dns nameserver delete
@nameserver.command('del')
@click.argument('ip_address_str', metavar='<ip_address>', required=True)
@click.pass_context
def del_dns_nameserver(ctx, ip_address_str):
"""Delete static DNS nameserver entry"""

ip_address = is_valid_ip_address(ip_address_str)
if not ip_address:
ctx.fail(f"{ip_address_str} invalid nameserver ip address")

db = ctx.obj['db']

nameservers = get_nameservers(db)
if ip_address not in nameservers:
ctx.fail(f"DNS nameserver {ip_address} is not configured")

db.set_entry('DNS_NAMESERVER', ip_address, None)
4 changes: 4 additions & 0 deletions config/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .config_mgmt import ConfigMgmtDPB, ConfigMgmt
from . import mclag
from . import syslog
from . import dns

# mock masic APIs for unit test
try:
Expand Down Expand Up @@ -1246,6 +1247,9 @@ def config(ctx):
# syslog module
config.add_command(syslog.syslog)

# DNS module
config.add_command(dns.dns)

@config.command()
@click.option('-y', '--yes', is_flag=True, callback=_abort_if_false,
expose_value=False, prompt='Existing files will be overwritten, continue?')
Expand Down
30 changes: 30 additions & 0 deletions show/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import click
import utilities_common.cli as clicommon
from natsort import natsorted
from tabulate import tabulate

from swsscommon.swsscommon import ConfigDBConnector
from utilities_common.cli import pass_db


# 'dns' group ("show dns ...")
@click.group(cls=clicommon.AliasedGroup)
@click.pass_context
def dns(ctx):
"""Show details of the static DNS configuration """
config_db = ConfigDBConnector()
config_db.connect()
ctx.obj = {'db': config_db}


# 'nameserver' subcommand ("show dns nameserver")
@dns.command()
@click.pass_context
def nameserver(ctx):
""" Show static DNS configuration """
header = ["Nameserver"]
db = ctx.obj['db']

nameservers = db.get_table('DNS_NAMESERVER')

click.echo(tabulate([(ns,) for ns in nameservers.keys()], header, tablefmt='simple', stralign='right'))
2 changes: 2 additions & 0 deletions show/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from . import warm_restart
from . import plugins
from . import syslog
from . import dns

# Global Variables
PLATFORM_JSON = 'platform.json'
Expand Down Expand Up @@ -289,6 +290,7 @@ def cli(ctx):
cli.add_command(vxlan.vxlan)
cli.add_command(system_health.system_health)
cli.add_command(warm_restart.warm_restart)
cli.add_command(dns.dns)

# syslog module
cli.add_command(syslog.syslog)
Expand Down
193 changes: 193 additions & 0 deletions tests/dns_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import os
import pytest

from click.testing import CliRunner

import config.main as config
import show.main as show
from utilities_common.db import Db

test_path = os.path.dirname(os.path.abspath(__file__))

dns_show_nameservers_header = """\
Nameserver
------------
"""

dns_show_nameservers = """\
Nameserver
--------------------
1.1.1.1
2001:4860:4860::8888
"""

class TestDns(object):

valid_nameservers = (
("1.1.1.1",),
("1.1.1.1", "8.8.8.8", "10.10.10.10",),
("1.1.1.1", "2001:4860:4860::8888"),
("2001:4860:4860::8888", "2001:4860:4860::8844", "2001:4860:4860::8800")
)

invalid_nameservers = (
"0.0.0.0",
"255.255.255.255",
"224.0.0.0",
"0::0",
"0::1",
"1.1.1.x",
"2001:4860:4860.8888",
"ff02::1"
)

config_dns_ns_add = config.config.commands["dns"].commands["nameserver"].commands["add"]
config_dns_ns_del = config.config.commands["dns"].commands["nameserver"].commands["del"]
show_dns_ns = show.cli.commands["dns"].commands["nameserver"]

@classmethod
def setup_class(cls):
print("SETUP")
os.environ["UTILITIES_UNIT_TESTING"] = "1"

@classmethod
def teardown_class(cls):
os.environ['UTILITIES_UNIT_TESTING'] = "0"
print("TEARDOWN")

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_add_del_with_valid_ip_addresses(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip not in db.cfgdb.get_table('DNS_NAMESERVER')

@pytest.mark.parametrize('nameserver', invalid_nameservers)
def test_dns_config_nameserver_add_del_with_invalid_ip_addresses(self, nameserver):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

# config dns nameserver add <nameserver>
result = runner.invoke(self.config_dns_ns_add, [nameserver], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "invalid nameserver ip address" in result.output

# config dns nameserver del <nameserver>
result = runner.invoke(self.config_dns_ns_del, [nameserver], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "invalid nameserver ip address" in result.output

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_add_existing_ip(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

# Execute command once more
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "nameserver is already configured" in result.output

# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_del_unexisting_ip(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "is not configured" in result.output

def test_dns_config_nameserver_add_max_number(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

nameservers = ("1.1.1.1", "2.2.2.2", "3.3.3.3")
for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, ["4.4.4.4"], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "nameservers exceeded" in result.output

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

def test_dns_show_nameserver_empty_table(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

# show dns nameserver
result = runner.invoke(self.show_dns_ns, [], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert result.output == dns_show_nameservers_header

def test_dns_show_nameserver(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

nameservers = ("1.1.1.1", "2001:4860:4860::8888")

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

# show dns nameserver
result = runner.invoke(self.show_dns_ns, [], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert result.output == dns_show_nameservers

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip not in db.cfgdb.get_table('DNS_NAMESERVER')

0 comments on commit c2fd9b9

Please sign in to comment.