diff --git a/.github/workflows/commit.yaml b/.github/workflows/commit.yaml index 81f3e6a27..03741ece4 100644 --- a/.github/workflows/commit.yaml +++ b/.github/workflows/commit.yaml @@ -1,4 +1,4 @@ ---- +--- name: build on: [push, pull_request] jobs: @@ -33,10 +33,14 @@ jobs: run: | pylama . + - name: Run type checker + run: | + mypy -p napalm --config-file mypy.ini + - name: Run Tests run: | py.test --cov=napalm --cov-report term-missing -vs --pylama - + build_docs: needs: std_tests runs-on: ubuntu-latest diff --git a/mypy.ini b/mypy.ini index 52d4dcba2..8a0948964 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,6 +5,21 @@ allow_redefinition = True [mypy-napalm.base.*] disallow_untyped_defs = True +[mypy-napalm.nxos.*] +disallow_untyped_defs = True + +[mypy-napalm.eos.*] +ignore_errors = True + +[mypy-napalm.ios.*] +ignore_errors = True + +[mypy-napalm.nxapi_plumbing.*] +disallow_untyped_defs = True + +[mypy-napalm.base.clitools.*] +ignore_errors = True + [mypy-napalm.base.test.*] ignore_errors = True diff --git a/napalm/base/__init__.py b/napalm/base/__init__.py index fe677a7d7..934cefe77 100644 --- a/napalm/base/__init__.py +++ b/napalm/base/__init__.py @@ -19,6 +19,8 @@ import importlib # NAPALM base +from typing import Type + from napalm.base.base import NetworkDriver from napalm.base.exceptions import ModuleImportError from napalm.base.mock import MockDriver @@ -29,7 +31,7 @@ ] -def get_network_driver(name, prepend=True): +def get_network_driver(name: str, prepend: bool = True) -> Type[NetworkDriver]: """ Searches for a class derived form the base NAPALM class NetworkDriver in a specific library. The library name must repect the following pattern: napalm_[DEVICE_OS]. diff --git a/napalm/base/base.py b/napalm/base/base.py index 6cf0e2c77..c2666c7e6 100644 --- a/napalm/base/base.py +++ b/napalm/base/base.py @@ -13,6 +13,10 @@ # the License. import sys +from types import TracebackType +from typing import Optional, Dict, Type, Any, List, Union + +from typing_extensions import Literal from netmiko import ConnectHandler, NetMikoTimeoutException @@ -22,26 +26,42 @@ from napalm.base import constants as c from napalm.base import validate from napalm.base.exceptions import ConnectionException +from napalm.base.test import models +from napalm.base.test.models import NTPStats class NetworkDriver(object): - def __init__(self, hostname, username, password, timeout=60, optional_args=None): + hostname: str + username: str + password: str + timeout: int + force_no_enable: bool + use_canonical_interface: bool + + def __init__( + self, + hostname: str, + username: str, + password: str, + timeout: int = 60, + optional_args: Dict = None, + ) -> None: """ This is the base class you have to inherit from when writing your own Network Driver to manage any device. You will, in addition, have to override all the methods specified on this class. Make sure you follow the guidelines for every method and that you return the correct data. - :param hostname: (str) IP or FQDN of the device you want to connect to. - :param username: (str) Username you want to use - :param password: (str) Password - :param timeout: (int) Time in seconds to wait for the device to respond. - :param optional_args: (dict) Pass additional arguments to underlying driver + :param hostname: IP or FQDN of the device you want to connect to. + :param username: Username you want to use + :param password: Password + :param timeout: Time in seconds to wait for the device to respond. + :param optional_args: Pass additional arguments to underlying driver :return: """ raise NotImplementedError - def __enter__(self): + def __enter__(self) -> "NetworkDriver": # type: ignore try: self.open() return self @@ -52,11 +72,16 @@ def __enter__(self): else: raise - def __exit__(self, exc_type, exc_value, exc_traceback): + def __exit__( # type: ignore + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ) -> Optional[Literal[False]]: self.close() if exc_type is not None and ( exc_type.__name__ not in dir(napalm.base.exceptions) - and exc_type.__name__ not in __builtins__.keys() + and exc_type.__name__ not in __builtins__.keys() # type: ignore ): epilog = ( "NAPALM didn't catch this exception. Please, fill a bugfix on " @@ -66,7 +91,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): print(epilog) return False - def __del__(self): + def __del__(self) -> None: """ This method is used to cleanup when the program is terminated suddenly. We need to make sure the connection is closed properly and the configuration DB @@ -78,7 +103,9 @@ def __del__(self): except Exception: pass - def _netmiko_open(self, device_type, netmiko_optional_args=None): + def _netmiko_open( + self, device_type: str, netmiko_optional_args: Optional[Dict] = None + ) -> ConnectHandler: """Standardized method of creating a Netmiko connection using napalm attributes.""" if netmiko_optional_args is None: netmiko_optional_args = {} @@ -104,26 +131,26 @@ def _netmiko_open(self, device_type, netmiko_optional_args=None): return self._netmiko_device - def _netmiko_close(self): + def _netmiko_close(self) -> None: """Standardized method of closing a Netmiko connection.""" if getattr(self, "_netmiko_device", None): self._netmiko_device.disconnect() self._netmiko_device = None - self.device = None + self.device: Optional[ConnectHandler] = None - def open(self): + def open(self) -> None: """ Opens a connection to the device. """ raise NotImplementedError - def close(self): + def close(self) -> None: """ Closes the connection to the device. """ raise NotImplementedError - def is_alive(self): + def is_alive(self) -> models.AliveDict: """ Returns a flag with the connection state. Depends on the nature of API used by each driver. @@ -133,7 +160,7 @@ def is_alive(self): """ raise NotImplementedError - def pre_connection_tests(self): + def pre_connection_tests(self) -> None: """ This is a helper function used by the cli tool cl_napalm_show_tech. Drivers can override this method to do some tests, show information, enable debugging, etc. @@ -141,7 +168,7 @@ def pre_connection_tests(self): """ raise NotImplementedError - def connection_tests(self): + def connection_tests(self) -> None: """ This is a helper function used by the cli tool cl_napalm_show_tech. Drivers can override this method to do some tests, show information, enable debugging, etc. @@ -149,7 +176,7 @@ def connection_tests(self): """ raise NotImplementedError - def post_connection_tests(self): + def post_connection_tests(self) -> None: """ This is a helper function used by the cli tool cl_napalm_show_tech. Drivers can override this method to do some tests, show information, enable debugging, etc. @@ -158,8 +185,12 @@ def post_connection_tests(self): raise NotImplementedError def load_template( - self, template_name, template_source=None, template_path=None, **template_vars - ): + self, + template_name: str, + template_source: Optional[str] = None, + template_path: Optional[str] = None, + **template_vars: Any + ) -> None: """ Will load a templated configuration on the device. @@ -185,7 +216,9 @@ def load_template( **template_vars ) - def load_replace_candidate(self, filename=None, config=None): + def load_replace_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: """ Populates the candidate configuration. You can populate it from a file or from a string. If you send both a filename and a string containing the configuration, the file takes @@ -201,7 +234,9 @@ def load_replace_candidate(self, filename=None, config=None): """ raise NotImplementedError - def load_merge_candidate(self, filename=None, config=None): + def load_merge_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: """ Populates the candidate configuration. You can populate it from a file or from a string. If you send both a filename and a string containing the configuration, the file takes @@ -217,7 +252,7 @@ def load_merge_candidate(self, filename=None, config=None): """ raise NotImplementedError - def compare_config(self): + def compare_config(self) -> str: """ :return: A string showing the difference between the running configuration and the \ candidate configuration. The running_config is loaded automatically just before doing the \ @@ -225,7 +260,7 @@ def compare_config(self): """ raise NotImplementedError - def commit_config(self, message="", revert_in=None): + def commit_config(self, message: str = "", revert_in: Optional[int] = None) -> None: """ Commits the changes requested by the method load_replace_candidate or load_merge_candidate. @@ -242,7 +277,7 @@ def commit_config(self, message="", revert_in=None): """ raise NotImplementedError - def confirm_commit(self): + def confirm_commit(self) -> None: """ Confirm the changes requested via commit_config when commit_confirm=True. @@ -250,19 +285,19 @@ def confirm_commit(self): """ raise NotImplementedError - def has_pending_commit(self): + def has_pending_commit(self) -> bool: """ :return Boolean indicating if a commit_config that needs confirmed is in process. """ raise NotImplementedError - def discard_config(self): + def discard_config(self) -> None: """ Discards the configuration loaded into the candidate. """ raise NotImplementedError - def rollback(self): + def rollback(self) -> None: """ If changes were made, revert changes to the original state. @@ -270,7 +305,7 @@ def rollback(self): """ raise NotImplementedError - def get_facts(self): + def get_facts(self) -> models.FactsDict: """ Returns a dictionary containing the following information: * uptime - Uptime of the device in seconds. @@ -298,7 +333,7 @@ def get_facts(self): """ raise NotImplementedError - def get_interfaces(self): + def get_interfaces(self) -> Dict[str, models.InterfaceDict]: """ Returns a dictionary of dictionaries. The keys for the first dictionary will be the \ interfaces in the devices. The inner dictionary will containing the following data for \ @@ -359,7 +394,7 @@ def get_interfaces(self): """ raise NotImplementedError - def get_lldp_neighbors(self): + def get_lldp_neighbors(self) -> Dict[str, List[models.LLDPNeighborDict]]: """ Returns a dictionary where the keys are local ports and the value is a list of \ dictionaries with the following information: @@ -405,7 +440,7 @@ def get_lldp_neighbors(self): """ raise NotImplementedError - def get_bgp_neighbors(self): + def get_bgp_neighbors(self) -> Dict[str, models.BGPStateNeighborsPerVRFDict]: """ Returns a dictionary of dictionaries. The keys for the first dictionary will be the vrf (global if no vrf). The inner dictionary will contain the following data for each vrf: @@ -462,7 +497,7 @@ def get_bgp_neighbors(self): """ raise NotImplementedError - def get_environment(self): + def get_environment(self) -> models.EnvironmentDict: """ Returns a dictionary where: @@ -484,7 +519,7 @@ def get_environment(self): """ raise NotImplementedError - def get_interfaces_counters(self): + def get_interfaces_counters(self) -> Dict[str, models.InterfaceCounterDict]: """ Returns a dictionary of dictionaries where the first key is an interface name and the inner dictionary contains the following keys: @@ -551,7 +586,9 @@ def get_interfaces_counters(self): """ raise NotImplementedError - def get_lldp_neighbors_detail(self, interface=""): + def get_lldp_neighbors_detail( + self, interface: str = "" + ) -> models.LLDPNeighborsDetailDict: """ Returns a detailed view of the LLDP neighbors as a dictionary containing lists of dictionaries for each interface. @@ -599,7 +636,9 @@ def get_lldp_neighbors_detail(self, interface=""): """ raise NotImplementedError - def get_bgp_config(self, group="", neighbor=""): + def get_bgp_config( + self, group: str = "", neighbor: str = "" + ) -> models.BPGConfigGroupDict: """ Returns a dictionary containing the BGP configuration. Can return either the whole config, either the config only for a group or neighbor. @@ -698,7 +737,7 @@ def get_bgp_config(self, group="", neighbor=""): """ raise NotImplementedError - def cli(self, commands): + def cli(self, commands: List[str]) -> Dict[str, Union[str, Dict[str, Any]]]: """ Will execute a list of commands and return the output in a dictionary format. @@ -726,7 +765,9 @@ def cli(self, commands): """ raise NotImplementedError - def get_bgp_neighbors_detail(self, neighbor_address=""): + def get_bgp_neighbors_detail( + self, neighbor_address: str = "" + ) -> Dict[str, models.PeerDetailsDict]: """ Returns a detailed view of the BGP neighbors as a dictionary of lists. @@ -821,7 +862,7 @@ def get_bgp_neighbors_detail(self, neighbor_address=""): """ raise NotImplementedError - def get_arp_table(self, vrf=""): + def get_arp_table(self, vrf: str = "") -> List[models.ARPTableDict]: """ Returns a list of dictionaries having the following set of keys: @@ -856,7 +897,7 @@ def get_arp_table(self, vrf=""): """ raise NotImplementedError - def get_ntp_peers(self): + def get_ntp_peers(self) -> Dict[str, models.NTPPeerDict]: """ Returns the NTP peers configuration as dictionary. @@ -876,7 +917,7 @@ def get_ntp_peers(self): raise NotImplementedError - def get_ntp_servers(self): + def get_ntp_servers(self) -> Dict[str, models.NTPServerDict]: """ Returns the NTP servers configuration as dictionary. @@ -896,7 +937,7 @@ def get_ntp_servers(self): raise NotImplementedError - def get_ntp_stats(self): + def get_ntp_stats(self) -> List[NTPStats]: """ Returns a list of NTP synchronization statistics. @@ -933,7 +974,7 @@ def get_ntp_stats(self): """ raise NotImplementedError - def get_interfaces_ip(self): + def get_interfaces_ip(self) -> Dict[str, models.InterfacesIPDict]: """ Returns all configured IP addresses on all interfaces as a dictionary of dictionaries. @@ -987,7 +1028,7 @@ def get_interfaces_ip(self): """ raise NotImplementedError - def get_mac_address_table(self): + def get_mac_address_table(self) -> List[models.MACAdressTable]: """ Returns a lists of dictionaries. Each dictionary represents an entry in the MAC Address @@ -1038,7 +1079,9 @@ def get_mac_address_table(self): """ raise NotImplementedError - def get_route_to(self, destination="", protocol="", longer=False): + def get_route_to( + self, destination: str = "", protocol: str = "", longer: bool = False + ) -> Dict[str, models.RouteDict]: """ Returns a dictionary of dictionaries containing details of all available routes to a @@ -1113,7 +1156,7 @@ def get_route_to(self, destination="", protocol="", longer=False): """ raise NotImplementedError - def get_snmp_information(self): + def get_snmp_information(self) -> models.SNMPDict: """ Returns a dict of dicts containing SNMP configuration. @@ -1157,7 +1200,7 @@ def get_snmp_information(self): """ raise NotImplementedError - def get_probes_config(self): + def get_probes_config(self) -> Dict[str, models.ProbeTestDict]: """ Returns a dictionary with the probes configured on the device. Probes can be either RPM on JunOS devices, either SLA on IOS-XR. Other vendors do not @@ -1195,7 +1238,7 @@ def get_probes_config(self): """ raise NotImplementedError - def get_probes_results(self): + def get_probes_results(self) -> Dict[str, models.ProbeTestResultDict]: """ Returns a dictionary with the results of the probes. The keys of the main dictionary represent the name of the probes. @@ -1266,14 +1309,14 @@ def get_probes_results(self): def ping( self, - destination, - source=c.PING_SOURCE, - ttl=c.PING_TTL, - timeout=c.PING_TIMEOUT, - size=c.PING_SIZE, - count=c.PING_COUNT, - vrf=c.PING_VRF, - ): + destination: str, + source: str = c.PING_SOURCE, + ttl: int = c.PING_TTL, + timeout: int = c.PING_TIMEOUT, + size: int = c.PING_SIZE, + count: int = c.PING_COUNT, + vrf: str = c.PING_VRF, + ) -> models.PingResultDict: """ Executes ping on the device and returns a dictionary with the result @@ -1345,12 +1388,12 @@ def ping( def traceroute( self, - destination, - source=c.TRACEROUTE_SOURCE, - ttl=c.TRACEROUTE_TTL, - timeout=c.TRACEROUTE_TIMEOUT, - vrf=c.TRACEROUTE_VRF, - ): + destination: str, + source: str = c.TRACEROUTE_SOURCE, + ttl: int = c.TRACEROUTE_TTL, + timeout: int = c.TRACEROUTE_TIMEOUT, + vrf: str = c.TRACEROUTE_VRF, + ) -> models.TracerouteResultDict: """ Executes traceroute on the device and returns a dictionary with the result. @@ -1462,7 +1505,7 @@ def traceroute( """ raise NotImplementedError - def get_users(self): + def get_users(self) -> Dict[str, models.UsersDict]: """ Returns a dictionary with the configured users. The keys of the main dictionary represents the username. The values represent the details @@ -1494,7 +1537,7 @@ def get_users(self): """ raise NotImplementedError - def get_optics(self): + def get_optics(self) -> Dict[str, models.OpticsDict]: """Fetches the power usage on the various transceivers installed on the switch (in dbm), and returns a view that conforms with the openconfig model openconfig-platform-transceiver.yang @@ -1558,7 +1601,9 @@ def get_optics(self): """ raise NotImplementedError - def get_config(self, retrieve="all", full=False, sanitized=False): + def get_config( + self, retrieve: str = "all", full: bool = False, sanitized: bool = False + ) -> models.ConfigDict: """ Return the configuration of a device. @@ -1581,7 +1626,9 @@ def get_config(self, retrieve="all", full=False, sanitized=False): """ raise NotImplementedError - def get_network_instances(self, name=""): + def get_network_instances( + self, name: str = "" + ) -> Dict[str, models.NetworkInstanceDict]: """ Return a dictionary of network instances (VRFs) configured, including default/global @@ -1633,7 +1680,7 @@ def get_network_instances(self, name=""): """ raise NotImplementedError - def get_firewall_policies(self): + def get_firewall_policies(self) -> Dict[str, List[models.FirewallPolicyDict]]: """ Returns a dictionary of lists of dictionaries where the first key is an unique policy name and the inner dictionary contains the following keys: @@ -1674,7 +1721,7 @@ def get_firewall_policies(self): """ raise NotImplementedError - def get_ipv6_neighbors_table(self): + def get_ipv6_neighbors_table(self) -> List[models.IPV6NeighborDict]: """ Get IPv6 neighbors table information. @@ -1707,7 +1754,7 @@ def get_ipv6_neighbors_table(self): """ raise NotImplementedError - def get_vlans(self): + def get_vlans(self) -> Dict[str, models.VlanDict]: """ Return structure being spit balled is as follows. * vlan_id (int) @@ -1729,7 +1776,11 @@ def get_vlans(self): """ raise NotImplementedError - def compliance_report(self, validation_file=None, validation_source=None): + def compliance_report( + self, + validation_file: Optional[str] = None, + validation_source: Optional[str] = None, + ) -> models.ReportResult: """ Return a compliance report. @@ -1745,7 +1796,7 @@ def compliance_report(self, validation_file=None, validation_source=None): self, validation_file=validation_file, validation_source=validation_source ) - def _canonical_int(self, interface): + def _canonical_int(self, interface: str) -> str: """Expose the helper function within this class.""" if self.use_canonical_interface is True: return napalm.base.helpers.canonical_interface_name( diff --git a/napalm/base/helpers.py b/napalm/base/helpers.py index f7d9954f1..311ff1081 100644 --- a/napalm/base/helpers.py +++ b/napalm/base/helpers.py @@ -1,25 +1,33 @@ """Helper functions for the NAPALM base.""" +import itertools +import logging + # std libs import os import re import sys -import itertools -import logging from collections.abc import Iterable # third party libs +from typing import Optional, Dict, Any, List, Union, Tuple, TypeVar, Callable + import jinja2 import textfsm +from ciscoconfparse import CiscoConfParse +from lxml import etree from netaddr import EUI -from netaddr import mac_unix from netaddr import IPAddress -from ciscoconfparse import CiscoConfParse +from netaddr import mac_unix # local modules import napalm.base.exceptions from napalm.base import constants -from napalm.base.utils.jinja_filters import CustomJinjaFilters from napalm.base.canonical_map import base_interfaces, reverse_mapping +from napalm.base.test.models import ConfigDict +from napalm.base.utils.jinja_filters import CustomJinjaFilters + +T = TypeVar("T") +R = TypeVar("R") # ------------------------------------------------------------------- # Functional Global @@ -41,14 +49,14 @@ class _MACFormat(mac_unix): # callable helpers # ------------------------------------------------------------------- def load_template( - cls, - template_name, - template_source=None, - template_path=None, - openconfig=False, - jinja_filters={}, - **template_vars, -): + cls: "napalm.base.NetworkDriver", + template_name: str, + template_source: Optional[str] = None, + template_path: Optional[str] = None, + openconfig: bool = False, + jinja_filters: Dict = {}, + **template_vars: Any, +) -> None: try: search_path = [] if isinstance(template_source, str): @@ -111,7 +119,9 @@ def load_template( return cls.load_merge_candidate(config=configuration) -def cisco_conf_parse_parents(parent, child, config): +def cisco_conf_parse_parents( + parent: str, child: str, config: Union[str, List[str]] +) -> List[str]: """ Use CiscoConfParse to find parent lines that contain a specific child line. @@ -120,13 +130,15 @@ def cisco_conf_parse_parents(parent, child, config): :param config: The device running/startup config """ if type(config) == str: - config = config.splitlines() + config = config.splitlines() # type: ignore parse = CiscoConfParse(config) cfg_obj = parse.find_parents_w_child(parent, child) return cfg_obj -def cisco_conf_parse_objects(cfg_section, config): +def cisco_conf_parse_objects( + cfg_section: str, config: Union[str, List[str]] +) -> List[str]: """ Use CiscoConfParse to find and return a section of Cisco IOS config. Similar to "show run | section " @@ -136,7 +148,7 @@ def cisco_conf_parse_objects(cfg_section, config): """ return_config = [] if type(config) is str: - config = config.splitlines() + config = config.splitlines() # type: ignore parse = CiscoConfParse(config) cfg_obj = parse.find_objects(cfg_section) for parent in cfg_obj: @@ -146,7 +158,7 @@ def cisco_conf_parse_objects(cfg_section, config): return return_config -def regex_find_txt(pattern, text, default=""): +def regex_find_txt(pattern: str, text: str, default: str = "") -> Any: """ "" RegEx search for pattern in text. Will try to match the data type of the "default" value or return the default value if no match is found. @@ -167,7 +179,7 @@ def regex_find_txt(pattern, text, default=""): if not isinstance(value, type(default)): if isinstance(value, list) and len(value) == 1: value = value[0] - value = type(default)(value) + value = type(default)(value) # type: ignore except Exception as regexFindTxtErr01: # in case of any exception, returns default logger.error( 'errorCode="regexFindTxtErr01" in napalm.base.helpers with systemMessage="%s"\ @@ -175,11 +187,13 @@ def regex_find_txt(pattern, text, default=""): default to empty string"' % (regexFindTxtErr01) ) - value = default + value = default # type: ignore return value -def textfsm_extractor(cls, template_name, raw_text): +def textfsm_extractor( + cls: "napalm.base.NetworkDriver", template_name: str, raw_text: str +) -> List[Dict]: """ Applies a TextFSM template over a raw text and return the matching table. @@ -245,7 +259,12 @@ def textfsm_extractor(cls, template_name, raw_text): ) -def find_txt(xml_tree, path, default="", namespaces=None): +def find_txt( + xml_tree: etree._Element, + path: str, + default: str = "", + namespaces: Optional[Dict] = None, +) -> str: """ Extracts the text value from an XML tree, using XPath. In case of error or text element unavailability, will return a default value. @@ -284,7 +303,7 @@ def find_txt(xml_tree, path, default="", namespaces=None): return str(value) -def convert(to, who, default=""): +def convert(to: Callable[[T], R], who: Optional[T], default: R = "") -> R: # type: ignore """ Converts data to a specific datatype. In case of error, will return a default value. @@ -302,7 +321,7 @@ def convert(to, who, default=""): return default -def mac(raw): +def mac(raw: str) -> str: """ Converts a raw string to a standardised MAC Address EUI Format. @@ -339,7 +358,7 @@ def mac(raw): return str(EUI(raw, dialect=_MACFormat)) -def ip(addr, version=None): +def ip(addr: str, version: Optional[int] = None) -> str: """ Converts a raw string to a valid IP address. Optional version argument will detect that \ object matches specified version. @@ -368,7 +387,7 @@ def ip(addr, version=None): return str(addr_obj) -def as_number(as_number_val): +def as_number(as_number_val: str) -> int: """Convert AS Number to standardized asplain notation as an integer.""" as_number_str = str(as_number_val) if "." in as_number_str: @@ -378,14 +397,16 @@ def as_number(as_number_val): return int(as_number_str) -def split_interface(intf_name): +def split_interface(intf_name: str) -> Tuple[str, str]: """Split an interface name based on first digit, slash, or space match.""" head = intf_name.rstrip(r"/\0123456789. ") tail = intf_name[len(head) :].lstrip() return (head, tail) -def canonical_interface_name(interface, addl_name_map=None): +def canonical_interface_name( + interface: str, addl_name_map: Optional[Dict[str, str]] = None +) -> str: """Function to return an interface's canonical name (fully expanded name). Use of explicit matches used to indicate a clear understanding on any potential @@ -410,13 +431,18 @@ def canonical_interface_name(interface, addl_name_map=None): # check in dict for mapping if name_map.get(interface_type): long_int = name_map.get(interface_type) + assert isinstance(long_int, str) return long_int + str(interface_number) # if nothing matched, return the original name else: return interface -def abbreviated_interface_name(interface, addl_name_map=None, addl_reverse_map=None): +def abbreviated_interface_name( + interface: str, + addl_name_map: Optional[Dict[str, str]] = None, + addl_reverse_map: Optional[Dict[str, str]] = None, +) -> str: """Function to return an abbreviated representation of the interface name. :param interface: The interface you are attempting to abbreviate. @@ -449,6 +475,8 @@ def abbreviated_interface_name(interface, addl_name_map=None, addl_reverse_map=N else: canonical_type = interface_type + assert isinstance(canonical_type, str) + try: abbreviated_name = rev_name_map[canonical_type] + str(interface_number) return abbreviated_name @@ -459,7 +487,7 @@ def abbreviated_interface_name(interface, addl_name_map=None, addl_reverse_map=N return interface -def transform_lldp_capab(capabilities): +def transform_lldp_capab(capabilities: Union[str, Any]) -> List[str]: if capabilities and isinstance(capabilities, str): capabilities = capabilities.strip().lower().split(",") return sorted( @@ -469,7 +497,7 @@ def transform_lldp_capab(capabilities): return [] -def generate_regex_or(filters): +def generate_regex_or(filters: Iterable) -> str: """ Build a regular expression logical-or from a list/tuple of regex patterns. @@ -490,7 +518,7 @@ def generate_regex_or(filters): return return_pattern -def sanitize_config(config, filters): +def sanitize_config(config: str, filters: Dict) -> str: """ Given a dictionary of filters, remove sensitive data from the provided config. """ @@ -499,12 +527,13 @@ def sanitize_config(config, filters): return config -def sanitize_configs(configs, filters): +def sanitize_configs(configs: ConfigDict, filters: Dict) -> ConfigDict: """ Apply sanitize_config on the dictionary of configs typically returned by the get_config method. """ for cfg_name, config in configs.items(): + assert isinstance(config, str) if config.strip(): - configs[cfg_name] = sanitize_config(config, filters) + configs[cfg_name] = sanitize_config(config, filters) # type: ignore return configs diff --git a/napalm/base/mock.py b/napalm/base/mock.py index b56682f49..6e8fd0451 100644 --- a/napalm/base/mock.py +++ b/napalm/base/mock.py @@ -11,6 +11,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. +from collections import Callable +from typing import Optional, List, Dict, Union, Any from napalm.base.base import NetworkDriver import napalm.base.exceptions @@ -23,8 +25,10 @@ from pydoc import locate +from napalm.base.test import models -def raise_exception(result): + +def raise_exception(result): # type: ignore exc = locate(result["exception"]) if exc: raise exc(*result.get("args", []), **result.get("kwargs", {})) @@ -32,19 +36,19 @@ def raise_exception(result): raise TypeError("Couldn't resolve exception {}", result["exception"]) -def is_mocked_method(method): +def is_mocked_method(method: str) -> bool: mocked_methods = ["traceroute", "ping"] if method.startswith("get_") or method in mocked_methods: return True return False -def mocked_method(path, name, count): +def mocked_method(path: str, name: str, count: int) -> Callable: parent_method = getattr(NetworkDriver, name) parent_method_args = inspect.getfullargspec(parent_method) modifier = 0 if "self" not in parent_method_args.args else 1 - def _mocked_method(*args, **kwargs): + def _mocked_method(*args, **kwargs): # type: ignore # Check len(args) if len(args) + len(kwargs) + modifier > len(parent_method_args.args): raise TypeError( @@ -64,7 +68,7 @@ def _mocked_method(*args, **kwargs): return _mocked_method -def mocked_data(path, name, count): +def mocked_data(path: str, name: str, count: int) -> Union[Dict, List]: filename = "{}.{}".format(os.path.join(path, name), count) try: with open(filename) as f: @@ -74,26 +78,36 @@ def mocked_data(path, name, count): if "exception" in result: raise_exception(result) + assert False else: return result class MockDevice(object): - def __init__(self, parent, profile): + def __init__(self, parent: NetworkDriver, profile: str) -> None: self.parent = parent self.profile = profile - def run_commands(self, commands): + def run_commands(self, commands: List[str]) -> str: """Mock for EOS""" - return list(self.parent.cli(commands).values())[0] + return_value = list(self.parent.cli(commands).values())[0] + assert isinstance(return_value, str) + return return_value - def show(self, command): + def show(self, command: str) -> str: """Mock for nxos""" return self.run_commands([command]) class MockDriver(NetworkDriver): - def __init__(self, hostname, username, password, timeout=60, optional_args=None): + def __init__( + self, + hostname: str, + username: str, + password: str, + timeout: int = 60, + optional_args: Optional[Dict] = None, + ) -> None: """ Supported optional_args: * path(str) - path to where the mocked files are located @@ -102,40 +116,42 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None) self.hostname = hostname self.username = username self.password = password + if not optional_args: + optional_args = {} self.path = optional_args.get("path", "") self.profile = optional_args.get("profile", []) self.fail_on_open = optional_args.get("fail_on_open", False) self.opened = False - self.calls = {} + self.calls: Dict[str, int] = {} self.device = MockDevice(self, self.profile) # None no action, True load_merge, False load_replace - self.merge = None - self.filename = None - self.config = None + self.merge: Optional[bool] = None + self.filename: Optional[str] = None + self.config: Optional[str] = None - def _count_calls(self, name): + def _count_calls(self, name: str) -> int: current_count = self.calls.get(name, 0) self.calls[name] = current_count + 1 return self.calls[name] - def _raise_if_closed(self): + def _raise_if_closed(self) -> None: if not self.opened: raise napalm.base.exceptions.ConnectionClosedException("connection closed") - def open(self): + def open(self) -> None: if self.fail_on_open: raise napalm.base.exceptions.ConnectionException("You told me to do this") self.opened = True - def close(self): + def close(self) -> None: self.opened = False - def is_alive(self): + def is_alive(self) -> models.AliveDict: return {"is_alive": self.opened} - def cli(self, commands): + def cli(self, commands: List[str]) -> Dict[str, Union[str, Dict[str, Any]]]: count = self._count_calls("cli") result = {} regexp = re.compile("[^a-zA-Z0-9]+") @@ -145,9 +161,11 @@ def cli(self, commands): filename = "{}.{}".format(os.path.join(self.path, name), i) with open(filename, "r") as f: result[c] = f.read() - return result + return result # type: ignore - def load_merge_candidate(self, filename=None, config=None): + def load_merge_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: count = self._count_calls("load_merge_candidate") self._raise_if_closed() self.merge = True @@ -155,7 +173,9 @@ def load_merge_candidate(self, filename=None, config=None): self.config = config mocked_data(self.path, "load_merge_candidate", count) - def load_replace_candidate(self, filename=None, config=None): + def load_replace_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: count = self._count_calls("load_replace_candidate") self._raise_if_closed() self.merge = False @@ -163,12 +183,16 @@ def load_replace_candidate(self, filename=None, config=None): self.config = config mocked_data(self.path, "load_replace_candidate", count) - def compare_config(self, filename=None, config=None): + def compare_config( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> str: count = self._count_calls("compare_config") self._raise_if_closed() - return mocked_data(self.path, "compare_config", count)["diff"] + mocked = mocked_data(self.path, "compare_config", count) + assert isinstance(mocked, dict) + return mocked["diff"] - def commit_config(self): + def commit_config(self, message: str = "", revert_in: Optional[int] = None) -> None: count = self._count_calls("commit_config") self._raise_if_closed() self.merge = None @@ -176,7 +200,7 @@ def commit_config(self): self.config = None mocked_data(self.path, "commit_config", count) - def discard_config(self): + def discard_config(self) -> None: count = self._count_calls("discard_config") self._raise_if_closed() self.merge = None @@ -184,11 +208,13 @@ def discard_config(self): self.config = None mocked_data(self.path, "discard_config", count) - def _rpc(self, get): + def _rpc(self, get: str) -> str: """This one is only useful for junos.""" - return list(self.cli([get]).values())[0] + return_value = list(self.cli([get]).values())[0] + assert isinstance(return_value, str) + return return_value - def __getattribute__(self, name): + def __getattribute__(self, name: str) -> Callable: if is_mocked_method(name): self._raise_if_closed() count = self._count_calls(name) diff --git a/napalm/base/netmiko_helpers.py b/napalm/base/netmiko_helpers.py index 382041f29..6455ee955 100644 --- a/napalm/base/netmiko_helpers.py +++ b/napalm/base/netmiko_helpers.py @@ -10,10 +10,12 @@ # License for the specific language governing permissions and limitations under # the License. import inspect +from typing import Dict + from netmiko import BaseConnection -def netmiko_args(optional_args): +def netmiko_args(optional_args: Dict) -> Dict: """Check for Netmiko arguments that were passed in as NAPALM optional arguments. Return a dictionary of these optional args that will be passed into the Netmiko @@ -22,6 +24,7 @@ def netmiko_args(optional_args): fields = inspect.getfullargspec(BaseConnection.__init__) args = fields[0] defaults = fields[3] + assert isinstance(defaults, tuple) check_self = args.pop(0) if check_self != "self": diff --git a/napalm/base/test/base.py b/napalm/base/test/base.py index 9a90b81d4..f005101ca 100644 --- a/napalm/base/test/base.py +++ b/napalm/base/test/base.py @@ -192,7 +192,7 @@ def test_get_firewall_policies(self): for policy_name, policy_details in policies.items(): for policy_term in policy_details: result = result and self._test_model( - models.firewall_policies, policy_term + models.FirewallPolicyDict, policy_term ) self.assertTrue(result) @@ -202,7 +202,7 @@ def test_is_alive(self): alive = self.device.is_alive() except NotImplementedError: raise SkipTest() - result = self._test_model(models.alive, alive) + result = self._test_model(models.AliveDict, alive) self.assertTrue(result) def test_get_facts(self): @@ -210,7 +210,7 @@ def test_get_facts(self): facts = self.device.get_facts() except NotImplementedError: raise SkipTest() - result = self._test_model(models.facts, facts) + result = self._test_model(models.FactsDict, facts) self.assertTrue(result) def test_get_interfaces(self): @@ -221,7 +221,7 @@ def test_get_interfaces(self): result = len(get_interfaces) > 0 for interface, interface_data in get_interfaces.items(): - result = result and self._test_model(models.interface, interface_data) + result = result and self._test_model(models.InterfaceDict, interface_data) self.assertTrue(result) @@ -234,7 +234,7 @@ def test_get_lldp_neighbors(self): for interface, neighbor_list in get_lldp_neighbors.items(): for neighbor in neighbor_list: - result = result and self._test_model(models.lldp_neighbors, neighbor) + result = result and self._test_model(models.LLDPNeighborDict, neighbor) self.assertTrue(result) @@ -247,7 +247,7 @@ def test_get_interfaces_counters(self): for interface, interface_data in get_interfaces_counters.items(): result = result and self._test_model( - models.interface_counters, interface_data + models.InterfaceCounterDict, interface_data ) self.assertTrue(result) @@ -260,18 +260,20 @@ def test_get_environment(self): result = len(environment) > 0 for fan, fan_data in environment["fans"].items(): - result = result and self._test_model(models.fan, fan_data) + result = result and self._test_model(models.FanDict, fan_data) for power, power_data in environment["power"].items(): - result = result and self._test_model(models.power, power_data) + result = result and self._test_model(models.PowerDict, power_data) for temperature, temperature_data in environment["temperature"].items(): - result = result and self._test_model(models.temperature, temperature_data) + result = result and self._test_model( + models.TemperatureDict, temperature_data + ) for cpu, cpu_data in environment["cpu"].items(): - result = result and self._test_model(models.cpu, cpu_data) + result = result and self._test_model(models.CPUDict, cpu_data) - result = result and self._test_model(models.memory, environment["memory"]) + result = result and self._test_model(models.MemoryDict, environment["memory"]) self.assertTrue(result) @@ -291,10 +293,10 @@ def test_get_bgp_neighbors(self): print("router_id is not {}".format(str)) for peer, peer_data in vrf_data["peers"].items(): - result = result and self._test_model(models.peer, peer_data) + result = result and self._test_model(models.PeerDict, peer_data) for af, af_data in peer_data["address_family"].items(): - result = result and self._test_model(models.af, af_data) + result = result and self._test_model(models.AFDict, af_data) self.assertTrue(result) @@ -308,7 +310,7 @@ def test_get_lldp_neighbors_detail(self): for interface, neighbor_list in get_lldp_neighbors_detail.items(): for neighbor in neighbor_list: result = result and self._test_model( - models.lldp_neighbors_detail, neighbor + models.LLDPNeighborsDetailDict, neighbor ) self.assertTrue(result) @@ -321,10 +323,10 @@ def test_get_bgp_config(self): result = len(get_bgp_config) > 0 for bgp_group in get_bgp_config.values(): - result = result and self._test_model(models.bgp_config_group, bgp_group) + result = result and self._test_model(models.BPGConfigGroupDict, bgp_group) for bgp_neighbor in bgp_group.get("neighbors", {}).values(): result = result and self._test_model( - models.bgp_config_neighbor, bgp_neighbor + models.BGPConfigNeighborDict, bgp_neighbor ) self.assertTrue(result) @@ -342,7 +344,9 @@ def test_get_bgp_neighbors_detail(self): for remote_as, neighbor_list in vrf_ases.items(): result = result and isinstance(remote_as, int) for neighbor in neighbor_list: - result = result and self._test_model(models.peer_details, neighbor) + result = result and self._test_model( + models.PeerDetailsDict, neighbor + ) self.assertTrue(result) @@ -354,7 +358,7 @@ def test_get_arp_table(self): result = len(get_arp_table) > 0 for arp_entry in get_arp_table: - result = result and self._test_model(models.arp_table, arp_entry) + result = result and self._test_model(models.ARPTableDict, arp_entry) self.assertTrue(result) @@ -366,7 +370,7 @@ def test_get_arp_table_with_vrf(self): result = len(get_arp_table) > 0 for arp_entry in get_arp_table: - result = result and self._test_model(models.arp_table, arp_entry) + result = result and self._test_model(models.ARPTableDict, arp_entry) self.assertTrue(result) @@ -379,7 +383,7 @@ def test_get_ntp_peers(self): for peer, peer_details in get_ntp_peers.items(): result = result and isinstance(peer, str) - result = result and self._test_model(models.ntp_peer, peer_details) + result = result and self._test_model(models.NTPPeerDict, peer_details) self.assertTrue(result) @@ -393,7 +397,7 @@ def test_get_ntp_servers(self): for server, server_details in get_ntp_servers.items(): result = result and isinstance(server, str) - result = result and self._test_model(models.ntp_server, server_details) + result = result and self._test_model(models.NTPServerDict, server_details) self.assertTrue(result) @@ -405,7 +409,7 @@ def test_get_ntp_stats(self): result = len(get_ntp_stats) > 0 for ntp_peer_details in get_ntp_stats: - result = result and self._test_model(models.ntp_stats, ntp_peer_details) + result = result and self._test_model(models.NTPStats, ntp_peer_details) self.assertTrue(result) @@ -420,9 +424,13 @@ def test_get_interfaces_ip(self): ipv4 = interface_details.get("ipv4", {}) ipv6 = interface_details.get("ipv6", {}) for ip, ip_details in ipv4.items(): - result = result and self._test_model(models.interfaces_ip, ip_details) + result = result and self._test_model( + models.InterfacesIPDictEntry, ip_details + ) for ip, ip_details in ipv6.items(): - result = result and self._test_model(models.interfaces_ip, ip_details) + result = result and self._test_model( + models.InterfacesIPDictEntry, ip_details + ) self.assertTrue(result) @@ -434,9 +442,7 @@ def test_get_mac_address_table(self): result = len(get_mac_address_table) > 0 for mac_table_entry in get_mac_address_table: - result = result and self._test_model( - models.mac_address_table, mac_table_entry - ) + result = result and self._test_model(models.MACAdressTable, mac_table_entry) self.assertTrue(result) @@ -454,7 +460,7 @@ def test_get_route_to(self): for prefix, routes in get_route_to.items(): for route in routes: - result = result and self._test_model(models.route, route) + result = result and self._test_model(models.RouteDict, route) self.assertTrue(result) def test_get_snmp_information(self): @@ -466,10 +472,12 @@ def test_get_snmp_information(self): result = len(get_snmp_information) > 0 for snmp_entry in get_snmp_information: - result = result and self._test_model(models.snmp, get_snmp_information) + result = result and self._test_model(models.SNMPDict, get_snmp_information) for community, community_data in get_snmp_information["community"].items(): - result = result and self._test_model(models.snmp_community, community_data) + result = result and self._test_model( + models.SNMPCommunityDict, community_data + ) self.assertTrue(result) @@ -483,7 +491,7 @@ def test_get_probes_config(self): for probe_name, probe_tests in get_probes_config.items(): for test_name, test_config in probe_tests.items(): - result = result and self._test_model(models.probe_test, test_config) + result = result and self._test_model(models.ProbeTestDict, test_config) self.assertTrue(result) @@ -497,7 +505,7 @@ def test_get_probes_results(self): for probe_name, probe_tests in get_probes_results.items(): for test_name, test_results in probe_tests.items(): result = result and self._test_model( - models.probe_test_results, test_results + models.ProbeTestResultDict, test_results ) self.assertTrue(result) @@ -511,10 +519,12 @@ def test_ping(self): result = isinstance(get_ping.get("success"), dict) ping_results = get_ping.get("success", {}) - result = result and self._test_model(models.ping, ping_results) + result = result and self._test_model(models.PingDict, ping_results) for ping_result in ping_results.get("results", []): - result = result and self._test_model(models.ping_result, ping_result) + result = result and self._test_model( + models.PingResultDictEntry, ping_result + ) self.assertTrue(result) @@ -530,7 +540,9 @@ def test_traceroute(self): for hope_id, hop_result in traceroute_results.items(): for probe_id, probe_result in hop_result.get("probes", {}).items(): - result = result and self._test_model(models.traceroute, probe_result) + result = result and self._test_model( + models.TracerouteDict, probe_result + ) self.assertTrue(result) @@ -543,7 +555,7 @@ def test_get_users(self): result = len(get_users) for user, user_details in get_users.items(): - result = result and self._test_model(models.users, user_details) + result = result and self._test_model(models.UsersDict, user_details) result = result and (0 <= user_details.get("level") <= 15) self.assertTrue(result) @@ -578,7 +590,7 @@ def test_get_config(self): raise SkipTest() assert isinstance(get_config, dict) - assert self._test_model(models.config, get_config) + assert self._test_model(models.ConfigDict, get_config) def test_get_config_filtered(self): """Test get_config method.""" @@ -602,13 +614,13 @@ def test_get_network_instances(self): result = isinstance(get_network_instances, dict) for network_instance_name, network_instance in get_network_instances.items(): result = result and self._test_model( - models.network_instance, network_instance + models.NetworkInstanceDict, network_instance ) result = result and self._test_model( - models.network_instance_state, network_instance["state"] + models.NetworkInstanceStateDict, network_instance["state"] ) result = result and self._test_model( - models.network_instance_interfaces, network_instance["interfaces"] + models.NetworkInstanceInterfacesDict, network_instance["interfaces"] ) self.assertTrue(result) @@ -621,7 +633,7 @@ def test_get_ipv6_neighbors_table(self): result = len(get_ipv6_neighbors_table) > 0 for entry in get_ipv6_neighbors_table: - result = result and self._test_model(models.ipv6_neighbor, entry) + result = result and self._test_model(models.IPV6NeighborDict, entry) self.assertTrue(result) @@ -633,6 +645,6 @@ def test_get_vlans(self): result = len(get_vlans) > 0 for vlan, vlan_data in get_vlans.items(): - result = result and self._test_model(models.vlan, vlan_data) + result = result and self._test_model(models.VlanDict, vlan_data) self.assertTrue(result) diff --git a/napalm/base/test/getters.py b/napalm/base/test/getters.py index dd80b9293..af52ebbd8 100644 --- a/napalm/base/test/getters.py +++ b/napalm/base/test/getters.py @@ -109,7 +109,10 @@ class BaseTestGetters(object): """Base class for testing drivers.""" def test_method_signatures(self): - """Test that all methods have the same signature.""" + """ + Test that all methods have the same signature. + + The type hint annotations are ignored here because the import paths might differ.""" errors = {} cls = self.driver # Create fictional driver instance (py3 needs bound methods) @@ -121,17 +124,17 @@ def test_method_signatures(self): continue try: orig = getattr(NetworkDriver, attr) - orig_spec = inspect.getfullargspec(orig) + orig_spec = inspect.getfullargspec(orig)[:4] except AttributeError: orig_spec = "Method does not exist in napalm.base" - func_spec = inspect.getfullargspec(func) + func_spec = inspect.getfullargspec(func)[:4] if orig_spec != func_spec: errors[attr] = (orig_spec, func_spec) EXTRA_METHODS = ["__init__"] for method in EXTRA_METHODS: - orig_spec = inspect.getfullargspec(getattr(NetworkDriver, method)) - func_spec = inspect.getfullargspec(getattr(cls, method)) + orig_spec = inspect.getfullargspec(getattr(NetworkDriver, method))[:4] + func_spec = inspect.getfullargspec(getattr(cls, method))[:4] if orig_spec != func_spec: errors[attr] = (orig_spec, func_spec) @@ -141,14 +144,14 @@ def test_method_signatures(self): def test_is_alive(self, test_case): """Test is_alive method.""" alive = self.device.is_alive() - assert helpers.test_model(models.alive, alive) + assert helpers.test_model(models.AliveDict, alive) return alive @wrap_test_cases def test_get_facts(self, test_case): """Test get_facts method.""" facts = self.device.get_facts() - assert helpers.test_model(models.facts, facts) + assert helpers.test_model(models.FactsDict, facts) return facts @wrap_test_cases @@ -158,7 +161,7 @@ def test_get_interfaces(self, test_case): assert len(get_interfaces) > 0 for interface, interface_data in get_interfaces.items(): - assert helpers.test_model(models.interface, interface_data) + assert helpers.test_model(models.InterfaceDict, interface_data) return get_interfaces @@ -170,7 +173,7 @@ def test_get_lldp_neighbors(self, test_case): for interface, neighbor_list in get_lldp_neighbors.items(): for neighbor in neighbor_list: - assert helpers.test_model(models.lldp_neighbors, neighbor) + assert helpers.test_model(models.LLDPNeighborDict, neighbor) return get_lldp_neighbors @@ -181,7 +184,7 @@ def test_get_interfaces_counters(self, test_case): assert len(self.device.get_interfaces_counters()) > 0 for interface, interface_data in get_interfaces_counters.items(): - assert helpers.test_model(models.interface_counters, interface_data) + assert helpers.test_model(models.InterfaceCounterDict, interface_data) return get_interfaces_counters @@ -192,18 +195,18 @@ def test_get_environment(self, test_case): assert len(environment) > 0 for fan, fan_data in environment["fans"].items(): - assert helpers.test_model(models.fan, fan_data) + assert helpers.test_model(models.FanDict, fan_data) for power, power_data in environment["power"].items(): - assert helpers.test_model(models.power, power_data) + assert helpers.test_model(models.PowerDict, power_data) for temperature, temperature_data in environment["temperature"].items(): - assert helpers.test_model(models.temperature, temperature_data) + assert helpers.test_model(models.TemperatureDict, temperature_data) for cpu, cpu_data in environment["cpu"].items(): - assert helpers.test_model(models.cpu, cpu_data) + assert helpers.test_model(models.CPUDict, cpu_data) - assert helpers.test_model(models.memory, environment["memory"]) + assert helpers.test_model(models.MemoryDict, environment["memory"]) return environment @@ -218,10 +221,10 @@ def test_get_bgp_neighbors(self, test_case): assert isinstance(vrf_data["router_id"], str) for peer, peer_data in vrf_data["peers"].items(): - assert helpers.test_model(models.peer, peer_data) + assert helpers.test_model(models.PeerDict, peer_data) for af, af_data in peer_data["address_family"].items(): - assert helpers.test_model(models.af, af_data) + assert helpers.test_model(models.AFDict, af_data) return get_bgp_neighbors @@ -233,7 +236,7 @@ def test_get_lldp_neighbors_detail(self, test_case): for interface, neighbor_list in get_lldp_neighbors_detail.items(): for neighbor in neighbor_list: - assert helpers.test_model(models.lldp_neighbors_detail, neighbor) + assert helpers.test_model(models.LLDPNeighborDetailDict, neighbor) return get_lldp_neighbors_detail @@ -244,9 +247,9 @@ def test_get_bgp_config(self, test_case): assert len(get_bgp_config) > 0 for bgp_group in get_bgp_config.values(): - assert helpers.test_model(models.bgp_config_group, bgp_group) + assert helpers.test_model(models.BPGConfigGroupDict, bgp_group) for bgp_neighbor in bgp_group.get("neighbors", {}).values(): - assert helpers.test_model(models.bgp_config_neighbor, bgp_neighbor) + assert helpers.test_model(models.BGPConfigNeighborDict, bgp_neighbor) return get_bgp_config @@ -262,7 +265,7 @@ def test_get_bgp_neighbors_detail(self, test_case): for remote_as, neighbor_list in vrf_ases.items(): assert isinstance(remote_as, int) for neighbor in neighbor_list: - assert helpers.test_model(models.peer_details, neighbor) + assert helpers.test_model(models.PeerDetailsDict, neighbor) return get_bgp_neighbors_detail @@ -273,7 +276,7 @@ def test_get_arp_table(self, test_case): assert len(get_arp_table) > 0 for arp_entry in get_arp_table: - assert helpers.test_model(models.arp_table, arp_entry) + assert helpers.test_model(models.ARPTableDict, arp_entry) return get_arp_table @@ -284,7 +287,7 @@ def test_get_arp_table_with_vrf(self, test_case): assert len(get_arp_table) > 0 for arp_entry in get_arp_table: - assert helpers.test_model(models.arp_table, arp_entry) + assert helpers.test_model(models.ARPTableDict, arp_entry) return get_arp_table @@ -294,7 +297,7 @@ def test_get_ipv6_neighbors_table(self, test_case): get_ipv6_neighbors_table = self.device.get_ipv6_neighbors_table() for entry in get_ipv6_neighbors_table: - assert helpers.test_model(models.ipv6_neighbor, entry) + assert helpers.test_model(models.IPV6NeighborDict, entry) return get_ipv6_neighbors_table @@ -306,7 +309,7 @@ def test_get_ntp_peers(self, test_case): for peer, peer_details in get_ntp_peers.items(): assert isinstance(peer, str) - assert helpers.test_model(models.ntp_peer, peer_details) + assert helpers.test_model(models.NTPPeerDict, peer_details) return get_ntp_peers @@ -318,7 +321,7 @@ def test_get_ntp_servers(self, test_case): for server, server_details in get_ntp_servers.items(): assert isinstance(server, str) - assert helpers.test_model(models.ntp_server, server_details) + assert helpers.test_model(models.NTPServerDict, server_details) return get_ntp_servers @@ -329,7 +332,7 @@ def test_get_ntp_stats(self, test_case): assert len(get_ntp_stats) > 0 for ntp_peer_details in get_ntp_stats: - assert helpers.test_model(models.ntp_stats, ntp_peer_details) + assert helpers.test_model(models.NTPStats, ntp_peer_details) return get_ntp_stats @@ -343,9 +346,9 @@ def test_get_interfaces_ip(self, test_case): ipv4 = interface_details.get("ipv4", {}) ipv6 = interface_details.get("ipv6", {}) for ip, ip_details in ipv4.items(): - assert helpers.test_model(models.interfaces_ip, ip_details) + assert helpers.test_model(models.InterfacesIPDictEntry, ip_details) for ip, ip_details in ipv6.items(): - assert helpers.test_model(models.interfaces_ip, ip_details) + assert helpers.test_model(models.InterfacesIPDictEntry, ip_details) return get_interfaces_ip @@ -356,7 +359,7 @@ def test_get_mac_address_table(self, test_case): assert len(get_mac_address_table) > 0 for mac_table_entry in get_mac_address_table: - assert helpers.test_model(models.mac_address_table, mac_table_entry) + assert helpers.test_model(models.MACAdressTable, mac_table_entry) return get_mac_address_table @@ -373,7 +376,7 @@ def test_get_route_to(self, test_case): for prefix, routes in get_route_to.items(): for route in routes: - assert helpers.test_model(models.route, route) + assert helpers.test_model(models.RouteDict, route) return get_route_to @@ -391,7 +394,7 @@ def test_get_route_to_longer(self, test_case): for prefix, routes in get_route_to.items(): for route in routes: - assert helpers.test_model(models.route, route) + assert helpers.test_model(models.RouteDict, route) return get_route_to @@ -403,10 +406,10 @@ def test_get_snmp_information(self, test_case): assert len(get_snmp_information) > 0 for snmp_entry in get_snmp_information: - assert helpers.test_model(models.snmp, get_snmp_information) + assert helpers.test_model(models.SNMPDict, get_snmp_information) for community, community_data in get_snmp_information["community"].items(): - assert helpers.test_model(models.snmp_community, community_data) + assert helpers.test_model(models.SNMPCommunityDict, community_data) return get_snmp_information @@ -419,7 +422,7 @@ def test_get_probes_config(self, test_case): for probe_name, probe_tests in get_probes_config.items(): for test_name, test_config in probe_tests.items(): - assert helpers.test_model(models.probe_test, test_config) + assert helpers.test_model(models.ProbeTestDict, test_config) return get_probes_config @@ -431,7 +434,7 @@ def test_get_probes_results(self, test_case): for probe_name, probe_tests in get_probes_results.items(): for test_name, test_results in probe_tests.items(): - assert helpers.test_model(models.probe_test_results, test_results) + assert helpers.test_model(models.ProbeTestResultDict, test_results) return get_probes_results @@ -443,10 +446,10 @@ def test_ping(self, test_case): assert isinstance(get_ping.get("success"), dict) ping_results = get_ping.get("success", {}) - assert helpers.test_model(models.ping, ping_results) + assert helpers.test_model(models.PingDict, ping_results) for ping_result in ping_results.get("results", []): - assert helpers.test_model(models.ping_result, ping_result) + assert helpers.test_model(models.PingResultDictEntry, ping_result) return get_ping @@ -460,7 +463,7 @@ def test_traceroute(self, test_case): for hope_id, hop_result in traceroute_results.items(): for probe_id, probe_result in hop_result.get("probes", {}).items(): - assert helpers.test_model(models.traceroute, probe_result) + assert helpers.test_model(models.TracerouteDict, probe_result) return get_traceroute @@ -471,7 +474,7 @@ def test_get_users(self, test_case): assert len(get_users) for user, user_details in get_users.items(): - assert helpers.test_model(models.users, user_details) + assert helpers.test_model(models.UsersDict, user_details) assert (0 <= user_details.get("level") <= 15) or ( user_details.get("level") == 20 ) @@ -505,7 +508,7 @@ def test_get_config(self, test_case): get_config = self.device.get_config() assert isinstance(get_config, dict) - assert helpers.test_model(models.config, get_config) + assert helpers.test_model(models.ConfigDict, get_config) return get_config @@ -527,7 +530,7 @@ def test_get_config_sanitized(self, test_case): get_config = self.device.get_config(sanitized=True) assert isinstance(get_config, dict) - assert helpers.test_model(models.config, get_config) + assert helpers.test_model(models.ConfigDict, get_config) return get_config @@ -538,12 +541,12 @@ def test_get_network_instances(self, test_case): assert isinstance(get_network_instances, dict) for network_instance_name, network_instance in get_network_instances.items(): - assert helpers.test_model(models.network_instance, network_instance) + assert helpers.test_model(models.NetworkInstanceDict, network_instance) assert helpers.test_model( - models.network_instance_state, network_instance["state"] + models.NetworkInstanceStateDict, network_instance["state"] ) assert helpers.test_model( - models.network_instance_interfaces, network_instance["interfaces"] + models.NetworkInstanceInterfacesDict, network_instance["interfaces"] ) return get_network_instances @@ -555,7 +558,7 @@ def test_get_firewall_policies(self, test_case): assert len(get_firewall_policies) > 0 for policy_name, policy_details in get_firewall_policies.items(): for policy_term in policy_details: - assert helpers.test_model(models.firewall_policies, policy_term) + assert helpers.test_model(models.FirewallPolicyDict, policy_term) return get_firewall_policies @wrap_test_cases @@ -566,6 +569,6 @@ def test_get_vlans(self, test_case): assert len(get_vlans) > 0 for vlan, vlan_data in get_vlans.items(): - assert helpers.test_model(models.vlan, vlan_data) + assert helpers.test_model(models.VlanDict, vlan_data) return get_vlans diff --git a/napalm/base/test/models.py b/napalm/base/test/models.py index e65cbf0d2..345266419 100644 --- a/napalm/base/test/models.py +++ b/napalm/base/test/models.py @@ -1,20 +1,19 @@ -try: - from typing import TypedDict -except ImportError: - from typing_extensions import TypedDict +from typing import Dict, List -configuration = TypedDict( - "configuration", {"running": str, "candidate": str, "startup": str} +from typing_extensions import TypedDict + +ConfigurationDict = TypedDict( + "ConfigurationDict", {"running": str, "candidate": str, "startup": str} ) -alive = TypedDict("alive", {"is_alive": bool}) +AliveDict = TypedDict("AliveDict", {"is_alive": bool}) -facts = TypedDict( - "facts", +FactsDict = TypedDict( + "FactsDict", { "os_version": str, "uptime": int, - "interface_list": list, + "interface_list": List, "vendor": str, "serial_number": str, "model": str, @@ -23,8 +22,8 @@ }, ) -interface = TypedDict( - "interface", +InterfaceDict = TypedDict( + "InterfaceDict", { "is_up": bool, "is_enabled": bool, @@ -36,10 +35,26 @@ }, ) -lldp_neighbors = TypedDict("lldp_neighbors", {"hostname": str, "port": str}) +LLDPNeighborDict = TypedDict("LLDPNeighborDict", {"hostname": str, "port": str}) -interface_counters = TypedDict( - "interface_counters", +LLDPNeighborDetailDict = TypedDict( + "LLDPNeighborDetailDict", + { + "parent_interface": str, + "remote_port": str, + "remote_chassis_id": str, + "remote_port_description": str, + "remote_system_name": str, + "remote_system_description": str, + "remote_system_capab": List, + "remote_system_enable_capab": List, + }, +) + +LLDPNeighborsDetailDict = Dict[str, List[LLDPNeighborDetailDict]] + +InterfaceCounterDict = TypedDict( + "InterfaceCounterDict", { "tx_errors": int, "rx_errors": int, @@ -56,20 +71,31 @@ }, ) -temperature = TypedDict( - "temperature", {"is_alert": bool, "is_critical": bool, "temperature": float} +TemperatureDict = TypedDict( + "TemperatureDict", {"is_alert": bool, "is_critical": bool, "temperature": float} ) -power = TypedDict("power", {"status": bool, "output": float, "capacity": float}) +PowerDict = TypedDict("PowerDict", {"status": bool, "output": float, "capacity": float}) + +MemoryDict = TypedDict("MemoryDict", {"used_ram": int, "available_ram": int}) -memory = TypedDict("memory", {"used_ram": int, "available_ram": int}) +FanDict = TypedDict("FanDict", {"status": bool}) -fan = TypedDict("fan", {"status": bool}) +CPUDict = TypedDict("CPUDict", {"%usage": float}) -cpu = TypedDict("cpu", {"%usage": float}) +EnvironmentDict = TypedDict( + "EnvironmentDict", + { + "fans": Dict[str, FanDict], + "temperature": Dict[str, TemperatureDict], + "power": Dict[str, PowerDict], + "cpu": Dict[int, CPUDict], + "memory": MemoryDict, + }, +) -peer = TypedDict( - "peer", +PeerDict = TypedDict( + "PeerDict", { "is_enabled": bool, "uptime": int, @@ -82,30 +108,57 @@ }, ) -af = TypedDict( - "af", {"sent_prefixes": int, "accepted_prefixes": int, "received_prefixes": int} -) - -lldp_neighbors_detail = TypedDict( - "lldp_neighbors_detail", +PeerDetailsDict = TypedDict( + "PeerDetailsDict", { - "parent_interface": str, - "remote_port": str, - "remote_chassis_id": str, - "remote_port_description": str, - "remote_system_name": str, - "remote_system_description": str, - "remote_system_capab": list, - "remote_system_enable_capab": list, + "up": bool, + "local_as": int, + "remote_as": int, + "router_id": str, + "local_address": str, + "routing_table": str, + "local_address_configured": bool, + "local_port": int, + "remote_address": str, + "remote_port": int, + "multihop": bool, + "multipath": bool, + "remove_private_as": bool, + "import_policy": str, + "export_policy": str, + "input_messages": int, + "output_messages": int, + "input_updates": int, + "output_updates": int, + "messages_queued_out": int, + "connection_state": str, + "previous_connection_state": str, + "last_event": str, + "suppress_4byte_as": bool, + "local_as_prepend": bool, + "holdtime": int, + "configured_holdtime": int, + "keepalive": int, + "configured_keepalive": int, + "active_prefix_count": int, + "received_prefix_count": int, + "accepted_prefix_count": int, + "suppressed_prefix_count": int, + "advertised_prefix_count": int, + "flap_count": int, }, ) -bgp_config_group = TypedDict( - "bgp_config_group", +AFDict = TypedDict( + "AFDict", {"sent_prefixes": int, "accepted_prefixes": int, "received_prefixes": int} +) + +BPGConfigGroupDict = TypedDict( + "BPGConfigGroupDict", { "type": str, "description": str, - "apply_groups": list, + "apply_groups": List, "multihop_ttl": int, "multipath": bool, "local_address": str, @@ -119,8 +172,8 @@ }, ) -bgp_config_neighbor = TypedDict( - "bgp_config_neighbor", +BGPConfigNeighborDict = TypedDict( + "BGPConfigNeighborDict", { "description": str, "import_policy": str, @@ -135,72 +188,57 @@ }, ) -peer_details = TypedDict( - "peer_details", +BGPStateAdressFamilyDict = TypedDict( + "BGPStateAdressFamilyDict", + {"received_prefixes": int, "accepted_prefixes": int, "sent_prefixes": int}, +) + +BGPStateNeighborDict = TypedDict( + "BGPStateNeighborDict", { - "up": bool, "local_as": int, "remote_as": int, - "router_id": str, - "local_address": str, - "routing_table": str, - "local_address_configured": bool, - "local_port": int, - "remote_address": str, - "remote_port": int, - "multihop": bool, - "multipath": bool, - "remove_private_as": bool, - "import_policy": str, - "export_policy": str, - "input_messages": int, - "output_messages": int, - "input_updates": int, - "output_updates": int, - "messages_queued_out": int, - "connection_state": str, - "previous_connection_state": str, - "last_event": str, - "suppress_4byte_as": bool, - "local_as_prepend": bool, - "holdtime": int, - "configured_holdtime": int, - "keepalive": int, - "configured_keepalive": int, - "active_prefix_count": int, - "received_prefix_count": int, - "accepted_prefix_count": int, - "suppressed_prefix_count": int, - "advertised_prefix_count": int, - "flap_count": int, + "remote_id": str, + "is_up": bool, + "is_enabled": bool, + "description": str, + "uptime": int, + "address_family": Dict[str, BGPStateAdressFamilyDict], }, ) -arp_table = TypedDict( - "arp_table", {"interface": str, "mac": str, "ip": str, "age": float} +BGPStateNeighborsPerVRFDict = TypedDict( + "BGPStateNeighborsPerVRFDict", + {"router_id": str, "peers": Dict[str, BGPStateNeighborDict]}, ) -ipv6_neighbor = TypedDict( - "ipv6_neighbor", +ARPTableDict = TypedDict( + "ARPTableDict", {"interface": str, "mac": str, "ip": str, "age": float} +) + +IPV6NeighborDict = TypedDict( + "IPV6NeighborDict", {"interface": str, "mac": str, "ip": str, "age": float, "state": str}, ) -ntp_peer = TypedDict( - "ntp_peer", +NTPPeerDict = TypedDict( + "NTPPeerDict", { # will populate it in the future wit potential keys }, + total=False, ) -ntp_server = TypedDict( - "ntp_server", +NTPServerDict = TypedDict( + "NTPServerDict", { # will populate it in the future wit potential keys }, + total=False, ) -ntp_stats = TypedDict( - "ntp_stats", +NTPStats = TypedDict( + "NTPStats", { "remote": str, "referenceid": str, @@ -216,10 +254,21 @@ }, ) -interfaces_ip = TypedDict("interfaces_ip", {"prefix_length": int}) +InterfacesIPDictEntry = TypedDict( + "InterfacesIPDictEntry", {"prefix_length": int}, total=False +) + +InterfacesIPDict = TypedDict( + "InterfacesIPDict", + { + "ipv4": Dict[str, InterfacesIPDictEntry], + "ipv6": Dict[str, InterfacesIPDictEntry], + }, + total=False, +) -mac_address_table = TypedDict( - "mac_address_table", +MACAdressTable = TypedDict( + "MACAdressTable", { "mac": str, "interface": str, @@ -231,8 +280,8 @@ }, ) -route = TypedDict( - "route", +RouteDict = TypedDict( + "RouteDict", { "protocol": str, "current_active": bool, @@ -248,14 +297,14 @@ }, ) -snmp = TypedDict( - "snmp", {"chassis_id": str, "community": dict, "contact": str, "location": str} +SNMPDict = TypedDict( + "SNMPDict", {"chassis_id": str, "community": dict, "contact": str, "location": str} ) -snmp_community = TypedDict("snmp_community", {"acl": str, "mode": str}) +SNMPCommunityDict = TypedDict("SNMPCommunityDict", {"acl": str, "mode": str}) -probe_test = TypedDict( - "probe_test", +ProbeTestDict = TypedDict( + "ProbeTestDict", { "probe_type": str, "target": str, @@ -265,8 +314,8 @@ }, ) -probe_test_results = TypedDict( - "probe_test_results", +ProbeTestResultDict = TypedDict( + "ProbeTestResultDict", { "target": str, "source": str, @@ -287,8 +336,12 @@ }, ) -ping = TypedDict( - "ping", +PingResultDictEntry = TypedDict( + "PingResultDictEntry", {"ip_address": str, "rtt": float} +) + +PingDict = TypedDict( + "PingDict", { "probes_sent": int, "packet_loss": int, @@ -300,34 +353,67 @@ }, ) -ping_result = TypedDict("ping_result", {"ip_address": str, "rtt": float}) +PingResultDict = TypedDict( + "PingResultDict", + {"success": PingDict, "error": str}, + total=False, +) + +TracerouteDict = TypedDict( + "TracerouteDict", {"rtt": float, "ip_address": str, "host_name": str} +) + +TracerouteResultDictEntry = TypedDict( + "TracerouteResultDictEntry", {"probes": Dict[int, TracerouteDict]}, total=False +) + +TracerouteResultDict = TypedDict( + "TracerouteResultDict", + {"success": Dict[int, TracerouteResultDictEntry], "error": str}, + total=False, +) + +UsersDict = TypedDict("UsersDict", {"level": int, "password": str, "sshkeys": List}) + +OpticsStateDict = TypedDict( + "OpticsStateDict", {"instant": float, "avg": float, "min": float, "max": float} +) -traceroute = TypedDict( - "traceroute", {"rtt": float, "ip_address": str, "host_name": str} +OpticsStatePerChannelDict = TypedDict( + "OpticsStatePerChannelDict", + { + "input_power": OpticsStateDict, + "output_power": OpticsStateDict, + "laser_bias_current": OpticsStateDict, + }, ) -users = TypedDict("users", {"level": int, "password": str, "sshkeys": list}) +OpticsPerChannelDict = TypedDict( + "OpticsPerChannelDict", {"index": int, "state": OpticsStatePerChannelDict} +) -optics_state = TypedDict( - "optics_state", {"instant": float, "avg": float, "min": float, "max": float} +OpticsPhysicalChannelsDict = TypedDict( + "OpticsPhysicalChannelsDict", {"channels": OpticsPerChannelDict} ) -config = TypedDict("config", {"running": str, "startup": str, "candidate": str}) +OpticsDict = TypedDict("OpticsDict", {"physical_channels": OpticsPhysicalChannelsDict}) -network_instance = TypedDict( - "network_instance", {"name": str, "type": str, "state": dict, "interfaces": dict} +ConfigDict = TypedDict("ConfigDict", {"running": str, "startup": str, "candidate": str}) + +NetworkInstanceDict = TypedDict( + "NetworkInstanceDict", {"name": str, "type": str, "state": dict, "interfaces": dict} ) -network_instance_state = TypedDict( - "network_instance_state", {"route_distinguisher": str} +NetworkInstanceStateDict = TypedDict( + "NetworkInstanceStateDict", {"route_distinguisher": str} ) -network_instance_interfaces = TypedDict( - "network_instance_interfaces", {"interface": dict} +NetworkInstanceInterfacesDict = TypedDict( + "NetworkInstanceInterfacesDict", {"interface": dict} ) -firewall_policies = TypedDict( - "firewall_policies", +FirewallPolicyDict = TypedDict( + "FirewallPolicyDict", { "position": int, "packet_hits": int, @@ -345,4 +431,16 @@ }, ) -vlan = TypedDict("vlan", {"name": str, "interfaces": list}) +VlanDict = TypedDict("VlanDict", {"name": str, "interfaces": List}) + +DictValidationResult = TypedDict( + "DictValidationResult", + {"complies": bool, "present": Dict, "missing": List, "extra": List}, +) + +ListValidationResult = TypedDict( + "ListValidationResult", + {"complies": bool, "present": List, "missing": List, "extra": List}, +) + +ReportResult = TypedDict("ReportResult", {"complies": bool, "skipped": List}) diff --git a/napalm/base/utils/jinja_filters.py b/napalm/base/utils/jinja_filters.py index 73f25fee4..fc1a1b95a 100644 --- a/napalm/base/utils/jinja_filters.py +++ b/napalm/base/utils/jinja_filters.py @@ -1,11 +1,12 @@ """Some common jinja filters.""" +from typing import Dict, Any class CustomJinjaFilters(object): """Utility filters for jinja2.""" @classmethod - def filters(cls): + def filters(cls) -> Dict: """Return jinja2 filters that this module provide.""" return { "oc_attr_isdefault": oc_attr_isdefault, @@ -14,7 +15,7 @@ def filters(cls): } -def oc_attr_isdefault(o): +def oc_attr_isdefault(o: Any) -> bool: """Return wether an OC attribute has been defined or not.""" if not o._changed() and not o.default(): return True @@ -23,7 +24,7 @@ def oc_attr_isdefault(o): return False -def openconfig_to_cisco_af(value): +def openconfig_to_cisco_af(value: str) -> str: """Translate openconfig AF name to Cisco AFI name.""" if ":" in value: value = value.split(":")[1] @@ -39,7 +40,7 @@ def openconfig_to_cisco_af(value): return mapd[value] -def openconfig_to_eos_af(value): +def openconfig_to_eos_af(value: str) -> str: """Translate openconfig AF name to EOS AFI name.""" if ":" in value: value = value.split(":")[1] diff --git a/napalm/base/utils/string_parsers.py b/napalm/base/utils/string_parsers.py index eb9142da3..2d1ca14a9 100644 --- a/napalm/base/utils/string_parsers.py +++ b/napalm/base/utils/string_parsers.py @@ -1,25 +1,28 @@ """ Common methods to normalize a string """ import re +from typing import Union, List, Iterable, Dict, Optional -def convert(text): +def convert(text: str) -> Union[str, int]: """Convert text to integer, if it is a digit.""" if text.isdigit(): return int(text) return text -def alphanum_key(key): +def alphanum_key(key: str) -> List[Union[str, int]]: """ split on end numbers.""" return [convert(c) for c in re.split("([0-9]+)", key)] -def sorted_nicely(sort_me): +def sorted_nicely(sort_me: Iterable) -> Iterable: """ Sort the given iterable in the way that humans expect.""" return sorted(sort_me, key=alphanum_key) -def colon_separated_string_to_dict(string, separator=":"): +def colon_separated_string_to_dict( + string: str, separator: str = ":" +) -> Dict[str, Optional[str]]: """ Converts a string in the format: @@ -38,7 +41,7 @@ def colon_separated_string_to_dict(string, separator=":"): into a dictionary """ - dictionary = dict() + dictionary: Dict[str, Optional[str]] = dict() for line in string.splitlines(): line_data = line.split(separator) if len(line_data) > 1: @@ -52,7 +55,7 @@ def colon_separated_string_to_dict(string, separator=":"): return dictionary -def hyphen_range(string): +def hyphen_range(string: str) -> List[int]: """ Expands a string of numbers separated by commas and hyphens into a list of integers. For example: 2-3,5-7,20-21,23,100-200 @@ -76,7 +79,7 @@ def hyphen_range(string): return list_numbers -def convert_uptime_string_seconds(uptime): +def convert_uptime_string_seconds(uptime: str) -> int: """Convert uptime strings to seconds. The string can be formatted various ways.""" regex_list = [ # n years, n weeks, n days, n hours, n minutes where each of the fields except minutes diff --git a/napalm/base/validate.py b/napalm/base/validate.py index a9616adac..1e2b529a7 100644 --- a/napalm/base/validate.py +++ b/napalm/base/validate.py @@ -3,18 +3,23 @@ See: https://napalm.readthedocs.io/en/latest/validate.html """ -import yaml import copy import re +from typing import Dict, List, Union, TypeVar, Optional, TYPE_CHECKING + +import yaml +if TYPE_CHECKING: + from napalm.base import NetworkDriver from napalm.base.exceptions import ValidationException +from napalm.base.test import models # We put it here to compile it only once numeric_compare_regex = re.compile(r"^(<|>|<=|>=|==|!=)(\d+(\.\d+){0,1})$") -def _get_validation_file(validation_file): +def _get_validation_file(validation_file: str) -> Dict[str, Dict]: try: with open(validation_file, "r") as stream: try: @@ -26,7 +31,7 @@ def _get_validation_file(validation_file): return validation_source -def _mode(mode_string): +def _mode(mode_string: str) -> Dict[str, bool]: mode = {"strict": False} for m in mode_string.split(): @@ -36,8 +41,15 @@ def _mode(mode_string): return mode -def _compare_getter_list(src, dst, mode): - result = {"complies": True, "present": [], "missing": [], "extra": []} +def _compare_getter_list( + src: List, dst: List, mode: Dict[str, bool] +) -> models.ListValidationResult: + result: models.ListValidationResult = { + "complies": True, + "present": [], + "missing": [], + "extra": [], + } for src_element in src: found = False @@ -71,8 +83,15 @@ def _compare_getter_list(src, dst, mode): return result -def _compare_getter_dict(src, dst, mode): - result = {"complies": True, "present": {}, "missing": [], "extra": []} +def _compare_getter_dict( + src: Dict[str, List], dst: Dict[str, List], mode: Dict[str, bool] +) -> models.DictValidationResult: + result: models.DictValidationResult = { + "complies": True, + "present": {}, + "missing": [], + "extra": [], + } dst = copy.deepcopy(dst) # Otherwise we are going to modify a "live" object for key, src_element in src.items(): @@ -111,7 +130,12 @@ def _compare_getter_dict(src, dst, mode): return result -def compare(src, dst): +CompareInput = TypeVar("CompareInput", str, Dict, List) + + +def compare( + src: CompareInput, dst: CompareInput +) -> Union[bool, models.DictValidationResult, models.ListValidationResult]: if isinstance(src, str): src = str(src) @@ -152,7 +176,7 @@ def compare(src, dst): return src == dst -def _compare_numeric(src_num, dst_num): +def _compare_numeric(src_num: str, dst_num: str) -> bool: """Compare numerical values. You can use '<%d','>%d'.""" dst_num = float(dst_num) @@ -174,7 +198,7 @@ def _compare_numeric(src_num, dst_num): return getattr(dst_num, operand[match.group(1)])(float(match.group(2))) -def _compare_range(src_num, dst_num): +def _compare_range(src_num: str, dst_num: str) -> bool: """Compare value against a range of values. You can use '%d<->%d'.""" dst_num = float(dst_num) @@ -191,7 +215,7 @@ def _compare_range(src_num, dst_num): return False -def empty_tree(input_list): +def empty_tree(input_list: List) -> bool: """Recursively iterate through values in nested lists.""" for item in input_list: if not isinstance(item, list) or not empty_tree(item): @@ -199,14 +223,20 @@ def empty_tree(input_list): return True -def compliance_report(cls, validation_file=None, validation_source=None): - report = {} +def compliance_report( + cls: "NetworkDriver", + validation_file: Optional[str] = None, + validation_source: Optional[str] = None, +) -> models.ReportResult: + report: models.ReportResult = {} # type: ignore if validation_file: - validation_source = _get_validation_file(validation_file) + validation_source = _get_validation_file(validation_file) # type: ignore # Otherwise we are going to modify a "live" object validation_source = copy.deepcopy(validation_source) + assert isinstance(validation_source, list), validation_source + for validation_check in validation_source: for getter, expected_results in validation_check.items(): if getter == "get_config": diff --git a/napalm/nxapi_plumbing/api_client.py b/napalm/nxapi_plumbing/api_client.py index d403a216d..a51723000 100644 --- a/napalm/nxapi_plumbing/api_client.py +++ b/napalm/nxapi_plumbing/api_client.py @@ -6,6 +6,8 @@ from __future__ import print_function, unicode_literals from builtins import super +from typing import Optional, List, Dict, Any + import requests from requests.auth import HTTPBasicAuth from requests.exceptions import ConnectionError @@ -30,13 +32,13 @@ class RPCBase(object): def __init__( self, - host, - username, - password, - transport="https", - port=None, - timeout=30, - verify=True, + host: str, + username: str, + password: str, + transport: str = "https", + port: Optional[int] = None, + timeout: int = 30, + verify: bool = True, ): if transport not in ["http", "https"]: raise NXAPIError("'{}' is an invalid transport.".format(transport)) @@ -52,11 +54,36 @@ def __init__( self.password = password self.timeout = timeout self.verify = verify + self.cmd_method: str + self.cmd_method_conf: str + self.cmd_method_raw: str + self.headers: Dict + + def _process_api_response( + self, response: str, commands: List[str], raw_text: bool = False + ) -> List[Any]: + raise NotImplementedError("Method must be implemented in child class") - def _process_api_response(self, response, commands, raw_text=False): + def _nxapi_command_conf( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: + raise NotImplementedError("Method must be implemented in child class") + + def _build_payload( + self, + commands: List[str], + method: str, + rpc_version: str = "2.0", + api_version: str = "1.0", + ) -> str: raise NotImplementedError("Method must be implemented in child class") - def _send_request(self, commands, method): + def _nxapi_command( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: + raise NotImplementedError("Method must be implemented in child class") + + def _send_request(self, commands: List[str], method: str) -> str: payload = self._build_payload(commands, method) try: @@ -90,7 +117,7 @@ def _send_request(self, commands, method): class RPCClient(RPCBase): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.headers = {"content-type": "application/json-rpc"} self.api = "jsonrpc" @@ -98,7 +125,9 @@ def __init__(self, *args, **kwargs): self.cmd_method_conf = "cli" self.cmd_method_raw = "cli_ascii" - def _nxapi_command(self, commands, method=None): + def _nxapi_command( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: """Send a command down the NX-API channel.""" if method is None: method = self.cmd_method @@ -111,12 +140,20 @@ def _nxapi_command(self, commands, method=None): api_response = self._process_api_response(response, commands, raw_text=raw_text) return api_response - def _nxapi_command_conf(self, commands, method=None): + def _nxapi_command_conf( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: if method is None: method = self.cmd_method_conf return self._nxapi_command(commands=commands, method=method) - def _build_payload(self, commands, method, rpc_version="2.0", api_version=1.0): + def _build_payload( + self, + commands: List[str], + method: str, + rpc_version: str = "2.0", + api_version: str = "1.0", + ) -> str: """Construct the JSON-RPC payload for NX-API.""" payload_list = [] id_num = 1 @@ -124,7 +161,7 @@ def _build_payload(self, commands, method, rpc_version="2.0", api_version=1.0): payload = { "jsonrpc": rpc_version, "method": method, - "params": {"cmd": command, "version": api_version}, + "params": {"cmd": command, "version": float(api_version)}, "id": id_num, } payload_list.append(payload) @@ -132,7 +169,9 @@ def _build_payload(self, commands, method, rpc_version="2.0", api_version=1.0): return json.dumps(payload_list) - def _process_api_response(self, response, commands, raw_text=False): + def _process_api_response( + self, response: str, commands: List[str], raw_text: bool = False + ) -> List[Any]: """ Normalize the API response including handling errors; adding the sent command into the returned data strucutre; make response structure consistent for raw_text and @@ -170,10 +209,11 @@ def _process_api_response(self, response, commands, raw_text=False): return new_response - def _error_check(self, command_response): + def _error_check(self, command_response: Dict) -> None: error = command_response.get("error") if error: command = command_response.get("command") + assert isinstance(command, str) if "data" in error: raise NXAPICommandError(command, error["data"]["msg"]) else: @@ -181,7 +221,7 @@ def _error_check(self, command_response): class XMLClient(RPCBase): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.headers = {"content-type": "application/xml"} self.api = "xml" @@ -189,7 +229,9 @@ def __init__(self, *args, **kwargs): self.cmd_method_conf = "cli_conf" self.cmd_method_raw = "cli_show_ascii" - def _nxapi_command(self, commands, method=None): + def _nxapi_command( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: """Send a command down the NX-API channel.""" if method is None: method = self.cmd_method @@ -203,12 +245,20 @@ def _nxapi_command(self, commands, method=None): self._error_check(command_response) return api_response - def _nxapi_command_conf(self, commands, method=None): + def _nxapi_command_conf( + self, commands: List[str], method: Optional[str] = None + ) -> List[Any]: if method is None: method = self.cmd_method_conf return self._nxapi_command(commands=commands, method=method) - def _build_payload(self, commands, method, xml_version="1.0", version="1.0"): + def _build_payload( + self, + commands: List[str], + method: str, + xml_version: str = "1.0", + version: str = "1.0", + ) -> str: xml_commands = "" for command in commands: if not xml_commands: @@ -234,7 +284,9 @@ def _build_payload(self, commands, method, xml_version="1.0", version="1.0"): ) return payload - def _process_api_response(self, response, commands, raw_text=False): + def _process_api_response( + self, response: str, commands: List[str], raw_text: bool = False + ) -> List[Any]: xml_root = etree.fromstring(response) response_list = xml_root.xpath("outputs/output") if len(commands) != len(response_list): @@ -244,7 +296,7 @@ def _process_api_response(self, response, commands, raw_text=False): return response_list - def _error_check(self, command_response): + def _error_check(self, command_response: etree) -> None: """commmand_response will be an XML Etree object.""" error_list = command_response.find("./clierror") command_obj = command_response.find("./input") diff --git a/napalm/nxapi_plumbing/device.py b/napalm/nxapi_plumbing/device.py index 188af45f1..3f2ac5f5b 100644 --- a/napalm/nxapi_plumbing/device.py +++ b/napalm/nxapi_plumbing/device.py @@ -6,21 +6,23 @@ from __future__ import print_function, unicode_literals +from typing import List, Optional, Any, Union + from napalm.nxapi_plumbing.errors import NXAPIError, NXAPICommandError -from napalm.nxapi_plumbing.api_client import RPCClient, XMLClient +from napalm.nxapi_plumbing.api_client import RPCClient, XMLClient, RPCBase class Device(object): def __init__( self, - host, - username, - password, - transport="http", - api_format="jsonrpc", - port=None, - timeout=30, - verify=True, + host: str, + username: str, + password: str, + transport: str = "http", + api_format: str = "jsonrpc", + port: Optional[int] = None, + timeout: int = 30, + verify: bool = True, ): self.host = host self.username = username @@ -30,6 +32,7 @@ def __init__( self.verify = verify self.port = port + self.api: RPCBase if api_format == "xml": self.api = XMLClient( host, @@ -51,7 +54,7 @@ def __init__( verify=verify, ) - def show(self, command, raw_text=False): + def show(self, command: str, raw_text: bool = False) -> Any: """Send a non-configuration command. Args: @@ -79,7 +82,7 @@ def show(self, command, raw_text=False): return result - def show_list(self, commands, raw_text=False): + def show_list(self, commands: List[str], raw_text: bool = False) -> List[Any]: """Send a list of non-configuration commands. Args: @@ -94,7 +97,7 @@ def show_list(self, commands, raw_text=False): cmd_method = self.api.cmd_method_raw if raw_text else self.api.cmd_method return self.api._nxapi_command(commands, method=cmd_method) - def config(self, command): + def config(self, command: str) -> Union[str, List]: """Send a configuration command. Args: @@ -120,7 +123,7 @@ def config(self, command): return result - def config_list(self, commands): + def config_list(self, commands: List[str]) -> List[Any]: """Send a list of configuration commands. Args: @@ -131,7 +134,7 @@ def config_list(self, commands): """ return self.api._nxapi_command_conf(commands) - def save(self, filename="startup-config"): + def save(self, filename: str = "startup-config") -> bool: """Save a device's running configuration. Args: @@ -148,7 +151,7 @@ def save(self, filename="startup-config"): raise return True - def rollback(self, filename): + def rollback(self, filename: str) -> None: """Rollback to a checkpoint file. Args: @@ -158,7 +161,7 @@ def rollback(self, filename): cmd = "rollback running-config file {}".format(filename) self.show(cmd, raw_text=True) - def checkpoint(self, filename): + def checkpoint(self, filename: str) -> None: """Save a checkpoint of the running configuration to the device. Args: diff --git a/napalm/nxapi_plumbing/errors.py b/napalm/nxapi_plumbing/errors.py index 144e8ae1f..f524959e5 100644 --- a/napalm/nxapi_plumbing/errors.py +++ b/napalm/nxapi_plumbing/errors.py @@ -14,11 +14,11 @@ class NXAPIError(Exception): class NXAPICommandError(NXAPIError): - def __init__(self, command, message): + def __init__(self, command: str, message: str) -> None: self.command = command self.message = message - def __repr__(self): + def __repr__(self) -> str: return 'The command "{}" gave the error "{}".'.format( self.command, self.message ) diff --git a/napalm/nxapi_plumbing/utilities.py b/napalm/nxapi_plumbing/utilities.py index 6300de068..04a2ea9df 100644 --- a/napalm/nxapi_plumbing/utilities.py +++ b/napalm/nxapi_plumbing/utilities.py @@ -1,5 +1,5 @@ from lxml import etree -def xml_to_string(xml_object): +def xml_to_string(xml_object: etree.Element) -> str: return etree.tostring(xml_object).decode() diff --git a/napalm/nxos/nxos.py b/napalm/nxos/nxos.py index ad14ed537..deea73c82 100644 --- a/napalm/nxos/nxos.py +++ b/napalm/nxos/nxos.py @@ -13,45 +13,74 @@ # License for the specific language governing permissions and limitations under # the License. -# import stdlib -from builtins import super +import json import os import re -import time import tempfile +import time import uuid + +# import stdlib +from abc import abstractmethod +from builtins import super from collections import defaultdict # import third party lib -from requests.exceptions import ConnectionError +from typing import ( + Optional, + Dict, + List, + Union, + Any, + cast, + Callable, + TypeVar, + DefaultDict, +) + +from typing_extensions import TypedDict + from netaddr import IPAddress from netaddr.core import AddrFormatError from netmiko import file_transfer -from napalm.nxapi_plumbing import Device as NXOSDevice -from napalm.nxapi_plumbing import ( - NXAPIAuthError, - NXAPIConnectionError, - NXAPICommandError, -) -import json +from requests.exceptions import ConnectionError + +import napalm.base.constants as c # import NAPALM Base import napalm.base.helpers from napalm.base import NetworkDriver +from napalm.base.exceptions import CommandErrorException from napalm.base.exceptions import ConnectionException from napalm.base.exceptions import MergeConfigException -from napalm.base.exceptions import CommandErrorException from napalm.base.exceptions import ReplaceConfigException -from napalm.base.helpers import generate_regex_or from napalm.base.helpers import as_number +from napalm.base.helpers import generate_regex_or from napalm.base.netmiko_helpers import netmiko_args -import napalm.base.constants as c +from napalm.base.test import models +from napalm.nxapi_plumbing import Device as NXOSDevice +from napalm.nxapi_plumbing import ( + NXAPIAuthError, + NXAPIConnectionError, + NXAPICommandError, +) +F = TypeVar("F", bound=Callable[..., Any]) +ShowIPInterfaceReturn = TypedDict( + "ShowIPInterfaceReturn", + { + "intf-name": str, + "prefix": str, + "unnum-intf": str, + "masklen": str, + }, +) -def ensure_netmiko_conn(func): + +def ensure_netmiko_conn(func: F) -> F: """Decorator that ensures Netmiko connection exists.""" - def wrap_function(self, filename=None, config=None): + def wrap_function(self, filename=None, config=None): # type: ignore try: netmiko_object = self._netmiko_device if netmiko_object is None: @@ -66,13 +95,20 @@ def wrap_function(self, filename=None, config=None): ) func(self, filename=filename, config=config) - return wrap_function + return cast(F, wrap_function) class NXOSDriverBase(NetworkDriver): """Common code shared between nx-api and nxos_ssh.""" - def __init__(self, hostname, username, password, timeout=60, optional_args=None): + def __init__( + self, + hostname: str, + username: str, + password: str, + timeout: int = 60, + optional_args: Optional[Dict] = None, + ) -> None: if optional_args is None: optional_args = {} self.hostname = hostname @@ -88,10 +124,12 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None) self._dest_file_system = optional_args.pop("dest_file_system", "bootflash:") self.force_no_enable = optional_args.get("force_no_enable", False) self.netmiko_optional_args = netmiko_args(optional_args) - self.device = None + self.device: Optional[NXOSDevice] = None @ensure_netmiko_conn - def load_replace_candidate(self, filename=None, config=None): + def load_replace_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: if not filename and not config: raise ReplaceConfigException( @@ -99,6 +137,7 @@ def load_replace_candidate(self, filename=None, config=None): ) if not filename: + assert isinstance(config, str) tmp_file = self._create_tmp_file(config) filename = tmp_file else: @@ -129,7 +168,9 @@ def load_replace_candidate(self, filename=None, config=None): if config and os.path.isfile(tmp_file): os.remove(tmp_file) - def load_merge_candidate(self, filename=None, config=None): + def load_merge_candidate( + self, filename: Optional[str] = None, config: Optional[str] = None + ) -> None: if not filename and not config: raise MergeConfigException("filename or config param must be provided.") @@ -138,14 +179,17 @@ def load_merge_candidate(self, filename=None, config=None): with open(filename, "r") as f: self.merge_candidate += f.read() else: + assert isinstance(config, str) self.merge_candidate += config self.replace = False self.loaded = True - def _send_command(self, command, raw_text=False): + def _send_command( + self, command: str, raw_text: bool = False + ) -> Dict[str, Union[str, Dict[str, Any]]]: raise NotImplementedError - def _commit_merge(self): + def _commit_merge(self) -> None: try: output = self._send_config(self.merge_candidate) if output and "Invalid command" in output: @@ -161,7 +205,7 @@ def _commit_merge(self): # clear the merge buffer self.merge_candidate = "" - def _get_merge_diff(self): + def _get_merge_diff(self) -> str: """ The merge diff is not necessarily what needs to be loaded for example under NTP, even though the 'ntp commit' command might be @@ -180,9 +224,9 @@ def _get_merge_diff(self): diff.append(line) return "\n".join(diff) - def _get_diff(self): + def _get_diff(self) -> str: """Get a diff between running config and a proposed file.""" - diff = [] + diff: List[str] = [] self._create_sot_file() diff_out = self._send_command( "show diff rollback-patch file {} file {}".format( @@ -190,6 +234,7 @@ def _get_diff(self): ), raw_text=True, ) + assert isinstance(diff_out, str) try: diff_out = ( diff_out.split("Generating Rollback Patch")[1] @@ -206,7 +251,7 @@ def _get_diff(self): ) return "\n".join(diff) - def compare_config(self): + def compare_config(self) -> str: if self.loaded: if not self.replace: return self._get_merge_diff() @@ -214,7 +259,7 @@ def compare_config(self): return diff return "" - def commit_config(self, message="", revert_in=None): + def commit_config(self, message: str = "", revert_in: Optional[int] = None) -> None: if revert_in is not None: raise NotImplementedError( "Commit confirm has not been implemented on this platform." @@ -243,7 +288,7 @@ def commit_config(self, message="", revert_in=None): else: raise ReplaceConfigException("No config loaded.") - def discard_config(self): + def discard_config(self) -> None: if self.loaded: # clear the buffer self.merge_candidate = "" @@ -251,7 +296,7 @@ def discard_config(self): self._delete_file(self.candidate_cfg) self.loaded = False - def _create_sot_file(self): + def _create_sot_file(self) -> None: """Create Source of Truth file to compare.""" # Bug on on NX-OS 6.2.16 where overwriting sot_file would take exceptionally long time @@ -269,14 +314,14 @@ def _create_sot_file(self): def ping( self, - destination, - source=c.PING_SOURCE, - ttl=c.PING_TTL, - timeout=c.PING_TIMEOUT, - size=c.PING_SIZE, - count=c.PING_COUNT, - vrf=c.PING_VRF, - ): + destination: str, + source: str = c.PING_SOURCE, + ttl: int = c.PING_TTL, + timeout: int = c.PING_TIMEOUT, + size: int = c.PING_SIZE, + count: int = c.PING_COUNT, + vrf: str = c.PING_VRF, + ) -> models.PingResultDict: """ Execute ping on the device and returns a dictionary with the result. Output dictionary has one of following keys: @@ -294,7 +339,7 @@ def ping( * ip_address (str) * rtt (float) """ - ping_dict = {} + ping_dict: models.PingResultDict = {} version = "" try: @@ -315,6 +360,7 @@ def ping( if vrf != "": command += " vrf {}".format(vrf) output = self._send_command(command, raw_text=True) + assert isinstance(output, str) if "connect:" in output: ping_dict["error"] = output @@ -364,12 +410,12 @@ def ping( fields[3] ) elif "min/avg/max" in line: - m = fields[3].split("/") + split_fields = fields[3].split("/") ping_dict["success"].update( { - "rtt_min": float(m[0]), - "rtt_avg": float(m[1]), - "rtt_max": float(m[2]), + "rtt_min": float(split_fields[0]), + "rtt_avg": float(split_fields[1]), + "rtt_max": float(split_fields[2]), } ) ping_dict["success"].update({"results": results_array}) @@ -377,12 +423,12 @@ def ping( def traceroute( self, - destination, - source=c.TRACEROUTE_SOURCE, - ttl=c.TRACEROUTE_TTL, - timeout=c.TRACEROUTE_TIMEOUT, - vrf=c.TRACEROUTE_VRF, - ): + destination: str, + source: str = c.TRACEROUTE_SOURCE, + ttl: int = c.TRACEROUTE_TTL, + timeout: int = c.TRACEROUTE_TIMEOUT, + vrf: str = c.TRACEROUTE_VRF, + ) -> models.TracerouteResultDict: _HOP_ENTRY_PROBE = [ r"\s+", @@ -403,7 +449,7 @@ def traceroute( _HOP_ENTRY = [r"\s?", r"(\d+)"] # space before hop index? # hop index - traceroute_result = {} + traceroute_result: models.TracerouteResultDict = {} timeout = 5 # seconds probes = 3 # 3 probes/jop and this cannot be changed on NXOS! @@ -434,6 +480,8 @@ def traceroute( "error": "Cannot execute traceroute on the device: {}".format(command) } + assert isinstance(traceroute_raw_output, str) + hop_regex = "".join(_HOP_ENTRY + _HOP_ENTRY_PROBE * probes) traceroute_result["success"] = {} if traceroute_raw_output: @@ -452,9 +500,9 @@ def traceroute( ip_address = napalm.base.helpers.convert( napalm.base.helpers.ip, ip_address_raw, ip_address_raw ) - rtt = hop_details[5 + probe_index * 5] - if rtt: - rtt = float(rtt) + rtt_as_string = hop_details[5 + probe_index * 5] + if rtt_as_string: + rtt = float(rtt_as_string) else: rtt = timeout * 1000.0 if not host_name: @@ -475,15 +523,16 @@ def traceroute( previous_probe_ip_address = ip_address return traceroute_result - def _get_checkpoint_file(self): + def _get_checkpoint_file(self) -> str: filename = "temp_cp_file_from_napalm" self._set_checkpoint(filename) command = "show file {}".format(filename) output = self._send_command(command, raw_text=True) + assert isinstance(output, str) self._delete_file(filename) return output - def _set_checkpoint(self, filename): + def _set_checkpoint(self, filename: str) -> None: commands = [ "terminal dont-ask", "checkpoint file {}".format(filename), @@ -491,7 +540,7 @@ def _set_checkpoint(self, filename): ] self._send_command_list(commands) - def _save_to_checkpoint(self, filename): + def _save_to_checkpoint(self, filename: str) -> None: """Save the current running config to the given file.""" commands = [ "terminal dont-ask", @@ -500,7 +549,7 @@ def _save_to_checkpoint(self, filename): ] self._send_command_list(commands) - def _delete_file(self, filename): + def _delete_file(self, filename: str) -> None: commands = [ "terminal dont-ask", "delete {}".format(filename), @@ -509,7 +558,7 @@ def _delete_file(self, filename): self._send_command_list(commands) @staticmethod - def _create_tmp_file(config): + def _create_tmp_file(config: str) -> str: tmp_dir = tempfile.gettempdir() rand_fname = str(uuid.uuid4()) filename = os.path.join(tmp_dir, rand_fname) @@ -517,10 +566,12 @@ def _create_tmp_file(config): fobj.write(config) return filename - def _disable_confirmation(self): + def _disable_confirmation(self) -> None: self._send_command_list(["terminal dont-ask"]) - def get_config(self, retrieve="all", full=False, sanitized=False): + def get_config( + self, retrieve: str = "all", full: bool = False, sanitized: bool = False + ) -> models.ConfigDict: # NX-OS adds some extra, unneeded lines that should be filtered. filter_strings = [ @@ -530,18 +581,24 @@ def get_config(self, retrieve="all", full=False, sanitized=False): ] filter_pattern = generate_regex_or(filter_strings) - config = {"startup": "", "running": "", "candidate": ""} # default values + config: models.ConfigDict = { + "startup": "", + "running": "", + "candidate": "", + } # default values # NX-OS only supports "all" on "show run" run_full = " all" if full else "" if retrieve.lower() in ("running", "all"): command = f"show running-config{run_full}" output = self._send_command(command, raw_text=True) + assert isinstance(output, str) output = re.sub(filter_pattern, "", output, flags=re.M) config["running"] = output.strip() if retrieve.lower() in ("startup", "all"): command = "show startup-config" output = self._send_command(command, raw_text=True) + assert isinstance(output, str) output = re.sub(filter_pattern, "", output, flags=re.M) config["startup"] = output.strip() @@ -552,9 +609,9 @@ def get_config(self, retrieve="all", full=False, sanitized=False): return config - def get_lldp_neighbors(self): + def get_lldp_neighbors(self) -> Dict[str, List[models.LLDPNeighborDict]]: """IOS implementation of get_lldp_neighbors.""" - lldp = {} + lldp: Dict[str, List[models.LLDPNeighborDict]] = {} neighbors_detail = self.get_lldp_neighbors_detail() for intf_name, entries in neighbors_detail.items(): lldp[intf_name] = [] @@ -564,14 +621,19 @@ def get_lldp_neighbors(self): # When lacking a system name (in show lldp neighbors) if hostname == "N/A": hostname = lldp_entry["remote_chassis_id"] - lldp_dict = {"port": lldp_entry["remote_port"], "hostname": hostname} + lldp_dict: models.LLDPNeighborDict = { + "port": lldp_entry["remote_port"], + "hostname": hostname, + } lldp[intf_name].append(lldp_dict) return lldp - def get_lldp_neighbors_detail(self, interface=""): - lldp = {} - lldp_interfaces = [] + def get_lldp_neighbors_detail( + self, interface: str = "" + ) -> models.LLDPNeighborsDetailDict: + lldp: models.LLDPNeighborsDetailDict = {} + lldp_interfaces: List[str] = [] if interface: command = "show lldp neighbors interface {} detail".format(interface) @@ -608,12 +670,14 @@ def get_lldp_neighbors_detail(self, interface=""): # Turn the interfaces into their long version local_intf = napalm.base.helpers.canonical_interface_name(local_intf) lldp.setdefault(local_intf, []) - lldp[local_intf].append(lldp_entry) + lldp[local_intf].append(lldp_entry) # type: ignore return lldp @staticmethod - def _get_table_rows(parent_table, table_name, row_name): + def _get_table_rows( + parent_table: Optional[Dict], table_name: str, row_name: str + ) -> List: """ Inconsistent behavior: {'TABLE_intf': [{'ROW_intf': { @@ -629,21 +693,24 @@ def _get_table_rows(parent_table, table_name, row_name): if isinstance(_table, list): _table_rows = [_table_row.get(row_name) for _table_row in _table] elif isinstance(_table, dict): - _table_rows = _table.get(row_name) + _table_rows = _table.get(row_name) # type: ignore if not isinstance(_table_rows, list): _table_rows = [_table_rows] return _table_rows - def _get_reply_table(self, result, table_name, row_name): + def _get_reply_table( + self, result: Optional[Dict], table_name: str, row_name: str + ) -> List: return self._get_table_rows(result, table_name, row_name) - def _get_command_table(self, command, table_name, row_name): + def _get_command_table(self, command: str, table_name: str, row_name: str) -> List: json_output = self._send_command(command) if type(json_output) is not dict and json_output: + assert isinstance(json_output, str) json_output = json.loads(json_output) return self._get_reply_table(json_output, table_name, row_name) - def _parse_vlan_ports(self, vlan_s): + def _parse_vlan_ports(self, vlan_s: Union[str, List]) -> List: vlans = [] find_regexp = r"^([A-Za-z\/-]+|.*\/)(\d+)-(\d+)$" vlan_str = "" @@ -666,9 +733,32 @@ def _parse_vlan_ports(self, vlan_s): vlans.append(napalm.base.helpers.canonical_interface_name(vls.strip())) return vlans + @abstractmethod + def _send_config(self, commands: Union[str, List]) -> List[str]: + raise NotImplementedError + + @abstractmethod + def _load_cfg_from_checkpoint(self) -> None: + raise NotImplementedError + + @abstractmethod + def _copy_run_start(self) -> None: + raise NotImplementedError + + @abstractmethod + def _send_command_list(self, commands: List[str]) -> List[str]: + raise NotImplementedError + class NXOSDriver(NXOSDriverBase): - def __init__(self, hostname, username, password, timeout=60, optional_args=None): + def __init__( + self, + hostname: str, + username: str, + password: str, + timeout: int = 60, + optional_args: Optional[Dict] = None, + ) -> None: super().__init__( hostname, username, password, timeout=timeout, optional_args=optional_args ) @@ -687,7 +777,7 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None) self.ssl_verify = optional_args.get("ssl_verify", False) self.platform = "nxos" - def open(self): + def open(self) -> None: try: self.device = NXOSDevice( host=self.hostname, @@ -704,28 +794,37 @@ def open(self): # unable to open connection raise ConnectionException("Cannot connect to {}".format(self.hostname)) - def close(self): + def close(self) -> None: self.device = None - def _send_command(self, command, raw_text=False): + def _send_command(self, command: str, raw_text: bool = False) -> Any: """ Wrapper for NX-API show method. Allows more code sharing between NX-API and SSH. """ + assert ( + self.device is not None + ), "Call open() or use as a context manager before calling _send_command" return self.device.show(command, raw_text=raw_text) - def _send_command_list(self, commands): + def _send_command_list(self, commands: List[str]) -> List[Any]: + assert ( + self.device is not None + ), "Call open() or use as a context manager before running _send_command_list" return self.device.config_list(commands) - def _send_config(self, commands): + def _send_config(self, commands: Union[str, List]) -> List[str]: + assert ( + self.device is not None + ), "Call open() or use as a context manager before running _send_config" if isinstance(commands, str): # Has to be a list generator and not generator expression (not JSON serializable) commands = [command for command in commands.splitlines() if command] return self.device.config_list(commands) @staticmethod - def _compute_timestamp(stupid_cisco_output): + def _compute_timestamp(stupid_cisco_output: str) -> float: """ Some fields such `uptime` are returned as: 23week(s) 3day(s) This method will determine the epoch of the event. @@ -744,7 +843,7 @@ def _compute_timestamp(stupid_cisco_output): stupid_cisco_output = stupid_cisco_output.replace("d", "day(s) ") stupid_cisco_output = stupid_cisco_output.replace("h", "hour(s)") - things = { + things: Dict[str, Dict[str, Union[int, float]]] = { "second(s)": {"weight": 1}, "minute(s)": {"weight": 60}, "hour(s)": {"weight": 3600}, @@ -761,24 +860,25 @@ def _compute_timestamp(stupid_cisco_output): int, part.replace(key, ""), 0 ) - delta = sum( - [det.get("count", 0) * det.get("weight") for det in things.values()] - ) + delta = sum([det.get("count", 0) * det["weight"] for det in things.values()]) return time.time() - delta - def is_alive(self): + def is_alive(self) -> models.AliveDict: if self.device: return {"is_alive": True} else: return {"is_alive": False} - def _copy_run_start(self): + def _copy_run_start(self) -> None: + assert ( + self.device is not None + ), "Call open() or use as a context manager before calling _copy_run_start" results = self.device.save(filename="startup-config") if not results: msg = "Unable to save running-config to startup-config!" raise CommandErrorException(msg) - def _load_cfg_from_checkpoint(self): + def _load_cfg_from_checkpoint(self) -> None: commands = [ "terminal dont-ask", "rollback running-config file {}".format(self.candidate_cfg), @@ -795,6 +895,7 @@ def _load_cfg_from_checkpoint(self): # For nx-api a list is returned so extract the result associated with the # 'rollback' command. rollback_result = rollback_result[1] + assert isinstance(rollback_result, dict) msg = ( rollback_result.get("msg") if rollback_result.get("msg") @@ -807,14 +908,15 @@ def _load_cfg_from_checkpoint(self): elif rollback_result == []: raise ReplaceConfigException - def rollback(self): + def rollback(self) -> None: + assert isinstance(self.device, NXOSDevice) if self.changed: self.device.rollback(self.rollback_cfg) self._copy_run_start() self.changed = False - def get_facts(self): - facts = {} + def get_facts(self) -> models.FactsDict: + facts: models.FactsDict = {} # type: ignore facts["vendor"] = "Cisco" show_inventory_table = self._get_command_table( @@ -823,7 +925,7 @@ def get_facts(self): if isinstance(show_inventory_table, dict): show_inventory_table = [show_inventory_table] - facts["serial_number"] = None + facts["serial_number"] = None # type: ignore for row in show_inventory_table: if row["name"] == '"Chassis"' or row["name"] == "Chassis": @@ -831,6 +933,7 @@ def get_facts(self): break show_version = self._send_command("show version") + show_version = cast(Dict[str, str], show_version) facts["model"] = show_version.get("chassis_id", "") facts["hostname"] = show_version.get("host_name", "") facts["os_version"] = show_version.get( @@ -863,14 +966,16 @@ def get_facts(self): return facts - def get_interfaces(self): - interfaces = {} + def get_interfaces(self) -> Dict[str, models.InterfaceDict]: + interfaces: Dict[str, models.InterfaceDict] = {} iface_cmd = "show interface" interfaces_out = self._send_command(iface_cmd) interfaces_body = interfaces_out["TABLE_interface"]["ROW_interface"] for interface_details in interfaces_body: + assert isinstance(interface_details, dict) interface_name = interface_details.get("interface") + assert isinstance(interface_name, str) if interface_details.get("eth_mtu"): interface_mtu = int(interface_details["eth_mtu"]) @@ -915,12 +1020,12 @@ def get_interfaces(self): "speed": interface_speed, "mtu": interface_mtu, "mac_address": napalm.base.helpers.convert( - napalm.base.helpers.mac, mac_address + napalm.base.helpers.mac, mac_address, "" ), } return interfaces - def get_bgp_neighbors(self): + def get_bgp_neighbors(self) -> Dict[str, models.BGPStateNeighborsPerVRFDict]: results = {} bgp_state_dict = { "Idle": {"is_up": False, "is_enabled": True}, @@ -950,7 +1055,10 @@ def get_bgp_neighbors(self): vrf_list = [] for vrf_dict in vrf_list: - result_vrf_dict = {"router_id": str(vrf_dict["vrf-router-id"]), "peers": {}} + result_vrf_dict: models.BGPStateNeighborsPerVRFDict = { + "router_id": str(vrf_dict["vrf-router-id"]), + "peers": {}, + } af_list = vrf_dict.get("TABLE_af", {}).get("ROW_af", []) if isinstance(af_list, dict): @@ -974,7 +1082,7 @@ def get_bgp_neighbors(self): afid_dict = af_name_dict[int(af_dict["af-id"])] safi_name = afid_dict[int(saf_dict["safi"])] - result_peer_dict = { + result_peer_dict: models.BGPStateNeighborDict = { "local_as": as_number(vrf_dict["vrf-local-as"]), "remote_as": remoteas, "remote_id": neighborid, @@ -1000,8 +1108,8 @@ def get_bgp_neighbors(self): results[vrf_name] = result_vrf_dict return results - def cli(self, commands): - cli_output = {} + def cli(self, commands: List[str]) -> Dict[str, Union[str, Dict[str, Any]]]: + cli_output: Dict[str, Union[str, Dict[str, Any]]] = {} if type(commands) is not list: raise TypeError("Please enter a valid list of commands!") @@ -1010,12 +1118,12 @@ def cli(self, commands): cli_output[str(command)] = command_output return cli_output - def get_arp_table(self, vrf=""): + def get_arp_table(self, vrf: str = "") -> List[models.ARPTableDict]: if vrf: msg = "VRF support has not been added for this getter on this platform." raise NotImplementedError(msg) - arp_table = [] + arp_table: List[models.ARPTableDict] = [] command = "show ip arp" arp_table_vrf = self._get_command_table(command, "TABLE_vrf", "ROW_vrf") arp_table_raw = self._get_table_rows(arp_table_vrf[0], "TABLE_adj", "ROW_adj") @@ -1056,8 +1164,10 @@ def get_arp_table(self, vrf=""): ) return arp_table - def _get_ntp_entity(self, peer_type): - ntp_entities = {} + def _get_ntp_entity( + self, peer_type: str + ) -> Dict[str, Union[models.NTPPeerDict, models.NTPServerDict]]: + ntp_entities: Dict[str, Union[models.NTPPeerDict, models.NTPServerDict]] = {} command = "show ntp peers" ntp_peers_table = self._get_command_table(command, "TABLE_peers", "ROW_peers") @@ -1065,18 +1175,19 @@ def _get_ntp_entity(self, peer_type): if ntp_peer.get("serv_peer", "").strip() != peer_type: continue peer_addr = napalm.base.helpers.ip(ntp_peer.get("PeerIPAddress").strip()) - ntp_entities[peer_addr] = {} + # Ignore the type of the following line until NTP data is modelled + ntp_entities[peer_addr] = {} # type: ignore return ntp_entities - def get_ntp_peers(self): + def get_ntp_peers(self) -> Dict[str, models.NTPPeerDict]: return self._get_ntp_entity("Peer") - def get_ntp_servers(self): + def get_ntp_servers(self) -> Dict[str, models.NTPServerDict]: return self._get_ntp_entity("Server") - def get_ntp_stats(self): - ntp_stats = [] + def get_ntp_stats(self) -> List[models.NTPStats]: + ntp_stats: List[models.NTPStats] = [] command = "show ntp peer-status" ntp_stats_table = self._get_command_table( command, "TABLE_peersstatus", "ROW_peersstatus" @@ -1106,15 +1217,19 @@ def get_ntp_stats(self): ) return ntp_stats - def get_interfaces_ip(self): - interfaces_ip = {} + def get_interfaces_ip(self) -> Dict[str, models.InterfacesIPDict]: + interfaces_ip: Dict[str, models.InterfacesIPDict] = {} ipv4_command = "show ip interface" ipv4_interf_table_vrf = self._get_command_table( ipv4_command, "TABLE_intf", "ROW_intf" ) + # Cast the table to the proper type + cast(ShowIPInterfaceReturn, ipv4_interf_table_vrf) + for interface in ipv4_interf_table_vrf: interface_name = str(interface.get("intf-name", "")) + assert isinstance(interface_name, str) addr_str = interface.get("prefix") unnumbered = str(interface.get("unnum-intf", "")) if addr_str: @@ -1124,7 +1239,7 @@ def get_interfaces_ip(self): interfaces_ip[interface_name] = {} if "ipv4" not in interfaces_ip[interface_name].keys(): interfaces_ip[interface_name]["ipv4"] = {} - if address not in interfaces_ip[interface_name].get("ipv4"): + if address not in interfaces_ip[interface_name]["ipv4"]: interfaces_ip[interface_name]["ipv4"][address] = {} interfaces_ip[interface_name]["ipv4"][address].update( {"prefix_length": prefix} @@ -1139,7 +1254,7 @@ def get_interfaces_ip(self): interfaces_ip[interface_name] = {} if "ipv4" not in interfaces_ip[interface_name].keys(): interfaces_ip[interface_name]["ipv4"] = {} - if address not in interfaces_ip[interface_name].get("ipv4"): + if address not in interfaces_ip[interface_name]["ipv4"]: interfaces_ip[interface_name]["ipv4"][address] = {} interfaces_ip[interface_name]["ipv4"][address].update( {"prefix_length": prefix} @@ -1157,9 +1272,7 @@ def get_interfaces_ip(self): secondary_address_prefix = int(secondary_address.get("masklen1", "")) if "ipv4" not in interfaces_ip[interface_name].keys(): interfaces_ip[interface_name]["ipv4"] = {} - if secondary_address_ip not in interfaces_ip[interface_name].get( - "ipv4" - ): + if secondary_address_ip not in interfaces_ip[interface_name]["ipv4"]: interfaces_ip[interface_name]["ipv4"][secondary_address_ip] = {} interfaces_ip[interface_name]["ipv4"][secondary_address_ip].update( {"prefix_length": secondary_address_prefix} @@ -1196,7 +1309,7 @@ def get_interfaces_ip(self): for ipv6_address in interface.get("addr", ""): address = napalm.base.helpers.ip(ipv6_address.split("/")[0]) prefix = int(ipv6_address.split("/")[-1]) - if address not in interfaces_ip[interface_name].get("ipv6"): + if address not in interfaces_ip[interface_name]["ipv6"]: interfaces_ip[interface_name]["ipv6"][address] = {} interfaces_ip[interface_name]["ipv6"][address].update( {"prefix_length": prefix} @@ -1211,15 +1324,15 @@ def get_interfaces_ip(self): else: prefix = 128 - if address not in interfaces_ip[interface_name].get("ipv6"): + if address not in interfaces_ip[interface_name]["ipv6"]: interfaces_ip[interface_name]["ipv6"][address] = {} interfaces_ip[interface_name]["ipv6"][address].update( {"prefix_length": prefix} ) return interfaces_ip - def get_mac_address_table(self): - mac_table = [] + def get_mac_address_table(self) -> List[models.MACAdressTable]: + mac_table: List[models.MACAdressTable] = [] command = "show mac address-table" mac_table_raw = self._get_command_table( command, "TABLE_mac_address", "ROW_mac_address" @@ -1249,10 +1362,11 @@ def get_mac_address_table(self): ) return mac_table - def get_snmp_information(self): - snmp_information = {} + def get_snmp_information(self) -> models.SNMPDict: + snmp_information: models.SNMPDict = {} # type: ignore snmp_command = "show running-config" snmp_raw_output = self.cli([snmp_command]).get(snmp_command, "") + assert isinstance(snmp_raw_output, str) snmp_config = napalm.base.helpers.textfsm_extractor( self, "snmp_config", snmp_raw_output ) @@ -1260,7 +1374,7 @@ def get_snmp_information(self): if not snmp_config: return snmp_information - snmp_information = { + snmp_information: models.SNMPDict = { "contact": str(""), "location": str(""), "community": {}, @@ -1293,14 +1407,19 @@ def get_snmp_information(self): snmp_information["community"][community_name]["mode"] = mode return snmp_information - def get_users(self): + def get_users(self) -> Dict[str, models.UsersDict]: _CISCO_TO_CISCO_MAP = {"network-admin": 15, "network-operator": 5} - _DEFAULT_USER_DICT = {"password": "", "level": 0, "sshkeys": []} + _DEFAULT_USER_DICT: models.UsersDict = { + "password": "", + "level": 0, + "sshkeys": [], + } - users = {} + users: Dict[str, models.UsersDict] = {} command = "show running-config" section_username_raw_output = self.cli([command]).get(command, "") + assert isinstance(section_username_raw_output, str) section_username_tabled_output = napalm.base.helpers.textfsm_extractor( self, "users", section_username_raw_output ) @@ -1322,7 +1441,7 @@ def get_users(self): level = int(role.split("-")[-1]) else: level = _CISCO_TO_CISCO_MAP.get(role, 0) - if level > users.get(username).get("level"): + if level > users[username]["level"]: # unfortunately on Cisco you can set different priv levels for the same user # Good news though: the device will consider the highest level users[username]["level"] = level @@ -1335,7 +1454,9 @@ def get_users(self): users[username]["sshkeys"].append(str(sshkeyvalue)) return users - def get_network_instances(self, name=""): + def get_network_instances( + self, name: str = "" + ) -> Dict[str, models.NetworkInstanceDict]: """ get_network_instances implementation for NX-OS """ # command 'show vrf detail' returns all VRFs with detailed information @@ -1353,7 +1474,7 @@ def get_network_instances(self, name=""): for intf in intf_table_raw: vrf_intfs[intf["vrf_name"]].append(str(intf["if_name"])) - vrfs = {} + vrfs: Dict[str, Any] = {} for vrf in vrf_table_raw: vrf_name = str(vrf.get("vrf_name")) vrfs[vrf_name] = {} @@ -1385,9 +1506,11 @@ def get_network_instances(self, name=""): else: return vrfs - def get_environment(self): - def _process_pdus(power_data): - normalized = defaultdict(dict) + def get_environment(self) -> models.EnvironmentDict: + def _process_pdus(power_data: Dict) -> Dict[str, models.PowerDict]: + # defaultdict(dict) doesn't yet seem to work with Mypy + # https://github.com/python/mypy/issues/7217 + normalized: DefaultDict[str, models.PowerDict] = defaultdict(dict) # type: ignore # some nexus devices have keys postfixed with the shorthand device series name (ie n3k) # ex. on a 9k, the key is TABLE_psinfo, but on a 3k it is TABLE_psinfo_n3k ps_info_key = [ @@ -1440,8 +1563,8 @@ def _process_pdus(power_data): ) return json.loads(json.dumps(normalized)) - def _process_fans(fan_data): - normalized = {} + def _process_fans(fan_data: Dict) -> Dict[str, models.FanDict]: + normalized: Dict[str, models.FanDict] = {} for entry in fan_data["TABLE_faninfo"]["ROW_faninfo"]: if "PS" in entry["fanname"]: # Skip fans in power supplies @@ -1454,8 +1577,10 @@ def _process_fans(fan_data): } return normalized - def _process_temperature(temperature_data): - normalized = {} + def _process_temperature( + temperature_data: Dict, + ) -> Dict[str, models.TemperatureDict]: + normalized: Dict[str, models.TemperatureDict] = {} # The modname and sensor type are not unique enough keys, so adding a count count = 1 past_tempmod = "1" @@ -1474,15 +1599,16 @@ def _process_temperature(temperature_data): count += 1 return normalized - def _process_cpu(cpu_data): + def _process_cpu(cpu_data: Dict) -> Dict[int, models.CPUDict]: idle = ( cpu_data.get("idle_percent") if cpu_data.get("idle_percent") else cpu_data["TABLE_cpu_util"]["ROW_cpu_util"]["idle_percent"] ) + assert isinstance(idle, str) return {0: {"%usage": round(100 - float(idle), 2)}} - def _process_memory(memory_data): + def _process_memory(memory_data: Dict) -> models.MemoryDict: avail = memory_data["TABLE_process_tag"]["ROW_process_tag"][ "process-memory-share-total-shm-avail" ] @@ -1503,8 +1629,8 @@ def _process_memory(memory_data): "memory": _process_memory(memory_raw), } - def get_vlans(self): - vlans = {} + def get_vlans(self) -> Dict[str, models.VlanDict]: + vlans: Dict[str, models.VlanDict] = {} command = "show vlan brief" vlan_table_raw = self._get_command_table( command, "TABLE_vlanbriefxbrief", "ROW_vlanbriefxbrief" diff --git a/test/base/test_helpers.py b/test/base/test_helpers.py index 5c9ffe796..5b6d0b150 100644 --- a/test/base/test_helpers.py +++ b/test/base/test_helpers.py @@ -240,11 +240,11 @@ def test_convert(self): napalm.base.helpers.convert(int, "non-int-value", default=-100) == -100 ) # default value returned - self.assertIsInstance(napalm.base.helpers.convert(float, "1e-17"), float) + self.assertIsInstance(napalm.base.helpers.convert(float, "1e-17", 1.0), float) # converts indeed to float - self.assertFalse(napalm.base.helpers.convert(str, None) == "None") + self.assertFalse(napalm.base.helpers.convert(str, None, "") == "None") # should not convert None-type to 'None' string - self.assertTrue(napalm.base.helpers.convert(str, None) == "") + self.assertTrue(napalm.base.helpers.convert(str, None, "") == "") # should return empty unicode def test_find_txt(self): diff --git a/test/nxapi_plumbing/conftest.py b/test/nxapi_plumbing/conftest.py index 245319eaa..2bc29249e 100644 --- a/test/nxapi_plumbing/conftest.py +++ b/test/nxapi_plumbing/conftest.py @@ -73,7 +73,7 @@ def mock_pynxos_device_xml(request): @pytest.fixture(scope="module") def pynxos_device(request): """Create a real pynxos test device.""" - device_under_test = request.config.getoption("test_device") + device_under_test = request.ConfigDict.getoption("test_device") test_devices = parse_yaml(PWD + "/etc/test_devices.yml") device = test_devices[device_under_test] conn = Device(**device) diff --git a/test/nxos/test_getters.py b/test/nxos/test_getters.py index 48a993d94..365dc44c2 100644 --- a/test/nxos/test_getters.py +++ b/test/nxos/test_getters.py @@ -24,6 +24,6 @@ def test_get_interfaces(self, test_case): assert len(get_interfaces) > 0 for interface, interface_data in get_interfaces.items(): - assert helpers.test_model(models.interface, interface_data) + assert helpers.test_model(models.InterfaceDict, interface_data) return get_interfaces diff --git a/test/nxos_ssh/test_getters.py b/test/nxos_ssh/test_getters.py index 48a993d94..365dc44c2 100644 --- a/test/nxos_ssh/test_getters.py +++ b/test/nxos_ssh/test_getters.py @@ -24,6 +24,6 @@ def test_get_interfaces(self, test_case): assert len(get_interfaces) > 0 for interface, interface_data in get_interfaces.items(): - assert helpers.test_model(models.interface, interface_data) + assert helpers.test_model(models.InterfaceDict, interface_data) return get_interfaces