From aecac896b7859012129bdcbefb1ab8bbe9582204 Mon Sep 17 00:00:00 2001 From: tmikuska <44396040+tmikuska@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:26:27 +0200 Subject: [PATCH] SIMPLE-5377 PCL general improvements (#59) Co-authored-by: Daniel Valent --- docs/source/conf.py | 2 +- examples/demo.ipynb | 2 - examples/licensing.py | 2 +- examples/link_conditioning.py | 49 +- examples/sample.py | 118 ++- tests/test_client_library.py | 62 +- virl2_client/event_handling.py | 272 ++++-- virl2_client/event_listening.py | 55 +- virl2_client/exceptions.py | 12 + virl2_client/models/auth_management.py | 199 +++- virl2_client/models/authentication.py | 83 +- virl2_client/models/cl_pyats.py | 198 ++-- virl2_client/models/configuration.py | 77 +- virl2_client/models/groups.py | 113 ++- virl2_client/models/interface.py | 225 +++-- virl2_client/models/lab.py | 846 ++++++++++-------- virl2_client/models/licensing.py | 193 ++-- virl2_client/models/link.py | 188 ++-- virl2_client/models/node.py | 441 ++++++--- virl2_client/models/node_image_definitions.py | 176 ++-- virl2_client/models/resource_pools.py | 210 +++-- virl2_client/models/system.py | 341 +++++-- virl2_client/models/users.py | 119 +-- virl2_client/utils.py | 30 +- virl2_client/virl2_client.py | 406 +++++---- 25 files changed, 2865 insertions(+), 1554 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7cc679a..02f6ff0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -176,7 +176,7 @@ # Bibliographic Dublin Core info. epub_title = project -# The unique identifier of the text. This can be a ISBN number +# The unique identifier of the text. This can be an ISBN # or the project homepage. # # epub_identifier = '' diff --git a/examples/demo.ipynb b/examples/demo.ipynb index ff51b0e..2b5f1d8 100644 --- a/examples/demo.ipynb +++ b/examples/demo.ipynb @@ -6,10 +6,8 @@ "metadata": {}, "outputs": [], "source": [ - "import urllib3\n", "from virl2_client import ClientLibrary\n", "\n", - "urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)\n", "client = ClientLibrary(\"https://192.168.1.1\", \"virl2\", \"virl2\", ssl_verify=False)" ] }, diff --git a/examples/licensing.py b/examples/licensing.py index 6579302..0fe5d0c 100644 --- a/examples/licensing.py +++ b/examples/licensing.py @@ -43,7 +43,7 @@ result = licensing.register_wait(SL_TOKEN) if not result: result = licensing.get_reservation_return_code() - print("ERROR: Failed to register with Smart License server: {}!".format(result)) + print(f"ERROR: Failed to register with Smart License server: {result}!") exit(1) # Get the current registration status. This returns a JSON blob with license diff --git a/examples/link_conditioning.py b/examples/link_conditioning.py index 8f38577..9ad02bc 100644 --- a/examples/link_conditioning.py +++ b/examples/link_conditioning.py @@ -21,7 +21,6 @@ import getpass -import re from httpx import HTTPStatusError @@ -39,67 +38,61 @@ labs = client.find_labs_by_title(LAB_NAME) if not labs or len(labs) != 1: - print("ERROR: Unable to find a unique lab named {}".format(LAB_NAME)) + print(f"ERROR: Unable to find a unique lab named {LAB_NAME}") exit(1) -lobj = client.join_existing_lab(labs[0].id) +lab = client.join_existing_lab(labs[0].id) -if not lobj: - print("ERROR: Failed to join lab {}".format(LAB_NAME)) +if not lab: + print(f"ERROR: Failed to join lab {LAB_NAME}") exit(1) # Print all links in the lab and ask which link to condition. i = 1 -liobjs = [] -for link in lobj.links(): +links = [] +for link in lab.links(): print( - "{}. {}[{}] <-> {}[{}]".format( - i, - link.interface_a.node.label, - link.interface_a.label, - link.interface_b.node.label, - link.interface_b.label, - ) + f"{i}. {link.interface_a.node.label}[{link.interface_a.label}] <-> " + f"{link.interface_b.node.label}[{link.interface_b.label}]" ) - liobjs.append(lobj.get_link_by_interfaces(link.interface_a, link.interface_b)) + links.append(link.interface_a.get_link_to(link.interface_b)) i += 1 print() -lnum = 0 -while lnum < 1 or lnum > i: - lnum = input("Enter link number to condition (1-{}): ".format(i)) +link_number = 0 +while link_number < 1 or link_number > i: try: - lnum = int(lnum) + link_number = int(input(f"Enter link number to condition (1-{i}): ")) except ValueError: - lnum = 0 + link_number = 0 # Print the selected link's current conditioning (if any). -link = liobjs[lnum - 1] -print("Current condition is {}".format(link.get_condition())) +link = links[link_number - 1] +print(f"Current condition is {link.get_condition()}") # Request the new conditoning for bandwidth, latency, jitter, and loss. # Bandwidth is an integer between 0-10000000 kbps # Bandwidth of 0 is "no bandwidth restriction" # Latency is an integer between 0-10000 ms # Jitter is an integer between 0-10000 ms # Loss is a float between 0-100% -new_cond = input( +new_condition = input( "enter new condition in format 'BANDWIDTH, " "LATENCY, JITTER, LOSS' or 'None' to disable: " ) # If "None" is provided disable any conditioning on the link. -if new_cond.lower() == "none": +if new_condition.lower() == "none": link.remove_condition() print("Link conditioning has been disabled.") else: try: # Set the current conditioning based on the provided values. - cond_list = re.split(r"\s*,\s*", new_cond) - bw = int(cond_list[0]) # Bandwidth is an int + cond_list = new_condition.split(",") + bandwidth = int(cond_list[0]) # Bandwidth is an int latency = int(cond_list[1]) # Latency is an int jitter = int(cond_list[2]) # Jitter is an int loss = float(cond_list[3]) # Loss is a float - link.set_condition(bw, latency, jitter, loss) + link.set_condition(bandwidth, latency, jitter, loss) print("Link conditioning set.") except HTTPStatusError as exc: - print("ERROR: Failed to set link conditioning: {}", format(exc)) + print(f"ERROR: Failed to set link conditioning: {exc}") exit(1) diff --git a/examples/sample.py b/examples/sample.py index e910f6e..6b8b72c 100644 --- a/examples/sample.py +++ b/examples/sample.py @@ -19,8 +19,13 @@ # limitations under the License. # +# This script demonstrates various functionalities of the client library. +# Each example is accompanied by comments explaining its purpose and usage. +import pathlib + from virl2_client import ClientLibrary +# Licensing setup configuration CERT = """-----BEGIN CERTIFICATE----- MIIDGzCCAgOgAwIBAgIBATANBgkqhkiG9w0BAQsFADAvMQ4wDAYDVQQKEwVDaXNj WhcNMzMwNDI0MjE1NTQzWjAvMQ4wDAYDVQQKEwVDaXNjbzEdMBsGA1UEAxMUTGlj @@ -45,55 +50,76 @@ ) -# setup the connection and clean everything -cl = ClientLibrary("http://localhost:8001", "cml2", "cml2cml2", allow_http=True) -cl.is_system_ready(wait=True) +# Set up the CML 2 connection +server_url = "http://localhost:8001" +username = "cml2" # Default username if not changed in CML instance +password = pathlib.Path("/etc/machine-id").read_text().strip() +# Default password is equal to the contents of the CML instances /etc/machine-id file +# If you are running this script remotely, replace the password above +client_library = ClientLibrary(server_url, username, password, allow_http=True) -# set transport if needed - also proxy can be set if needed -# cl.licensing.licensing.set_transport( -# ssms=ssms, proxy_server="172.16.1.100", proxy_port=8888 -# ) -cl.licensing.set_transport(ssms=SSMS) -cl.licensing.install_certificate(cert=CERT) -# 'register_wait' method waits max 45s for registration status to become COMPLETED -# and another 45s for authorization status to become IN_COMPLIANCE -cl.licensing.register_wait(token=TOKEN) +# Check if the CML 2 system is ready +client_library.is_system_ready(wait=True) +# Set up licensing configuration +client_library.licensing.set_transport(ssms=SSMS) +client_library.licensing.install_certificate(cert=CERT) +client_library.licensing.register_wait(token=TOKEN) -lab_list = cl.get_lab_list() +# Get a list of existing labs and print their details +lab_list = client_library.get_lab_list() for lab_id in lab_list: - lab = cl.join_existing_lab(lab_id) - lab.stop() - lab.wipe() - cl.remove_lab(lab_id) - -lab = cl.create_lab() -lab = cl.join_existing_lab(lab_id="lab_1") - -s1 = lab.create_node("s1", "server", 50, 100) -s2 = lab.create_node("s2", "server", 50, 200) -print(s1, s2) - -# create a link between s1 and s2, equivalent to -# s1_i1 = s1.create_interface() -# s2_i1 = s2.create_interface() -# lab.create_link(s1_i1, s2_i1) -lab.connect_two_nodes(s1, s2) - -# this must remove the link between s1 and s2 -lab.remove_node(s2) - + lab = client_library.join_existing_lab(lab_id) + print("Lab ID:", lab.id) + print("Lab Title:", lab.title) + print("Lab Description:", lab.description) + print("Lab State:", lab.state) + print("----") + +# A simpler way to join all labs at once +labs = client_library.all_labs() + +# Create a lab +lab = client_library.create_lab() + +# Create two server nodes +server1 = lab.create_node("server1", "server", 50, 100) +server2 = lab.create_node("server2", "server", 50, 200) +print("Created nodes:", server1, server2) + +# Create a link between server1 and server2 +link = lab.connect_two_nodes(server1, server2) +print("Created link between server1 and server2") + +# Remove the link between server1 and server2 +link.remove() +print("Removed link between server1 and server2") + +# Manually synchronize lab states - this happens automatically once per second +# by default, but we can skip the wait by calling this method lab.sync_states() + +# Print the state of each node and its interfaces for node in lab.nodes(): - print(node, node.state) - for iface in node.interfaces(): - print(iface, iface.state) - -assert [link for link in lab.links() if link.state is not None] == [] - -# in case you's like to deregister after you're done -status = cl.licensing.deregister() -cl.licensing.remove_certificate() -# set licensing back to default transport -# default ssms is "https://smartreceiver.cisco.com/licservice/license" -cl.licensing.set_default_transport() + print(f"Node: {node.label} | State: {node.state}") + for interface in node.interfaces(): + print(f" Interface: {interface.label} | State: {interface.state}") + +# Export a lab topology to a file +lab_data = lab.download() +with open("demo_lab_export.yaml", "w") as file: + file.write(lab_data) +print("Lab exported successfully.") + +# Clean up the lab +lab.stop() +lab.wipe() +lab.remove() # or client_library.remove_lab(lab_id) + +# Deregister and remove the certificate (optional) +client_library.licensing.deregister() +client_library.licensing.remove_certificate() + +# Set licensing back to the default transport (optional) +# Default SSMS is "https://smartreceiver.cisco.com/licservice/license" +client_library.licensing.set_default_transport() diff --git a/tests/test_client_library.py b/tests/test_client_library.py index c09c65a..54e8ebe 100644 --- a/tests/test_client_library.py +++ b/tests/test_client_library.py @@ -44,7 +44,7 @@ ) -# TODO: split into multiple test modules, by feature +# TODO: split into multiple test modules, by feature. @pytest.fixture def reset_env(monkeypatch): env_vars = [ @@ -74,13 +74,13 @@ def test_import_lab_from_path_virl( lab = cl.import_lab_from_path(path=(tmp_path / "topology.virl").as_posix()) assert lab.title is not None - assert lab.lab_base_url.startswith("labs/") + assert lab._url_for("lab").startswith("labs/") - cl.session.post.assert_called_once_with( + cl._session.post.assert_called_once_with( "import/virl-1x", content="", ) - cl.session.post.assert_called_once() + cl._session.post.assert_called_once() sync_mock.assert_called_once_with() @@ -99,10 +99,11 @@ def test_import_lab_from_path_virl_title( path=(tmp_path / "topology.virl").as_posix(), title=new_title ) assert lab.title is not None - assert lab.lab_base_url.startswith("labs/") + assert lab._url_for("lab").startswith("labs/") - cl.session.post.assert_called_once_with( - f"import/virl-1x?title={new_title}", + cl._session.post.assert_called_once_with( + "import/virl-1x", + params={"title": new_title}, content="", ) @@ -117,7 +118,7 @@ def test_ssl_certificate(client_library_server_current, mocked_session): cl.is_system_ready(wait=True) assert cl._ssl_verify == "/home/user/cert.pem" - assert cl.session.mock_calls[:4] == [call.get("authok")] + assert cl._session.mock_calls[:4] == [call.get("authok")] def test_ssl_certificate_from_env_variable( @@ -130,7 +131,7 @@ def test_ssl_certificate_from_env_variable( assert cl.is_system_ready() assert cl._ssl_verify == "/home/user/cert.pem" - assert cl.session.mock_calls[:4] == [call.get("authok")] + assert cl._session.mock_calls[:4] == [call.get("authok")] @python37_or_newer @@ -195,10 +196,10 @@ def initial_different_response(initial, subsequent=httpx.Response(200)): def test_client_library_init_allow_http(client_library_server_current): cl = ClientLibrary("http://somehost", "virl2", "virl2", allow_http=True) - assert cl.session.base_url.scheme == "http" - assert cl.session.base_url.host == "somehost" - assert cl.session.base_url.port is None - assert cl.session.base_url.path.endswith("/api/v0/") + assert cl._session.base_url.scheme == "http" + assert cl._session.base_url.host == "somehost" + assert cl._session.base_url.port is None + assert cl._session.base_url.path.endswith("/api/v0/") assert cl.username == "virl2" assert cl.password == "virl2" @@ -254,12 +255,12 @@ def test_client_library_init_url( assert re.match(pattern, str(err.value)) else: cl = ClientLibrary(url, username="virl2", password="virl2", allow_http=True) - url_parts = cl.session.base_url + url_parts = cl._session.base_url assert url_parts.scheme == (expected_parts.scheme or "https") assert url_parts.host == (expected_parts.host or expected_parts.path) assert url_parts.port == expected_parts.port assert url_parts.path == "/api/v0/" - assert cl.session.base_url.path.endswith("/api/v0/") + assert cl._session.base_url.path.endswith("/api/v0/") assert cl.username == "virl2" assert cl.password == "virl2" @@ -294,7 +295,7 @@ def test_client_library_init_user( cl = ClientLibrary(url, username=user, password="virl2") assert cl.username == params[1] assert cl.password == "virl2" - assert cl.session.base_url == "https://validhostname/api/v0/" + assert cl._session.base_url == "https://validhostname/api/v0/" # the test fails if you have variables set in env @@ -327,7 +328,7 @@ def test_client_library_init_password( cl = ClientLibrary(url, username="virl2", password=password) assert cl.username == "virl2" assert cl.password == params[1] - assert cl.session.base_url == "https://validhostname/api/v0/" + assert cl._session.base_url == "https://validhostname/api/v0/" @pytest.mark.parametrize( @@ -343,14 +344,14 @@ def test_client_library_init_password( ) def test_client_library_config(client_library_server_current, mocked_session, config): client_library = config.make_client() - assert client_library.session.base_url.path.startswith(config.url) + assert client_library._session.base_url.path.startswith(config.url) assert client_library.username == config.username assert client_library.password == config.password assert client_library.allow_http == config.allow_http assert client_library._ssl_verify == config.ssl_verify assert client_library.auto_sync == (config.auto_sync >= 0.0) assert client_library.auto_sync_interval == config.auto_sync - assert client_library.session.mock_calls == [ + assert client_library._session.mock_calls == [ call.get("authok"), call.base_url.path.startswith(config.url), call.base_url.path.startswith().__bool__(), @@ -369,9 +370,10 @@ def test_client_library_str_and_repr(client_library_server_current): def test_major_version_mismatch(client_library_server_1_0_0): with pytest.raises(InitializationError) as err: ClientLibrary("somehost", "virl2", password="virl2") - assert str( - err.value - ) == "Major version mismatch. Client {}, controller 1.0.0.".format(CURRENT_VERSION) + assert ( + str(err.value) + == f"Major version mismatch. Client {CURRENT_VERSION}, controller 1.0.0." + ) def test_incompatible_version(client_library_server_2_0_0): @@ -391,8 +393,8 @@ def test_client_minor_version_gt_nowarn(client_library_server_current, caplog): client_library = ClientLibrary("somehost", "virl2", password="virl2") assert client_library is not None assert ( - "Please ensure the client version is compatible with the controller version. " - "Client {}, controller 2.0.0.".format(CURRENT_VERSION) not in caplog.text + f"Please ensure the client version is compatible with the controller version. " + f"Client {CURRENT_VERSION}, controller 2.0.0." not in caplog.text ) @@ -401,8 +403,8 @@ def test_client_minor_version_lt_warn(client_library_server_2_9_0, caplog): client_library = ClientLibrary("somehost", "virl2", password="virl2") assert client_library is not None assert ( - "Please ensure the client version is compatible with the controller version. " - "Client {}, controller 2.9.0.".format(CURRENT_VERSION) in caplog.text + f"Please ensure the client version is compatible with the controller version. " + f"Client {CURRENT_VERSION}, controller 2.9.0." in caplog.text ) @@ -411,8 +413,8 @@ def test_client_minor_version_lt_warn_1(client_library_server_2_19_0, caplog): client_library = ClientLibrary("somehost", "virl2", password="virl2") assert client_library is not None assert ( - "Please ensure the client version is compatible with the controller version. " - "Client {}, controller 2.19.0.".format(CURRENT_VERSION) in caplog.text + f"Please ensure the client version is compatible with the controller version. " + f"Client {CURRENT_VERSION}, controller 2.19.0." in caplog.text ) @@ -421,8 +423,8 @@ def test_exact_version_no_warn(client_library_server_current, caplog): client_library = ClientLibrary("somehost", "virl2", password="virl2") assert client_library is not None assert ( - "Please ensure the client version is compatible with the controller version. " - "Client {}, controller 2.0.0.".format(CURRENT_VERSION) not in caplog.text + f"Please ensure the client version is compatible with the controller version. " + f"Client {CURRENT_VERSION}, controller 2.0.0." not in caplog.text ) diff --git a/virl2_client/event_handling.py b/virl2_client/event_handling.py index d15325a..fd8e0ca 100644 --- a/virl2_client/event_handling.py +++ b/virl2_client/event_handling.py @@ -1,10 +1,30 @@ +# +# This file is part of VIRL 2 +# Copyright (c) 2019-2023, Cisco Systems, Inc. +# All rights reserved. +# +# Python bindings for the Cisco VIRL 2 Network Simulation Platform +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from __future__ import annotations import asyncio import logging from abc import ABC, abstractmethod from os import name as os_name -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from .exceptions import ElementNotFound, LabNotFound @@ -22,111 +42,195 @@ class Event: - """An event object, stores parsed info about the event it represents.""" - def __init__(self, event_dict: dict[str, Any]): + """ + An event object, stores parsed info about the event it represents. + + :param event_dict: A dictionary containing event data. + """ self.type: str = (event_dict.get("event_type") or "").casefold() - self.subtype_original: Optional[str] = event_dict.get("event") + self.subtype_original: str | None = event_dict.get("event") self.subtype: str = (self.subtype_original or "").casefold() self.element_type: str = (event_dict.get("element_type") or "").casefold() self.lab_id: str = event_dict.get("lab_id", "") self.element_id: str = event_dict.get("element_id", "") - self.data: Optional[dict] = event_dict.get("data") - self.lab: Optional[Lab] = None + self.data: dict | None = event_dict.get("data") + self.lab: Lab | None = None self.element: Union[Node, Interface, Link, None] = None + def __str__(self): + return ( + f"Event type: {self.type}, " + f"Subtype: {self.subtype}, " + f"Element type: {self.element_type}" + ) -class EventHandlerBase(ABC): - """ - Abstract base class for event handlers. Subclass this if you - want to implement the entire event handling mechanism from scratch, - otherwise subclass EventHandler instead. - """ +class EventHandlerBase(ABC): def __init__(self, client_library: ClientLibrary = None): - self.client_library = client_library + """ + Abstract base class for event handlers. - def parse_event(self, event: Event) -> None: + Subclass this if you want to implement the entire event handling mechanism + from scratch, otherwise subclass EventHandler instead. + + :param client_library: The client library which should be modified by events. + """ + self._client_library = client_library + + def handle_event(self, event: Event) -> None: + """ + Parse and handle the given event. + + :param event: An Event object representing the event to be parsed. + """ if event.type == "lab_event": - self._parse_lab(event) + self._handle_lab(event) elif event.type == "lab_element_event": - self._parse_element(event) + self._handle_element(event) elif event.type == "state_change": - self._parse_state_change(event) + self._handle_state_change(event) else: - self._parse_other(event) + self._handle_other(event) + + def _handle_lab(self, event: Event) -> None: + """ + Handle lab events. - def _parse_lab(self, event: Event) -> None: + :param event: An Event object representing the lab event. + """ if event.subtype == "created": - self._parse_lab_created(event) + self._handle_lab_created(event) elif event.subtype == "modified": - self._parse_lab_modified(event) + self._handle_lab_modified(event) elif event.subtype == "deleted": - self._parse_lab_deleted(event) + self._handle_lab_deleted(event) elif event.subtype == "state": - self._parse_lab_state(event) + self._handle_lab_state(event) else: - _LOGGER.warning("Received an invalid lab event (%s)", event.subtype) + # There are only four subtypes, anything else is invalid + _LOGGER.warning(f"Received an invalid event. {event}") @abstractmethod - def _parse_lab_created(self, event: Event) -> None: + def _handle_lab_created(self, event: Event) -> None: + """ + Handle lab created events. + + :param event: An Event object representing the lab created event. + """ pass @abstractmethod - def _parse_lab_modified(self, event: Event) -> None: + def _handle_lab_modified(self, event: Event) -> None: + """ + Handle lab modified events. + + :param event: An Event object representing the lab modified event. + """ pass @abstractmethod - def _parse_lab_deleted(self, event: Event) -> None: + def _handle_lab_deleted(self, event: Event) -> None: + """ + Handle lab deleted events. + + :param event: An Event object representing the lab deleted event. + """ pass @abstractmethod - def _parse_lab_state(self, event: Event) -> None: - pass # TODO: implement lab state sync via events + def _handle_lab_state(self, event: Event) -> None: + """ + Handle lab state events. - def _parse_element(self, event: Event) -> None: - if event.subtype == "created": - self._parse_element_created(event) + :param event: An Event object representing the lab state event. + """ + pass + + def _handle_element(self, event: Event) -> None: + """ + Handle lab element events. + + :param event: An Event object representing the lab element event. + """ + if event.element_type in ("annotation", "connectormapping"): + # These are not used in this client library + _LOGGER.debug(f"Received an unused element type: {event.data}") + elif event.subtype == "created": + self._handle_element_created(event) elif event.subtype == "modified": - self._parse_element_modified(event) + self._handle_element_modified(event) elif event.subtype == "deleted": - self._parse_element_deleted(event) + self._handle_element_deleted(event) else: - _LOGGER.warning("Received an invalid element event (%s)", event.subtype) + # There are only three subtypes, anything else is invalid + # ("state" is under type "state_change", not "lab_element_event") + _LOGGER.warning(f"Received an invalid event. {event}") @abstractmethod - def _parse_element_created(self, event: Event) -> None: + def _handle_element_created(self, event: Event) -> None: + """ + Handle element created events. + + :param event: An Event object representing the element created event. + """ pass @abstractmethod - def _parse_element_modified(self, event: Event) -> None: + def _handle_element_modified(self, event: Event) -> None: + """ + Handle element modified events. + + :param event: An Event object representing the element modified event. + """ pass @abstractmethod - def _parse_element_deleted(self, event: Event) -> None: + def _handle_element_deleted(self, event: Event) -> None: + """ + Handle element deleted events. + + :param event: An Event object representing the element deleted event. + """ pass @abstractmethod - def _parse_state_change(self, event: Event) -> None: + def _handle_state_change(self, event: Event) -> None: + """ + Handle state change events. + + :param event: An Event object representing the state change event. + """ pass - def _parse_other(self, event: Event) -> None: + def _handle_other(self, event: Event) -> None: + """ + Handle other events. + + :param event: An Event object representing the other event. + """ # All other events are useless to the client, but in case some handling # needs to be done on them, this method can be overridden pass class EventHandler(EventHandlerBase): - """Handler for JSON events received from controller over websockets. + """ + Handler for JSON events received from controller over websockets. Used by EventListener by default, but can be subclassed and methods overridden - to change/extend the handling mechanism, then passed to EventListener.""" + to change/extend the handling mechanism, then passed to EventListener. + """ - def parse_event(self, event: Event) -> None: - if event.type in ("lab_stats", "system_stats"): + def handle_event(self, event: Event) -> None: + if event.type in ("lab_stats", "system_stats") or ( + event.element_type in ("annotation", "connectormapping") + ): + # Some events are unused in the client library + _LOGGER.debug(f"Received an unused event. {event}") return try: - event.lab = self.client_library.get_local_lab(event.lab_id) + event.lab = self._client_library.get_local_lab(event.lab_id) except LabNotFound: # lab is not locally joined, so we can ignore its events return @@ -135,47 +239,37 @@ def parse_event(self, event: Event) -> None: try: if event.element_type == "node": event.element = event.lab.get_node_by_id(event.element_id) - if event.element_type == "interface": + elif event.element_type == "interface": event.element = event.lab.get_interface_by_id(event.element_id) - if event.element_type == "link": + elif event.element_type == "link": event.element = event.lab.get_link_by_id(event.element_id) except ElementNotFound: - return - # This should really raise the following error, but the order - # in which the events arrive is inconsistent, so sometimes - # when e.g. a node is deleted, the node deletion event - # arrives before the interface deletion events, so the interfaces - # no longer exist by the time their deletion events arrive, - # which would raise an unwarranted DesynchronizedError - # so instead for now, we just ignore those events - # TODO: implement serverside echo prevention and consistent message - # order for all events and raise the following error - # raise DesynchronizedError( - # "{} {} not found in lab {}, " - # "likely due to desynchronization".format( - # event.element_type, - # event.element_id, - # event.lab_id, - # ) - # ) - - super().parse_event(event) - - def _parse_lab_created(self, event: Event) -> None: + if event.subtype == "deleted": + # Element was likely already deleted in a cascading deletion + # (e.g. node being deleted and all its links and interfaces being + # deleted with it) so the event is useless + return + else: + # A modify event arrived for a missing element - something is wrong + raise + + super().handle_event(event) + + def _handle_lab_created(self, event: Event) -> None: # we don't care about labs the user hasn't joined, # so we don't need the lab creation event pass - def _parse_lab_modified(self, event: Event) -> None: + def _handle_lab_modified(self, event: Event) -> None: event.lab.update_lab_properties(event.data) - def _parse_lab_deleted(self, event: Event) -> None: - self.client_library._remove_lab_local(event.lab) + def _handle_lab_deleted(self, event: Event) -> None: + self._client_library._remove_lab_local(event.lab) - def _parse_lab_state(self, event: Event) -> None: + def _handle_lab_state(self, event: Event) -> None: event.lab._state = event.data["state"] - def _parse_element_created(self, event: Event) -> None: + def _handle_element_created(self, event: Event) -> None: new_element: Node | Interface | Link existing_elements: dict = getattr(event.lab, f"_{event.element_type}s", {}) if event.element_id in existing_elements: @@ -203,13 +297,17 @@ def _parse_element_created(self, event: Event) -> None: event.data["interface_b"], ) else: - _LOGGER.warning("Received an invalid element type (%s)", event.element_type) + # "Annotation" and "ConnectorMapping" were weeded out before, + # so we should never get here + _LOGGER.warning(f"Received an invalid event. {event}") return new_element._state = event.data.get("state") - def _parse_element_modified(self, event: Event) -> None: + def _handle_element_modified(self, event: Event) -> None: if event.element_type == "node": - event.element.update(event.data, exclude_configurations=False) + event.element.update( + event.data, exclude_configurations=False, push_to_server=False + ) elif event.element_type == "interface": # it seems only port change info arrives here, @@ -222,13 +320,11 @@ def _parse_element_modified(self, event: Event) -> None: pass else: - _LOGGER.warning( - "Received an invalid lab element event (%s %s)", - event.subtype, - event.element_type, - ) + # "Annotation" and "ConnectorMapping" were weeded out before, + # so we should never get here + _LOGGER.warning(f"Received an invalid event. {event}") - def _parse_element_deleted(self, event: Event) -> None: + def _handle_element_deleted(self, event: Event) -> None: if event.element_type == "node": event.lab._remove_node_local(event.element) @@ -239,11 +335,9 @@ def _parse_element_deleted(self, event: Event) -> None: event.lab._remove_link_local(event.element) else: - _LOGGER.warning( - "Received an invalid lab element event (%s %s)", - event.subtype, - event.element_type, - ) + # "Annotation" and "ConnectorMapping" were weeded out before, + # so we should never get here + _LOGGER.warning(f"Received an invalid event. {event}") - def _parse_state_change(self, event: Event) -> None: + def _handle_state_change(self, event: Event) -> None: event.element._state = event.subtype_original diff --git a/virl2_client/event_listening.py b/virl2_client/event_listening.py index dd7f7c8..2dc66ab 100644 --- a/virl2_client/event_listening.py +++ b/virl2_client/event_listening.py @@ -1,3 +1,23 @@ +# +# This file is part of VIRL 2 +# Copyright (c) 2019-2023, Cisco Systems, Inc. +# All rights reserved. +# +# Python bindings for the Cisco VIRL 2 Network Simulation Platform +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from __future__ import annotations import asyncio @@ -6,7 +26,7 @@ import ssl import threading from pathlib import Path -from typing import TYPE_CHECKING, Coroutine, Optional +from typing import TYPE_CHECKING, Coroutine from urllib.parse import urlparse import aiohttp @@ -22,25 +42,26 @@ class EventListener: def __init__(self, client_library: ClientLibrary): """ - Initializes a EventListener instance. + Initialize an EventListener instance. EventListener creates and listens to a websocket connection to the server. - Events are then sent to the EventHandler instance for handling + Events are then sent to the EventHandler instance for handling. Use start_listening() to open and stop_listening() to close connection. - :param client_library: parent ClientLibrary instance which provides connection + :param client_library: Parent ClientLibrary instance which provides connection info and is modified when synchronizing. """ - self._thread: Optional[threading.Thread] = None - self._ws_close: Optional[Coroutine] = None - self._ws_close_event: Optional[asyncio.Event] = None - self._ws_connected_event: Optional[threading.Event] = None + self._thread: threading.Thread | None = None + self._ws_close: Coroutine | None = None + self._ws_close_event: asyncio.Event | None = None + self._ws_connected_event: threading.Event | None = None self._synchronizing = False self._listening = False self._connected = False - self._queue: Optional[asyncio.Queue] = None - self._ws_url: Optional[str] = None - self._ssl_context: Optional[ssl.SSLContext] = None + self._auth_data = None + self._queue: asyncio.Queue | None = None + self._ws_url: str | None = None + self._ssl_context: ssl.SSLContext | None = None self._event_handler = EventHandler(client_library) self._init_ws_connection_data(client_library) @@ -49,10 +70,10 @@ def __bool__(self): return self._listening def _init_ws_connection_data(self, client_library: ClientLibrary) -> None: - """Creates an SSL context and Websocket url from client library data. + """ + Create an SSL context and Websocket url from client library data. - :param client_library: the client library instance - :returns: a tuple of url and context + :param client_library: The client library instance. """ ssl_verify = client_library._ssl_verify # Create an SSL context based on the 'verify' str/bool, @@ -75,11 +96,12 @@ def _init_ws_connection_data(self, client_library: ClientLibrary) -> None: self._ws_url = str(ws_url_pieces.geturl()) self._auth_data = { - "token": client_library.session.auth.token, + "token": client_library._session.auth.token, "client_uuid": client_library.uuid, } def start_listening(self): + """Start listening for events.""" if self._listening: return @@ -92,6 +114,7 @@ def start_listening(self): self._thread.start() def stop_listening(self): + """Stop listening for events.""" if not self._listening: return @@ -132,7 +155,7 @@ async def _parse(self): await self._ws_close return event = Event(json.loads(queue_get.result())) - self._event_handler.parse_event(event) + self._event_handler.handle_event(event) self._queue.task_done() async def _ws_client(self): diff --git a/virl2_client/exceptions.py b/virl2_client/exceptions.py index 0d87291..5a6db8e 100644 --- a/virl2_client/exceptions.py +++ b/virl2_client/exceptions.py @@ -69,3 +69,15 @@ class InvalidProperty(VirlException): class MethodNotActive(VirlException): pass + + +class PyatsException(Exception): + pass + + +class PyatsNotInstalled(PyatsException): + pass + + +class PyatsDeviceNotFound(PyatsException): + pass diff --git a/virl2_client/models/auth_management.py b/virl2_client/models/auth_management.py index c565f2c..144c0d0 100644 --- a/virl2_client/models/auth_management.py +++ b/virl2_client/models/auth_management.py @@ -21,28 +21,29 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..exceptions import MethodNotActive +from ..utils import get_url_from_template from .resource_pools import ResourcePool if TYPE_CHECKING: from httpx import Client -_BASE_URL = "system/auth" -_CONFIG_URL = _BASE_URL + "/config" -_TEST_URL = _BASE_URL + "/test" - - class AuthManagement: + _URL_TEMPLATES = { + "config": "system/auth/config", + "test": "system/auth/test", + } + def __init__(self, session: Client, auto_sync=True, auto_sync_interval=1.0): """ - Sync and modify authentication settings. + Manage authentication settings and synchronization. - :param session: parent client's httpx.Client object - :param auto_sync: whether to automatically synchronize resource pools - :param auto_sync_interval: how often to synchronize resource pools in seconds + :param session: The httpx-based HTTP client for this session with the server. + :param auto_sync: Whether to automatically synchronize resource pools. + :param auto_sync_interval: How often to synchronize resource pools in seconds. """ self.auto_sync = auto_sync self.auto_sync_interval = auto_sync_interval @@ -54,7 +55,21 @@ def __init__(self, session: Client, auto_sync=True, auto_sync_interval=1.0): "ldap": LDAPManager(self), } + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + def sync_if_outdated(self) -> None: + """ + Synchronize local data with the server if the auto sync interval + has elapsed since the last synchronization. + """ timestamp = time.time() if ( self.auto_sync @@ -63,60 +78,71 @@ def sync_if_outdated(self) -> None: self.sync() def sync(self) -> None: - self._settings = self._session.get(_CONFIG_URL).json() + """Synchronize the authentication settings with the server.""" + url = self._url_for("config") + self._settings = self._session.get(url).json() self._last_sync_time = time.time() @property def method(self) -> str: - """Current authentication method.""" + """Return the current authentication method.""" self.sync_if_outdated() return self._settings["method"] @property def manager(self) -> AuthMethodManager: - """ - The property manager for the current authentication method. - """ + """Return the property manager for the current authentication method.""" self.sync_if_outdated() return self._managers[self._settings["method"]] def _get_current_settings(self) -> dict: - return self._session.get(_CONFIG_URL).json() + """ + Get the current authentication settings. + + :returns: The current authentication settings. + """ + url = self._url_for("config") + return self._session.get(url).json() def _get_setting(self, setting: str) -> Any: """ - Returns the value of a specific setting of the current authentication method. + Get the value of a specific setting of the current authentication method. - Note: consider using parameters instead. + :param setting: The setting to retrieve. + :returns: The value of the specified setting. """ self.sync_if_outdated() return self._settings[setting] def get_settings(self) -> dict: - """Returns a dictionary of the settings of the current authentication method.""" + """ + Get a dictionary of the settings of the current authentication method. + + :returns: A dictionary of the settings of the current authentication method. + """ self.sync_if_outdated() return self._settings.copy() def _update_setting(self, setting: str, value: Any) -> None: """ - Update a setting for the current auth method. + Update a setting for the current authentication method. - :param setting: the setting to update - :param value: the value to set the setting to + :param setting: The setting to update. + :param value: The value to set the setting to. """ - - self._session.put( - _CONFIG_URL, - json={setting: value, "method": self._settings["method"]}, - ) + url = self._url_for("config") + settings = {setting: value, "method": self._settings["method"]} + self._session.put(url, json=settings) if setting in self._settings: self._settings[setting] = value - def update_settings(self, settings_dict: Optional[dict] = None, **kwargs) -> None: + def update_settings(self, settings_dict: dict | None = None, **kwargs) -> None: """ - Update multiple auth settings at once. - More efficient in batches than setting manager properties. - If passed a dictionary, reads the dictionary for settings. + Update multiple authentication settings at once. + + This method is more efficient in batches than setting manager properties. + + If passed a dictionary, it reads the dictionary for settings. If passed multiple keyword arguments, each is taken as a setting. Can mix and match; keywords take precedence. @@ -128,8 +154,8 @@ def update_settings(self, settings_dict: Optional[dict] = None, **kwargs) -> Non # "verify_tls" key in dictionary overridden by keyword argument into False am.update_settings({"method": "ldap", "verify_tls": True}, verify_tls=False) - :param settings_dict: a dictionary of settings - :param kwargs: keywords of settings + :param settings_dict: A dictionary of settings. + :param kwargs: Setting keywords. """ settings = {} if settings_dict: @@ -137,37 +163,39 @@ def update_settings(self, settings_dict: Optional[dict] = None, **kwargs) -> Non settings.update(kwargs) if not settings: raise TypeError("No settings to update.") - self._session.put(_CONFIG_URL, json=settings) + url = self._url_for("config") + self._session.put(url, json=settings) self.sync() def test_auth(self, config: dict, username: str, password: str) -> dict: """ - Tests a set of credentials against the specified authentication configuration. + Test a set of credentials against the specified authentication configuration. - :param config: a dictionary of authentication settings to test against - (including manager password) - :param username: the username to test - :param password: the password to test - :returns: results of the test + :param config: A dictionary of authentication settings to test against + (including manager password). + :param username: The username to test. + :param password: The password to test. + :returns: Results of the test. """ body = { "auth-config": config, "auth-data": {"username": username, "password": password}, } - response = self._session.post(_TEST_URL, json=body) + url = self._url_for("test") + response = self._session.post(url, json=body) return response.json() def test_current_auth( self, manager_password: str, username: str, password: str ) -> dict: """ - Tests a set of credentials against the currently applied authentication + Test a set of credentials against the currently applied authentication configuration. - :param manager_password: the auth manager password to allow testing - :param username: the username to test - :param password: the password to test - :returns: results of the test + :param manager_password: The manager password to allow testing. + :param username: The username to test. + :param password: The password to test. + :returns: Results of the test. """ current = self.get_settings() current["manager_password"] = manager_password @@ -175,147 +203,214 @@ def test_current_auth( "auth-config": current, "auth-data": {"username": username, "password": password}, } - response = self._session.post(_TEST_URL, json=body) + url = self._url_for("test") + response = self._session.post(url, json=body) return response.json() class AuthMethodManager: + """ + Base class for managing authentication methods. + Provides a mechanism to retrieve and update settings + related to authentication methods. + """ + METHOD = "" def __init__(self, auth_management: AuthManagement): + """ + Initialize the AuthMethodManager with an instance of AuthManagement. + + :param auth_management: An instance of AuthManagement. + """ + self._auth_management = auth_management def _check_method(self): - """This makes sure that storing the manager in a variable won't allow the user - to accidentaly set settings of other methods using this.""" + """ + Internal method to ensure the current method is active. + + :raises MethodNotActive: If the current method is not active. + """ + if self._auth_management._get_setting("method") != self.METHOD: raise MethodNotActive(f"{self.METHOD} is not the currently active method.") def _get_setting(self, setting: str) -> Any: + """ + Internal method to get a setting value from the AuthManagement instance. + + :param setting: The name of the setting. + :returns: The value of the setting. + """ + self._check_method() return self._auth_management._get_setting(setting) def _update_setting(self, setting: str, value: Any) -> None: + """ + Internal method to update a setting value in the AuthManagement instance. + + :param setting: The name of the setting. + :param value: The new value for the setting. + """ + self._check_method() self._auth_management._update_setting(setting, value) class LDAPManager(AuthMethodManager): + """ + Manages LDAP authentication settings. + + Extends the AuthMethodManager class and provides properties + for retrieving and updating LDAP settings. + """ + METHOD = "ldap" @property def server_urls(self) -> str: + """Return the LDAP server URLs.""" return self._get_setting("server_urls") @server_urls.setter def server_urls(self, value: str) -> None: + """Set the LDAP server URLs.""" self._update_setting("server_urls", value) @property def verify_tls(self) -> bool: + """Return the flag indicating whether to verify the TLS certificate.""" return self._get_setting("verify_tls") @verify_tls.setter def verify_tls(self, value: bool) -> None: + """Set the flag indicating whether to verify the TLS certificate.""" self._update_setting("verify_tls", value) @property def cert_data_pem(self) -> str: + """Return the PEM-encoded certificate data.""" return self._get_setting("cert_data_pem") @cert_data_pem.setter def cert_data_pem(self, value: str) -> None: + """Set the PEM-encoded certificate data.""" self._update_setting("cert_data_pem", value) @property def use_ntlm(self) -> bool: + """Return the flag indicating whether to use NTLM authentication.""" return self._get_setting("use_ntlm") @use_ntlm.setter def use_ntlm(self, value: bool) -> None: + """Set the flag indicating whether to use NTLM authentication.""" self._update_setting("use_ntlm", value) @property def root_dn(self) -> str: + """Return the root DN.""" return self._get_setting("root_dn") @root_dn.setter def root_dn(self, value: str) -> None: + """Set the root DN.""" self._update_setting("root_dn", value) @property def user_search_base(self) -> str: + """Return the user search base.""" return self._get_setting("user_search_base") @user_search_base.setter def user_search_base(self, value: str) -> None: + """Set the user search base.""" self._update_setting("user_search_base", value) @property def user_search_filter(self) -> str: + """Return the user search filter.""" return self._get_setting("user_search_filter") @user_search_filter.setter def user_search_filter(self, value: str) -> None: + """Set the user search filter.""" self._update_setting("user_search_filter", value) @property def admin_search_filter(self) -> str: + """Return the admin search filter.""" return self._get_setting("admin_search_filter") @admin_search_filter.setter def admin_search_filter(self, value: str) -> None: + """Set the admin search filter.""" self._update_setting("admin_search_filter", value) @property def group_search_base(self) -> str: + """Return the group search base.""" return self._get_setting("group_search_base") @group_search_base.setter def group_search_base(self, value: str) -> None: + """Set the group search base.""" self._update_setting("group_search_base", value) @property def group_search_filter(self) -> str: + """Return the group search filter.""" return self._get_setting("group_search_filter") @group_search_filter.setter def group_search_filter(self, value: str) -> None: + """Set the group search filter.""" self._update_setting("group_search_filter", value) @property def group_via_user(self) -> bool: + """Return the flag indicating whether to use group via user.""" return self._get_setting("group_via_user") @group_via_user.setter def group_via_user(self, value: bool) -> None: + """Set the flag indicating whether to use group via user.""" self._update_setting("group_via_user", value) @property def group_user_attribute(self) -> str: + """Return the group user attribute.""" return self._get_setting("group_user_attribute") @group_user_attribute.setter def group_user_attribute(self, value: str) -> None: + """Set the group user attribute.""" self._update_setting("group_user_attribute", value) @property def group_membership_filter(self) -> str: + """Return the group membership filter.""" return self._get_setting("group_membership_filter") @group_membership_filter.setter def group_membership_filter(self, value: str) -> None: + """Set the group membership filter.""" self._update_setting("group_membership_filter", value) @property def manager_dn(self) -> str: + """Return the manager DN.""" return self._get_setting("manager_dn") @manager_dn.setter def manager_dn(self, value: str) -> None: + """Set the manager DN.""" self._update_setting("manager_dn", value) def manager_password(self, value: str) -> None: + """Set the manager password.""" self._update_setting("display_attribute", value) # manager password can't be retrieved, only set, so we only provide a setter @@ -323,26 +418,32 @@ def manager_password(self, value: str) -> None: @property def display_attribute(self) -> str: + """Return the display name LDAP attribute.""" return self._get_setting("display_attribute") @display_attribute.setter def display_attribute(self, value: str) -> None: + """Set the display name LDAP attribute.""" self._update_setting("display_attribute", value) @property def email_address_attribute(self) -> str: + """Return the email address LDAP attribute.""" return self._get_setting("email_address_attribute") @email_address_attribute.setter def email_address_attribute(self, value: str) -> None: + """Set the email address LDAP attribute.""" self._update_setting("email_address_attribute", value) @property def resource_pool(self) -> str: + """Return the resource pool a new user will be added to.""" return self._get_setting("resource_pool") @resource_pool.setter def resource_pool(self, value: str | ResourcePool) -> None: + """Set the resource pool a new user will be added to.""" if isinstance(value, ResourcePool): value = value.id self._update_setting("resource_pool", value) diff --git a/virl2_client/models/authentication.py b/virl2_client/models/authentication.py index 003cbf3..a4c93ec 100644 --- a/virl2_client/models/authentication.py +++ b/virl2_client/models/authentication.py @@ -22,7 +22,7 @@ import json import logging -from typing import TYPE_CHECKING, Generator, Optional +from typing import TYPE_CHECKING, Generator from uuid import uuid4 import httpx @@ -33,66 +33,98 @@ from ..virl2_client import ClientLibrary +_AUTH_URL = "authenticate" + + class TokenAuth(httpx.Auth): """ - Initially inspired by: + Token-based authentication for an httpx session. + + Inspired by: https://requests.readthedocs.io/en/v2.9.1/user/authentication/?highlight=AuthBase#new-forms-of-authentication - and modified for httpx based on: + Modified for httpx based on: https://www.python-httpx.org/advanced/#customizing-authentication """ requires_response_body = True def __init__(self, client_library: ClientLibrary): + """ + Initialize the TokenAuth object with a client library instance. + + :param client_library: A client library instance. + """ self.client_library = client_library - self._token: Optional[str] = None + self._token: str | None = None @property - def token(self) -> Optional[str]: + def token(self) -> str | None: + """ + Return the authentication token. If the token has not been set, it is obtained + from the server. + """ if self._token is not None: return self._token - auth_url = "authenticate" - base_url = self.client_library.session.base_url + base_url = self.client_library._session.base_url if base_url.port is not None and base_url.port != 443: - _LOGGER.warning("Not using SSL port of 443: %d", base_url.port) + _LOGGER.warning(f"Not using SSL port of 443: {base_url.port:d}") if base_url.scheme != "https": - _LOGGER.warning("Not using https scheme: %s", base_url.scheme) + _LOGGER.warning(f"Not using https scheme: {base_url.scheme}") data = { "username": self.client_library.username, "password": self.client_library.password, } - response = self.client_library.session.post( - auth_url, json=data, auth=None # type: ignore + response = self.client_library._session.post( + _AUTH_URL, json=data, auth=None # type: ignore ) # auth=None works but is missing from .post's type hint response_raise(response) self._token = response.json() return self._token @token.setter - def token(self, value: Optional[str]): + def token(self, value: str | None) -> None: + """ + Set the authentication token to the specified value. + + :param value: The value to set as the authentication token. + """ self._token = value def auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: - request.headers["Authorization"] = "Bearer {}".format(self.token) + """ + Implement the authentication flow for the token-based authentication. + + :param request: The request object to authenticate. + :returns: A generator of the authenticated request and response objects. + """ + request.headers["Authorization"] = f"Bearer {self.token}" response = yield request if response.status_code == 401: _LOGGER.warning("re-auth called on 401 unauthorized") self.token = None - request.headers["Authorization"] = "Bearer {}".format(self.token) + request.headers["Authorization"] = f"Bearer {self.token}" response = yield request response_raise(response) - def logout(self, clear_all_sessions=False): + def logout(self, clear_all_sessions=False) -> bool: + """ + Log out the user. Invalidates the current token. + + :param clear_all_sessions: Whether to clear all sessions. + :returns: Whether the logout succeeded. + """ url = "logout" + ("?clear_all_sessions=true" if clear_all_sessions else "") - return self.client_library.session.delete(url).json() + return self.client_library._session.delete(url).json() class BlankAuth(httpx.Auth): + """A class that implements an httpx Auth object that does nothing.""" + def auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: @@ -101,7 +133,13 @@ def auth_flow( def response_raise(response: httpx.Response) -> None: - """Replaces the useless link to httpstatuses.com with error description.""" + """ + Raise an exception if the response has an HTTP status error. + Replace the useless link to httpstatuses.com with error description. + + :param response: The response object to check. + :raises httpx.HTTPStatusError: If the response has an HTTP status error. + """ try: response.raise_for_status() except httpx.HTTPStatusError as error: @@ -117,6 +155,17 @@ def response_raise(response: httpx.Response) -> None: def make_session(base_url: str, ssl_verify: bool = True) -> httpx.Client: + """ + Create an httpx Client object with the specified base URL + and SSL verification setting. + + Note: The base URL is automatically prepended to all HTTP calls. This means you + should use ``_session.get("labs")`` rather than ``_session.get(base_url + "labs")``. + + :param base_url: The base URL for the client. + :param ssl_verify: Whether to perform SSL verification. + :returns: The created httpx Client object. + """ return httpx.Client( base_url=base_url, verify=ssl_verify, diff --git a/virl2_client/models/cl_pyats.py b/virl2_client/models/cl_pyats.py index c40df41..344f92f 100644 --- a/virl2_client/models/cl_pyats.py +++ b/virl2_client/models/cl_pyats.py @@ -21,33 +21,29 @@ from __future__ import annotations import io -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING + +from ..exceptions import PyatsDeviceNotFound, PyatsNotInstalled if TYPE_CHECKING: - from pyats.topology import Device + from genie.libs.conf.device import Device + from genie.libs.conf.testbed import Testbed from .lab import Lab -class PyatsNotInstalled(Exception): - pass - - -class PyatsDeviceNotFound(Exception): - pass - - class ClPyats: - def __init__(self, lab: Lab, hostname: Optional[str] = None) -> None: + def __init__(self, lab: Lab, hostname: str | None = None) -> None: """ - Creates a pyATS object that can be used to run commands + Create a pyATS object that can be used to run commands against a device either in exec mode ``show version`` or in configuration mode ``interface gi0/0 \\n no shut``. - :param lab: the lab which should be used with pyATS - :param hostname: force hostname/ip and port for console terminal server - :raises PyatsNotInstalled: when pyATS can not be found - :raises PyatsDeviceNotFound: when the device can not be found + :param lab: The lab object to be used with pyATS. + :param hostname: Forced hostname or IP address and port of the console + terminal server. + :raises PyatsNotInstalled: If pyATS is not installed. + :raises PyatsDeviceNotFound: If the device cannot be found. """ self._pyats_installed = False self._lab = lab @@ -58,31 +54,40 @@ def __init__(self, lab: Lab, hostname: Optional[str] = None) -> None: return else: self._pyats_installed = True - self._testbed: Any = None - self._connections: list[Any] = [] + self._testbed: Testbed | None = None + self._connections: set[Device] = set() @property - def hostname(self) -> Optional[str]: - """Return forced hostname/ip and port terminal server setting""" + def hostname(self) -> str | None: + """Return the forced hostname/IP and port terminal server setting.""" return self._hostname @hostname.setter - def hostname(self, hostname: Optional[str] = None) -> None: + def hostname(self, hostname: str | None = None) -> None: + """ + Set the forced hostname/IP and port terminal server setting. + + :param hostname: The hostname or IP address and port of the console terminal + server. + """ self._hostname = hostname def _check_pyats_installed(self) -> None: + """ + Check if pyATS is installed and raise an exception if not. + + :raises PyatsNotInstalled: If pyATS is not installed. + """ if not self._pyats_installed: raise PyatsNotInstalled def sync_testbed(self, username: str, password: str) -> None: """ - Syncs the testbed from the server. Note that this - will always fetch the latest topology data from the server. + Sync the testbed from the server. + This fetches the latest topology data from the server. - :param username: the username that will be inserted into - the testbed data - :param password: the password that will be inserted into - the testbed data + :param username: The username to be inserted into the testbed data. + :param password: The password to be inserted into the testbed data. """ self._check_pyats_installed() from pyats.topology import loader @@ -93,32 +98,19 @@ def sync_testbed(self, username: str, password: str) -> None: data.devices.terminal_server.credentials.default.password = password self._testbed = data - def _prepare_device(self, node_label: str) -> Device: - self._check_pyats_installed() - try: - pyats_device = self._testbed.devices[node_label] - except KeyError: - raise PyatsDeviceNotFound(node_label) - - # TODO: later check if connected - # TODO: later look at pooling connections - pyats_device.connect(log_stdout=False, learn_hostname=True) - self._connections.append(pyats_device) - return pyats_device - def _prepare_params( self, - init_exec_commands: Optional[list] = None, - init_config_commands: Optional[list] = None, + init_exec_commands: list[str] | None = None, + init_config_commands: list[str] | None = None, ) -> dict: """ Prepare a dictionary of optional parameters to be executed before a command. - None means that default commands will be executed. If you want no commands - to be executed pass an empty list instead. + None means that default commands will be executed. If you want no commands + to be executed, pass an empty list instead. - :param init_exec_commands: a list of exec commands to be executed - :param init_config_commangs: a list of config commands to be executed - :returns: a dictionary of optional parameters to be executed with a command + :param init_exec_commands: A list of exec commands to be executed. + :param init_config_commands: A list of config commands to be executed. + :returns: A dictionary of optional parameters to be executed with a command. """ params = {} if init_exec_commands: @@ -127,48 +119,108 @@ def _prepare_params( params["init_config_commands"] = init_config_commands return params - def run_command( + def _execute_command( self, node_label: str, command: str, - init_exec_commands: Optional[list] = None, - init_config_commands: Optional[list] = None, + configure_mode: bool = False, + init_exec_commands: list[str] | None = None, + init_config_commands: list[str] | None = None, ) -> str: """ - Run a command on the device in `exec` mode. - - :param node_label: the label / title of the device - :param command: the command to be run in exec mode - :param init_exec_commands: a list of exec commands to be executed - :param init_config_commangs: a list of config commands to be executed - :returns: the output from the device + Execute a command on the device. + + :param node_label: The label/title of the device. + :param command: The command to be executed. + :param configure_mode: True if the command is to be run in configure mode, + False for exec mode. + :param init_exec_commands: A list of exec commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :param init_config_commands: A list of config commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :returns: The output from the device. + :raises PyatsDeviceNotFound: If the device cannot be found. """ - pyats_device = self._prepare_device(node_label) + self._check_pyats_installed() + + try: + pyats_device: Device = self._testbed.devices[node_label] + except KeyError: + raise PyatsDeviceNotFound(node_label) + + if pyats_device not in self._connections or not pyats_device.is_connected(): + if pyats_device in self._connections: + pyats_device.destroy() + + pyats_device.connect(log_stdout=False, learn_hostname=True) + self._connections.add(pyats_device) params = self._prepare_params(init_exec_commands, init_config_commands) - return pyats_device.execute(command, log_stdout=False, **params) + if configure_mode: + return pyats_device.configure(command, log_stdout=False, **params) + else: + return pyats_device.execute(command, log_stdout=False, **params) - def run_config_command( + def run_command( self, node_label: str, command: str, - init_exec_commands: Optional[list] = None, - init_config_commands: Optional[list] = None, + init_exec_commands: list[str] | None = None, + init_config_commands: list[str] | None = None, ) -> str: """ - Run a command on the device in `configure` mode. pyATS - handles the change into `configure` mode automatically. + Run a command on the device in exec mode. + + :param node_label: The label/title of the device. + :param command: The command to be run in exec mode. + :param init_exec_commands: A list of exec commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :param init_config_commands: A list of config commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :returns: The output from the device. + """ + return self._execute_command( + node_label, + command, + configure_mode=False, + init_exec_commands=init_exec_commands, + init_config_commands=init_config_commands, + ) - :param node_label: the label / title of the device - :param command: the command to be run in exec mode - :param init_exec_commands: a list of exec commands to be executed - :param init_config_commangs: a list of config commands to be executed - :returns: the output from the device + def run_config_command( + self, + node_label: str, + command: str, + init_exec_commands: list[str] | None = None, + init_config_commands: list[str] | None = None, + ) -> str: """ - pyats_device = self._prepare_device(node_label) - params = self._prepare_params(init_exec_commands, init_config_commands) - return pyats_device.configure(command, log_stdout=False, **params) + Run a command on the device in configure mode. pyATS automatically handles the + change into configure mode. + + :param node_label: The label/title of the device. + :param command: The command to be run in configure mode. + :param init_exec_commands: A list of exec commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :param init_config_commands: A list of config commands to be executed + before the command. Default commands will be run if omitted. + Pass an empty list to run no commands. + :returns: The output from the device. + """ + return self._execute_command( + node_label, + command, + configure_mode=True, + init_exec_commands=init_exec_commands, + init_config_commands=init_config_commands, + ) def cleanup(self) -> None: - """Cleans up the pyATS connections.""" + """Clean up the pyATS connections.""" for pyats_device in self._connections: pyats_device.destroy() + self._connections.clear() diff --git a/virl2_client/models/configuration.py b/virl2_client/models/configuration.py index 79d8fd3..059af6e 100644 --- a/virl2_client/models/configuration.py +++ b/virl2_client/models/configuration.py @@ -1,7 +1,24 @@ +# +# This file is part of VIRL 2 +# Copyright (c) 2019-2023, Cisco Systems, Inc. +# All rights reserved. +# +# Python bindings for the Cisco VIRL 2 Network Simulation Platform +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# """ -https://github.com/CiscoDevNet/virlutils/blob/master/virl/api/credentials.py - -Collection of utility classes to make getting credentials and +A collection of utility classes to make getting credentials and configuration easier. """ from __future__ import annotations @@ -10,12 +27,19 @@ import os from pathlib import Path -from virl2_client.exceptions import InitializationError +from ..exceptions import InitializationError _CONFIG_FILE_NAME = ".virlrc" def _get_from_file(virlrc_parent: Path, prop_name: str) -> str | None: + """ + Retrieve a property value from a file. + + :param virlrc_parent: The parent directory of the file. + :param prop_name: The name of the property to retrieve. + :returns: The value of the property if found, otherwise None. + """ virlrc = virlrc_parent / _CONFIG_FILE_NAME if virlrc.is_file(): with virlrc.open() as fh: @@ -32,12 +56,17 @@ def _get_from_file(virlrc_parent: Path, prop_name: str) -> str | None: def _get_prop(prop_name: str) -> str: """ - Gets a variable using the following order: - * Check for .virlrc in current directory - * Recurse up directory tree for .virlrc - * Check environment variables - * Check ~/.virlrc + Get the value of a variable. + + The function follows the order: + + 1. Check for .virlrc in the current directory + 2. Recurse up the directory tree for .virlrc + 3. Check environment variables + 4. Check ~/.virlrc + :param prop_name: The name of the property to retrieve. + :returns: The value of the property if found, otherwise None. """ # check for .virlrc in current directory cwd = Path.cwd() @@ -64,20 +93,30 @@ def get_configuration( host: str, username: str, password: str, ssl_verify: bool | str ) -> tuple[str, str, str, str]: """ - Used to get login configuration - The login configuration is taken in the following order: - * Check for .virlrc in the current directory - * Recurse up directory tree for .virlrc - * Check environment variables - * Check ~/.virlrc - * Prompt user (except cert path) + Get the login configuration. + + The login configuration is retrieved in the following order: + + 1. Retrieve from function arguments + 2. Check for .virlrc in the current directory + 3. Recurse up the directory tree for .virlrc + 4. Check environment variables + 5. Check ~/.virlrc + 6. Prompt the user (except for the cert path) + + :param host: The host address of the CML server. + :param username: The username. + :param password: The password. + :param ssl_verify: The CA bundle path or boolean value indicating SSL verification. + :returns: A tuple containing the host, username, password, + and SSL verification information. """ if not ( host or (host := _get_prop("VIRL2_URL") or _get_prop("VIRL_HOST")) or (host := input("Please enter the IP / hostname of your virl server: ")) ): - message = "no URL provided" + message = "No URL provided." raise InitializationError(message) if not ( @@ -85,7 +124,7 @@ def get_configuration( or (username := _get_prop("VIRL2_USER") or _get_prop("VIRL_USERNAME")) or (username := input("Please enter your VIRL username: ")) ): - message = "no username provided" + message = "No username provided." raise InitializationError(message) if not ( @@ -93,7 +132,7 @@ def get_configuration( or (password := _get_prop("VIRL2_PASS") or _get_prop("VIRL_PASSWORD")) or (password := getpass.getpass("Please enter your password: ")) ): - message = "no password provided" + message = "No password provided." raise InitializationError(message) if ssl_verify is True: diff --git a/virl2_client/models/groups.py b/virl2_client/models/groups.py index 5394323..251aca4 100644 --- a/virl2_client/models/groups.py +++ b/virl2_client/models/groups.py @@ -20,65 +20,81 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from ..utils import UNCHANGED +from ..utils import UNCHANGED, get_url_from_template if TYPE_CHECKING: import httpx class GroupManagement: + """Manage groups.""" + + _URL_TEMPLATES = { + "groups": "groups", + "group": "groups/{group_id}", + "members": "groups/{group_id}/members", + "labs": "groups/{group_id}/labs", + "id": "groups/{group_name}/id", + } + def __init__(self, session: httpx.Client) -> None: self._session = session - @property - def base_url(self) -> str: - return "groups" + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) def groups(self) -> list: """ Get the list of available groups. - :return: list of group objects + :returns: A list of group objects. """ - return self._session.get(self.base_url).json() + url = self._url_for("groups") + return self._session.get(url).json() def get_group(self, group_id: str) -> dict: """ - Gets info for the specified group. + Get information for the specified group. - :param group_id: group uuid4 - :return: group object + :param group_id: The UUID4 of the group. + :returns: A group object. """ - url = self.base_url + "/{}".format(group_id) + url = self._url_for("group", group_id=group_id) return self._session.get(url).json() def delete_group(self, group_id: str) -> None: """ - Deletes a group. + Delete a group. - :param group_id: group uuid4 - :return: None + :param group_id: The UUID4 of the group. """ - url = self.base_url + "/{}".format(group_id) + url = self._url_for("group", group_id=group_id) self._session.delete(url) def create_group( self, name: str, description: str = "", - members: Optional[list[str]] = None, - labs: Optional[list[dict[str, str]]] = None, + members: list[str] | None = None, + labs: list[dict[str, str]] | None = None, ) -> dict: """ - Creates a group. + Create a group. - :param name: group name - :param description: group description - :param members: group members - :param labs: group labs - :return: created group object + :param name: The name of the group. + :param description: The description of the group. + :param members: The members of the group. + :param labs: The labs associated with the group. + :returns: The created group object. """ data = { "name": name, @@ -86,25 +102,26 @@ def create_group( "members": members or [], "labs": labs or [], } - return self._session.post(self.base_url, json=data).json() + url = self._url_for("groups") + return self._session.post(url, json=data).json() def update_group( self, group_id: str, - name: Optional[str] = None, - description: Optional[str] = UNCHANGED, - members: Optional[list[str]] = UNCHANGED, - labs: Optional[list[dict[str, str]]] = UNCHANGED, + name: str | None = None, + description: str | None = UNCHANGED, + members: list[str] | None = UNCHANGED, + labs: list[dict[str, str]] | None = UNCHANGED, ) -> dict: """ - Updates a group. + Update a group. - :param group_id: group uuid4 - :param name: new group name - :param description: group description - :param members: group members - :param labs: group labs - :return: updated group object + :param group_id: The UUID4 of the group. + :param name: The new name of the group. + :param description: The description of the group. + :param members: The members of the group. + :param labs: The labs associated with the group. + :returns: The updated group object. """ data: dict[str, str | list] = {} if name is not None: @@ -115,35 +132,35 @@ def update_group( data["members"] = members if labs is not UNCHANGED: data["labs"] = labs - url = self.base_url + "/{}".format(group_id) + url = self._url_for("group", group_id=group_id) return self._session.patch(url, json=data).json() def group_members(self, group_id: str) -> list[str]: """ - Gets group members. + Get the members of a group. - :param group_id: group uuid4 - :return: list of users associated with this group + :param group_id: The UUID4 of the group. + :returns: A list of users associated with this group. """ - url = self.base_url + "/{}/members".format(group_id) + url = self._url_for("members", group_id=group_id) return self._session.get(url).json() def group_labs(self, group_id: str) -> list[str]: """ - Get the list of labs that are associated with this group. + Get a list of labs associated with a group. - :param group_id: group uuid4 - :return: list of labs associated with this group + :param group_id: The UUID4 of the group. + :returns: A list of labs associated with this group. """ - url = self.base_url + "/{}/labs".format(group_id) + url = self._url_for("labs", group_id=group_id) return self._session.get(url).json() def group_id(self, group_name: str) -> str: """ - Get unique user uuid4. + Get the unique UUID4 of a group. - :param group_name: group name - :return: group unique identifier + :param group_name: The name of the group. + :returns: The unique identifier of the group. """ - url = self.base_url + "/{}/id".format(group_name) + url = self._url_for("id", group_name=group_name) return self._session.get(url).json() diff --git a/virl2_client/models/interface.py b/virl2_client/models/interface.py index 8c4eb54..32bc6a9 100644 --- a/virl2_client/models/interface.py +++ b/virl2_client/models/interface.py @@ -23,9 +23,9 @@ import logging import warnings from functools import total_ordering -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from ..utils import check_stale +from ..utils import check_stale, get_url_from_template from ..utils import property_s as property if TYPE_CHECKING: @@ -39,30 +39,37 @@ @total_ordering class Interface: + _URL_TEMPLATES = { + "interface": "{lab}/interfaces/{id}", + "state": "{lab}/interfaces/{id}/state", + "start": "{lab}/interfaces/{id}/state/start", + "stop": "{lab}/interfaces/{id}/state/stop", + } + def __init__( self, iid: str, node: Node, label: str, - slot: Optional[int], + slot: int | None, iface_type: str = "physical", ) -> None: """ - A VIRL2 network interface, part of a node. + A CML 2 network interface, part of a node. - :param iid: interface ID - :param node: node object - :param label: the label of the interface - :param slot: the slot of the interface - :param iface_type: the type of the interface, defaults to "physical" + :param iid: The interface ID. + :param node: The node object. + :param label: The label of the interface. + :param slot: The slot of the interface. + :param iface_type: The type of the interface. """ self._id = iid self._node = node self._type = iface_type self._label = label self._slot = slot - self._state: Optional[str] = None - self._session: httpx.Client = node.lab.session + self._state: str | None = None + self._session: httpx.Client = node.lab._session self._stale = False self.statistics = { "readbytes": 0, @@ -70,7 +77,7 @@ def __init__( "writebytes": 0, "writepackets": 0, } - self.ip_snooped_info: dict[str, Optional[str]] = { + self._ip_snooped_info: dict[str, str | None] = { "mac_address": None, "ipv4": None, "ipv6": None, @@ -102,56 +109,66 @@ def __repr__(self): def __hash__(self): return hash(self._id) - @property - def lab_base_url(self) -> str: - return self.node.lab_base_url + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. - @property - def _base_url(self) -> str: - return self.lab_base_url + "/interfaces/{}".format(self.id) + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["lab"] = self.node.lab._url_for("lab") + kwargs["id"] = self.id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) @property def id(self) -> str: + """Return the ID of the interface.""" return self._id @property def node(self) -> Node: + """Return the node object to which the interface belongs.""" return self._node @property def type(self) -> str: + """Return the type of the interface.""" return self._type @property def label(self) -> str: + """Return the label of the interface.""" return self._label @property - def slot(self) -> Optional[int]: + def slot(self) -> int | None: + """Return the slot of the interface.""" return self._slot @property def physical(self) -> bool: - """Whether the interface is physical.""" + """Check if the interface is physical.""" self.node.lab.sync_topology_if_outdated() return self.type == "physical" @property def connected(self) -> bool: - """Whether the interface is connected to a link.""" + """Check if the interface is connected to a link.""" return self.link is not None @property - def state(self) -> Optional[str]: + def state(self) -> str | None: + """Return the state of the interface.""" self.node.lab.sync_states_if_outdated() if self._state is None: - url = self._base_url + "/state" + url = self._url_for("state") self._state = self._session.get(url).json()["state"] return self._state @property - def link(self) -> Optional[Link]: - """Is link if connected, otherwise None.""" + def link(self) -> Link | None: + """Get the link if the interface is connected, otherwise None.""" self.node.lab.sync_topology_if_outdated() for link in self.node.lab.links(): if self in link.interfaces: @@ -159,7 +176,8 @@ def link(self) -> Optional[Link]: return None @property - def peer_interface(self) -> Optional[Interface]: + def peer_interface(self) -> Interface | None: + """Return the peer interface connected to the interface.""" link = self.link if link is None: return None @@ -169,60 +187,98 @@ def peer_interface(self) -> Optional[Interface]: return interfaces[0] @property - def peer_node(self) -> Optional[Node]: + def peer_node(self) -> Node | None: + """Return the node to which the peer interface belongs.""" peer_interface = self.peer_interface return peer_interface.node if peer_interface is not None else None @property def readbytes(self) -> int: + """Return the number of bytes read by the interface.""" self.node.lab.sync_statistics_if_outdated() return int(self.statistics["readbytes"]) @property def readpackets(self) -> int: + """Return the number of packets read by the interface.""" self.node.lab.sync_statistics_if_outdated() return int(self.statistics["readpackets"]) @property def writebytes(self) -> int: + """Return the number of bytes written by the interface.""" self.node.lab.sync_statistics_if_outdated() return int(self.statistics["writebytes"]) @property def writepackets(self) -> int: + """Return the number of packets written by the interface.""" self.node.lab.sync_statistics_if_outdated() return int(self.statistics["writepackets"]) @property - def discovered_mac_address(self) -> Optional[str]: - self.node.lab.sync_l3_addresses_if_outdated() - return self.ip_snooped_info["mac_address"] + def ip_snooped_info(self) -> dict[str, str | None]: + """ + Return the discovered MAC, IPv4 and IPv6 addresses + of the interface in a dictionary. + """ + self.node.sync_l3_addresses_if_outdated() + return self._ip_snooped_info.copy() @property - def discovered_ipv4(self) -> Optional[str]: - self.node.lab.sync_l3_addresses_if_outdated() - return self.ip_snooped_info["ipv4"] + def discovered_mac_address(self) -> str | None: + """Return the discovered MAC address of the interface.""" + self.node.sync_l3_addresses_if_outdated() + return self._ip_snooped_info["mac_address"] @property - def discovered_ipv6(self) -> Optional[str]: - self.node.lab.sync_l3_addresses_if_outdated() - return self.ip_snooped_info["ipv6"] + def discovered_ipv4(self) -> str | None: + """Return the discovered IPv4 address of the interface.""" + self.node.sync_l3_addresses_if_outdated() + return self._ip_snooped_info["ipv4"] @property - def is_physical(self): - warnings.warn("Deprecated, use .physical instead.", DeprecationWarning) + def discovered_ipv6(self) -> str | None: + """Return the discovered IPv6 address of the interface.""" + self.node.sync_l3_addresses_if_outdated() + return self._ip_snooped_info["ipv6"] + + @property + def is_physical(self) -> bool: + """ + DEPRECATED: Use `.physical` instead. + (Reason: renamed to match similar parameters) + + Check if the interface is physical. + """ + warnings.warn( + "'Interface.is_physical' is deprecated. Use '.physical' instead.", + DeprecationWarning, + ) return self.physical def as_dict(self) -> dict[str, str]: - # TODO what should be here in 'data' key? - return {"id": self.id, "node": self.node.id, "data": self.id} + """Convert the interface to a dictionary representation.""" + return { + "id": self.id, + "node": self.node.id, + "data": { + "lab_id": self.node.lab.id, + "label": self.label, + "slot": self.slot, + "type": self.type, + "mac_address": self.discovered_mac_address, + "is_connected": self.connected, + "state": self.state, + }, + } - def get_link_to(self, other_interface: Interface) -> Optional[Link]: + def get_link_to(self, other_interface: Interface) -> Link | None: """ - Returns the link between this interface and another. + Return the link between this interface and another. - :param other_interface: the other interface - :returns: A Link + :param other_interface: The other interface. + :returns: A Link object if a link exists between the interfaces, None otherwise. """ link = self.link if link is not None and other_interface in link.interfaces: @@ -230,51 +286,110 @@ def get_link_to(self, other_interface: Interface) -> Optional[Link]: return None def remove(self) -> None: + """Remove the interface from the node.""" self.node.lab.remove_interface(self) @check_stale def _remove_on_server(self) -> None: - _LOGGER.info("Removing interface %s", self) - url = self._base_url + """ + Remove the interface on the server. + + This method is internal and should not be used directly. Use 'remove' instead. + """ + _LOGGER.info(f"Removing interface {self}") + url = self._url_for("interface") self._session.delete(url) def remove_on_server(self) -> None: + """ + DEPRECATED: Use `.remove()` instead. + (Reason: was never meant to be public, removing only on server is not useful) + + Remove the interface on the server. + """ warnings.warn( - "'Interface.remove_on_server()' is deprecated, " - "use 'Interface.remove()' instead.", + "'Interface.remove_on_server()' is deprecated. Use '.remove()' instead.", DeprecationWarning, ) self._remove_on_server() @check_stale def bring_up(self) -> None: - url = self._base_url + "/state/start" + """Bring up the interface.""" + url = self._url_for("start") self._session.put(url) @check_stale def shutdown(self) -> None: - url = self._base_url + "/state/stop" + """Shutdown the interface.""" + url = self._url_for("stop") self._session.put(url) def peer_interfaces(self): - warnings.warn("Deprecated, use .peer_interface instead.", DeprecationWarning) + """ + DEPRECATED: Use `.peer_interface` instead. + (Reason: pointless plural, could have been a parameter) + + Return the peer interface connected to this interface in a set. + """ + warnings.warn( + "'Interface.peer_interfaces()' is deprecated, " + "use '.peer_interface' instead.", + DeprecationWarning, + ) return {self.peer_interface} def peer_nodes(self): - warnings.warn("Deprecated, use .peer_node instead.", DeprecationWarning) + """ + DEPRECATED: Use `.peer_node` instead. + (Reason: pointless plural, could have been a parameter) + + Return the node of the peer interface in a set. + """ + warnings.warn( + "'Interface.peer_nodes() is deprecated. Use '.peer_node' instead.", + DeprecationWarning, + ) return {self.peer_node} def links(self): - warnings.warn("Deprecated, use .link instead.", DeprecationWarning) + """ + DEPRECATED: Use `.link` instead. + (Reason: pointless plural, could have been a parameter) + + Return the link connected to this interface in a list. + """ + warnings.warn( + "'Interface.links()' is deprecated. Use '.link' instead.", + DeprecationWarning, + ) link = self.link if link is None: return [] return [link] def degree(self): - warnings.warn("Deprecated, use .connected instead.", DeprecationWarning) + """ + DEPRECATED: Use `.connected` instead. + (Reason: degree always 0 or 1) + + Return the degree of the interface. + """ + warnings.warn( + "'Interface.degree()' is deprecated. Use '.connected' instead.", + DeprecationWarning, + ) return int(self.connected) def is_connected(self): - warnings.warn("Deprecated, use .connected instead.", DeprecationWarning) + """ + DEPRECATED: Use `.connected` instead. + (Reason: should have been a parameter, renamed to match similar parameters) + + Check if the interface is connected to a link. + """ + warnings.warn( + "'Interface.is_connected()' is deprecated. Use '.connected' instead.", + DeprecationWarning, + ) return self.connected diff --git a/virl2_client/models/lab.py b/virl2_client/models/lab.py index 814c6a4..e2dbc75 100644 --- a/virl2_client/models/lab.py +++ b/virl2_client/models/lab.py @@ -24,7 +24,7 @@ import logging import time import warnings -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable from httpx import HTTPStatusError @@ -36,7 +36,7 @@ NodeNotFound, VirlException, ) -from ..utils import check_stale, locked +from ..utils import check_stale, get_url_from_template, locked from ..utils import property_s as property from .cl_pyats import ClPyats from .interface import Interface @@ -53,37 +53,65 @@ class Lab: + _URL_TEMPLATES = { + "lab": "labs/{lab_id}", + "nodes": "labs/{lab_id}/nodes", + "nodes_populated": "labs/{lab_id}/nodes?populate_interfaces=true", + "links": "labs/{lab_id}/links", + "interfaces": "labs/{lab_id}/interfaces", + "simulation_stats": "labs/{lab_id}/simulation_stats", + "lab_element_state": "labs/{lab_id}/lab_element_state", + "check_if_converged": "labs/{lab_id}/check_if_converged", + "start": "labs/{lab_id}/start", + "stop": "labs/{lab_id}/stop", + "state": "labs/{lab_id}/state", + "wipe": "labs/{lab_id}/wipe", + "events": "labs/{lab_id}/events", + "build_configurations": "build_configurations?lab_id={lab_id}", + "topology": "labs/{lab_id}/topology", + "pyats_testbed": "labs/{lab_id}/pyats_testbed", + "layer3_addresses": "labs/{lab_id}/layer3_addresses", + "download": "labs/{lab_id}/download", + "groups": "labs/{lab_id}/groups", + "connector_mappings": "labs/{lab_id}/connector_mappings", + "resource_pools": "labs/{lab_id}/resource_pools", + } + def __init__( self, - title: Optional[str], + title: str | None, lab_id: str, session: httpx.Client, username: str, password: str, - auto_sync=True, - auto_sync_interval=1.0, - wait=True, - wait_max_iterations=500, - wait_time=5, - hostname: Optional[str] = None, - resource_pool_manager: Optional[ResourcePoolManagement] = None, + auto_sync: bool = True, + auto_sync_interval: float = 1.0, + wait: bool = True, + wait_max_iterations: int = 500, + wait_time: int = 5, + hostname: str | None = None, + resource_pool_manager: ResourcePoolManagement | None = None, ) -> None: """ - A VIRL2 lab network topology. Contains nodes, links and interfaces. - - :param title: Name / title of the lab - :param lab_id: A lab ID - :param session: httpx Client session - :param username: Username of the user to authenticate - :param password: Password of the user to authenticate - :param auto_sync: Should local changes sync to the server automatically - :param auto_sync_interval: Interval to auto sync in seconds - :param wait: Wait for convergence on backend - :param wait_max_iterations: Maximum number of tries or calls for convergence - :param wait_time: Time to sleep between calls for convergence on backend - :param hostname: Force hostname/ip and port for pyATS console terminal server - :param resource_pool_manager: a ResourcePoolManagement object - shared with parent ClientLibrary + A VIRL2 lab network topology. + + :param title: Name or title of the lab. + :param lab_id: Lab ID. + :param session: The httpx-based HTTP client for this session with the server. + :param username: Username of the user to authenticate. + :param password: Password of the user to authenticate. + :param auto_sync: A flag indicating whether local changes should be + automatically synced to the server. + :param auto_sync_interval: Interval in seconds for automatic syncing. + :param wait: A flag indicating whether to wait for convergence on the backend. + :param wait_max_iterations: Maximum number of iterations for convergence. + :param wait_time: Time in seconds to sleep between convergence calls + on the backend. + :param hostname: Hostname/IP and port for pyATS console terminal server. + :param resource_pool_manager: ResourcePoolManagement object shared + with parent ClientLibrary. + :raises VirlException: If the lab object is created without + a resource pool manager. """ self.username = username @@ -99,17 +127,17 @@ def __init__( self._nodes: dict[str, Node] = {} """ Dictionary containing all nodes in the lab. - It maps node identifier to `models.Node` + It maps node identifier to `models.Node`. """ self._links: dict[str, Link] = {} """ Dictionary containing all links in the lab. - It maps link identifier to `models.Link` + It maps link identifier to `models.Link`. """ self._interfaces: dict[str, Interface] = {} """ Dictionary containing all interfaces in the lab. - It maps interface identifier to `models.Interface` + It maps interface identifier to `models.Interface`. """ self.events: list = [] self.pyats = ClPyats(self, hostname) @@ -154,7 +182,28 @@ def __repr__(self): self.wait_for_convergence, ) - def need_to_wait(self, local_wait: Optional[bool]) -> bool: + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["lab_id"] = self._id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + + def need_to_wait(self, local_wait: bool | None) -> bool: + """ + Check if waiting is required. + + If `local_wait` is `None`, return the value of `wait_for_convergence`. + If `local_wait` is not a boolean, raise a `ValueError`. + + :param local_wait: Local wait flag. + :returns: True if waiting is required, False otherwise. + :raises ValueError: If `local_wait` is not a boolean. + """ if local_wait is None: return self.wait_for_convergence if not isinstance(local_wait, bool): @@ -164,6 +213,7 @@ def need_to_wait(self, local_wait: Optional[bool]) -> bool: @check_stale @locked def sync_statistics_if_outdated(self) -> None: + """Sync statistics if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -174,6 +224,7 @@ def sync_statistics_if_outdated(self) -> None: @check_stale @locked def sync_states_if_outdated(self) -> None: + """Sync states if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -184,6 +235,7 @@ def sync_states_if_outdated(self) -> None: @check_stale @locked def sync_l3_addresses_if_outdated(self) -> None: + """Sync L3 addresses if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -194,6 +246,7 @@ def sync_l3_addresses_if_outdated(self) -> None: @check_stale @locked def sync_topology_if_outdated(self) -> None: + """Sync the topology if it is outdated.""" timestamp = time.time() if ( self.auto_sync @@ -204,6 +257,7 @@ def sync_topology_if_outdated(self) -> None: @check_stale @locked def sync_operational_if_outdated(self) -> None: + """Sync the operational data if it is outdated.""" timestamp = time.time() if ( self.auto_sync @@ -213,140 +267,101 @@ def sync_operational_if_outdated(self) -> None: @property def id(self) -> str: - """ - Returns the ID of the lab (a 6 digit hex string). - - :returns: The Lab ID - """ - + """Return the ID of the lab (a 6 digit hex string).""" return self._id @property - def title(self) -> Optional[str]: - """ - Returns the title of the lab. - - :returns: The lab name - """ + def title(self) -> str | None: + """Return the title of the lab.""" self.sync_topology_if_outdated() return self._title @title.setter def title(self, value: str) -> None: - """ - Set the title of the lab. - - :param value: The new lab title - """ + """Set the title of the lab.""" self._set_property("title", value) @property def notes(self) -> str: - """ - Returns the notes of the lab. - - :returns: The lab name - """ + """Return the notes of the lab.""" self.sync_topology_if_outdated() return self._notes @notes.setter def notes(self, value: str) -> None: - """ - Set the notes of the lab. - - :param value: The new lab notes - """ + """Set the notes of the lab.""" self._set_property("notes", value) @property def description(self) -> str: - """ - Returns the description of the lab. - - :returns: The lab name - """ + """Return the description of the lab.""" self.sync_topology_if_outdated() return self._description @description.setter def description(self, value: str) -> None: - """ - Set the description of the lab. - - :param value: The new lab description - """ + """Set the description of the lab.""" self._set_property("description", value) @check_stale @locked def _set_property(self, prop: str, value: Any): - url = self.lab_base_url - self._session.patch(url, json={prop: value}) - setattr(self, f"_{prop}", value) - - @property - def session(self) -> httpx.Client: """ - Returns the API client session object. + Set the value of a lab property both locally and on the server. - :returns: The session object + :param prop: The name of the property. + :param value: The new value of the property. """ - return self._session + url = self._url_for("lab") + self._session.patch(url, json={prop: value}) + setattr(self, f"_{prop}", value) @property def owner(self) -> str: - """ - Returns the owner of the lab. - - :returns: A username - """ + """Return the owner of the lab.""" self.sync_topology_if_outdated() return self._owner @property def resource_pools(self) -> list[ResourcePool]: + """Return the list of resource pools this lab's nodes belong to.""" self.sync_operational_if_outdated() return self._resource_pools def nodes(self) -> list[Node]: """ - Returns the list of nodes in the lab. + Return the list of nodes in the lab. - :returns: A list of Node objects + :returns: A list of Node objects. """ self.sync_topology_if_outdated() return list(self._nodes.values()) def links(self) -> list[Link]: """ - Returns the list of links in the lab. + Return the list of links in the lab. - :returns: A list of Link objects + :returns: A list of Link objects. """ self.sync_topology_if_outdated() return list(self._links.values()) def interfaces(self) -> list[Interface]: """ - Returns the list of interfaces in the lab. + Return the list of interfaces in the lab. - :returns: A list of Interface objects + :returns: A list of Interface objects. """ self.sync_topology_if_outdated() return list(self._interfaces.values()) - @property - def lab_base_url(self) -> str: - return "labs/{}".format(self._id) - @property @locked def statistics(self) -> dict: """ - Returns some statistics about the lab. + Return lab statistics. - :returns: A dictionary with stats of the lab + :returns: A dictionary with stats of the lab. """ return { "nodes": len(self._nodes), @@ -356,11 +371,11 @@ def statistics(self) -> dict: def get_node_by_id(self, node_id: str) -> Node: """ - Returns the node identified by the ID. + Return the node identified by the ID. - :param node_id: ID of the node to be returned - :returns: A Node object - :raises NodeNotFound: If node not found + :param node_id: ID of the node to be returned. + :returns: A Node object. + :raises NodeNotFound: If the node is not found. """ self.sync_topology_if_outdated() try: @@ -370,11 +385,11 @@ def get_node_by_id(self, node_id: str) -> Node: def get_node_by_label(self, label: str) -> Node: """ - Returns the node identified by the label. + Return the node identified by the label. - :param label: label of the node to be returned - :returns: A Node object - :raises NodeNotFound: If node not found + :param label: Label of the node to be returned. + :returns: A Node object. + :raises NodeNotFound: If the node is not found. """ self.sync_topology_if_outdated() for node in self._nodes.values(): @@ -384,11 +399,11 @@ def get_node_by_label(self, label: str) -> Node: def get_interface_by_id(self, interface_id: str) -> Interface: """ - Returns the interface identified by the ID. + Return the interface identified by the ID. - :param interface_id: ID of the interface to be returned - :returns: An Interface object - :raises InterfaceNotFound: If interface not found + :param interface_id: ID of the interface to be returned. + :returns: An Interface object. + :raises InterfaceNotFound: If the interface is not found. """ self.sync_topology_if_outdated() try: @@ -398,11 +413,11 @@ def get_interface_by_id(self, interface_id: str) -> Interface: def get_link_by_id(self, link_id: str) -> Link: """ - Returns the link identified by the ID. + Return the link identified by the ID. - :param link_id: ID of the interface to be returned - :returns: A Link object - :raises LinkNotFound: If interface not found + :param link_id: ID of the link to be returned. + :returns: A Link object. + :raises LinkNotFound: If the link is not found. """ self.sync_topology_if_outdated() try: @@ -413,20 +428,20 @@ def get_link_by_id(self, link_id: str) -> Link: @staticmethod def get_link_by_nodes(node1: Node, node2: Node) -> Link: """ - DEPRECATED + DEPRECATED: Use `Node.get_link_to()` to get one link + or `Node.get_links_to()` to get all links. + (Reason: redundancy) - Returns ONE of the links identified by two node objects. + Return ONE of the links identified by two node objects. - Deprecated. Use `Node.get_link_to` to get one link - or `Node.get_links_to` to get all links. - - :param node1: the first node - :param node2: the second node - :returns: one of links between the nodes - :raises LinkNotFound: If no such link exists + :param node1: The first node. + :param node2: The second node. + :returns: One of links between the nodes. + :raises LinkNotFound: If no such link exists. """ warnings.warn( - "Deprecated, use Node.get_link_to or Node.get_links_to instead.", + "'Lab.get_link_by_nodes()' is deprecated. " + "Use 'Node.get_link_to()' or 'Node.get_links_to()' instead.", DeprecationWarning, ) if not (links := node1.get_links_to(node2)): @@ -434,21 +449,22 @@ def get_link_by_nodes(node1: Node, node2: Node) -> Link: return links[0] @staticmethod - def get_link_by_interfaces(iface1: Interface, iface2: Interface) -> Optional[Link]: + def get_link_by_interfaces(iface1: Interface, iface2: Interface) -> Link | None: """ - DEPRECATED - - Returns the link identified by two interface objects. + DEPRECATED: Use `Interface.get_link_to()` instead. + (Reason: redundancy) - Deprecated. Use `Interface.get_link_to` to get link. + Return the link identified by two interface objects. - :param iface1: the first interface - :param iface2: the second interface - :returns: the link between the interfaces - :raises LinkNotFound: If no such link exists + :param iface1: The first interface. + :param iface2: The second interface. + :returns: The link between the interfaces. + :raises LinkNotFound: If no such link exists. """ warnings.warn( - "Deprecated, use Interface.get_link_to instead.", DeprecationWarning + "'Lab.get_link_by_interfaces()' is deprecated. " + "Use 'Interface.get_link_to()' instead.", + DeprecationWarning, ) if (link := iface1.link) is not None and iface2 in link.interfaces: return link @@ -456,10 +472,10 @@ def get_link_by_interfaces(iface1: Interface, iface2: Interface) -> Optional[Lin def find_nodes_by_tag(self, tag: str) -> list[Node]: """ - Returns the nodes identified by the given tag. + Return the nodes identified by the given tag. - :param tag: tag of the nodes to be returned - :returns: a list of nodes + :param tag: The tag by which to search. + :returns: A list of nodes. """ self.sync_topology_if_outdated() return [node for node in self.nodes() if tag in node.tags()] @@ -472,27 +488,34 @@ def create_node( node_definition: str, x: int = 0, y: int = 0, - wait: Optional[bool] = None, + wait: bool | None = None, populate_interfaces: bool = False, **kwargs, ) -> Node: """ - Creates a node in the lab with the given parameters. + Create a node in the lab with the given parameters. - :param label: Label - :param node_definition: Node definition to use - :param x: x coordinate - :param y: y coordinate - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) - :param populate_interfaces: automatically create pre-defined number - of interfaces on node creation - :returns: a Node object + :param label: Label. + :param node_definition: Node definition. + :param x: The X coordinate. + :param y: The Y coordinate. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. + :param populate_interfaces: Automatically create a pre-defined number + of interfaces on node creation. + :returns: A Node object. """ - # TODO: warn locally if label in use already? - url = self.lab_base_url + "/nodes" + try: + self.get_node_by_label(label) + _LOGGER.warning(f"Node with label {label} already exists.") + except NodeNotFound: + pass + if populate_interfaces: - url += "?populate_interfaces=true" + url = self._url_for("nodes_populated") + else: + url = self._url_for("nodes") + kwargs["label"] = label kwargs["node_definition"] = node_definition kwargs["x"] = x @@ -518,10 +541,17 @@ def create_node( return node def add_node_local(self, *args, **kwargs): + """ + DEPRECATED: Use `.create_node()` instead. + (Reason: only creates a node in the client, which is not useful; + if really needed, use `._create_node_local()`) + + Creates a node in the client, but not on the server. + """ warnings.warn( - "Deprecated. You probably want .create_node instead; " - "if you really want to create a node locally only, " - "use ._create_node_local.", + "'Lab.add_node_local' is deprecated. You probably want .create_node " + "instead. (If you really want to create a node locally only, " + "use '._create_node_local()'.)", DeprecationWarning, ) return self._create_node_local(*args, **kwargs) @@ -532,24 +562,22 @@ def _create_node_local( node_id: str, label: str, node_definition: str, - image_definition: Optional[str], - configuration: Optional[str], + image_definition: str | None, + configuration: str | None, x: int, y: int, - ram: Optional[int] = None, - cpus: Optional[int] = None, - cpu_limit: Optional[int] = None, - data_volume: Optional[int] = None, - boot_disk_size: Optional[int] = None, + ram: int | None = None, + cpus: int | None = None, + cpu_limit: int | None = None, + data_volume: int | None = None, + boot_disk_size: int | None = None, hide_links: bool = False, - tags: Optional[list] = None, - compute_id: Optional[str] = None, - resource_pool: Optional[ResourcePool] = None, + tags: list | None = None, + compute_id: str | None = None, + resource_pool: ResourcePool | None = None, ) -> Node: """Helper function to add a node to the client library.""" if tags is None: - # TODO: see if can deprecate now tags set automatically - # on server at creation tags = [] node = Node( @@ -577,18 +605,18 @@ def _create_node_local( @check_stale @locked - def remove_node(self, node: Node | str, wait: Optional[bool] = None) -> None: + def remove_node(self, node: Node | str, wait: bool | None = None) -> None: """ - Removes a node from the lab. + Remove a node from the lab. If you have a node object, you can also simply do:: node.remove() - :param node: the node object or ID - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param node: The node object or ID. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ if isinstance(node, str): node = self.get_node_by_id(node) @@ -597,7 +625,7 @@ def remove_node(self, node: Node | str, wait: Optional[bool] = None) -> None: if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("%s node removed from lab %s", node._id, self._id) + _LOGGER.debug(f"{node._id} node removed from lab {self._id}") @locked def _remove_node_local(self, node: Node) -> None: @@ -614,34 +642,35 @@ def _remove_node_local(self, node: Node) -> None: pass @locked - def remove_nodes(self, wait: Optional[bool] = None) -> None: + def remove_nodes(self, wait: bool | None = None) -> None: """ Remove all nodes from the lab. - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ - # TODO: see if this is used - in testing? + # Use case - user was assigned one lab, wants to reset work; + # can't delete lab, so removing all nodes is the only option for node in list(self._nodes.values()): self.remove_node(node, wait=False) if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("all nodes removed from lab %s", self._id) + _LOGGER.debug(f"all nodes removed from lab {self._id}") @check_stale @locked - def remove_link(self, link: Link | str, wait: Optional[bool] = None) -> None: + def remove_link(self, link: Link | str, wait: bool | None = None) -> None: """ - Removes a link from the lab. + Remove a link from the lab. If you have a link object, you can also simply do:: link.remove() - :param link: the link object or ID - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param link: The link object or ID. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ if isinstance(link, str): link = self.get_link_by_id(link) @@ -650,7 +679,7 @@ def remove_link(self, link: Link | str, wait: Optional[bool] = None) -> None: if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("link %s removed from lab %s", link._id, self._id) + _LOGGER.debug(f"link {link._id} removed from lab {self._id}") @locked def _remove_link_local(self, link: Link) -> None: @@ -666,18 +695,18 @@ def _remove_link_local(self, link: Link) -> None: @check_stale @locked def remove_interface( - self, iface: Interface | str, wait: Optional[bool] = None + self, iface: Interface | str, wait: bool | None = None ) -> None: """ - Removes an interface from the lab. + Remove an interface from the lab. If you have an interface object, you can also simply do:: interface.remove() - :param iface: the interface object or ID - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param iface: The interface object or ID. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ if isinstance(iface, str): iface = self.get_interface_by_id(iface) @@ -687,7 +716,7 @@ def remove_interface( if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("interface %s removed from lab %s", iface._id, self._id) + _LOGGER.debug(f"interface {iface._id} removed from lab {self._id}") @locked def _remove_interface_local(self, iface: Interface) -> None: @@ -707,22 +736,22 @@ def _remove_interface_local(self, iface: Interface) -> None: @check_stale @locked def create_link( - self, i1: Interface | str, i2: Interface | str, wait: Optional[bool] = None + self, i1: Interface | str, i2: Interface | str, wait: bool | None = None ) -> Link: """ - Creates a link between two interfaces + Create a link between two interfaces. - :param i1: the first interface object or ID - :param i2: the second interface object or ID - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) - :returns: the created link + :param i1: The first interface object or ID. + :param i2: The second interface object or ID. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. + :returns: The created link. """ if isinstance(i1, str): i1 = self.get_interface_by_id(i1) if isinstance(i2, str): i2 = self.get_interface_by_id(i2) - url = self.lab_base_url + "/links" + url = self._url_for("links") data = { "src_int": i1.id, "dst_int": i2.id, @@ -741,7 +770,7 @@ def create_link( @check_stale @locked def _create_link_local( - self, i1: Interface, i2: Interface, link_id: str, label: Optional[str] = None + self, i1: Interface, i2: Interface, link_id: str, label: str | None = None ) -> Link: """Helper function to create a link in the client library.""" link = Link(self, link_id, i1, i2, label) @@ -752,11 +781,11 @@ def _create_link_local( @locked def connect_two_nodes(self, node1: Node, node2: Node) -> Link: """ - Convenience method to connect two nodes within a lab. + Connect two nodes within a lab. - :param node1: the first node object - :param node2: the second node object - :returns: the created link + :param node1: The first node object. + :param node2: The second node object. + :returns: The created link. """ iface1 = node1.next_available_interface() or node1.create_interface() iface2 = node2.next_available_interface() or node2.create_interface() @@ -765,34 +794,33 @@ def connect_two_nodes(self, node1: Node, node2: Node) -> Link: @check_stale @locked def create_interface( - self, node: Node | str, slot: Optional[int] = None, wait: Optional[bool] = None + self, node: Node | str, slot: int | None = None, wait: bool | None = None ) -> Interface: """ Create an interface in the next available slot, or, if a slot is specified, in all available slots up to and including the specified slot. - :param node: The node on which the interface is created - :param slot: (optional) - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) - :returns: The newly created interface + :param node: The node on which the interface is created. + :param slot: The slot number to create the interface in. + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. + :returns: The newly created interface. """ if isinstance(node, str): node = self.get_node_by_id(node) - url = self.lab_base_url + "/interfaces" + url = self._url_for("interfaces") payload: dict[str, str | int] = {"node": node.id} if slot is not None: payload["slot"] = slot result = self._session.post(url, json=payload).json() if isinstance(result, dict): - # in case API returned just one interface, pack it into the list: + # in case API returned just one interface, pack it into a list: result = [result] if self.need_to_wait(wait): self.wait_until_lab_converged() - # TODO: need to import the topology then - desired_interface: Optional[Interface] = None + desired_interface: Interface | None = None for iface in result: lab_interface = self._create_interface_local( iface_id=iface["id"], @@ -817,7 +845,7 @@ def _create_interface_local( iface_id: str, label: str, node: Node, - slot: Optional[int], + slot: int | None, iface_type: str = "physical", ) -> Interface: """Helper function to create an interface in the client library.""" @@ -832,26 +860,24 @@ def _create_interface_local( iface._type = iface_type return iface - @staticmethod - def _get_element_from_data(data: dict, element: str) -> int: - try: - return int(data[element]) - except (TypeError, KeyError): - return 0 - @check_stale @locked def sync_statistics(self) -> None: """Retrieve the simulation statistic data from the back end server.""" - url = self.lab_base_url + "/simulation_stats" + + def _get_element_from_data(data: dict, element: str) -> int: + try: + return int(data[element]) + except (TypeError, KeyError): + return 0 + + url = self._url_for("simulation_stats") states: dict[str, dict[str, dict]] = self._session.get(url).json() node_statistics = states.get("nodes", {}) link_statistics = states.get("links", {}) for node_id, node_data in node_statistics.items(): - # TODO: standardise so if shutdown, then these are set to 0 on - # server side - disk_read = self._get_element_from_data(node_data, "block0_rd_bytes") - disk_write = self._get_element_from_data(node_data, "block0_wr_bytes") + disk_read = _get_element_from_data(node_data, "block0_rd_bytes") + disk_write = _get_element_from_data(node_data, "block0_wr_bytes") self._nodes[node_id].statistics = { "cpu_usage": float(node_data.get("cpu_usage", 0)), @@ -860,12 +886,10 @@ def sync_statistics(self) -> None: } for link_id, link_data in link_statistics.items(): - # TODO: standardise so if shutdown, then these are set to 0 on - # server side - readbytes = self._get_element_from_data(link_data, "readbytes") - readpackets = self._get_element_from_data(link_data, "readpackets") - writebytes = self._get_element_from_data(link_data, "writebytes") - writepackets = self._get_element_from_data(link_data, "writepackets") + readbytes = _get_element_from_data(link_data, "readbytes") + readpackets = _get_element_from_data(link_data, "readpackets") + writebytes = _get_element_from_data(link_data, "writebytes") + writepackets = _get_element_from_data(link_data, "writepackets") link = self._links[link_id] link.statistics = { @@ -890,10 +914,8 @@ def sync_statistics(self) -> None: @check_stale @locked def sync_states(self) -> None: - """ - Sync all the states of the various elements with the back end server. - """ - url = self.lab_base_url + "/lab_element_state" + """Sync all the states of the various elements with the backend server.""" + url = self._url_for("lab_element_state") states: dict[str, dict[str, str]] = self._session.get(url).json() for node_id, node_state in states["nodes"].items(): self._nodes[node_id]._state = node_state @@ -902,9 +924,6 @@ def sync_states(self) -> None: try: iface = ifaces.pop(interface_id) except KeyError: - # TODO: handle loopbacks created server-side - # check how the UI handles these created today - # - extra call after node created? pass else: iface._state = interface_state @@ -917,146 +936,144 @@ def sync_states(self) -> None: @check_stale def wait_until_lab_converged( - self, max_iterations: Optional[int] = None, wait_time: Optional[int] = None + self, max_iterations: int | None = None, wait_time: int | None = None ) -> None: - """Wait until lab converges.""" + """ + Wait until the lab converges. + + :param max_iterations: The maximum number of iterations to wait. + :param wait_time: The time to wait between iterations. + """ max_iter = ( self.wait_max_iterations if max_iterations is None else max_iterations ) wait_time = self.wait_time if wait_time is None else wait_time - _LOGGER.info("Waiting for lab %s to converge", self._id) + _LOGGER.info(f"Waiting for lab {self._id} to converge.") for index in range(max_iter): converged = self.has_converged() if converged: - _LOGGER.info("Lab %s has booted", self._id) + _LOGGER.info(f"Lab {self._id} has booted.") return if index % 10 == 0: _LOGGER.info( - "Lab has not converged, attempt %s/%s, waiting...", - index, - max_iter, + f"Lab has not converged, attempt {index}/{max_iter}, waiting..." ) time.sleep(wait_time) - msg = "Lab %s has not converged, maximum tries %s exceeded" % ( - self.id, - max_iter, - ) + msg = f"Lab {self.id} has not converged, maximum tries {max_iter} exceeded." _LOGGER.error(msg) - # after maximum retries are exceeded and lab has not converged - # error must be raised - it makes no sense to just log info - # and let client fail with something else if wait is explicitly - # specified + # After maximum retries are exceeded and lab has not converged, + # an error must be raised - it makes no sense to just log info + # and let the client fail with something else if wait is explicitly + # specified. raise RuntimeError(msg) @check_stale def has_converged(self) -> bool: - url = self.lab_base_url + "/check_if_converged" + """ + Check whether the lab has converged. + + :returns: True if the lab has converged, False otherwise. + """ + url = self._url_for("check_if_converged") converged = self._session.get(url).json() return converged @check_stale - def start(self, wait: Optional[bool] = None) -> None: + def start(self, wait: bool | None = None) -> None: """ Start all the nodes and links in the lab. - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ - url = self.lab_base_url + "/start" + url = self._url_for("start") self._session.put(url) if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("started lab: %s", self._id) + _LOGGER.debug(f"Started lab: {self._id}") + + @check_stale + def stop(self, wait: bool | None = None) -> None: + """ + Stop all the nodes and links in the lab. + + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. + """ + url = self._url_for("stop") + self._session.put(url) + if self.need_to_wait(wait): + self.wait_until_lab_converged() + _LOGGER.debug(f"Stopped lab: {self._id}") @check_stale def state(self) -> str: """ - Returns the state of the lab. + Return the state of the lab. - :returns: The state as text + :returns: The state as text. """ if self._state is None or getattr(self._session, "lock", None) is None: # no lock == event listening not enabled - url = self.lab_base_url + "/state" + url = self._url_for("state") response = self._session.get(url) self._state = response.json() - _LOGGER.debug("lab state: %s -> %s", self._id, self._state) + _LOGGER.debug(f"lab state: {self._id} -> {self._state}") return self._state def is_active(self) -> bool: """ - Returns whether the lab is started. + Check if the lab is active. - :returns: Whether the lab is started + :return: True if the lab is active, False otherwise """ return self.state() == "STARTED" @check_stale def details(self) -> dict[str, str | list | int]: """ - Returns the lab details (including state) of the lab. + Retrieve the details of the lab, including its state. - :returns: The detailed lab state + :return: A dictionary containing the detailed lab state """ - url = self.lab_base_url + url = self._url_for("lab") response = self._session.get(url) - _LOGGER.debug("lab state: %s -> %s", self._id, response.text) + _LOGGER.debug(f"lab state: {self._id} -> {response.text}") return response.json() @check_stale - def stop(self, wait: Optional[bool] = None) -> None: - """ - Stops all the nodes and links in the lab. - - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) - """ - url = self.lab_base_url + "/stop" - self._session.put(url) - if self.need_to_wait(wait): - self.wait_until_lab_converged() - _LOGGER.debug("stopped lab: %s", self._id) - - @check_stale - def wipe(self, wait: Optional[bool] = None) -> None: + def wipe(self, wait: bool | None = None) -> None: """ Wipe all the nodes and links in the lab. - :param wait: Wait for convergence (if left at default, - the lab wait property takes precedence) + :param wait: A flag indicating whether to wait for convergence. + If left at the default value, the lab's wait property is used instead. """ - url = self.lab_base_url + "/wipe" + url = self._url_for("wipe") self._session.put(url) if self.need_to_wait(wait): self.wait_until_lab_converged() - _LOGGER.debug("wiped lab: %s", self._id) + _LOGGER.debug(f"wiped lab: {self._id}") def remove(self) -> None: - """ - Removes the lab from the server. The lab has to - be stopped and wiped for this to work. - - Use carefully, it removes current lab:: - - lab.remove() - - """ - self._remove_on_server() - self._stale = True - - @check_stale - def _remove_on_server(self): - url = self.lab_base_url + """Remove the lab from the server. The lab must be stopped and wiped first.""" + url = self._url_for("lab") response = self._session.delete(url) - _LOGGER.debug("removed lab: %s", response.text) + _LOGGER.debug(f"Removed lab: {response.text}") + self._stale = True @check_stale @locked def sync_events(self) -> bool: - url = self.lab_base_url + "/events" + """ + Synchronize the events in the lab. + + :returns: True if the events have changed, False otherwise + """ + url = self._url_for("events") result = self._session.get(url).json() changed = self.events != result self.events = result @@ -1065,11 +1082,10 @@ def sync_events(self) -> bool: @check_stale def build_configurations(self) -> None: """ - Build basic configurations for all nodes in the lab which - do not already have a configuration and also do support - configuration building. + Build basic configurations for all nodes in the lab that do not + already have a configuration and support configuration building. """ - url = "build_configurations?lab_id=" + self._id + url = self._url_for("build_configurations") self._session.get(url) # sync to get the updated configs self.sync_topology_if_outdated() @@ -1079,25 +1095,24 @@ def build_configurations(self) -> None: def sync( self, topology_only=True, - with_node_configurations: Optional[bool] = None, + with_node_configurations: bool | None = None, exclude_configurations=False, ) -> None: """ - Synchronize current lab applying the changes that - were done in UI or in another ClientLibrary session:: - - lab.sync() + Synchronize the current lab, locally applying changes made to the server. - :param topology_only: do not sync statistics and IP addresses - :param with_node_configurations: DEPRECATED, does the opposite of - what is expected: disables syncing node configuration when True - :param exclude_configurations: disable syncing node configuration + :param topology_only: Only sync the topology without statistics and IP + addresses. + :param with_node_configurations: DEPRECATED: does the opposite of what + is expected. Use exclude_configurations instead. + :param exclude_configurations: Whether to exclude configurations + from synchronization. """ - if with_node_configurations is not None: warnings.warn( - 'The argument "with_node_configurations" is deprecated, as it does ' - "the opposite of what is expected. Use exclude_configurations instead.", + "Lab.sync(): The argument 'with_node_configurations' is deprecated, " + "as it does the opposite of what is expected. " + "Use exclude_configurations instead.", DeprecationWarning, ) exclude_configurations = with_node_configurations @@ -1112,8 +1127,7 @@ def sync( @locked def _sync_topology(self, exclude_configurations=False) -> None: """Helper function to sync topologies from the backend server.""" - # TODO: check what happens if call twice - url = self.lab_base_url + "/topology" + url = self._url_for("topology") params = {"exclude_configurations": exclude_configurations} try: result = self._session.get(url, params=params) @@ -1131,32 +1145,44 @@ def _sync_topology(self, exclude_configurations=False) -> None: and f"Lab not found: {self._id}" in exc.response.text ): self._stale = True - raise LabNotFound("Error syncing lab: {}".format(error_msg)) - # TODO: get the error message from response/headers also? + raise LabNotFound(f"Error syncing lab: {error_msg}") + topology = result.json() + if self._initialized: self.update_lab(topology, exclude_configurations) else: self.import_lab(topology) self._initialized = True + self._last_sync_topology_time = time.time() @locked def import_lab(self, topology: dict) -> None: - self._import_lab(topology) + """ + Import a lab from a given topology. + :param topology: The topology to import. + """ + self._import_lab(topology) self._handle_import_nodes(topology) self._handle_import_interfaces(topology) self._handle_import_links(topology) @locked - def _import_lab(self, topology: dict[str, Any]) -> None: - """Replaces lab properties. Will raise KeyError if any property is missing.""" + def _import_lab(self, topology: dict) -> None: + """ + Replace lab properties with the given topology. + + :param topology: The topology to import. + :raises KeyError: If any property is missing in the topology. + """ lab_dict = topology.get("lab") + if lab_dict is None: warnings.warn( "Labs created in older CML releases (schema version 0.0.5 or lower) " - "are deprecated. Use labs with schema version 0.1.0 or higher instead.", + "are deprecated. Use labs with schema version 0.1.0 or higher.", DeprecationWarning, ) self._title = topology["lab_title"] @@ -1171,10 +1197,17 @@ def _import_lab(self, topology: dict[str, Any]) -> None: @locked def _handle_import_nodes(self, topology: dict) -> None: + """ + Handle the import of nodes from the given topology. + + :param topology: The topology to import nodes from. + """ for node in topology["nodes"]: node_id = node["id"] + if node_id in self._nodes: raise ElementAlreadyExists("Node already exists") + interfaces = node.pop("interfaces", []) self._import_node(node_id, node) @@ -1183,29 +1216,46 @@ def _handle_import_nodes(self, topology: dict) -> None: for iface in interfaces: iface_id = iface["id"] + if iface_id in self._interfaces: raise ElementAlreadyExists("Interface already exists") + self._import_interface(iface_id, node_id, iface) @locked def _handle_import_interfaces(self, topology: dict) -> None: + """ + Handle the import of interfaces from the given topology. + + :param topology: The topology to import interfaces from. + """ if "interfaces" in topology: for iface in topology["interfaces"]: iface_id = iface["id"] node_id = iface["node"] + if iface_id in self._interfaces: raise ElementAlreadyExists("Interface already exists") + self._import_interface(iface_id, node_id, iface) @locked def _handle_import_links(self, topology: dict) -> None: + """ + Handle the import of links from the given topology. + + :param topology: The topology to import links from. + """ for link in topology["links"]: link_id = link["id"] + if link_id in self._links: raise ElementAlreadyExists("Link already exists") + iface_a_id = link["interface_a"] iface_b_id = link["interface_b"] label = link.get("label") + self._import_link(link_id, iface_b_id, iface_a_id, label) @locked @@ -1214,8 +1264,17 @@ def _import_link( link_id: str, iface_b_id: str, iface_a_id: str, - label: Optional[str] = None, + label: str | None = None, ) -> Link: + """ + Import a link with the given parameters. + + :param link_id: The ID of the link. + :param iface_b_id: The ID of the second interface. + :param iface_a_id: The ID of the first interface. + :param label: The label of the link. + :returns: The imported Link object. + """ iface_a = self._interfaces[iface_a_id] iface_b = self._interfaces[iface_b_id] return self._create_link_local(iface_a, iface_b, link_id, label) @@ -1224,8 +1283,17 @@ def _import_link( def _import_interface( self, iface_id: str, node_id: str, iface_data: dict ) -> Interface: + """ + Import an interface with the given parameters. + + :param iface_id: The ID of the interface. + :param node_id: The ID of the node the interface belongs to. + :param iface_data: The data of the interface. + :returns: The imported Interface object. + """ if "data" in iface_data: iface_data = iface_data["data"] + label = iface_data["label"] slot = iface_data.get("slot") iface_type = iface_data["type"] @@ -1234,8 +1302,16 @@ def _import_interface( @locked def _import_node(self, node_id: str, node_data: dict) -> Node: + """ + Import a node with the given parameters. + + :param node_id: The ID of the node. + :param node_data: The data of the node. + :returns: The imported Node object. + """ if "data" in node_data: node_data = node_data["data"] + node_data.pop("id", None) state = node_data.pop("state", None) node_data.pop("lab_id", None) @@ -1244,12 +1320,19 @@ def _import_node(self, node_id: str, node_data: dict) -> Node: for key in ("image_definition", "configuration"): if key not in node_data: node_data[key] = None + node = self._create_node_local(node_id, **node_data) node._state = state return node @locked def update_lab(self, topology: dict, exclude_configurations: bool) -> None: + """ + Update the lab with the given topology. + + :param topology: The updated topology. + :param exclude_configurations: Whether to exclude configurations from updating. + """ self._import_lab(topology) # add in order: node -> interface -> link @@ -1260,6 +1343,7 @@ def update_lab(self, topology: dict, exclude_configurations: bool) -> None: update_node_keys = set(node["id"] for node in topology["nodes"]) update_link_keys = set(links["id"] for links in topology["links"]) + if "interfaces" in topology: update_interface_keys = set(iface["id"] for iface in topology["interfaces"]) else: @@ -1297,19 +1381,26 @@ def _remove_elements( removed_links: Iterable[str], removed_interfaces: Iterable[str], ) -> None: + """ + Remove elements from the lab. + + :param removed_nodes: Iterable of node IDs to be removed. + :param removed_links: Iterable of link IDs to be removed. + :param removed_interfaces: Iterable of interface IDs to be removed. + """ for link_id in removed_links: link = self._links.pop(link_id) - _LOGGER.info("Removed link %s", link) + _LOGGER.info(f"Removed link {link}") link._stale = True for interface_id in removed_interfaces: interface = self._interfaces.pop(interface_id) - _LOGGER.info("Removed interface %s", interface) + _LOGGER.info(f"Removed interface {interface}") interface._stale = True for node_id in removed_nodes: node = self._nodes.pop(node_id) - _LOGGER.info("Removed node %s", node) + _LOGGER.info(f"Removed node {node}") node._stale = True @locked @@ -1320,12 +1411,20 @@ def _add_elements( new_links: Iterable[str], new_interfaces: Iterable[str], ) -> None: + """ + Add elements to the lab. + + :param topology: Dictionary containing the lab topology. + :param new_nodes: Iterable of node IDs to be added. + :param new_links: Iterable of link IDs to be added. + :param new_interfaces: Iterable of interface IDs to be added. + """ for node in topology["nodes"]: node_id = node["id"] interfaces = node.pop("interfaces", []) if node_id in new_nodes: node = self._import_node(node_id, node) - _LOGGER.info("Added node %s", node) + _LOGGER.info(f"Added node {node}") if not interfaces: continue @@ -1334,7 +1433,7 @@ def _add_elements( interface_id = interface["id"] if interface_id in new_interfaces: interface = self._import_interface(interface_id, node_id, interface) - _LOGGER.info("Added interface %s", interface) + _LOGGER.info(f"Added interface {interface}") if "interfaces" in topology: for iface in topology["interfaces"]: @@ -1349,16 +1448,24 @@ def _add_elements( iface_b_id = link_data["interface_b"] label = link_data.get("label") link = self._import_link(link_id, iface_b_id, iface_a_id, label) - _LOGGER.info("Added link %s", link) + _LOGGER.info(f"Added link {link}") @locked def _update_elements( self, topology: dict, kept_nodes: Iterable[str], exclude_configurations: bool ) -> None: + """ + Update elements in the lab. + + :param topology: Dictionary containing the lab topology. + :param kept_nodes: Iterable of node IDs to be updated. + :param exclude_configurations: Boolean indicating whether to exclude + configurations during update. + """ for node_id in kept_nodes: node = self._find_node_in_topology(node_id, topology) lab_node = self._nodes[node_id] - lab_node.update(node, exclude_configurations) + lab_node.update(node, exclude_configurations, push_to_server=False) # For now, can't update interface data server-side, this will change with tags # for interface_id in kept_interfaces: @@ -1370,9 +1477,12 @@ def _update_elements( @locked def update_lab_properties(self, properties: dict[str, Any]): - """Updates lab properties. Will not modify unspecified properties. - Is not compatible with schema version 0.0.5.""" - # TODO: Should it be? + """ + Update lab properties. Will not modify unspecified properties. + Is not compatible with schema version 0.0.5. + + :param properties: Dictionary containing the updated lab properties. + """ self._title = properties.get("title", self._title) self._description = properties.get("description", self._description) self._notes = properties.get("notes", self._notes) @@ -1380,6 +1490,15 @@ def update_lab_properties(self, properties: dict[str, Any]): @staticmethod def _find_link_in_topology(link_id: str, topology: dict) -> dict: + """ + Find a link in the given topology. + + :param link_id: The ID of the link to find. + :param topology: Dictionary containing the lab topology. + :returns: The link with the specified ID. + :raises LinkNotFound: If the link cannot be found in the topology. + """ + for link in topology["links"]: if link["id"] == link_id: return link @@ -1397,6 +1516,15 @@ def _find_link_in_topology(link_id: str, topology: dict) -> dict: @staticmethod def _find_node_in_topology(node_id: str, topology: dict) -> dict: + """ + Find a node in the given topology. + + :param node_id: The ID of the node to find. + :param topology: Dictionary containing the lab topology. + :returns: The node with the specified ID. + :raises NodeNotFound: If the node cannot be found in the topology. + """ + for node in topology["nodes"]: if node["id"] == node_id: return node @@ -1404,7 +1532,7 @@ def _find_node_in_topology(node_id: str, topology: dict) -> dict: raise NodeNotFound @check_stale - def get_pyats_testbed(self, hostname: Optional[str] = None) -> str: + def get_pyats_testbed(self, hostname: str | None = None) -> str: """ Return lab's pyATS YAML testbed. Example usage:: @@ -1420,10 +1548,10 @@ def get_pyats_testbed(self, hostname: Optional[str] = None) -> str: aetest.main(testbed=testbed) - :param hostname: Force hostname/ip and port for console terminal server - :returns: The pyATS testbed for the lab in YAML format + :param hostname: Force hostname/ip and port for console terminal server. + :returns: The pyATS testbed for the lab in YAML format. """ - url = self.lab_base_url + "/pyats_testbed" + url = self._url_for("pyats_testbed") params = {} if hostname is not None: params["hostname"] = hostname @@ -1432,17 +1560,18 @@ def get_pyats_testbed(self, hostname: Optional[str] = None) -> str: @check_stale def sync_pyats(self) -> None: + """Sync the pyATS testbed.""" self.pyats.sync_testbed(self.username, self.password) def cleanup_pyats_connections(self) -> None: - """Closes and cleans up connection that pyATS might still hold.""" + """Close and clean up connection that pyATS might still hold.""" self.pyats.cleanup() @check_stale @locked def sync_layer3_addresses(self) -> None: - """Syncs all layer 3 IP addresses from the backend server.""" - url = self.lab_base_url + "/layer3_addresses" + """Sync all layer 3 IP addresses from the backend server.""" + url = self._url_for("layer3_addresses") result: dict[str, dict] = self._session.get(url).json() for node_id, node_data in result.items(): node = self.get_node_by_id(node_id) @@ -1455,19 +1584,15 @@ def download(self) -> str: """ Download the lab from the server in YAML format. - :returns: The lab in YAML format + :returns: The lab in YAML format. """ - url = self.lab_base_url + "/download" + url = self._url_for("download") return self._session.get(url).text @property def groups(self) -> list[dict[str, str]]: - """ - Returns the groups this lab is associated with. - - :return: associated groups - """ - url = self.lab_base_url + "/groups" + """Return the groups this lab is associated with.""" + url = self._url_for("groups") return self._session.get(url).json() @check_stale @@ -1475,30 +1600,32 @@ def update_lab_groups( self, group_list: list[dict[str, str]] ) -> list[dict[str, str]]: """ - Modifies lab / group association + Modify lab/group association. - :param group_list: list of objects consisting of group id and permission - :return: updated objects consisting of group id and permission + :param group_list: List of objects consisting of group ID and permission. + :returns: Updated objects consisting of group ID and permission. """ - url = self.lab_base_url + "/groups" + url = self._url_for("groups") return self._session.put(url, json=group_list).json() @property def connector_mappings(self) -> list[dict[str, Any]]: - """Retrieve lab's external connector mappings + """ + Retrieve lab's external connector mappings. Returns a list of mappings; each mapping has a key, which is used as the configuration of external connector nodes, and a device name, used to uniquely identify the controller host's bridge to use. If unset, the mapping is invalid and nodes using it cannot start. """ - url = self.lab_base_url + "/connector_mappings" + url = self._url_for("connector_mappings") return self._session.get(url).json() def update_connector_mappings( self, updates: list[dict[str, str]] ) -> list[dict[str, str]]: - """Update lab's external connector mappings + """ + Update lab's external connector mappings. Accepts a list of mappings, each with a key to add or modify, and the associated device name (bridge) of the controller host. @@ -1506,15 +1633,16 @@ def update_connector_mappings( value may be None to disassociate the bridge from the key; if no node uses this key in its configuration, the mapping entry is removed. - Returns all connector mappings after updates were applied. + :returns: All connector mappings after updates were applied. """ - url = self.lab_base_url + "/connector_mappings" + url = self._url_for("connector_mappings") return self._session.patch(url, json=updates).json() @check_stale @locked def sync_operational(self) -> None: - url = self.lab_base_url + "/resource_pools" + """Sync the operational status of the lab.""" + url = self._url_for("resource_pools") response = self._session.get(url).json() res_pools = self._resource_pool_manager.get_resource_pools_by_ids(response) self._resource_pools = list(res_pools.values()) diff --git a/virl2_client/models/licensing.py b/virl2_client/models/licensing.py index f90e1a4..99e01c8 100644 --- a/virl2_client/models/licensing.py +++ b/virl2_client/models/licensing.py @@ -22,7 +22,9 @@ import logging import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any + +from ..utils import get_url_from_template if TYPE_CHECKING: import httpx @@ -34,34 +36,53 @@ class Licensing: + _URL_TEMPLATES = { + "licensing": "licensing", + "tech_support": "licensing/tech_support", + "authorization_renew": "licensing/authorization/renew", + "transport": "licensing/transport", + "product_license": "licensing/product_license", + "certificate": "licensing/certificate", + "registration": "licensing/registration", + "registration_renew": "licensing/registration/renew", + "deregistration": "licensing/deregistration", + "features": "licensing/features", + "reservation_action": "licensing/reservation/{action}", + } max_wait = 30 wait_interval = 1.5 def __init__(self, session: httpx.Client) -> None: - self._session = session + """ + Manage licensing. - @property - def base_url(self) -> str: - return "licensing" + :param session: The httpx-based HTTP client for this session with the server. + """ + self._session = session - def status(self) -> dict[str, Any]: + def _url_for(self, endpoint, **kwargs): """ - Get current licensing configuration and status. + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. """ - return self._session.get(self.base_url).json() + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + + def status(self) -> dict[str, Any]: + """Return current licensing configuration and status.""" + url = self._url_for("licensing") + return self._session.get(url).json() def tech_support(self) -> str: - """ - Get current licensing tech support. - """ - url = self.base_url + "/tech_support" + """Return current licensing tech support.""" + url = self._url_for("tech_support") return self._session.get(url).text def renew_authorization(self) -> bool: - """ - Renew licensing authorization with the backend. - """ - url = self.base_url + "/authorization/renew" + """Renew licensing authorization with the backend.""" + url = self._url_for("authorization_renew") response = self._session.put(url) _LOGGER.info("The agent has scheduled an authorization renewal.") return response.status_code == 204 @@ -69,22 +90,18 @@ def renew_authorization(self) -> bool: def set_transport( self, ssms: str, - proxy_server: Optional[str] = None, - proxy_port: Optional[int] = None, + proxy_server: str | None = None, + proxy_port: int | None = None, ) -> bool: - """ - Setup licensing transport configuration. - """ - url = self.base_url + "/transport" + """Setup licensing transport configuration.""" + url = self._url_for("transport") data = {"ssms": ssms, "proxy": {"server": proxy_server, "port": proxy_port}} response = self._session.put(url, json=data) - _LOGGER.info("The transport configuration has been accepted. Config: %s.", data) + _LOGGER.info(f"The transport configuration has been accepted. Config: {data}.") return response.status_code == 204 def set_default_transport(self) -> bool: - """ - Setup licensing transport configuration to default values. - """ + """Setup licensing transport configuration to default values.""" default_ssms = self.status()["transport"]["default_ssms"] return self.set_transport( ssms=default_ssms, @@ -93,19 +110,15 @@ def set_default_transport(self) -> bool: ) def set_product_license(self, product_license: str) -> bool: - """ - Setup a product license. - """ - url = self.base_url + "/product_license" + """Setup a product license.""" + url = self._url_for("product_license") response = self._session.put(url, json=product_license) _LOGGER.info("Product license was accepted by the agent.") return response.status_code == 204 - def get_certificate(self) -> Optional[str]: - """ - Get the currently installed licensing public certificate. - """ - url = self.base_url + "/certificate" + def get_certificate(self) -> str | None: + """Get the currently installed licensing public certificate.""" + url = self._url_for("certificate") response = self._session.get(url) if response.is_success: _LOGGER.info("Certificate received.") @@ -117,7 +130,7 @@ def install_certificate(self, cert: str) -> bool: Set up a licensing public certificate for internal deployment of an unregistered product instance. """ - url = self.base_url + "/certificate" + url = self._url_for("certificate") response = self._session.post(url, content=cert) _LOGGER.info("Certificate was accepted by the agent.") return response.status_code == 204 @@ -127,16 +140,14 @@ def remove_certificate(self) -> bool: Clear any licensing public certificate for internal deployment of an unregistered product instance. """ - url = self.base_url + "/certificate" + url = self._url_for("certificate") response = self._session.delete(url) _LOGGER.info("Certificate was removed.") return response.status_code == 204 def register(self, token: str, reregister=False) -> bool: - """ - Setup licensing registration. - """ - url = self.base_url + "/registration" + """Setup licensing registration.""" + url = self._url_for("registration") response = self._session.post( url, json={"token": token, "reregister": reregister} ) @@ -144,10 +155,8 @@ def register(self, token: str, reregister=False) -> bool: return response.status_code == 204 def register_renew(self) -> bool: - """ - Request a renewal of licensing registration against current SSMS. - """ - url = self.base_url + "/registration/renew" + """Request a renewal of licensing registration against current SSMS.""" + url = self._url_for("registration_renew") response = self._session.put(url) _LOGGER.info("The renewal request was accepted by the agent.") return response.status_code == 204 @@ -163,10 +172,8 @@ def register_wait(self, token: str, reregister=False) -> bool: return res def deregister(self) -> int: - """ - Request deregistration from the current SSMS. - """ - url = self.base_url + "/deregistration" + """Request deregistration from the current SSMS.""" + url = self._url_for("deregistration") response = self._session.delete(url) if response.status_code == 202: _LOGGER.warning( @@ -174,7 +181,6 @@ def deregister(self) -> int: "unable to deregister from Smart Software Licensing due to a " "communication timeout." ) - # TODO try to register again and unregister if response.status_code == 204: _LOGGER.info( "The Product Instance was successfully deregistered from Smart " @@ -183,74 +189,54 @@ def deregister(self) -> int: return response.status_code def features(self) -> list[dict[str, str | int]]: - """ - Get current licensing features. - """ - url = self.base_url + "/features" + """Get current licensing features.""" + url = self._url_for("features") return self._session.get(url).json() def update_features(self, features: dict[str, int] | list[dict[str, int]]) -> None: - """ - Update licensing feature's explicit count in reservation mode. - """ - url = self.base_url + "/features" + """Update licensing feature's explicit count in reservation mode.""" + url = self._url_for("features") self._session.patch(url, json=features) def reservation_mode(self, data: bool) -> None: - """ - Enable or disable reservation mode in unregistered agent. - """ - url = self.base_url + "/reservation/mode" + """Enable or disable reservation mode in unregistered agent.""" + url = self._url_for("reservation_action", action="mode") self._session.put(url, json=data) msg = "enabled" if data else "disabled" - _LOGGER.info("The reservation mode has been %s.", msg) + _LOGGER.info(f"The reservation mode has been {msg}.") def enable_reservation_mode(self) -> None: - """ - Enable reservation mode in unregistered agent. - """ + """Enable reservation mode in unregistered agent.""" return self.reservation_mode(data=True) def disable_reservation_mode(self) -> None: - """ - Disable reservation mode in unregistered agent. - """ + """Disable reservation mode in unregistered agent.""" return self.reservation_mode(data=False) def request_reservation(self) -> str: - """ - Initiate reservation by generating request code and message to the user. - """ - url = self.base_url + "/reservation/request" + """Initiate reservation by generating request code and message to the user.""" + url = self._url_for("reservation_action", action="request") response = self._session.post(url) _LOGGER.info("Reservation request code received.") return response.json() def complete_reservation(self, authorization_code: str) -> str: - """ - Complete reservation by installing authorization code from SSMS. - """ - # TODO - url = self.base_url + "/reservation/complete" + """Complete reservation by installing authorization code from SSMS.""" + url = self._url_for("reservation_action", action="complete") response = self._session.post(url, json=authorization_code) _LOGGER.info("The confirmation code of completed reservation received.") return response.json() def cancel_reservation(self) -> bool: - """ - Cancel reservation request without completing it. - """ - url = self.base_url + "/reservation/cancel" + """Cancel reservation request without completing it.""" + url = self._url_for("reservation_action", action="cancel") response = self._session.delete(url) _LOGGER.info("The reservation request has been cancelled.") return response.status_code == 204 def release_reservation(self) -> str: - """ - Return a completed reservation. - """ - # TODO - url = self.base_url + "/reservation/release" + """Return a completed reservation.""" + url = self._url_for("reservation_action", action="release") response = self._session.delete(url) _LOGGER.info("The return code of the released reservation received.") return response.json() @@ -260,8 +246,7 @@ def discard_reservation(self, data: str) -> str: Discard a reservation authorization code for an already cancelled reservation request. """ - # TODO - url = self.base_url + "/reservation/discard" + url = self._url_for("reservation_action", action="discard") response = self._session.post(url, json=data) _LOGGER.info( "The discard code for an already cancelled reservation request received." @@ -269,37 +254,29 @@ def discard_reservation(self, data: str) -> str: return response.json() def get_reservation_confirmation_code(self) -> str: - """ - Get the confirmation code. - """ - url = self.base_url + "/reservation/confirmation_code" + """Get the reservation confirmation code.""" + url = self._url_for("reservation_action", action="confirmation_code") response = self._session.get(url) _LOGGER.info("The confirmation code of the completed reservation received.") return response.json() def delete_reservation_confirmation_code(self) -> bool: - """ - Remove the confirmation code. - """ - url = self.base_url + "/reservation/confirmation_code" + """Remove the reservation confirmation code.""" + url = self._url_for("reservation_action", action="confirmation_code") response = self._session.delete(url) _LOGGER.info("The confirmation code has been removed.") return response.status_code == 204 def get_reservation_return_code(self) -> str: - """ - Get the return code. - """ - url = self.base_url + "/reservation/return_code" + """Get the reservation return code.""" + url = self._url_for("reservation_action", action="return_code") response = self._session.get(url) _LOGGER.info("The return code of the released reservation received.") return response.json() def delete_reservation_return_code(self) -> bool: - """ - Remove the return code. - """ - url = self.base_url + "/reservation/return_code" + """Remove the reservation return code.""" + url = self._url_for("reservation_action", action="return_code") response = self._session.delete(url) _LOGGER.info("The return code has been removed.") return response.status_code == 204 @@ -319,9 +296,9 @@ def wait_for_status(self, what: str, *target_status: str) -> None: if count > self.max_wait: timeout = self.max_wait * self.wait_interval raise RuntimeError( - "Timeout: licensing {} did not reach {} status after {} secs. " - "Last status was {}".format(what, target_status, timeout, status) + f"Timeout: licensing {what} did not reach {target_status} status " + f"after {timeout} secs. Last status was {status}" ) status = self.status()[what]["status"] - _LOGGER.debug("%s status: %s", what, status) + _LOGGER.debug(f"{what} status: {status}") count += 1 diff --git a/virl2_client/models/link.py b/virl2_client/models/link.py index a5eb188..451b6eb 100644 --- a/virl2_client/models/link.py +++ b/virl2_client/models/link.py @@ -24,9 +24,9 @@ import time import warnings from functools import total_ordering -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from ..utils import check_stale, locked +from ..utils import check_stale, get_url_from_template, locked from ..utils import property_s as property if TYPE_CHECKING: @@ -41,31 +41,42 @@ @total_ordering class Link: + _URL_TEMPLATES = { + "link": "{lab}/links/{id}", + "check_if_converged": "{lab}/links/{id}/check_if_converged", + "state": "{lab}/links/{id}/state", + "start": "{lab}/links/{id}/state/start", + "stop": "{lab}/links/{id}/state/stop", + "condition": "{lab}/links/{id}/condition", + } + def __init__( self, lab: Lab, lid: str, iface_a: Interface, iface_b: Interface, - label: Optional[str] = None, + label: str | None = None, ) -> None: """ A VIRL2 network link between two nodes, connecting to two interfaces on these nodes. - :param lab: the lab object - :param lid: the link ID - :param iface_a: the first interface of the link - :param iface_b: the second interface of the link - :param label: the link label + :param lab: The lab object to which the link belongs. + :param lid: The ID of the link. + :param iface_a: The first interface of the link. + :param iface_b: The second interface of the link. + :param label: The label of the link. """ self._id = lid self._interface_a = iface_a self._interface_b = iface_b self._label = label self.lab = lab - self._session: httpx.Client = lab.session - self._state: Optional[str] = None + self._session: httpx.Client = lab._session + self._state: str | None = None + # When the link is removed on the server, this link object is marked stale + # and can no longer be interacted with - the user should discard it self._stale = False self.statistics = { "readbytes": 0, @@ -100,111 +111,147 @@ def __lt__(self, other: object): def __hash__(self): return hash(self._id) + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["lab"] = self.lab._url_for("lab") + kwargs["id"] = self.id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + @property - def id(self): + def id(self) -> str: + """Return the ID of the link.""" return self._id @property - def interface_a(self): + def interface_a(self) -> Interface: + """Return the first interface of the link.""" return self._interface_a @property - def interface_b(self): + def interface_b(self) -> Interface: + """Return the second interface of the link.""" return self._interface_b @property @locked - def state(self) -> Optional[str]: + def state(self) -> str | None: + """Return the current state of the link.""" self.lab.sync_states_if_outdated() if self._state is None: - url = self.base_url + url = self._url_for("link") self._state = self._session.get(url).json()["state"] return self._state @property def readbytes(self) -> int: + """Return the number of read bytes on the link.""" self.lab.sync_statistics_if_outdated() return self.statistics["readbytes"] @property def readpackets(self) -> int: + """Return the number of read packets on the link.""" self.lab.sync_statistics_if_outdated() return self.statistics["readpackets"] @property def writebytes(self) -> int: + """Return the number of written bytes on the link.""" self.lab.sync_statistics_if_outdated() return self.statistics["writebytes"] @property def writepackets(self) -> int: + """Return the number of written packets on the link.""" self.lab.sync_statistics_if_outdated() return self.statistics["writepackets"] @property def node_a(self) -> Node: + """Return the first node connected to the link.""" self.lab.sync_topology_if_outdated() return self.interface_a.node @property def node_b(self) -> Node: + """Return the second node connected to the link.""" self.lab.sync_topology_if_outdated() return self.interface_b.node @property @locked def nodes(self) -> tuple[Node, Node]: - """Return nodes this link connects.""" + """Return the nodes connected by the link.""" self.lab.sync_topology_if_outdated() return self.node_a, self.node_b @property @locked def interfaces(self) -> tuple[Interface, Interface]: + """Return the interfaces connected by the link.""" self.lab.sync_topology_if_outdated() return self.interface_a, self.interface_b @property - def label(self) -> Optional[str]: + def label(self) -> str | None: + """Return the label of the link.""" self.lab.sync_topology_if_outdated() return self._label @locked def as_dict(self) -> dict[str, str]: + """ + Convert the link object to a dictionary representation. + + :returns: A dictionary representation of the link object. + """ return { "id": self.id, "interface_a": self.interface_a.id, "interface_b": self.interface_b.id, } - @property - def lab_base_url(self) -> str: - return self.lab.lab_base_url - - @property - def base_url(self) -> str: - return self.lab_base_url + "/links/{}".format(self.id) - def remove(self): + """Remove the link from the lab.""" self.lab.remove_link(self) @check_stale def _remove_on_server(self) -> None: - _LOGGER.info("Removing link %s", self) - url = self.base_url + _LOGGER.info(f"Removing link {self}") + url = self._url_for("link") self._session.delete(url) def remove_on_server(self) -> None: + """ + DEPRECATED: Use `.remove()` instead. + (Reason: was never meant to be public, removing only on server is not useful) + + Remove the link on the server. + """ warnings.warn( - "'Link.remove_on_server()' is deprecated, use 'Link.remove()' instead.", + "'Link.remove_on_server()' is deprecated. Use '.remove()' instead.", DeprecationWarning, ) self._remove_on_server() def wait_until_converged( - self, max_iterations: Optional[int] = None, wait_time: Optional[int] = None + self, max_iterations: int | None = None, wait_time: int | None = None ) -> None: - _LOGGER.info("Waiting for link %s to converge", self.id) + """ + Wait until the link has converged. + + :param max_iterations: The maximum number of iterations to wait for convergence. + :param wait_time: The time to wait between iterations. + :raises RuntimeError: If the link does not converge within the specified number + of iterations. + """ + _LOGGER.info(f"Waiting for link {self.id} to converge") max_iter = ( self.lab.wait_max_iterations if max_iterations is None else max_iterations ) @@ -212,21 +259,16 @@ def wait_until_converged( for index in range(max_iter): converged = self.has_converged() if converged: - _LOGGER.info("Link %s has converged", self.id) + _LOGGER.info(f"Link {self.id} has converged") return if index % 10 == 0: _LOGGER.info( - "Link has not converged, attempt %s/%s, waiting...", - index, - max_iter, + f"Link has not converged, attempt {index}/{max_iter}, waiting..." ) time.sleep(wait_time) - msg = "Link %s has not converged, maximum tries %s exceeded" % ( - self.id, - max_iter, - ) + msg = f"Link {self.id} has not converged, maximum tries {max_iter} exceeded" _LOGGER.error(msg) # after maximum retries are exceeded and link has not converged # error must be raised - it makes no sense to just log info @@ -236,20 +278,34 @@ def wait_until_converged( @check_stale def has_converged(self) -> bool: - url = self.base_url + "/check_if_converged" - converged = self._session.get(url).json() - return converged + """ + Check if the link has converged. + + :returns: True if the link has converged, False otherwise. + """ + url = self._url_for("check_if_converged") + return self._session.get(url).json() @check_stale - def start(self, wait: Optional[bool] = None) -> None: - url = self.base_url + "/state/start" + def start(self, wait: bool | None = None) -> None: + """ + Start the link. + + :param wait: Whether to wait for convergence after starting the link. + """ + url = self._url_for("start") self._session.put(url) if self.lab.need_to_wait(wait): self.wait_until_converged() @check_stale - def stop(self, wait: Optional[bool] = None) -> None: - url = self.base_url + "/state/stop" + def stop(self, wait: bool | None = None) -> None: + """ + Stop the link. + + :param wait: Whether to wait for convergence after stopping the link. + """ + url = self._url_for("stop") self._session.put(url) if self.lab.need_to_wait(wait): self.wait_until_converged() @@ -259,14 +315,14 @@ def set_condition( self, bandwidth: int, latency: int, jitter: int, loss: float ) -> None: """ - Applies conditioning to this link. + Set the conditioning parameters for the link. - :param bandwidth: desired bandwidth, 0-10000000 kbps - :param latency: desired latency, 0-10000 ms - :param jitter: desired jitter, 0-10000 ms - :param loss: desired loss, 0-100% + :param bandwidth: The desired bandwidth in kbps (0-10000000). + :param latency: The desired latency in ms (0-10000). + :param jitter: The desired jitter in ms (0-10000). + :param loss: The desired packet loss percentage (0-100). """ - url = self.base_url + "/condition" + url = self._url_for("condition") data = { "bandwidth": bandwidth, "latency": latency, @@ -278,27 +334,19 @@ def set_condition( @check_stale def get_condition(self) -> dict: """ - Retrieves the current condition on this link. - If there is no link condition applied, an empty dictionary is returned. + Get the current conditioning parameters for the link. - (Note: this used to (erroneously) say None would be returned - when no condition is applied, but that was never the case.) - - :return: the applied link condition + :returns: A dictionary containing the current conditioning parameters. """ - url = self.base_url + "/condition" + url = self._url_for("condition") condition = self._session.get(url).json() keys = ["bandwidth", "latency", "jitter", "loss"] - result = {k: v for (k, v) in condition.items() if k in keys} - return result + return {k: v for k, v in condition.items() if k in keys} @check_stale def remove_condition(self) -> None: - """ - Removes link conditioning. - If there's no condition applied then this is a no-op for the controller. - """ - url = self.base_url + "/condition" + """Remove the link conditioning.""" + url = self._url_for("condition") self._session.delete(url) def set_condition_by_name(self, name: str) -> None: @@ -323,8 +371,8 @@ def set_condition_by_name(self, name: str) -> None: satellite 1500 1 mbps 0.2 ========= ============ ========= ======== - :param name: the predefined condition name as outlined in the table above - :raises ValueError: if the given name isn't known + :param name: The name of the predefined link condition. + :raises ValueError: If the given name is not a known predefined condition. """ options = { "gprs": (500, 50, 2.0), @@ -340,9 +388,9 @@ def set_condition_by_name(self, name: str) -> None: } if name not in options: - msg = "unknown condition name '{}', known values: '{}'".format( - name, - ", ".join(sorted(options)), + msg = ( + f"Unknown condition name: '{name}', " + f"known values: '{', '.join(sorted(options))}'" ) _LOGGER.error(msg) raise ValueError(msg) diff --git a/virl2_client/models/node.py b/virl2_client/models/node.py index 5aed0ab..c375b39 100644 --- a/virl2_client/models/node.py +++ b/virl2_client/models/node.py @@ -25,10 +25,10 @@ import time import warnings from functools import total_ordering -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..exceptions import InterfaceNotFound -from ..utils import check_stale, locked +from ..utils import check_stale, get_url_from_template, locked from ..utils import property_s as property if TYPE_CHECKING: @@ -40,50 +40,67 @@ _LOGGER = logging.getLogger(__name__) -CONFIG_WARNING = 'The property "config" is deprecated in favor of "configuration"' - @total_ordering class Node: + _URL_TEMPLATES = { + "node": "{lab}/nodes/{id}", + "state": "{lab}/nodes/{id}/state", + "check_if_converged": "{lab}/nodes/{id}/check_if_converged", + "start": "{lab}/nodes/{id}/state/start", + "stop": "{lab}/nodes/{id}/state/stop", + "wipe_disks": "{lab}/nodes/{id}/wipe_disks", + "extract_configuration": "{lab}/nodes/{id}/extract_configuration", + "console_log": "{lab}/nodes/{id}/consoles/{console_id}/log", + "console_log_lines": "{lab}/nodes/{id}/consoles/{console_id}/log?lines={lines}", + "console_key": "{lab}/nodes/{id}/keys/console", + "vnc_key": "{lab}/nodes/{id}/keys/vnc", + "layer3_addresses": "{lab}/nodes/{id}/layer3_addresses", + "operational": "{lab}/nodes/{id}?operational=true", + } + def __init__( self, lab: Lab, nid: str, label: str, node_definition: str, - image_definition: Optional[str], - configuration: Optional[str], + image_definition: str | None, + configuration: str | None, x: int, y: int, - ram: Optional[int], - cpus: Optional[int], - cpu_limit: Optional[int], - data_volume: Optional[int], - boot_disk_size: Optional[int], + ram: int | None, + cpus: int | None, + cpu_limit: int | None, + data_volume: int | None, + boot_disk_size: int | None, hide_links: bool, tags: list[str], - resource_pool: Optional[str], + resource_pool: str | None, ) -> None: """ - A VIRL2 Node object. Typically, a virtual machine representing a router, - switch or server. - - :param lab: the Lab this node belongs to - :param nid: the Node ID - :param label: node label - :param node_definition: The node definition of this node - :param image_definition: The image definition of this node - :param configuration: The initial configuration of this node - :param x: X coordinate on topology canvas - :param y: Y coordinate on topology canvas - :param ram: memory of node in MiB (if applicable) - :param cpus: Amount of CPUs in this node (if applicable) - :param cpu_limit: CPU limit (default at 100%) - :param data_volume: Size in GiB of 2nd HDD (if > 0) - :param boot_disk_size: Size in GiB of boot disk (will expand to this size) - :param hide_links: Whether node's links should be hidden in UI visualization - :param tags: List of tags - :param resource_pool: A resource pool ID if the node is in a resource pool + A VIRL2 node object representing a virtual machine that serves + as a router, switch, or server. + + :param lab: The Lab instance to which this node belongs. + :param nid: The ID of the node. + :param label: The label of the node. + :param node_definition: The definition of this node. + :param image_definition: The definition of the image used by this node. + :param configuration: The initial configuration of this node. + :param x: The X coordinate of the node on the topology canvas. + :param y: The Y coordinate of the node on the topology canvas. + :param ram: The memory of the node in MiB (if applicable). + :param cpus: The number of CPUs in this node (if applicable). + :param cpu_limit: The CPU limit of the node (default is 100%). + :param data_volume: The size in GiB of the second HDD (if > 0). + :param boot_disk_size: The size in GiB of the boot disk + (will expand to this size). + :param hide_links: A flag indicating whether the node's links should be hidden + in UI visualization. + :param tags: A list of tags associated with the node. + :param resource_pool: The ID of the resource pool if the node is part + of a resource pool. """ self.lab = lab self._id = nid @@ -91,8 +108,8 @@ def __init__( self._node_definition = node_definition self._x = x self._y = y - self._state: Optional[str] = None - self._session: httpx.Client = lab.session + self._state: str | None = None + self._session: httpx.Client = lab._session self._image_definition = image_definition self._ram = ram self._configuration = configuration @@ -102,9 +119,10 @@ def __init__( self._boot_disk_size = boot_disk_size self._hide_links = hide_links self._tags = tags - self._compute_id: Optional[str] = None + self._compute_id: str | None = None self._resource_pool = resource_pool self._stale = False + self._last_sync_l3_address_time = 0.0 self.statistics: dict[str, int | float] = { "cpu_usage": 0, @@ -151,51 +169,78 @@ def __lt__(self, other): def __hash__(self): return hash(self._id) + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["lab"] = self.lab._url_for("lab") + kwargs["id"] = self.id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + + @check_stale + @locked + def sync_l3_addresses_if_outdated(self) -> None: + timestamp = time.time() + if ( + self.lab.auto_sync + and timestamp - self._last_sync_l3_address_time + > self.lab.auto_sync_interval + ): + self.sync_layer3_addresses() + @property @locked - def state(self) -> Optional[str]: + def state(self) -> str | None: + """Return the state of the node.""" self.lab.sync_states_if_outdated() if self._state is None: - url = self._base_url + "/state" + url = self._url_for("state") self._state = self._session.get(url).json()["state"] return self._state @check_stale @locked def interfaces(self) -> list[Interface]: - # self.lab.sync_topology_if_outdated() + """Return a list of interfaces on the node.""" + self.lab.sync_topology_if_outdated() return [iface for iface in self.lab.interfaces() if iface.node is self] @locked def physical_interfaces(self) -> list[Interface]: - # self.lab.sync_topology_if_outdated() + """Return a list of physical interfaces on the node.""" + self.lab.sync_topology_if_outdated() return [iface for iface in self.interfaces() if iface.physical] @check_stale @locked def create_interface( - self, slot: Optional[int] = None, wait: bool = False + self, slot: int | None = None, wait: bool = False ) -> Interface: """ Create an interface in the specified slot or, if no slot is given, in the next available slot. - :param slot: (optional) - :param wait: Wait for the creation - :returns: The newly created interface + :param slot: The slot in which the interface will be created. + :param wait: Wait for the creation to complete. + :returns: The newly created interface. """ return self.lab.create_interface(self, slot, wait=wait) @locked - def next_available_interface(self) -> Optional[Interface]: + def next_available_interface(self) -> Interface | None: """ - Returns the next available physical interface on this node. + Return the next available physical interface on this node. Note: On XR 9000v, the first two physical interfaces are marked as "do not use"... Only the third physical interface can be used to connect to other nodes! - :returns: an interface or None, if all existing ones are connected + :returns: An available physical interface or None if all existing + ones are connected. """ for iface in self.interfaces(): if not iface.connected and iface.physical: @@ -204,6 +249,7 @@ def next_available_interface(self) -> Optional[Interface]: @locked def peer_interfaces(self) -> list[Interface]: + """Return a list of interfaces connected to this node.""" peer_ifaces = [] for iface in self.interfaces(): peer_iface = iface.peer_interface @@ -213,212 +259,274 @@ def peer_interfaces(self) -> list[Interface]: @locked def peer_nodes(self) -> list[Node]: + """Return a list of nodes connected to this node.""" return list({iface.node for iface in self.peer_interfaces()}) @locked def links(self) -> list[Link]: + """Return a list of links connected to this node.""" return list( {link for iface in self.interfaces() if (link := iface.link) is not None} ) @locked def degree(self) -> int: + """Return the degree of the node.""" self.lab.sync_topology_if_outdated() return len(self.links()) @property def id(self) -> str: + """Return the ID of the node.""" return self._id @property def label(self) -> str: + """Return the label of the node.""" self.lab.sync_topology_if_outdated() return self._label @label.setter @locked def label(self, value: str) -> None: + """Set the label of the node to the given value.""" self._set_node_property("label", value) self._label = value @property def x(self) -> int: + """Return the X coordinate of the node.""" self.lab.sync_topology_if_outdated() return self._x @x.setter @locked def x(self, value: int) -> None: + """Set the X coordinate of the node to the given value.""" self._set_node_property("x", value) self._x = value @property def y(self) -> int: + """Return the Y coordinate of the node.""" self.lab.sync_topology_if_outdated() return self._y @y.setter @locked def y(self, value: int) -> None: + """Set the Y coordinate of the node to the given value.""" self._set_node_property("y", value) self._y = value @property def ram(self) -> int: + """Return the RAM size of the node in bytes.""" self.lab.sync_topology_if_outdated() return self._ram @ram.setter @locked def ram(self, value: int) -> None: + """Set the RAM size of the node to the given value in bytes.""" self._set_node_property("ram", value) self._ram = value @property def cpus(self) -> int: + """Return the number of CPUs assigned to the node.""" self.lab.sync_topology_if_outdated() return self._cpus @cpus.setter @locked def cpus(self, value: int) -> None: + """Set the number of CPUs assigned to the node.""" self._set_node_property("cpus", value) self._cpus = value @property def cpu_limit(self) -> int: + """Return the CPU limit of the node.""" self.lab.sync_topology_if_outdated() return self._cpu_limit @cpu_limit.setter @locked def cpu_limit(self, value: int) -> None: + """Set the CPU limit of the node.""" self._set_node_property("cpu_limit", value) self._cpu_limit = value @property def data_volume(self) -> int: + """Return the size (in GiB) of the second HDD.""" self.lab.sync_topology_if_outdated() return self._data_volume @data_volume.setter @locked def data_volume(self, value: int) -> None: + """Set the size (in GiB) of the second HDD.""" self._set_node_property("data_volume", value) self._data_volume = value @property def hide_links(self) -> bool: + """ + Return a flag indicating whether the node's links should be hidden + in UI visualization. + """ self.lab.sync_topology_if_outdated() return self._hide_links @hide_links.setter def hide_links(self, value: bool) -> None: + """ + Set the flag indicating whether the node's links should be hidden + in UI visualization. + """ self._set_node_property("hide_links", value) self._hide_links = value @property def boot_disk_size(self) -> int: + """Return the size of the boot disk in GiB.""" self.lab.sync_topology_if_outdated() return self._boot_disk_size @boot_disk_size.setter @locked def boot_disk_size(self, value: int) -> None: + """Set the size of the boot disk in GiB (will expand to this size).""" self._set_node_property("boot_disk_size", value) self._boot_disk_size = value @property - def configuration(self) -> Optional[str]: - # TODO: auto sync if out of date + def configuration(self) -> str | None: + """Return the initial configuration of this node.""" + self.lab._sync_topology(exclude_configurations=False) return self._configuration @configuration.setter def configuration(self, value) -> None: + """Set the initial configuration of this node.""" self._set_node_property("configuration", value) self._configuration = value @property - def config(self) -> Optional[str]: - warnings.warn(CONFIG_WARNING, DeprecationWarning) + def config(self) -> str | None: + """ + DEPRECATED: Use `.configuration` instead. + (Reason: consistency with API) + + Return the initial configuration of this node. + """ + warnings.warn( + "'Node.config' is deprecated. Use '.configuration' instead.", + DeprecationWarning, + ) return self.configuration @config.setter @locked def config(self, value: str) -> None: - warnings.warn(CONFIG_WARNING, DeprecationWarning) + """ + DEPRECATED: Use `.configuration` instead. + (Reason: consistency with API) + + Set the initial configuration of this node. + """ + warnings.warn( + "'Node.config' is deprecated. Use '.configuration' instead.", + DeprecationWarning, + ) self.configuration = value @property - def image_definition(self) -> Optional[str]: + def image_definition(self) -> str | None: + """Return the definition of the image used by this node.""" self.lab.sync_topology_if_outdated() return self._image_definition @image_definition.setter @locked def image_definition(self, value: str) -> None: + """Set the definition of the image used by this node.""" self.lab.sync_topology_if_outdated() self._set_node_property("image_definition", value) self._image_definition = value @property def node_definition(self) -> str: + """Return the definition of this node.""" self.lab.sync_topology_if_outdated() return self._node_definition @property def compute_id(self): + """Return the ID of the compute this node is assigned to.""" self.lab.sync_operational_if_outdated() return self._compute_id @property def resource_pool(self) -> str: + """Return the ID of the resource pool if the node is part of a resource pool.""" self.lab.sync_operational_if_outdated() return self._resource_pool - @property - def lab_base_url(self) -> str: - return self.lab.lab_base_url - - @property - def _base_url(self) -> str: - return self.lab_base_url + "/nodes/{}".format(self.id) - @property def cpu_usage(self) -> int | float: + """Return the CPU usage of this node.""" self.lab.sync_statistics_if_outdated() return min(self.statistics["cpu_usage"], 100) @property def disk_read(self) -> int: + """Return the amount of disk read by this node.""" self.lab.sync_statistics_if_outdated() return round(self.statistics["disk_read"] / 1048576) @property def disk_write(self) -> int: + """Return the amount of disk write by this node.""" self.lab.sync_statistics_if_outdated() return round(self.statistics["disk_write"] / 1048576) @locked def get_interface_by_label(self, label: str) -> Interface: + """ + Get the interface of the node with the specified label. + + :param label: The label of the interface. + :returns: The interface with the specified label. + :raises InterfaceNotFound: If no interface with the specified label is found. + """ for iface in self.interfaces(): if iface.label == label: return iface - raise InterfaceNotFound("{}:{}".format(label, self)) + raise InterfaceNotFound(f"{label}:{self}") @locked def get_interface_by_slot(self, slot: int) -> Interface: + """ + Get the interface of the node with the specified slot. + + :param slot: The slot number of the interface. + :returns: The interface with the specified slot. + :raises InterfaceNotFound: If no interface with the specified slot is found. + """ for iface in self.interfaces(): if iface.slot == slot: return iface - raise InterfaceNotFound("{}:{}".format(slot, self)) + raise InterfaceNotFound(f"{slot}:{self}") def get_links_to(self, other_node: Node) -> list[Link]: """ - Returns all links between this node and another. + Return all links between this node and another. - :param other_node: the other node - :returns: a list of links + :param other_node: The other node. + :returns: A list of links between this node and the other node. """ links = [] for link in self.links(): @@ -426,12 +534,12 @@ def get_links_to(self, other_node: Node) -> list[Link]: links.append(link) return links - def get_link_to(self, other_node: Node) -> Optional[Link]: + def get_link_to(self, other_node: Node) -> Link | None: """ - Returns one link between this node and another. + Return one link between this node and another. - :param other_node: the other node - :returns: a link, if one exists + :param other_node: The other node. + :returns: A link between this node and the other node, if one exists. """ for link in self.links(): if other_node in link.nodes: @@ -440,9 +548,17 @@ def get_link_to(self, other_node: Node) -> Optional[Link]: @check_stale def wait_until_converged( - self, max_iterations: Optional[int] = None, wait_time: Optional[int] = None + self, max_iterations: int | None = None, wait_time: int | None = None ) -> None: - _LOGGER.info("Waiting for node %s to converge", self.id) + """ + Wait until the node has converged. + + :param max_iterations: The maximum number of iterations to wait for convergence. + :param wait_time: The time to wait between iterations. + :raises RuntimeError: If the node does not converge within the specified number + of iterations. + """ + _LOGGER.info(f"Waiting for node {self.id} to converge.") max_iter = ( self.lab.wait_max_iterations if max_iterations is None else max_iterations ) @@ -450,21 +566,16 @@ def wait_until_converged( for index in range(max_iter): converged = self.has_converged() if converged: - _LOGGER.info("Node %s has converged", self.id) + _LOGGER.info(f"Node {self.id} has converged.") return if index % 10 == 0: _LOGGER.info( - "Node has not converged, attempt %s/%s, waiting...", - index, - max_iter, + f"Node has not converged, attempt {index}/{max_iter}, waiting..." ) time.sleep(wait_time) - msg = "Node %s has not converged, maximum tries %s exceeded" % ( - self.id, - max_iter, - ) + msg = f"Node {self.id} has not converged, maximum tries {max_iter} exceeded." _LOGGER.error(msg) # after maximum retries are exceeded and node has not converged # error must be raised - it makes no sense to just log info @@ -474,75 +585,134 @@ def wait_until_converged( @check_stale def has_converged(self) -> bool: - url = self._base_url + "/check_if_converged" + """ + Check if the node has converged. + + :returns: True if the node has converged, False otherwise. + """ + url = self._url_for("check_if_converged") return self._session.get(url).json() @check_stale def start(self, wait=False) -> None: - url = self._base_url + "/state/start" + """ + Start the node. + + :param wait: Whether to wait until the node has converged. + """ + url = self._url_for("start") self._session.put(url) if self.lab.need_to_wait(wait): self.wait_until_converged() @check_stale def stop(self, wait=False) -> None: - url = self._base_url + "/state/stop" + """ + Stop the node. + + :param wait: Whether to wait until the node has converged. + """ + url = self._url_for("stop") self._session.put(url) if self.lab.need_to_wait(wait): self.wait_until_converged() @check_stale def wipe(self, wait=False) -> None: - url = self._base_url + "/wipe_disks" + """ + Wipe the node's disks. + + :param wait: Whether to wait until the node has converged. + """ + url = self._url_for("wipe_disks") self._session.put(url) if self.lab.need_to_wait(wait): self.wait_until_converged() @check_stale def extract_configuration(self) -> None: - url = self._base_url + "/extract_configuration" + """Update the configuration from the running node.""" + url = self._url_for("extract_configuration") self._session.put(url) @check_stale - def console_logs(self, console_id: int, lines: Optional[int] = None) -> dict: - query = "?lines=%d" % lines if lines else "" - url = self._base_url + "/consoles/%d/log%s" % (console_id, query) + def console_logs(self, console_id: int, lines: int | None = None) -> dict: + """ + Get the console logs of the node. + + :param console_id: The ID of the console. + :param lines: Limit the number of lines to retrieve. + :returns: A dictionary containing the console logs. + """ + if lines: + url = self._url_for("console_log_lines", console_id=console_id, lines=lines) + else: + url = self._url_for("console_log", console_id=console_id) return self._session.get(url).json() @check_stale def console_key(self) -> str: - url = self._base_url + "/keys/console" + """ + Get the console key of the node. + + :returns: The console key. + """ + url = self._url_for("console_key") return self._session.get(url).json() @check_stale def vnc_key(self) -> str: - url = self._base_url + "/keys/vnc" + """ + Get the VNC key of the node. + + :returns: The VNC key. + """ + url = self._url_for("vnc_key") return self._session.get(url).json() + @check_stale def remove(self) -> None: + """Remove the node from the system.""" self.lab.remove_node(self) @check_stale def _remove_on_server(self) -> None: - _LOGGER.info("Removing node %s", self) - url = self._base_url + """Helper function to remove the node from the server.""" + _LOGGER.info(f"Removing node {self}") + url = self._url_for("node") self._session.delete(url) def remove_on_server(self) -> None: + """ + DEPRECATED: Use `.remove()` instead. + (Reason: was never meant to be public, removing only on server is not useful) + + Remove the node on the server. + """ warnings.warn( - "'Node.remove_on_server()' is deprecated, use 'Node.remove()' instead.", + "'Node.remove_on_server()' is deprecated. Use '.remove()' instead.", DeprecationWarning, ) + # To not change behavior of scripts, this will still remove on server only. self._remove_on_server() @check_stale def tags(self) -> list[str]: - """Returns the tags set on this node""" + """ + Get the tags set on this node. + + :returns: A list of tags. + """ self.lab.sync_topology_if_outdated() return self._tags @locked def add_tag(self, tag: str) -> None: + """ + Add a tag to this node. + + :param tag: The tag to add. + """ current = self.tags() if tag not in current: current.append(tag) @@ -550,24 +720,31 @@ def add_tag(self, tag: str) -> None: @locked def remove_tag(self, tag: str) -> None: + """ + Remove a tag from this node. + + :param tag: The tag to remove. + """ current = self.tags() current.remove(tag) self._set_node_property("tags", current) def run_pyats_command(self, command: str) -> str: - """Run a pyATS command in exec mode. + """ + Run a pyATS command in exec mode on the node. - :param command: the command (like "show version") - :returns: the output from the device + :param command: The command to run (e.g. "show version"). + :returns: The output from the device. """ label = self.label return self.lab.pyats.run_command(label, command) def run_pyats_config_command(self, command: str) -> str: - """Run a pyATS command in config mode. + """ + Run a pyATS command in config mode on the node. - :param command: the command (like "interface gi0") - :returns: the output from the device + :param command: The command to run (e.g. "interface gi0"). + :returns: The output from the device. """ label = self.label return self.lab.pyats.run_config_command(label, command) @@ -575,13 +752,13 @@ def run_pyats_config_command(self, command: str) -> str: @check_stale @locked def sync_layer3_addresses(self) -> None: - """Acquire all L3 addresses from the controller. For this - to work, the device has to be attached to the external network + """ + Acquire all layer 3 addresses from the controller. + + For this to work, the device has to be attached to the external network in bridge mode and must run DHCP to acquire an IP address. """ - # TODO: can optimise the sync of l3 to only be for the node - # rather than whole lab - url = self._base_url + "/layer3_addresses" + url = self._url_for("layer3_addresses") result = self._session.get(url).json() interfaces = result.get("interfaces", {}) self.map_l3_addresses_to_interfaces(interfaces) @@ -589,8 +766,13 @@ def sync_layer3_addresses(self) -> None: @check_stale @locked def sync_operational(self, response: dict[str, Any] = None): + """ + Synchronize the operational state of the node. + + :param response: The response from the server. + """ if response is None: - url = self._base_url + "?operational=true" + url = self._url_for("operational") response = self._session.get(url).json() operational = response.get("operational", {}) self._compute_id = operational.get("compute_id") @@ -601,21 +783,26 @@ def sync_operational(self, response: dict[str, Any] = None): def map_l3_addresses_to_interfaces( self, mapping: dict[str, dict[str, str]] ) -> None: + """ + Map layer 3 addresses to interfaces. + + :param mapping: A dictionary mapping MAC addresses to interface information. + """ for mac_address, entry in mapping.items(): - label = entry.get("label") - if not label: + if not (label := entry.get("label")): continue - ipv4 = entry.get("ip4") - ipv6 = entry.get("ip6") try: iface = self.get_interface_by_label(label) except InterfaceNotFound: continue - iface.ip_snooped_info = { + ipv4 = entry.get("ip4") + ipv6 = entry.get("ip6") + iface._ip_snooped_info = { "mac_address": mac_address, "ipv4": ipv4, "ipv6": ipv6, } + self._last_sync_l3_address_time = time.time() @check_stale @locked @@ -623,8 +810,16 @@ def update( self, node_data: dict[str, Any], exclude_configurations: bool, - push_to_server: bool = False, + push_to_server: bool = True, ) -> None: + """ + Update the node with the provided data. + + :param node_data: The data to update the node with. + :param exclude_configurations: Whether to exclude configuration updates. + :param push_to_server: Whether to push the changes to the server. + Defaults to True; should only be False when used by internal methods. + """ if push_to_server: self._set_node_properties(node_data) if "data" in node_data: @@ -639,16 +834,38 @@ def update( setattr(self, f"_{key}", value) def is_active(self) -> bool: + """ + Check if the node is in an active state. + + :returns: True if the node is in an active state, False otherwise. + """ active_states = {"STARTED", "QUEUED", "BOOTED"} return self.state in active_states def is_booted(self) -> bool: + """ + Check if the node is booted. + + :returns: True if the node is booted, False otherwise. + """ return self.state == "BOOTED" def _set_node_property(self, key: str, val: Any) -> None: - _LOGGER.debug("Setting node property %s %s: %s", self, key, val) + """ + Set a property of the node. + + :param key: The key of the property to set. + :param val: The value to set. + """ + _LOGGER.debug(f"Setting node property {self} {key}: {val}") self._set_node_properties({key: val}) @check_stale def _set_node_properties(self, node_data: dict[str, Any]) -> None: - self._session.patch(url=self._base_url, json=node_data) + """ + Set multiple properties of the node. + + :param node_data: A dictionary containing the properties to set. + """ + url = self._url_for("node") + self._session.patch(url, json=node_data) diff --git a/virl2_client/models/node_image_definitions.py b/virl2_client/models/node_image_definitions.py index 8a1b525..e4a6d39 100644 --- a/virl2_client/models/node_image_definitions.py +++ b/virl2_client/models/node_image_definitions.py @@ -24,62 +24,75 @@ import pathlib import time import warnings -from typing import TYPE_CHECKING, BinaryIO, Callable, Optional +from typing import TYPE_CHECKING, BinaryIO, Callable -from virl2_client.exceptions import InvalidContentType, InvalidImageFile +from ..exceptions import InvalidContentType, InvalidImageFile +from ..utils import get_url_from_template if TYPE_CHECKING: import httpx class NodeImageDefinitions: + _URL_TEMPLATES = { + "node_defs": "node_definitions/", + "image_defs": "image_definitions/", + "node_def": "node_definitions/{definition_id}", + "image_def": "image_definitions/{definition_id}", + "node_image_defs": "node_definitions/{definition_id}/image_definitions", + "upload": "images/upload", + "image_list": "list_image_definition_drop_folder/", + "image_manage": "images/manage/{filename}", + } + def __init__(self, session: httpx.Client) -> None: """ - VIRL2 Definition classes to specify a node VM and associated disk images. + Manage node and image definitions. + + Node definitions define the properties of a virtual network node. + Image definitions define disk images that are required to boot a network node. + Together, they define a complete virtual network node. - :param session: httpx Client session + :param session: The httpx-based HTTP client for this session with the server. """ self._session = session - @property - def session(self) -> httpx.Client: + def _url_for(self, endpoint, **kwargs): """ - Returns the used httpx client session object. + Generate the URL for a given API endpoint. - :returns: The session object + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. """ - return self._session + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) def node_definitions(self) -> list[dict]: """ - Returns all node definitions. + Return all node definitions. - :return: list of node definitions + :returns: A list of node definitions. """ - url = "node_definitions/" + url = self._url_for("node_defs") return self._session.get(url).json() def image_definitions(self) -> list[dict]: """ - Returns all image definitions. + Return all image definitions. - :return: list of image definitions + :returns: A list of image definitions. """ - url = "image_definitions/" + url = self._url_for("image_defs") return self._session.get(url).json() def image_definitions_for_node_definition(self, definition_id: str) -> list[dict]: """ - Returns all image definitions for a given node definition - - example:: + Return all image definitions for a given node definition. - client_library.definitions.image_definitions_for_node_definition("iosv") - - :param definition_id: node definition id - :returns: list of image definition objects + :param definition_id: The ID of the node definition. + :returns: A list of image definition objects. """ - url = "node_definitions/" + definition_id + "/image_definitions" + url = self._url_for("node_image_defs", definition_id=definition_id) return self._session.get(url).json() def set_image_definition_read_only( @@ -110,21 +123,22 @@ def set_node_definition_read_only( def upload_node_definition(self, body: str | dict, json: bool | None = None) -> str: """ - Uploads a new node definition. + Upload a new node definition. - :param body: node definition (yaml or json) - :param json: DEPRECATED, replaced with type check - :return: "Success" + :param body: The node definition (yaml or json). + :param json: DEPRECATED: Replaced with type check. + :returns: "Success". """ is_json = _is_json_content(body) if json is not None: warnings.warn( - 'The argument "json" is deprecated as the content type is determined ' - "from the provided body", + "'NodeImageDefinitions.upload_node_definition()': " + "The argument 'json' is deprecated as the content type " + "is determined from the provided 'body'.", DeprecationWarning, ) is_json = True - url = "node_definitions/" + url = self._url_for("node_defs") if is_json: return self._session.post(url, json=body).json() else: @@ -135,21 +149,22 @@ def upload_image_definition( self, body: str | dict, json: bool | None = None ) -> str: """ - Uploads a new image definition. + Upload a new image definition. - :param body: image definition (yaml or json) - :param json: DEPRECATED, replaced with type check - :return: "Success" + :param body: The image definition (yaml or json). + :param json: DEPRECATED: Replaced with type check. + :returns: "Success". """ is_json = _is_json_content(body) if json is not None: warnings.warn( - 'The argument "json" is deprecated as the content type is determined ' - "from the provided body", + "'NodeImageDefinitions.upload_image_definition()': " + "The argument 'json' is deprecated as the content type " + "is determined from the provided 'body'.", DeprecationWarning, ) is_json = True - url = "image_definitions/" + url = self._url_for("image_defs") if is_json: return self._session.post(url, json=body).json() else: @@ -158,51 +173,45 @@ def upload_image_definition( def download_node_definition(self, definition_id: str) -> str: """ - Returns the node definition for a given definition ID + Return the node definition for a given definition ID. Example:: client_library.definitions.download_node_definition("iosv") - :param definition_id: the node definition ID - :returns: the node definition as YAML + :param definition_id: The ID of the node definition. + :returns: The node definition as YAML. """ - url = "node_definitions/" + definition_id + url = self._url_for("node_def", definition_id=definition_id) return self._session.get(url).json() def download_image_definition(self, definition_id: str) -> str: """ + Return the image definition for a given definition ID. + Example:: client_library.definitions.download_image_definition("iosv-158-3") - :param definition_id: the image definition ID - :returns: the image definition as YAML + :param definition_id: The ID of the image definition. + :returns: The image definition as YAML. """ - - url = "image_definitions/" + definition_id + url = self._url_for("image_def", definition_id=definition_id) return self._session.get(url).json() def upload_image_file( self, filename: str, - rename: Optional[str] = None, - chunk_size_mb: Optional[int] = None, + rename: str | None = None, ) -> None: """ - :param filename: the path of the image to upload - :param rename: Optional filename to rename to - :param chunk_size_mb: Optional size of upload chunk (mb) - (deprecated since 2.2.0) - """ - if chunk_size_mb is not None: - warnings.warn( - 'The argument "chunk_size_mb" is deprecated as it never worked', - DeprecationWarning, - ) + Upload an image file. + :param filename: The path of the image to upload. + :param rename: Optional filename to rename to. + """ extension_list = [".qcow", ".qcow2"] - url = "images/upload" + url = self._url_for("upload") path = pathlib.Path(filename) extension = "".join(path.suffixes) @@ -245,55 +254,69 @@ def callback_read(__n): return callback_read - file = open(filename, "rb") - file.read = callback_read_factory(file, print_progress_bar) - files = {"field0": (name, file)} + _file = open(filename, "rb") + _file.read = callback_read_factory(_file, print_progress_bar) + files = {"field0": (name, _file)} self._session.post(url, files=files, headers=headers) print("Upload completed") def download_image_file_list(self) -> list[str]: - url = "list_image_definition_drop_folder/" + """ + Return a list of image files. + + :returns: A list of image file names. + """ + url = self._url_for("image_list") return self._session.get(url).json() def remove_dropfolder_image(self, filename: str) -> str: """ - :returns: "Success" + Remove an image file from the drop folder. + + :param filename: The name of the image file to remove. + :returns: "Success". """ - url = "images/manage/" + filename + url = self._url_for("image_manage", filename=filename) return self._session.delete(url).json() def remove_node_definition(self, definition_id: str) -> None: """ - Removes the node definition with the given ID. + Remove the node definition with the given ID. Example:: client_library.definitions.remove_node_definition("iosv-custom") - :param definition_id: the definition ID to delete + :param definition_id: The ID of the node definition to remove. """ - - url = "node_definitions/" + definition_id + url = self._url_for("node_def", definition_id=definition_id) self._session.delete(url) def remove_image_definition(self, definition_id: str) -> None: """ - Removes the image definition with the given ID. + Remove the image definition with the given ID. Example:: client_library.definitions.remove_image_definition("iosv-158-3-custom") - :param definition_id: the image definition ID to remove + :param definition_id: The ID of the image definition to remove. """ - - url = "image_definitions/" + definition_id + url = self._url_for("image_def", definition_id=definition_id) self._session.delete(url) def print_progress_bar(cur: int, total: int, start_time: float, length=50) -> None: - percent = "{0:.1f}".format(100 * (cur / float(total))) + """ + Print a progress bar. + + :param cur: The current progress value. + :param total: The total progress value. + :param start_time: The start time of the progress. + :param length: The length of the progress bar. + """ + percent = f"{100 * (cur / float(total)):.1f}" filled_len = int(round(length * cur / float(total))) bar = "#" * filled_len + "-" * (length - filled_len) raw_elapsed = time.time() - start_time @@ -308,6 +331,13 @@ def print_progress_bar(cur: int, total: int, start_time: float, length=50) -> No def _is_json_content(content: dict | str) -> bool: + """ + Check if the content is JSON. + + :param content: The content to check. + :returns: True if the content is JSON, False otherwise. + :raises InvalidContentType: If the content type is invalid. + """ if isinstance(content, dict): return True elif isinstance(content, str): diff --git a/virl2_client/models/resource_pools.py b/virl2_client/models/resource_pools.py index 27a2d7f..481f325 100644 --- a/virl2_client/models/resource_pools.py +++ b/virl2_client/models/resource_pools.py @@ -23,9 +23,10 @@ import logging import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable -from virl2_client.exceptions import InvalidProperty +from ..exceptions import InvalidProperty +from ..utils import get_url_from_template if TYPE_CHECKING: import httpx @@ -38,13 +39,18 @@ class ResourcePoolManagement: # updatable in both the client itself and its lab objects, as otherwise # the lab might receive the ID of a resource pool that doesn't exist yet # in the client, because the client wasn't updated recently + _URL_TEMPLATES = { + "resource_pools": "resource_pools", + "resource_pools_data": "resource_pools?data=true", + } + def __init__(self, session: httpx.Client, auto_sync=True, auto_sync_interval=1.0): """ - Sync and modify resource pools. + Manage and synchronize resource pools. - :param session: parent client's httpx.Client object - :param auto_sync: whether to automatically synchronize resource pools - :param auto_sync_interval: how often to synchronize resource pools in seconds + :param session: The httpx-based HTTP client for this session with the server. + :param auto_sync: Whether to automatically synchronize resource pools. + :param auto_sync_interval: How often to synchronize resource pools in seconds. """ self._session = session self.auto_sync = auto_sync @@ -52,16 +58,24 @@ def __init__(self, session: httpx.Client, auto_sync=True, auto_sync_interval=1.0 self._last_sync_resource_pool_time = 0.0 self._resource_pools: ResourcePools = {} - @property - def session(self) -> httpx.Client: - return self._session + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) @property def resource_pools(self) -> ResourcePools: + """Return all resource pools.""" self.sync_resource_pools_if_outdated() return self._resource_pools def sync_resource_pools_if_outdated(self) -> None: + """Synchronize resource pools if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -70,7 +84,8 @@ def sync_resource_pools_if_outdated(self) -> None: self.sync_resource_pools() def sync_resource_pools(self) -> None: - url = "resource_pools?data=true" + """Synchronize the resource pools with the server.""" + url = self._url_for("resource_pools_data") res_pools = self._session.get(url).json() res_pool_ids = [] # update existing pools and add entries for new pools @@ -78,7 +93,7 @@ def sync_resource_pools(self) -> None: pool_id = res_pool.pop("id") res_pool["pool_id"] = pool_id if pool_id in self._resource_pools: - self._resource_pools[pool_id].update(res_pool) + self._resource_pools[pool_id].update(res_pool, push_to_server=False) else: self._add_resource_pool_local(**res_pool) res_pool_ids.append(pool_id) @@ -92,14 +107,12 @@ def get_resource_pools_by_ids( self, resource_pool_ids: str | Iterable[str] ) -> ResourcePool | ResourcePools: """ - Converts either one resource pool ID into a `ResourcePool` object - or multiple IDs in an iterable into a dict of IDs and `ResourcePool` objects. - - :param resource_pool_ids: a resource pool ID or an iterable of IDs - :returns: a `ResourcePool` when one ID is passed or, when an iterable of pools - is passed, a dictionary of IDs to `ResourcePool`s or `None`s when - resource pool doesn't exist - :raises KeyError: when one ID is passed and it doesn't exist + Get resource pools by their IDs. + + :param resource_pool_ids: A resource pool ID or an iterable of IDs. + :returns: A ResourcePool object when one ID is passed, or a dictionary of IDs to + either ResourcePools when a resource pool exists or None when it doesn't. + :raises KeyError: When one ID is passed and it doesn't exist. """ self.sync_resource_pools_if_outdated() if isinstance(resource_pool_ids, str): @@ -118,14 +131,15 @@ def create_resource_pool( **kwargs, ) -> ResourcePool: """ - Creates a resource pool with the given parameters. + Create a resource pool with the given parameters. - :param label: required pool label - :param kwargs: other resource pool parameters as accepted by ResourcePool object - :returns: resource pool object + :param label: The label for the resource pools. + :param kwargs: Additional resource pool parameters as accepted by the + ResourcePool object. + :returns: The created ResourcePool object. """ kwargs["label"] = label - url = "resource_pools" + url = self._url_for("resource_pools") response: dict = self._session.post(url, json=kwargs).json() response["pool_id"] = response.pop("id") return self._add_resource_pool_local(**response) @@ -137,20 +151,22 @@ def create_resource_pools( **kwargs, ) -> list[ResourcePool]: """ - Creates a list of resource pools with the given parameters. + Create a list of resource pools with the given parameters. If no template is supplied, a new template pool is created with the specified parameters, and each user is assigned a new pool with no additional limits. If a template pool is supplied, then parameters are applied to each user pool. - :param label: pool label - :param users: list of user IDs for which to create resource pools - :param kwargs: other resource pool parameters as accepted by ResourcePool object - :returns: resource pool objects, with template pool first if created + :param label: The label for the resource pools. + :param users: The list of user IDs for which to create resource pools. + :param kwargs: Additional resource pool parameters as accepted by the + ResourcePool object. + :returns: A list of created resource pool objects, with the template pool first + if created. """ kwargs["label"] = label kwargs["users"] = users kwargs["shared"] = False - url = "resource_pools" + url = self._url_for("resource_pools") response: dict = self._session.post(url, json=kwargs).json() result = [] @@ -163,16 +179,17 @@ def _add_resource_pool_local( self, pool_id: str, label: str, - description: Optional[str] = None, - template: Optional[str] = None, - licenses: Optional[int] = None, - ram: Optional[int] = None, - cpus: Optional[int] = None, - disk_space: Optional[int] = None, - external_connectors: Optional[list[str]] = None, - users: Optional[list[str]] = None, - user_pools: Optional[list[str]] = None, + description: str | None = None, + template: str | None = None, + licenses: int | None = None, + ram: int | None = None, + cpus: int | None = None, + disk_space: int | None = None, + external_connectors: list[str] | None = None, + users: list[str] | None = None, + user_pools: list[str] | None = None, ) -> ResourcePool: + """Helper method to add resource pool locally.""" new_resource_pool = ResourcePool( self, pool_id, @@ -192,23 +209,44 @@ def _add_resource_pool_local( class ResourcePool: + _URL_TEMPLATES = { + "resource_pool": "resource_pools/{pool_id}", + "resource_pool_usage": "resource_pool_usage/{pool_id}", + } + def __init__( self, resource_pools: ResourcePoolManagement, pool_id: str, label: str, - description: Optional[str], - template: Optional[str], - licenses: Optional[int], - ram: Optional[int], - cpus: Optional[int], - disk_space: Optional[int], - external_connectors: Optional[list[str]], - users: Optional[list[str]], - user_pools: Optional[list[str]], + description: str | None, + template: str | None, + licenses: int | None, + ram: int | None, + cpus: int | None, + disk_space: int | None, + external_connectors: list[str] | None, + users: list[str] | None, + user_pools: list[str] | None, ): + """ + Initialize a resource pool. + + :param resource_pools: The parent ResourcePoolManagement object. + :param pool_id: The ID of the resource pool. + :param label: The label of the resource pool. + :param description: The description of the resource pool. + :param template: The template of the resource pool. + :param licenses: The number of licenses in the resource pool. + :param ram: The amount of RAM in the resource pool. + :param cpus: The number of CPUs in the resource pool. + :param disk_space: The amount of disk space in the resource pool. + :param external_connectors: A list of external connectors in the resource pool. + :param users: A list of users in the resource pool. + :param user_pools: A list of user pools in the resource pool. + """ self._resource_pools = resource_pools - self._session: httpx.Client = resource_pools.session + self._session: httpx.Client = resource_pools._session self._id = pool_id self._label = label self._description = description @@ -220,8 +258,6 @@ def __init__( self._external_connectors = external_connectors self._users = users if users is not None else [] self._user_pools = user_pools if user_pools is not None else [] - self._base_url = f"{self._session.base_url}/resource_pools/{pool_id}" - self._usage_base_url = f"{self._session.base_url}/resource_pool_usage/{pool_id}" def __str__(self): return f"Resource pool: {self._label}" @@ -235,113 +271,160 @@ def __repr__(self): self._template, ) + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["pool_id"] = self.id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + @property def id(self) -> str: + """Return the ID of the resource pool.""" return self._id @property def label(self) -> str: + """Set the label of the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._label @label.setter def label(self, value: str): + """Set the label of the resource pool.""" self._set_resource_pool_property("label", value) self._label = value @property def description(self) -> str: + """Return the description of the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._description @description.setter def description(self, value: str): + """Set the description of the resource pool.""" self._set_resource_pool_property("description", value) self._description = value @property def template(self) -> str: + """Return the template of the resource pool.""" return self._template @property def is_template(self) -> bool: + """Return whether the resource pool is a template.""" return self._template is None @property def licenses(self) -> int: + """Return the number of licenses in the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._licenses @licenses.setter def licenses(self, value: int): + """Set the number of licenses in the resource pool.""" self._set_resource_pool_property("licenses", value) self._licenses = value @property def ram(self) -> int: + """Return the amount of RAM in the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._ram @ram.setter def ram(self, value: int): + """Set the amount of RAM in the resource pool.""" self._set_resource_pool_property("ram", value) self._ram = value @property def cpus(self) -> int: + """Return the number of CPUs in the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._cpus @cpus.setter def cpus(self, value: int): + """Set the number of CPUs in the resource pool.""" self._set_resource_pool_property("cpus", value) self._cpus = value @property def disk_space(self) -> int: + """Return the amount of disk space in the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() return self._disk_space @disk_space.setter def disk_space(self, value: int): + """Set the amount of disk space in the resource pool.""" self._set_resource_pool_property("disk_space", value) self._disk_space = value @property - def external_connectors(self) -> Optional[list[str]]: + def external_connectors(self) -> list[str] | None: + """Return a list of external connectors in the resource pool.""" self._resource_pools.sync_resource_pools_if_outdated() - return self._external_connectors + return ( + self._external_connectors.copy() + if self._external_connectors is not None + else None + ) @external_connectors.setter - def external_connectors(self, value: Optional[list[str]]): + def external_connectors(self, value: list[str] | None): + """Set the external connectors in the resource pool.""" self._set_resource_pool_property("external_connectors", value) self._external_connectors = value @property def users(self) -> list[str]: + """Return the list of users in the resource pool.""" if self.is_template: - raise InvalidProperty("A template does not have 'users' property") + raise InvalidProperty("A template does not have the 'users' property.") self._resource_pools.sync_resource_pools_if_outdated() return self._users @property def user_pools(self) -> list[str]: + """Return the list of user pools in the template.""" if not self.is_template: - raise InvalidProperty("A resource pool does not have 'user_pools' property") + raise InvalidProperty( + "A resource pool does not have the 'user_pools' property." + ) self._resource_pools.sync_resource_pools_if_outdated() return self._user_pools def get_usage(self) -> ResourcePoolUsage: - result = self._session.get(self._usage_base_url).json() + """Get the usage stats of the resource pool.""" + url = self._url_for("resource_pool_usage") + result = self._session.get(url).json() limit = ResourcePoolUsageBase(**result["limit"]) usage = ResourcePoolUsageBase(**result["usage"]) return ResourcePoolUsage(limit, usage) def remove(self) -> None: - _LOGGER.info("Removing resource pool %s", self) - self._session.delete(self._base_url) + """Remove the resource pool.""" + _LOGGER.info(f"Removing resource pool {self}") + url = self._url_for("resource_pool") + self._session.delete(url) + + def update(self, pool_data: dict[str, Any], push_to_server: bool = True): + """ + Update multiple properties of the pool at once. - def update(self, pool_data: dict[str, Any], push_to_server: bool = False): + :param pool_data: A dictionary of the properties to update. + :param push_to_server: Whether to push the changes to the server. + Defaults to True; should only be False when used by internal methods. + """ if push_to_server: self._set_resource_pool_properties(pool_data) @@ -349,15 +432,18 @@ def update(self, pool_data: dict[str, Any], push_to_server: bool = False): setattr(self, f"_{key}", value) def _set_resource_pool_property(self, key: str, val: Any) -> None: - _LOGGER.debug("Setting resource pool property %s %s: %s", self, key, val) + """Helper method to set a property on the server.""" + _LOGGER.debug(f"Setting resource pool property {self} {key}: {val}") self._set_resource_pool_properties({key: val}) def _set_resource_pool_properties(self, resource_pool_data: dict[str, Any]) -> None: + """Helper method to set multiple properties on the server.""" for key in list(resource_pool_data): # drop unmodifiable properties if key in ("id", "template", "users", "user_pools"): resource_pool_data.pop(key) - self._session.patch(url=self._base_url, json=resource_pool_data) + url = self._url_for("resource_pool") + self._session.patch(url, json=resource_pool_data) ResourcePools = Dict[str, ResourcePool] diff --git a/virl2_client/models/system.py b/virl2_client/models/system.py index 1d5fe83..a3047bf 100644 --- a/virl2_client/models/system.py +++ b/virl2_client/models/system.py @@ -22,7 +22,9 @@ import logging import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any + +from ..utils import get_url_from_template if TYPE_CHECKING: import httpx @@ -31,7 +33,31 @@ class SystemManagement: - def __init__(self, session: httpx.Client, auto_sync=True, auto_sync_interval=1.0): + _URL_TEMPLATES = { + "maintenance_mode": "system/maintenance_mode", + "compute_hosts": "system/compute_hosts", + "notices": "system/notices", + "external_connectors": "system/external_connectors", + "external_connector": "system/external_connectors/{connector_id}", + "web_session_timeout": "web_session_timeout/{timeout}", + "mac_address_block": "mac_address_block/{block}", + "host_configuration": "system/compute_hosts/configuration", + } + + def __init__( + self, + session: httpx.Client, + auto_sync: bool = True, + auto_sync_interval: float = 1.0, + ): + """ + Manage the underlying controller software and the host system where it runs. + + :param session: The httpx-based HTTP client for this session with the server. + :param auto_sync: A boolean indicating whether auto synchronization is enabled. + :param auto_sync_interval: The interval in seconds between auto + synchronizations. + """ self._session = session self.auto_sync = auto_sync self.auto_sync_interval = auto_sync_interval @@ -40,53 +66,66 @@ def __init__(self, session: httpx.Client, auto_sync=True, auto_sync_interval=1.0 self._compute_hosts: dict[str, ComputeHost] = {} self._system_notices: dict[str, SystemNotice] = {} self._maintenance_mode = False - self._maintenance_notice: Optional[SystemNotice] = None + self._maintenance_notice: SystemNotice | None = None - @property - def session(self) -> httpx.Client: - return self._session + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) @property def compute_hosts(self) -> dict[str, ComputeHost]: + """Return a dictionary of compute hosts.""" self.sync_compute_hosts_if_outdated() - return self._compute_hosts + return self._compute_hosts.copy() @property def system_notices(self) -> dict[str, SystemNotice]: + """Return a dictionary of system notices.""" self.sync_system_notices_if_outdated() - return self._system_notices + return self._system_notices.copy() @property def maintenance_mode(self) -> bool: + """Return the maintenance mode status.""" self.sync_system_notices_if_outdated() return self._maintenance_mode @maintenance_mode.setter def maintenance_mode(self, value: bool) -> None: - url = "system/maintenance_mode" + """Set the maintenance mode status.""" + url = self._url_for("maintenance_mode") self._session.patch(url, json={"maintenance_mode": value}) self._maintenance_mode = value @property - def maintenance_notice(self) -> Optional[SystemNotice]: + def maintenance_notice(self) -> SystemNotice | None: + """Return the current maintenance notice.""" self.sync_system_notices_if_outdated() return self._maintenance_notice @maintenance_notice.setter - def maintenance_notice(self, notice: Optional[SystemNotice]) -> None: - url = "system/maintenance_mode" + def maintenance_notice(self, notice: SystemNotice | None) -> None: + """Set the maintenance notice.""" + url = self._url_for("maintenance_mode") notice_id = None if notice is None else notice.id - result = self._session.patch(url, json={"notice": notice_id}) + result: dict = self._session.patch(url, json={"notice": notice_id}).json() resolved = result["resolved_notice"] if resolved is None: notice = None else: notice = self._system_notices.get(resolved["id"]) if notice is not None and resolved is not None: - notice.update(resolved) + notice.update(resolved, push_to_server=False) self._maintenance_notice = notice def sync_compute_hosts_if_outdated(self) -> None: + """Synchronize compute hosts if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -95,6 +134,7 @@ def sync_compute_hosts_if_outdated(self) -> None: self.sync_compute_hosts() def sync_system_notices_if_outdated(self) -> None: + """Synchronize system notices if they are outdated.""" timestamp = time.time() if ( self.auto_sync @@ -103,7 +143,8 @@ def sync_system_notices_if_outdated(self) -> None: self.sync_system_notices() def sync_compute_hosts(self) -> None: - url = "system/compute_hosts" + """Synchronize compute hosts from the server.""" + url = self._url_for("compute_hosts") compute_hosts = self._session.get(url).json() compute_host_ids = [] @@ -111,7 +152,9 @@ def sync_compute_hosts(self) -> None: compute_id = compute_host.pop("id") compute_host["compute_id"] = compute_id if compute_id in self._compute_hosts: - self._compute_hosts[compute_id].update(compute_host) + self._compute_hosts[compute_id].update( + compute_host, push_to_server=False + ) else: self.add_compute_host_local(**compute_host) compute_host_ids.append(compute_id) @@ -122,14 +165,17 @@ def sync_compute_hosts(self) -> None: self._last_sync_compute_host_time = time.time() def sync_system_notices(self) -> None: - url = "system/notices" + """Synchronize system notices from the server.""" + url = self._url_for("notices") system_notices = self._session.get(url).json() system_notice_ids = [] for system_notice in system_notices: notice_id = system_notice.get("id") if notice_id in self._system_notices: - self._system_notices[notice_id].update(system_notice) + self._system_notices[notice_id].update( + system_notice, push_to_server=False + ) else: self.add_system_notice_local(**system_notice) system_notice_ids.append(notice_id) @@ -138,8 +184,8 @@ def sync_system_notices(self) -> None: if notice_id not in system_notice_ids: self._system_notices.pop(notice_id) - url = "system/maintenance_mode" - maintenance = self.session.get(url).json() + url = self._url_for("maintenance_mode") + maintenance = self._session.get(url).json() self._maintenance_mode = maintenance["maintenance_mode"] notice_id = maintenance["notice"] if notice_id is None: @@ -151,14 +197,14 @@ def sync_system_notices(self) -> None: def get_external_connectors(self, sync: bool | None = None) -> list[dict[str, str]]: """ Get the list of external connectors present on the controller. - Admin users may enable sync to refresh the cached list from host state. - If sync is False, the state is retrieved; if True, configuration is applied - back into the controller host. - Device names or tags are used as External Connector nodes' configuration. - Returns a list of objects with the device name and label. + + :param sync: Admin only. A boolean indicating whether to refresh the cached list + from host state. If sync is False, the state is retrieved; + if True, configuration is applied back into the controller host. + :return: A list of objects with the device name and label. """ - url = "/system/external_connectors" + url = self._url_for("external_connectors") if sync is None: return self._session.get(url).json() else: @@ -168,61 +214,87 @@ def get_external_connectors(self, sync: bool | None = None) -> list[dict[str, st def update_external_connector( self, connector_id: str, data: dict[str, Any] ) -> dict[str, Any]: - url = f"/system/external_connectors/{connector_id}" + """ + Update an external connector. + + :param connector_id: The ID of the connector to update. + :param data: The data to update. + :return: The updated data. + """ + url = self._url_for("external_connector", connector_id=connector_id) return self._session.patch(url, json=data).json() - def delete_external_connector(self, connector_id: str): - url = f"/system/external_connectors/{connector_id}" + def delete_external_connector(self, connector_id: str) -> None: + """ + Delete an external connector. + + :param connector_id: The ID of the connector to delete. + """ + url = self._url_for("external_connector", connector_id=connector_id) return self._session.delete(url) def get_web_session_timeout(self) -> int: """ Get the web session timeout in seconds. - :return: web session timeout + :return: The web session timeout. """ - url = "/web_session_timeout" + url = self._url_for("web_session_timeout", timeout="") return self._session.get(url).json() def set_web_session_timeout(self, timeout: int) -> str: """ Set the web session timeout in seconds. + :param timeout: The timeout value in seconds. :return: 'OK' """ - - url = "/web_session_timeout/{}".format(timeout) + url = self._url_for("web_session_timeout", timeout=timeout) return self._session.patch(url).json() def get_mac_address_block(self) -> int: """ - Get mac address block. + Get the MAC address block. - :return: mac address block + :return: The MAC address block. """ - url = "/mac_address_block" + url = self._url_for("mac_address_block", block="") return self._session.get(url).json() - def _set_mac_address_block(self, block: int) -> str: - url = "/mac_address_block/{}".format(block) - return self._session.patch(url).json() - def set_mac_address_block(self, block: int) -> str: """ - Set mac address block. + Set the MAC address block. + :param block: The MAC address block. :return: 'OK' + :raises ValueError: If the MAC address block is not in the range 0-7. """ if block < 0 or block > 7: raise ValueError("MAC address block has to be in range 0-7.") return self._set_mac_address_block(block=block) + def _set_mac_address_block(self, block: int) -> str: + """Helper method to set the MAC address block.""" + url = self._url_for("mac_address_block", block=block) + return self._session.patch(url).json() + def get_new_compute_host_state(self) -> str: - url = "/system/compute_hosts/configuration" + """ + Get the admission state of the new compute host. + + :return: The admission state of the new compute host. + """ + url = self._url_for("host_configuration") return self._session.get(url).json()["admission_state"] def set_new_compute_host_state(self, admission_state: str) -> str: - url = "/system/compute_hosts/configuration" + """ + Set the admission state of the new compute host. + + :param admission_state: The admission state to set. + :return: The updated admission state. + """ + url = self._url_for("host_configuration") return self._session.patch( url, json={"admission_state": admission_state} ).json()["admission_state"] @@ -237,8 +309,22 @@ def add_compute_host_local( is_connected: bool, is_synced: bool, admission_state: str, - nodes: Optional[list[str]] = None, + nodes: list[str] | None = None, ) -> ComputeHost: + """ + Add a compute host locally. + + :param compute_id: The ID of the compute host. + :param hostname: The hostname of the compute host. + :param server_address: The server address of the compute host. + :param is_connector: A boolean indicating if the compute host is a connector. + :param is_simulator: A boolean indicating if the compute host is a simulator. + :param is_connected: A boolean indicating if the compute host is connected. + :param is_synced: A boolean indicating if the compute host is synced. + :param admission_state: The admission state of the compute host. + :param nodes: A list of node IDs associated with the compute host. + :return: The added compute host. + """ new_compute_host = ComputeHost( self, compute_id, @@ -262,8 +348,22 @@ def add_system_notice_local( content: str, enabled: bool, acknowledged: dict[str, bool], - groups: Optional[list[str]] = None, + groups: list[str] | None = None, ) -> SystemNotice: + """ + Add a system notice locally. + + :param id: The unique identifier of the system notice. + :param level: The level of the system notice. + :param label: The label or title of the system notice. + :param content: The content or description of the system notice. + :param enabled: A flag indicating whether the system notice is enabled or not. + :param acknowledged: A dictionary mapping user IDs to their acknowledgment + status. + :param groups: A list of group names to associate with the system notice. + (Optional) + :return: The newly created system notice object. + """ new_system_notice = SystemNotice( self, id, @@ -279,6 +379,8 @@ def add_system_notice_local( class ComputeHost: + _URL_TEMPLATES = {"compute_host": "system/compute_hosts/{compute_id}"} + def __init__( self, system: SystemManagement, @@ -290,10 +392,24 @@ def __init__( is_connected: bool, is_synced: bool, admission_state: str, - nodes: Optional[list[str]] = None, + nodes: list[str] | None = None, ): + """ + A compute host, which hosts some of the nodes of the simulation. + + :param system: The SystemManagement instance. + :param compute_id: The ID of the compute host. + :param hostname: The hostname of the compute host. + :param server_address: The server address of the compute host. + :param is_connector: Whether the compute host is a connector. + :param is_simulator: Whether the compute host is a simulator. + :param is_connected: Whether the compute host is connected. + :param is_synced: Whether the compute host is synced. + :param admission_state: The admission state of the compute host. + :param nodes: The list of nodes associated with the compute host. + """ self._system = system - self._session: httpx.Client = system.session + self._session: httpx.Client = system._session self._compute_id = compute_id self._hostname = hostname self._server_address = server_address @@ -303,63 +419,92 @@ def __init__( self._is_synced = is_synced self._admission_state = admission_state self._nodes = nodes if nodes is not None else [] - self._base_url = f"{self._session.base_url}/system/compute_hosts/{compute_id}" def __str__(self): return f"Compute host: {self._hostname}" + def _url_for(self, endpoint, **kwargs) -> str: + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["compute_id"] = self._compute_id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) + @property def compute_id(self) -> str: + """Return the ID of the compute host.""" return self._compute_id @property def hostname(self) -> str: + """Return the hostname of the compute host.""" self._system.sync_compute_hosts_if_outdated() return self._hostname @property def server_address(self) -> str: + """Return the server address of the compute host.""" self._system.sync_compute_hosts_if_outdated() return self._server_address @property def is_connector(self) -> bool: + """Return whether the compute host is a connector.""" return self._is_connector @property def is_simulator(self) -> bool: + """Return whether the compute host is a simulator.""" return self._is_simulator @property def is_connected(self) -> bool: + """Return whether the compute host is connected.""" self._system.sync_compute_hosts_if_outdated() return self._is_connected @property def is_synced(self) -> bool: + """Return whether the compute host is synced.""" self._system.sync_compute_hosts_if_outdated() return self._is_synced @property def nodes(self) -> list[str]: + """Return the list of nodes associated with the compute host.""" self._system.sync_compute_hosts_if_outdated() return self._nodes @property def admission_state(self) -> str: + """Return the admission state of the compute host.""" self._system.sync_compute_hosts_if_outdated() return self._admission_state @admission_state.setter - def admission_state(self, value: str): + def admission_state(self, value: str) -> None: + """Set the admission state of the compute host.""" self._set_compute_host_property("admission_state", value) self._admission_state = value def remove(self) -> None: - _LOGGER.info("Removing compute host %s", self) - self._session.delete(self._base_url) + """Remove the compute host.""" + _LOGGER.info(f"Removing compute host {self}") + url = self._url_for("compute_host") + self._session.delete(url) - def update(self, host_data: dict[str, Any], push_to_server: bool = False): + def update(self, host_data: dict[str, Any], push_to_server: bool = True) -> None: + """ + Update the compute host with the given data. + + :param host_data: The data to update the compute host. + :param push_to_server: Whether to push the changes to the server. + Defaults to True; should only be False when used by internal methods. + """ if push_to_server: self._set_compute_host_properties(host_data) return @@ -368,28 +513,54 @@ def update(self, host_data: dict[str, Any], push_to_server: bool = False): setattr(self, f"_{key}", value) def _set_compute_host_property(self, key: str, val: Any) -> None: - _LOGGER.debug("Setting compute host property %s %s: %s", self, key, val) + """ + Set a specific property of the compute host. + + :param key: The property key. + :param val: The new value for the property. + """ + _LOGGER.debug(f"Setting compute host property {self} {key}: {val}") self._set_compute_host_properties({key: val}) def _set_compute_host_properties(self, host_data: dict[str, Any]) -> None: - new_data = self._session.patch(url=self._base_url, json=host_data).json() - self.update(new_data) + """ + Set multiple properties of the compute host. + + :param host_data: The data to set as properties of the compute host. + """ + url = self._url_for("compute_host") + new_data = self._session.patch(url, json=host_data).json() + self.update(new_data, push_to_server=False) class SystemNotice: + _URL_TEMPLATES = {"notice": "system/notices/{notice_id}"} + def __init__( self, system: SystemManagement, id: str, level: str, label: str, - content: bool, + content: str, enabled: bool, acknowledged: dict[str, bool], - groups: Optional[list[str]] = None, + groups: list[str] | None = None, ): + """ + A system notice, which notifies users of maintenance or other events. + + :param system: The SystemManagement instance. + :param id: The ID of the system notice. + :param level: The level of the system notice. + :param label: The label of the system notice. + :param content: The content of the system notice. + :param enabled: Whether the system notice is enabled. + :param acknowledged: The acknowledgement status of the system notice. + :param groups: The groups associated with the system notice. + """ self._system = system - self._session: httpx.Client = system.session + self._session: httpx.Client = system._session self._id = id self._level = level self._label = label @@ -397,47 +568,73 @@ def __init__( self._enabled = enabled self._acknowledged = acknowledged self._groups = groups - self._base_url = f"{self._session.base_url}/system/notices/{id}" + + def _url_for(self, endpoint, **kwargs) -> str: + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + kwargs["notice_id"] = self._id + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) @property def id(self) -> str: + """Return the ID of the system notice.""" return self._id @property def level(self) -> str: + """Return the level of the system notice.""" self._system.sync_system_notices_if_outdated() return self._level @property def label(self) -> str: + """Return the label of the system notice.""" self._system.sync_system_notices_if_outdated() return self._label @property def content(self) -> str: + """Return the content of the system notice.""" self._system.sync_system_notices_if_outdated() return self._content @property def enabled(self) -> bool: + """Return whether the system notice is enabled.""" self._system.sync_system_notices_if_outdated() return self._enabled @property def acknowledged(self) -> dict[str, bool]: + """Return the acknowledgement status of the system notice.""" self._system.sync_system_notices_if_outdated() return self._acknowledged @property - def groups(self) -> Optional[list[str]]: + def groups(self) -> list[str] | None: + """Return the groups associated with the system notice.""" self._system.sync_system_notices_if_outdated() return self._groups def remove(self) -> None: - _LOGGER.info("Removing system notice %s", self) - self._session.delete(self._base_url) + """Remove the system notice.""" + _LOGGER.info(f"Removing system notice {self}") + url = self._url_for("notice") + self._session.delete(url) - def update(self, notice_data: dict[str, Any], push_to_server: bool = False): + def update(self, notice_data: dict[str, Any], push_to_server: bool = True) -> None: + """ + Update the system notice with the given data. + + :param notice_data: The data to update the system notice with. + :param push_to_server: Whether to push the changes to the server. + Defaults to True; should only be False when used by internal methods. + """ if push_to_server: self._set_notice_properties(notice_data) return @@ -446,9 +643,21 @@ def update(self, notice_data: dict[str, Any], push_to_server: bool = False): setattr(self, f"_{key}", value) def _set_notice_property(self, key: str, val: Any) -> None: - _LOGGER.debug("Setting system notice property %s %s: %s", self, key, val) + """ + Set a specific property of the system notice. + + :param key: The property key. + :param val: The new value for the property. + """ + _LOGGER.debug(f"Setting system notice property {self} {key}: {val}") self._set_notice_properties({key: val}) def _set_notice_properties(self, notice_data: dict[str, Any]) -> None: - new_data = self._session.patch(url=self._base_url, json=notice_data).json() - self.update(new_data) + """ + Set multiple properties of the system notice. + + :param notice_data: The data to set as properties of the system notice. + """ + url = self._url_for("notice") + new_data = self._session.patch(url, json=notice_data).json() + self.update(new_data, push_to_server=False) diff --git a/virl2_client/models/users.py b/virl2_client/models/users.py index 7e591b6..591954c 100644 --- a/virl2_client/models/users.py +++ b/virl2_client/models/users.py @@ -20,47 +20,61 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any -from ..utils import UNCHANGED +from ..utils import UNCHANGED, get_url_from_template if TYPE_CHECKING: import httpx class UserManagement: + _URL_TEMPLATES = { + "users": "users", + "user": "users/{user_id}", + "user_groups": "users/{user_id}/groups", + "user_id": "users/{username}/id", + } + def __init__(self, session: httpx.Client) -> None: self._session = session - @property - def base_url(self) -> str: - return "users" + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) - def users(self) -> list: + def users(self) -> list[str]: """ Get the list of available users. - :return: list of user objects + :returns: List of user IDs. """ - return self._session.get(self.base_url).json() + url = self._url_for("users") + return self._session.get(url).json() def get_user(self, user_id: str) -> dict: """ - Gets user info. + Get user information. - :param user_id: user uuid4 - :return: user object + :param user_id: User UUID4. + :returns: User object. """ - url = self.base_url + "/{}".format(user_id) + url = self._url_for("user", user_id=user_id) return self._session.get(url).json() def delete_user(self, user_id: str) -> None: """ - Deletes user. + Delete a user. - :param user_id: user uuid4 + :param user_id: User UUID4. """ - url = self.base_url + "/{}".format(user_id) + url = self._url_for("user", user_id=user_id) self._session.delete(url) def create_user( @@ -70,20 +84,20 @@ def create_user( fullname: str = "", description: str = "", admin: bool = False, - groups: Optional[list[str]] = None, - resource_pool: Optional[str] = None, + groups: list[str] | None = None, + resource_pool: str | None = None, ) -> dict: """ - Creates user. - - :param username: username - :param pwd: desired password - :param fullname: full name - :param description: description - :param admin: whether to create admin user - :param groups: adds user to groups specified - :param resource_pool: adds user to resource pool specified - :return: user object + Create a new user. + + :param username: Username. + :param pwd: Desired password. + :param fullname: Full name. + :param description: Description. + :param admin: Whether to create an admin user. + :param groups: List of groups to which the user should be added. + :param resource_pool: Resource pool to which the user should be added. + :returns: User object. """ data = { "username": username, @@ -94,30 +108,30 @@ def create_user( "groups": groups or [], "resource_pool": resource_pool, } - url = self.base_url + url = self._url_for("users") return self._session.post(url, json=data).json() def update_user( self, user_id: str, - fullname: Optional[str] = UNCHANGED, - description: Optional[str] = UNCHANGED, - groups: Optional[list[str]] = UNCHANGED, - admin: Optional[bool] = None, - password_dict: Optional[dict[str, str]] = None, - resource_pool: Optional[str] = UNCHANGED, + fullname: str | None = UNCHANGED, + description: str | None = UNCHANGED, + groups: list[str] | None = UNCHANGED, + admin: bool | None = None, + password_dict: dict[str, str] | None = None, + resource_pool: str | None = UNCHANGED, ) -> dict: """ - Updates user. - - :param user_id: user uuid4 - :param fullname: full name - :param description: description - :param admin: whether to create admin user - :param groups: adds user to groups specified - :param password_dict: dict containing old and new passwords - :param resource_pool: adds user to resource pool specified - :return: user object + Update an existing user. + + :param user_id: User UUID4. + :param fullname: Full name. + :param description: Description. + :param admin: Whether to create an admin user. + :param groups: List of groups to which the user should be added. + :param password_dict: Dictionary containing old and new passwords. + :param resource_pool: Resource pool to which the user should be added. + :returns: User object. """ data: dict[str, Any] = {} if fullname is not UNCHANGED: @@ -133,24 +147,25 @@ def update_user( if resource_pool is not UNCHANGED: data["resource_pool"] = resource_pool - url = self.base_url + "/{}".format(user_id) + url = self._url_for("user", user_id=user_id) return self._session.patch(url, json=data).json() def user_groups(self, user_id: str) -> list[str]: """ - Get the groups the user is member of. + Get the groups that a user is a member of. - :param user_id: user uuid4 + :param user_id: User UUID4. + :returns: List of group names. """ - url = self.base_url + "/{}/groups".format(user_id) + url = self._url_for("user_groups", user_id=user_id) return self._session.get(url).json() def user_id(self, username: str) -> str: """ - Get unique user uuid4. + Get the unique UUID4 of a user. - :param username: user name - :return: user unique identifier + :param username: User name. + :returns: User unique identifier. """ - url = self.base_url + "/{}/id".format(username) + url = self._url_for("user_id", username=username) return self._session.get(url).json() diff --git a/virl2_client/utils.py b/virl2_client/utils.py index c614357..80ffd2f 100644 --- a/virl2_client/utils.py +++ b/virl2_client/utils.py @@ -22,7 +22,7 @@ from contextlib import nullcontext from functools import wraps -from typing import TYPE_CHECKING, Callable, Optional, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Callable, Type, TypeVar, Union, cast import httpx @@ -32,6 +32,7 @@ LabNotFound, LinkNotFound, NodeNotFound, + VirlException, ) if TYPE_CHECKING: @@ -71,12 +72,12 @@ def _make_not_found(instance: Element, owner: Type[Element]) -> ElementNotFound: def _check_and_mark_stale( func: Callable, instance: Element, - owner: Optional[Type[Element]] = None, + owner: Type[Element] | None = None, *args, **kwargs, ): """ - Checks staleness before and after calling `func` + Check staleness before and after calling `func` and updates staleness if a 404 is raised. :param func: the function to be called if not stale @@ -112,9 +113,7 @@ def _check_and_mark_stale( def check_stale(func: TCallable) -> TCallable: - """ - A decorator that will make the wrapped function check staleness. - """ + """A decorator that will make the wrapped function check staleness.""" @wraps(func) def wrapper_stale(*args, **kwargs): @@ -148,3 +147,22 @@ def wrapper_locked(*args, **kwargs): return func(*args, **kwargs) return cast(TCallable, wrapper_locked) + + +def get_url_from_template( + endpoint: str, url_templates: dict[str, str], values: dict | None = None +): + """ + Generate the URL for a given API endpoint from given templates. + + :param endpoint: The desired endpoint. + :param url_templates: The templates to map values to. + :param values: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + endpoint_url_template = url_templates.get(endpoint) + if endpoint_url_template is None: + raise VirlException(f"Invalid endpoint: {endpoint}") + if values is None: + return endpoint_url_template + return endpoint_url_template.format(**values) diff --git a/virl2_client/virl2_client.py b/virl2_client/virl2_client.py index 5323f68..d96a58e 100644 --- a/virl2_client/virl2_client.py +++ b/virl2_client/virl2_client.py @@ -28,7 +28,7 @@ from functools import lru_cache from pathlib import Path from threading import RLock -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from urllib.parse import urljoin, urlsplit, urlunsplit import httpx @@ -47,7 +47,7 @@ ) from .models.authentication import make_session from .models.configuration import get_configuration -from .utils import locked +from .utils import get_url_from_template, locked _LOGGER = logging.getLogger(__name__) cached = lru_cache(maxsize=None) # cache results forever @@ -72,7 +72,7 @@ def parse_version_str(version_str: str) -> str: return res[0] def __repr__(self): - return "{}".format(self.version_str) + return self.version_str def __eq__(self, other): return ( @@ -132,9 +132,9 @@ class ClientConfig(NamedTuple): """Stores client library configuration, which can be used to create any number of identically configured instances of ClientLibrary.""" - url: Optional[str] = None - username: Optional[str] = None - password: Optional[str] = None + url: str | None = None + username: str | None = None + password: str | None = None ssl_verify: bool | str = True allow_http: bool = False auto_sync: float = 1.0 @@ -176,12 +176,27 @@ class ClientLibrary: Version("2.2.2"), Version("2.2.3"), ] + _URL_TEMPLATES = { + "auth_test": "authok", + "system_info": "system_information", + "import": "import", + "import_1x": "import/virl-1x", + "sample_labs": "sample/labs", + "sample_lab": "sample/labs/{lab_title}", + "labs": "labs", + "lab": "labs/{lab_id}", + "lab_topology": "labs/{lab_id}/topology", + "diagnostics": "diagnostics", + "system_health": "system_health", + "system_stats": "system_stats", + "populate_lab_tiles": "populate_lab_tiles", + } def __init__( self, - url: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, + url: str | None = None, + username: str | None = None, + password: str | None = None, ssl_verify: bool | str = True, raise_for_auth_failure: bool = False, allow_http: bool = False, @@ -190,7 +205,7 @@ def __init__( events: bool = False, ) -> None: """ - Initializes a ClientLibrary instance. Note that ssl_verify can + Initialize a ClientLibrary instance. Note that ssl_verify can also be a string that points to a cert (see class documentation). :param url: URL of controller. It's also possible to pass the @@ -200,20 +215,21 @@ def __init__( to pass the username via ``VIRL2_USER`` or ``VIRL_USERNAME`` variable. :param password: Password of the user to authenticate. It's also possible to pass the password via ``VIRL2_PASS`` or ``VIRL_PASSWORD`` variable. - :param ssl_verify: Path SSL controller certificate, or True to load + :param ssl_verify: Path of the SSL controller certificate, or True to load from ``CA_BUNDLE`` or ``CML_VERIFY_CERT`` environment variable, or False to disable. - :param raise_for_auth_failure: Raise an exception if unable - to connect to controller (use for scripting scenarios). - :param allow_http: If set, a https URL will not be enforced - :param convergence_wait_max_iter: maximum number of iterations to wait for - lab to converge. - :param convergence_wait_time: wait interval in seconds to wait during one - iteration during lab convergence wait. - :param events: whether to use events - event listeners for state updates and - event handlers for custom handling. - :raises InitializationError: If no URL is provided, - authentication fails or host can't be reached. + :param raise_for_auth_failure: Raise an exception if unable to connect to + controller. (Use for scripting scenarios.) + :param allow_http: If set, a https URL will not be enforced. + :param convergence_wait_max_iter: Maximum number of iterations for convergence. + :param convergence_wait_time: Time in seconds to sleep between convergence calls + on the backend. + :param events: A flag indicating whether to enable event-based data + synchronization from the server. When enabled, utilizes a mechanism for + receiving real-time updates from the server, instead of periodically + requesting the data. + :raises InitializationError: If no URL is provided, authentication fails or host + can't be reached. """ url, username, password, cert = get_configuration( url, username, password, ssl_verify @@ -221,7 +237,7 @@ def __init__( if cert is not None: ssl_verify = cert - url, base_url = prepare_url(url, allow_http) + url, base_url = _prepare_url(url, allow_http) self.username: str = username self.password: str = password @@ -286,7 +302,7 @@ def __init__( return self.event_listener = None - self.session.lock = None + self._session.lock = None if events: # http-based auto sync should be off by default when using events self.auto_sync = False @@ -304,11 +320,25 @@ def __repr__(self): ) def __str__(self): - return "{} URL: {}".format(self.__class__.__name__, self._session.base_url) + return f"{self.__class__.__name__} URL: {self._session.base_url}" + + def _url_for(self, endpoint, **kwargs): + """ + Generate the URL for a given API endpoint. + + :param endpoint: The desired endpoint. + :param **kwargs: Keyword arguments used to format the URL. + :returns: The formatted URL. + """ + return get_url_from_template(endpoint, self._URL_TEMPLATES, kwargs) def _make_test_auth_call(self) -> None: - """Make a call to confirm auth works by using the "authok" endpoint.""" - url = "authok" + """ + Make a call to confirm that authentication works. + + :raises InitializationError: If authentication fails. + """ + url = self._url_for("auth_test") try: self._session.get(url) except httpx.HTTPStatusError as exc: @@ -320,73 +350,68 @@ def _make_test_auth_call(self) -> None: raise InitializationError(message) raise except httpx.HTTPError as exc: - # TODO: subclass InitializationError raise InitializationError(exc) @staticmethod def _environ_get( - key: str, value: Optional[Any] = None, default: Optional[Any] = None - ) -> Optional[Any]: - """If the value is not set yet, fetch it from environment or return default.""" + key: str, value: Any | None = None, default: Any | None = None + ) -> Any | None: + """ + If the value is not yet set, fetch it from the environment or return the default + value. + + :param key: The key to fetch the value for. + :param value: The value to use if it is already set. + :param default: The default value to use if the key is not set + and value is None. + :returns: The fetched value or the default value. + """ if value is None: value = os.environ.get(key) if value: - _LOGGER.info("Using value %s from environment", key) + _LOGGER.info(f"Using value {key} from environment") else: value = default return value - @property - def session(self) -> httpx.Client: - """ - Returns the used `httpx` client session object. - - :returns: The session object - """ - return self._session - @property def uuid(self) -> str: - """ - Returns the UUID4 that identifies this client to the server. - - :returns: the UUID4 string - """ + """Return the UUID4 that identifies this client to the server.""" return self._session.headers["X-Client-UUID"] - def logout(self, clear_all_sessions: bool = False) -> None: + def logout(self, clear_all_sessions: bool = False) -> bool: """ - Invalidate current token. + Invalidate the current token. + + :param clear_all_sessions: Whether to clear all user sessions as well. """ - return self._session.auth.logout( # type: ignore - clear_all_sessions=clear_all_sessions - ) + return self._session.auth.logout(clear_all_sessions=clear_all_sessions) def get_host(self) -> str: """ - Returns the hostname of the session to the controller. + Return the hostname of the session to the controller. - :returns: A hostname + :returns: The hostname. """ - return self.session.base_url.host + return self._session.base_url.host def system_info(self) -> dict: """ Get information about the system where the application runs. Can be called without authentication. - :returns: system information json + :returns: The system information as a JSON object. """ - url = "system_information" + url = self._url_for("system_info") return self._session.get(url).json() def check_controller_version(self) -> None: """ - Checks remote controller version against current client version - (specified in self.VERSION) and against controller version - blacklist (specified in self.INCOMPATIBLE_CONTROLLER_VERSIONS) and - raises exception if versions are incompatible or prints warning - if client minor version is lower than controller minor version. + Check remote controller version against current client version + (specified in `self.VERSION`) and against controller version + blacklist (specified in `self.INCOMPATIBLE_CONTROLLER_VERSIONS`). + Raise exception if versions are incompatible, or print warning + if the client minor version is lower than the controller minor version. """ controller_version = self.system_info().get("version", "").split("-")[0] @@ -398,37 +423,33 @@ def check_controller_version(self) -> None: controller_version_obj = Version(controller_version) if controller_version_obj in self.INCOMPATIBLE_CONTROLLER_VERSIONS: raise InitializationError( - "Controller version {} is marked incompatible! " - "List of versions marked explicitly as incompatible: {}.".format( - controller_version_obj, self.INCOMPATIBLE_CONTROLLER_VERSIONS - ) + f"Controller version {controller_version_obj} is marked incompatible! " + f"List of versions marked explicitly as incompatible: " + f"{self.INCOMPATIBLE_CONTROLLER_VERSIONS}." ) if self.VERSION.major_differs(controller_version_obj): raise InitializationError( - "Major version mismatch. Client {}, controller {}.".format( - self.VERSION, controller_version_obj - ) + f"Major version mismatch. Client {self.VERSION}, " + f"controller {controller_version_obj}." ) if self.VERSION.minor_differs(controller_version_obj) and self.VERSION.minor_lt( controller_version_obj ): _LOGGER.warning( - "Please ensure the client version is compatible with the controller" - " version. Client %s, controller %s.", - self.VERSION, - controller_version_obj, + f"Please ensure the client version is compatible with the controller " + f"version. Client {self.VERSION}, controller {controller_version_obj}." ) def is_system_ready( self, wait: bool = False, max_wait: int = 60, sleep: int = 5 ) -> bool: """ - Reports whether the system is ready or not. + Report whether the system is ready or not. - :param wait: if this is true, the call will block until it's ready - :param max_wait: maximum time to wait, in seconds - :param sleep: time to wait between tries, in seconds - :returns: ready state + :param wait: Whether to block until the system is ready. + :param max_wait: The maximum time to wait in seconds. + :param sleep: The time to wait between tries in seconds. + :returns: The ready state of the system. """ loops = 1 if wait: @@ -453,6 +474,12 @@ def is_system_ready( @staticmethod def is_virl_1x(path: Path) -> bool: + """ + Check if the given file is of VIRL version 1.x. + + :param path: The path to check. + :returns: Whether the file is of VIRL version 1.x. + """ if path.suffix == ".virl": return True return False @@ -460,13 +487,17 @@ def is_virl_1x(path: Path) -> bool: @locked def start_event_listening(self): """ - Starts listening for and parsing websocket events. + Start listening for and parsing websocket events. To replace the default event handling mechanism, subclass `event_handling.EventHandler` (or `EventHandlerBase` if necessary), - then do: - self.websockets = event_listening.EventListener(self) - self.websockets._event_handler = handler + then do:: + from .event_listening import EventListener + custom_listener = EventListener() + custom_listener._event_handler = CustomHandler(client_library) + client_library.event_listener = custom_listener + + :returns: """ from .event_listening import EventListener @@ -478,44 +509,49 @@ def start_event_listening(self): @locked def stop_event_listening(self): + """Stop listening for and parsing websocket events.""" if self.event_listener: - self.session.lock = None + self._session.lock = None self.event_listener.stop_listening() @locked def import_lab( self, topology: str, - title: Optional[str] = None, + title: str | None = None, offline: bool = False, virl_1x: bool = False, ) -> Lab: """ - Imports an existing topology from a string. - - :param topology: Topology representation as a string - :param title: Title of the lab (optional) - :param offline: whether the ClientLibrary should import the - lab locally. The topology parameter has to be a JSON string - in this case. This can not be XML or YAML representation of - the lab. Compatible JSON from `GET /labs/{lab_id}/topology`. - :param virl_1x: Is this old virl-1x topology format (default=False) - :returns: A Lab instance - :raises ValueError: if there's no lab ID in the API response - :raises httpx.HTTPError: if there was a transport error - """ - # TODO: refactor to return the local lab, and sync it, - # if already exists in self._labs + Import an existing topology from a string. + + :param topology: The topology representation as a string. + :param title: The title of the lab. + :param offline: Whether to import the lab locally. + :param virl_1x: Whether the topology format is the old, VIRL 1.x format. + :returns: The imported Lab instance. + :raises ValueError: If no lab ID is returned in the API response. + :raises httpx.HTTPError: If there was a transport error. + """ + if title is not None: + for lab_id in self._labs: + if (lab := self._labs[lab_id]).title == title: + # Lab of this title already exists, sync and return it + lab.sync() + return lab if offline: lab_id = "offline_lab" else: if virl_1x: - url = "import/virl-1x" + url = self._url_for("import_1x") else: - url = "import" + url = self._url_for("import") if title is not None: - url = "{}?title={}".format(url, title) - result = self._session.post(url, content=topology).json() + result = self._session.post( + url, params={"title": title}, content=topology + ).json() + else: + result = self._session.post(url, content=topology).json() lab_id = result.get("id") if lab_id is None: raise ValueError("No lab ID returned!") @@ -543,17 +579,18 @@ def import_lab( return lab @locked - def import_lab_from_path(self, path: str, title: Optional[str] = None) -> Lab: + def import_lab_from_path(self, path: str, title: str | None = None) -> Lab: """ - Imports an existing topology from a file / path. + Import an existing topology from a file or path. - :param path: Topology filename / path - :param title: Title of the lab - :returns: A Lab instance + :param path: The topology filename or path. + :param title: The title of the lab. + :returns: The imported Lab instance. + :raises FileNotFoundError: If the specified path does not exist. """ topology_path = Path(path) if not topology_path.exists(): - message = "{} can not be found".format(path) + message = f"{path} can not be found" raise FileNotFoundError(message) topology = topology_path.read_text() @@ -563,38 +600,40 @@ def import_lab_from_path(self, path: str, title: Optional[str] = None) -> Lab: def get_sample_labs(self) -> dict[str, dict]: """ - Returns a dictionary with information about all sample labs available on host. + Return a dictionary with information about all sample labs + available on the host. - :returns: A dictionary of sample lab information, where keys are titles + :returns: A dictionary of sample lab information where the keys are the titles. """ - url = "sample/labs" + url = self._url_for("sample_labs") return self._session.get(url).json() @locked def import_sample_lab(self, title: str) -> Lab: """ - Imports a built-in sample lab. + Import a built-in sample lab. - :param title: sample lab name, as returned by get_sample_labs - :returns: a Lab instance + :param title: The sample lab name. + :returns: The imported Lab instance. """ - url = "sample/labs/" + title + url = self._url_for("sample_lab", lab_title=title) lab_id = self._session.put(url).json() return self.join_existing_lab(lab_id) @locked def all_labs(self, show_all: bool = False) -> list[Lab]: """ - Joins all labs owned by this user (or all labs period) and returns their list. + Join all labs owned by this user (or all labs) and return their list. - :param show_all: Whether to get only labs owned by the admin or all user labs - :returns: A list of lab objects + :param show_all: Whether to get only labs owned by the admin or all user labs. + :returns: A list of Lab objects. """ - url = "labs" + url = {"url": self._url_for("labs")} if show_all: - url += "?show_all=true" - lab_ids = self._session.get(url).json() + url["params"] = {"show_all": True} + lab_ids = self._session.get(**url).json() + result = [] for lab_id in lab_ids: lab = self.join_existing_lab(lab_id) @@ -604,17 +643,30 @@ def all_labs(self, show_all: bool = False) -> list[Lab]: @locked def _remove_stale_labs(self): + """Remove stale labs from the client library.""" for lab in list(self._labs.values()): if lab._stale: self._remove_lab_local(lab) @locked def local_labs(self) -> list[Lab]: + """ + Return a list of local labs. + + :returns: A list of local labs. + """ self._remove_stale_labs() return list(self._labs.values()) @locked def get_local_lab(self, lab_id: str) -> Lab: + """ + Get a local lab by its ID. + + :param lab_id: The ID of the lab. + :returns: The Lab object with the specified ID. + :raises LabNotFound: If the lab with the given ID does not exist. + """ self._remove_stale_labs() try: return self._labs[lab_id] @@ -622,14 +674,21 @@ def get_local_lab(self, lab_id: str) -> Lab: raise LabNotFound(lab_id) @locked - def create_lab(self, title: Optional[str] = None) -> Lab: + def create_lab( + self, + title: str | None = None, + description: str | None = None, + notes: str | None = None, + ) -> Lab: """ - Creates an empty lab with the given name. If no name is given, then - the created lab ID is set as the name. + Create a new lab with optional title, description, and notes. - Note that the lab will be auto-syncing according to the Client Library's - auto-sync setting when created. The lab has a property to override this - on a per-lab basis. + If no title, description, or notes are provided, the server will generate a + default title in the format "Lab at Mon 13:30 PM" and leave the description + and notes blank. + + The lab will automatically sync based on the Client Library's auto-sync setting + when created, but this behavior can be overridden on a per-lab basis. Example:: @@ -637,16 +696,19 @@ def create_lab(self, title: Optional[str] = None) -> Lab: print(lab.id) lab.create_node("r1", "iosv", 50, 100) - :param title: The Lab name (or title) - :returns: A Lab instance + :param title: The title of the lab. + :param description: The description of the lab. + :param notes: The notes of the lab. + :returns: A Lab instance representing the created lab. """ - url = "labs" - params = {"title": title} - result = self._session.post(url, params=params).json() - # TODO: server generated title not loaded + url = self._url_for("labs") + body = {"title": title, "description": description, "notes": notes} + # exclude values left at None + body = {k: v for k, v in body.items() if v is not None} + result = self._session.post(url, json=body).json() lab_id = result["id"] lab = Lab( - title or lab_id, + result["lab_title"], lab_id, self._session, self.username, @@ -657,21 +719,20 @@ def create_lab(self, title: Optional[str] = None) -> Lab: wait_time=self.convergence_wait_time, resource_pool_manager=self.resource_pool_management, ) + lab._import_lab(result) self._labs[lab_id] = lab return lab @locked def remove_lab(self, lab_id: str | Lab) -> None: """ - Use carefully, it removes a given lab:: - - client_library.remove_lab("1f6cd7") + Remove a lab identified by its ID or Lab object. - If you have a lab object, you can also simply do: + Use this method with caution as it permanently deletes the specified lab. - lab.remove() + If you have the lab object, you can also do ``lab.remove()``. - :param lab_id: the lab ID or object + :param lab_id: The ID or Lab object representing the lab to be removed. """ self._remove_stale_labs() if isinstance(lab_id, Lab): @@ -696,27 +757,26 @@ def _remove_lab_local(self, lab: Lab) -> None: pass def _remove_unjoined_lab(self, lab_id: str): - url = "labs/{}".format(lab_id) + """Helper function to remove an unjoined lab from the server.""" + url = self._url_for("lab", lab_id=lab_id) response = self._session.delete(url) - _LOGGER.debug("removed lab: %s", response.text) + _LOGGER.debug(f"Removed lab: {response.text}") @locked def join_existing_lab(self, lab_id: str, sync_lab: bool = True) -> Lab: """ - Creates a new ClientLibrary lab instance that is retrieved - from the controller. - - If `sync_lab` is set to true, then synchronize current lab - by applying the changes that were done in UI or in another - ClientLibrary session. + Join a lab that exists on the server and make it accessible locally. - Join preexisting lab:: + If `sync_lab` is set to True, the current lab will be synchronized by applying + any changes that were made in the UI or in another ClientLibrary session. + Example:: lab = client_library.join_existing_lab("2e6a18") - :param lab_id: The lab ID to be joined. - :param sync_lab: Synchronize changes. - :returns: A Lab instance + :param lab_id: The ID of the lab to be joined. + :param sync_lab: Whether to synchronize the lab. + :returns: A Lab instance representing the joined lab. + :raises LabNotFound: If no lab with the given ID exists on the host. """ self._remove_stale_labs() if lab_id in self._labs: @@ -725,7 +785,8 @@ def join_existing_lab(self, lab_id: str, sync_lab: bool = True) -> Lab: if sync_lab: try: # check if lab exists through REST call - topology = self.session.get(f"labs/{lab_id}/topology").json() + url = self._url_for("lab_topology", lab_id=lab_id) + topology = self._session.get(url).json() except httpx.HTTPStatusError as exc: if exc.response.status_code == 404: raise LabNotFound("No lab with the given ID exists on the host.") @@ -753,29 +814,29 @@ def join_existing_lab(self, lab_id: str, sync_lab: bool = True) -> Lab: def get_diagnostics(self) -> dict: """ - Returns the controller diagnostic data as JSON + Return the controller diagnostic data as a JSON object. - :returns: diagnostic data + :returns: The diagnostic data. """ - url = "diagnostics" + url = self._url_for("diagnostics") return self._session.get(url).json() def get_system_health(self) -> dict: """ - Returns the controller system health data as JSON + Return the controller system health data as a JSON object. - :returns: system health data + :returns: The system health data. """ - url = "system_health" + url = self._url_for("system_health") return self._session.get(url).json() def get_system_stats(self) -> dict: """ - Returns the controller resource statistics as JSON + Return the controller resource statistics as a JSON object. - :returns: system resource statistics + :returns: The system resource statistics. """ - url = "system_stats" + url = self._url_for("system_stats") return self._session.get(url).json() @locked @@ -783,10 +844,10 @@ def find_labs_by_title(self, title: str) -> list[Lab]: """ Return a list of labs which match the given title. - :param title: The title to search for - :returns: A list of lab objects which match the title specified + :param title: The title to search for. + :returns: A list of Lab objects matching the specified title. """ - url = "populate_lab_tiles" + url = self._url_for("populate_lab_tiles") resp = self._session.get(url).json() # populate_lab_tiles response has been changed in 2.1 @@ -810,18 +871,19 @@ def find_labs_by_title(self, title: str) -> list[Lab]: def get_lab_list(self, show_all: bool = False) -> list[str]: """ - Returns list of all lab IDs. + Get the list of all lab IDs. - :param show_all: Whether to get only labs owned by the admin or all user labs - :returns: a list of Lab IDs + :param show_all: Whether to include labs owned by all users (True) or only labs + owned by the admin (False). + :returns: A list of lab IDs. """ - url = "labs" + url = {"url": self._url_for("labs")} if show_all: - url += "?show_all=true" - return self._session.get(url).json() + url["params"] = {"show_all": True} + return self._session.get(**url).json() -def prepare_url(url: str, allow_http: bool) -> tuple[str, str]: +def _prepare_url(url: str, allow_http: bool) -> tuple[str, str]: # prepare the URL try: url_parts = urlsplit(url, "https")