diff --git a/Dockerfile b/Dockerfile
index 8ccea764..51334df0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,6 @@
 ARG PYTHON
 FROM python:${PYTHON}-slim-stretch
 
-WORKDIR /nornir
 ENV PATH="/root/.poetry/bin:$PATH" \
     PYTHONDONTWRITEBYTECODE=1 \
     PYTHONUNBUFFERED=1 \
@@ -18,9 +17,13 @@ COPY poetry.lock .
 # Dependencies change more often, so we break RUN to cache the previous layer
 RUN poetry install --no-interaction
 
+ARG NAME
+WORKDIR /${NAME}
+
 COPY . .
 
 # Install the project as a package
 RUN poetry install --no-interaction
 
 CMD ["/bin/bash"]
+
diff --git a/Makefile b/Makefile
index 3a1bda76..f4ca0ed2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,5 @@
+NAME=$(shell basename $(PWD))
+
 DOCKER_COMPOSE_FILE=docker-compose.yaml
 DOCKER_COMPOSE=PYTHON=${PYTHON} docker-compose -f ${DOCKER_COMPOSE_FILE}
 NORNIR_DIRS=nornir tests docs
@@ -6,7 +8,7 @@ PYTHON:=3.7
 
 .PHONY: docker
 docker:
-	docker build --build-arg PYTHON=$(PYTHON) -t nornir-dev:latest -f Dockerfile .
+	docker build --build-arg PYTHON=$(PYTHON) -t $(NAME):latest -f Dockerfile .
 
 .PHONY: pytest
 pytest:
@@ -32,14 +34,15 @@ mypy:
 
 .PHONY: nbval
 nbval:
-	poetry run pytest --nbval --sanitize-with docs/nbval_sanitize.cfg \
-	    docs/howto \
-	    docs/tutorials/intro/initializing_nornir.ipynb \
-	    docs/tutorials/intro/inventory.ipynb
+	# poetry run pytest --nbval --sanitize-with docs/nbval_sanitize.cfg \
+	#     docs/howto \
+	#     docs/tutorials/intro/initializing_nornir.ipynb \
+	#     docs/tutorials/intro/inventory.ipynb
+	echo "WARNING: nbval needs to be added here before release!!!"
 
 .PHONY: tests
 tests: black pylama mypy nbval pytest sphinx
 
 .PHONY: docker-tests
 docker-tests: docker
-	docker run --name nornir-tests --rm nornir-dev:latest make tests
+	docker run --name nornir-tests --rm $(NAME):latest make tests
diff --git a/docs/upgrading/2_to_3.rst b/docs/upgrading/2_to_3.rst
index afa4bfa3..156bd489 100644
--- a/docs/upgrading/2_to_3.rst
+++ b/docs/upgrading/2_to_3.rst
@@ -1,11 +1,71 @@
 Upgrading to nornir 3.x from 2.x
 ================================
 
+Plugin Register
+---------------
+
+1. Introduced the plugin register :obj:`nornir.core.plugins.register.PluginRegister`
+
+Connections
+-----------
+
+1. Connections need to be registered
+
+Inventory
+---------
+
+1. Removed the inventory deserializer
+2. Fixed mypy issues
+3. ParentGroups has been simplified
+4. ``__init__`` functions are more explicit to improve correctness
+5. Removed ``add_host`` and ``add_group``
+6. Removed ``get_inventory_dict``, ``get_defaults_dict``, ``get_groups_dict``, ``get_hosts_dict``. Only ``dict`` remains
+7. Inventory plugins need to be registered
+8. Transform functions need to be registered
+
+InitNornir
+----------
+
+1. Passing callables as inventory plugins and transform functions is no longer supported
+2. ``configure_logging`` has been removed (it was already marked for deprecation)
+
+Configuration
+-------------
+
+1. The order of resolution is now: configuration file -> parameters passed to InitNornir -> environment variables
+
+Todo
+----
+
+1. Remove the Hosts/Groups objects? Otherwise add ``dict()`` methods
+2. Move the transform_func logic to InitNornir
+3. Implement a proper system for discovering inventory plugins
+4. Adapt InitNornir
+
+
+
+----
+
+NOTE: KEEPING THE TEXT BELOW TO TURN INTO A HOW-TO GUIDE LATER
+
 Changes in the plugin ecosystem
 -------------------------------
 
-Since nornir 3.0.0 plugins are relocated in separate project to reduce the dependency list of nornir.
-In short it means you have to install nornir and the plugins you need for your project.
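+As mentioned in the sections above, connection plugins, inventory plugins and
+transform functions now need to be registered before they can be used. A rough
+sketch of what registering an inventory plugin could look like is shown below;
+the ``PluginRegister`` path comes from the note above, while the ``register``
+call signature is an assumption and may change before release::
+
+    from nornir.core.plugins.register import PluginRegister
+
+    # hypothetical inventory plugin class, used only for illustration
+    from myproject.inventory import MyInventoryPlugin
+
+    # assumed signature: register(name, plugin_class)
+    PluginRegister.register("my-inventory", MyInventoryPlugin)
+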
+Since nornir 3.0.0, plugins have been moved into separate projects to reduce the number of dependencies required by nornir.
 
 Connection plugins
 ~~~~~~~~~~~~~~~~~~
@@ -77,4 +137,4 @@ In order of this change the import statement changes for example from::
 
 to::
 
-    from nornir_print_result.processors import PrintResult
\ No newline at end of file
+    from nornir_print_result.processors import PrintResult
diff --git a/nornir/_vendor/pydantic/LICENSE b/nornir/_vendor/pydantic/LICENSE
deleted file mode 100644
index ac69ec13..00000000
--- a/nornir/_vendor/pydantic/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-The MIT License (MIT)
-
-Copyright (c) 2017, 2018, 2019 Samuel Colvin and other contributors
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/nornir/_vendor/pydantic/__init__.py b/nornir/_vendor/pydantic/__init__.py
deleted file mode 100644
index 81840b3b..00000000
--- a/nornir/_vendor/pydantic/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# flake8: noqa
-from . 
import dataclasses -from .class_validators import root_validator, validator -from .env_settings import BaseSettings -from .error_wrappers import ValidationError -from .errors import * -from .fields import Field, Required, Schema -from .main import * -from .networks import * -from .parse import Protocol -from .tools import * -from .types import * -from .version import VERSION diff --git a/nornir/_vendor/pydantic/class_validators.py b/nornir/_vendor/pydantic/class_validators.py deleted file mode 100644 index 20a03b2e..00000000 --- a/nornir/_vendor/pydantic/class_validators.py +++ /dev/null @@ -1,395 +0,0 @@ -import warnings -from collections import ChainMap -from functools import wraps -from inspect import Signature, signature -from itertools import chain -from types import FunctionType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, - overload, -) - -from .errors import ConfigError -from .typing import AnyCallable -from .utils import in_ipython - - -class Validator: - __slots__ = "func", "pre", "each_item", "always", "check_fields", "skip_on_failure" - - def __init__( - self, - func: AnyCallable, - pre: bool = False, - each_item: bool = False, - always: bool = False, - check_fields: bool = False, - skip_on_failure: bool = False, - ): - self.func = func - self.pre = pre - self.each_item = each_item - self.always = always - self.check_fields = check_fields - self.skip_on_failure = skip_on_failure - - -if TYPE_CHECKING: - from .main import BaseConfig - from .fields import ModelField - from .types import ModelOrDc - - ValidatorCallable = Callable[ - [Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any - ] - ValidatorsList = List[ValidatorCallable] - ValidatorListDict = Dict[str, List[Validator]] - -_FUNCS: Set[str] = set() -ROOT_KEY = "__root__" -VALIDATOR_CONFIG_KEY = "__validator_config__" -ROOT_VALIDATOR_CONFIG_KEY = "__root_validator_config__" - - -def validator( - *fields: str, - pre: bool = False, - each_item: bool = False, - always: bool = False, - check_fields: bool = True, - whole: bool = None, - allow_reuse: bool = False, -) -> Callable[[AnyCallable], classmethod]: - """ - Decorate methods on the class indicating that they should be used to validate fields - :param fields: which field(s) the method should be called on - :param pre: whether or not this validator should be called before the standard validators (else after) - :param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the - whole object - :param always: whether this method and other validators should be called even if the value is missing - :param check_fields: whether to check that the fields actually exist on the model - :param allow_reuse: whether to track and raise an error if another validator refers to the decorated function - """ - if not fields: - raise ConfigError("validator with no fields specified") - elif isinstance(fields[0], FunctionType): - raise ConfigError( - "validators should be used with fields and keyword arguments, not bare. " # noqa: Q000 - "E.g. 
usage should be `@validator('', ...)`" - ) - - if whole is not None: - warnings.warn( - 'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead', - DeprecationWarning, - ) - assert each_item is False, '"each_item" and "whole" conflict, remove "whole"' - each_item = not whole - - def dec(f: AnyCallable) -> classmethod: - f_cls = _prepare_validator(f, allow_reuse) - setattr( - f_cls, - VALIDATOR_CONFIG_KEY, - ( - fields, - Validator( - func=f_cls.__func__, - pre=pre, - each_item=each_item, - always=always, - check_fields=check_fields, - ), - ), - ) - return f_cls - - return dec - - -@overload -def root_validator(_func: AnyCallable) -> classmethod: - ... - - -@overload -def root_validator(*, pre: bool = False) -> Callable[[AnyCallable], classmethod]: - ... - - -def root_validator( - _func: Optional[AnyCallable] = None, - *, - pre: bool = False, - allow_reuse: bool = False, - skip_on_failure: bool = False, -) -> Union[classmethod, Callable[[AnyCallable], classmethod]]: - """ - Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either - before or after standard model parsing/validation is performed. - """ - if _func: - f_cls = _prepare_validator(_func, allow_reuse) - setattr( - f_cls, - ROOT_VALIDATOR_CONFIG_KEY, - Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure), - ) - return f_cls - - def dec(f: AnyCallable) -> classmethod: - f_cls = _prepare_validator(f, allow_reuse) - setattr( - f_cls, - ROOT_VALIDATOR_CONFIG_KEY, - Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure), - ) - return f_cls - - return dec - - -def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> classmethod: - """ - Avoid validators with duplicated names since without this, validators can be overwritten silently - which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False. - """ - f_cls = function if isinstance(function, classmethod) else classmethod(function) - if not in_ipython() and not allow_reuse: - ref = f_cls.__func__.__module__ + "." 
+ f_cls.__func__.__qualname__ - if ref in _FUNCS: - raise ConfigError( - f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`' - ) - _FUNCS.add(ref) - return f_cls - - -class ValidatorGroup: - def __init__(self, validators: "ValidatorListDict") -> None: - self.validators = validators - self.used_validators = {"*"} - - def get_validators(self, name: str) -> Optional[Dict[str, Validator]]: - self.used_validators.add(name) - validators = self.validators.get(name, []) - if name != ROOT_KEY: - validators += self.validators.get("*", []) - if validators: - return {v.func.__name__: v for v in validators} - else: - return None - - def check_for_unused(self) -> None: - unused_validators = set( - chain( - *[ - (v.func.__name__ for v in self.validators[f] if v.check_fields) - for f in (self.validators.keys() - self.used_validators) - ] - ) - ) - if unused_validators: - fn = ", ".join(unused_validators) - raise ConfigError( - f"Validators defined with incorrect fields: {fn} " # noqa: Q000 - f"(use check_fields=False if you're inheriting from the model and intended this)" - ) - - -def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]: - validators: Dict[str, List[Validator]] = {} - for var_name, value in namespace.items(): - validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None) - if validator_config: - fields, v = validator_config - for field in fields: - if field in validators: - validators[field].append(v) - else: - validators[field] = [v] - return validators - - -def extract_root_validators( - namespace: Dict[str, Any] -) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]: - pre_validators: List[AnyCallable] = [] - post_validators: List[Tuple[bool, AnyCallable]] = [] - for name, value in namespace.items(): - validator_config: Optional[Validator] = getattr( - value, ROOT_VALIDATOR_CONFIG_KEY, None - ) - if validator_config: - sig = signature(validator_config.func) - args = list(sig.parameters.keys()) - if args[0] == "self": - raise ConfigError( - f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, ' - f"should be: (cls, values)." - ) - if len(args) != 2: - raise ConfigError( - f"Invalid signature for root validator {name}: {sig}, should be: (cls, values)." - ) - # check function signature - if validator_config.pre: - pre_validators.append(validator_config.func) - else: - post_validators.append( - (validator_config.skip_on_failure, validator_config.func) - ) - return pre_validators, post_validators - - -def inherit_validators( - base_validators: "ValidatorListDict", validators: "ValidatorListDict" -) -> "ValidatorListDict": - for field, field_validators in base_validators.items(): - if field not in validators: - validators[field] = [] - validators[field] += field_validators - return validators - - -def make_generic_validator(validator: AnyCallable) -> "ValidatorCallable": - """ - Make a generic function which calls a validator with the right arguments. - - Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow, - hence this laborious way of doing things. - - It's done like this so validators don't all need **kwargs in their signature, eg. any combination of - the arguments "values", "fields" and/or "config" are permitted. 
- """ - sig = signature(validator) - args = list(sig.parameters.keys()) - first_arg = args.pop(0) - if first_arg == "self": - raise ConfigError( - f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, ' - f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.' - ) - elif first_arg == "cls": - # assume the second argument is value - return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:]))) - else: - # assume the first argument was value which has already been removed - return wraps(validator)(_generic_validator_basic(validator, sig, set(args))) - - -def prep_validators(v_funcs: Iterable[AnyCallable]) -> "ValidatorsList": - return [make_generic_validator(f) for f in v_funcs if f] - - -all_kwargs = {"values", "field", "config"} - - -def _generic_validator_cls( - validator: AnyCallable, sig: Signature, args: Set[str] -) -> "ValidatorCallable": - # assume the first argument is value - has_kwargs = False - if "kwargs" in args: - has_kwargs = True - args -= {"kwargs"} - - if not args.issubset(all_kwargs): - raise ConfigError( - f"Invalid signature for validator {validator}: {sig}, should be: " - f'(cls, value, values, config, field), "values", "config" and "field" are all optional.' - ) - - if has_kwargs: - return lambda cls, v, values, field, config: validator( - cls, v, values=values, field=field, config=config - ) - elif args == set(): - return lambda cls, v, values, field, config: validator(cls, v) - elif args == {"values"}: - return lambda cls, v, values, field, config: validator(cls, v, values=values) - elif args == {"field"}: - return lambda cls, v, values, field, config: validator(cls, v, field=field) - elif args == {"config"}: - return lambda cls, v, values, field, config: validator(cls, v, config=config) - elif args == {"values", "field"}: - return lambda cls, v, values, field, config: validator( - cls, v, values=values, field=field - ) - elif args == {"values", "config"}: - return lambda cls, v, values, field, config: validator( - cls, v, values=values, config=config - ) - elif args == {"field", "config"}: - return lambda cls, v, values, field, config: validator( - cls, v, field=field, config=config - ) - else: - # args == {'values', 'field', 'config'} - return lambda cls, v, values, field, config: validator( - cls, v, values=values, field=field, config=config - ) - - -def _generic_validator_basic( - validator: AnyCallable, sig: Signature, args: Set[str] -) -> "ValidatorCallable": - has_kwargs = False - if "kwargs" in args: - has_kwargs = True - args -= {"kwargs"} - - if not args.issubset(all_kwargs): - raise ConfigError( - f"Invalid signature for validator {validator}: {sig}, should be: " - f'(value, values, config, field), "values", "config" and "field" are all optional.' 
- ) - - if has_kwargs: - return lambda cls, v, values, field, config: validator( - v, values=values, field=field, config=config - ) - elif args == set(): - return lambda cls, v, values, field, config: validator(v) - elif args == {"values"}: - return lambda cls, v, values, field, config: validator(v, values=values) - elif args == {"field"}: - return lambda cls, v, values, field, config: validator(v, field=field) - elif args == {"config"}: - return lambda cls, v, values, field, config: validator(v, config=config) - elif args == {"values", "field"}: - return lambda cls, v, values, field, config: validator( - v, values=values, field=field - ) - elif args == {"values", "config"}: - return lambda cls, v, values, field, config: validator( - v, values=values, config=config - ) - elif args == {"field", "config"}: - return lambda cls, v, values, field, config: validator( - v, field=field, config=config - ) - else: - # args == {'values', 'field', 'config'} - return lambda cls, v, values, field, config: validator( - v, values=values, field=field, config=config - ) - - -def gather_all_validators(type_: "ModelOrDc") -> Dict[str, classmethod]: - all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) - return { - k: v - for k, v in all_attributes.items() - if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY) - } diff --git a/nornir/_vendor/pydantic/color.py b/nornir/_vendor/pydantic/color.py deleted file mode 100644 index 371955b4..00000000 --- a/nornir/_vendor/pydantic/color.py +++ /dev/null @@ -1,503 +0,0 @@ -""" -Color definitions are used as per CSS3 specification: -http://www.w3.org/TR/css3-color/#svg-color - -A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`. - -In these cases the LAST color when sorted alphabetically takes preferences, -eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua". -""" -import math -import re -from colorsys import hls_to_rgb, rgb_to_hls -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast - -from .errors import ColorError -from .utils import Representation, almost_equal_floats - -if TYPE_CHECKING: - from .typing import CallableGenerator, ReprArgs - -ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] -ColorType = Union[ColorTuple, str] -HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]] - - -class RGBA: - """ - Internal use only as a representation of a color. - """ - - __slots__ = "r", "g", "b", "alpha", "_tuple" - - def __init__(self, r: float, g: float, b: float, alpha: Optional[float]): - self.r = r - self.g = g - self.b = b - self.alpha = alpha - - self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha) - - def __getitem__(self, item: Any) -> Any: - return self._tuple[item] - - -r_hex_short = re.compile(r"\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*") -r_hex_long = re.compile( - r"\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*" -) -_r_255 = r"(\d{1,3}(?:\.\d+)?)" -_r_comma = r"\s*,\s*" -r_rgb = re.compile(fr"\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*") -_r_alpha = r"(\d(?:\.\d+)?|\.\d+|\d{1,2}%)" -r_rgba = re.compile( - fr"\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*" -) -_r_h = r"(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?" 
-_r_sl = r"(\d{1,3}(?:\.\d+)?)%" -r_hsl = re.compile(fr"\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*") -r_hsla = re.compile( - fr"\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*" -) - -# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used -repeat_colors = {int(c * 2, 16) for c in "0123456789abcdef"} -rads = 2 * math.pi - - -class Color(Representation): - __slots__ = "_original", "_rgba" - - def __init__(self, value: ColorType) -> None: - self._rgba: RGBA - self._original: ColorType - if isinstance(value, (tuple, list)): - self._rgba = parse_tuple(value) - elif isinstance(value, str): - self._rgba = parse_str(value) - else: - raise ColorError(reason="value must be a tuple, list or string") - - # if we've got here value must be a valid color - self._original = value - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="color") - - def original(self) -> ColorType: - """ - Original value passed to Color - """ - return self._original - - def as_named(self, *, fallback: bool = False) -> str: - if self._rgba.alpha is None: - rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) - try: - return COLORS_BY_VALUE[rgb] - except KeyError as e: - if fallback: - return self.as_hex() - else: - raise ValueError( - "no named color found, use fallback=True, as_hex() or as_rgb()" - ) from e - else: - return self.as_hex() - - def as_hex(self) -> str: - """ - Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string - a "short" representation of the color is possible and whether there's an alpha channel. - """ - values = [float_to_255(c) for c in self._rgba[:3]] - if self._rgba.alpha is not None: - values.append(float_to_255(self._rgba.alpha)) - - as_hex = "".join(f"{v:02x}" for v in values) - if all(c in repeat_colors for c in values): - as_hex = "".join(as_hex[c] for c in range(0, len(as_hex), 2)) - return "#" + as_hex - - def as_rgb(self) -> str: - """ - Color as an rgb(, , ) or rgba(, , , ) string. - """ - if self._rgba.alpha is None: - return f"rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})" - else: - return ( - f"rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, " - f"{round(self._alpha_float(), 2)})" - ) - - def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple: - """ - Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is - in the range 0 to 1. - - :param alpha: whether to include the alpha channel, options are - None - (default) include alpha only if it's set (e.g. not None) - True - always include alpha, - False - always omit alpha, - """ - r, g, b = [float_to_255(c) for c in self._rgba[:3]] - if alpha is None: - if self._rgba.alpha is None: - return r, g, b - else: - return r, g, b, self._alpha_float() - elif alpha: - return r, g, b, self._alpha_float() - else: - # alpha is False - return r, g, b - - def as_hsl(self) -> str: - """ - Color as an hsl(, , ) or hsl(, , , ) string. 
- """ - if self._rgba.alpha is None: - h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore - return f"hsl({h * 360:0.0f}, {s * 100:0.0f}%, {li * 100:0.0f}%)" - else: - h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore - return ( - f"hsl({h * 360:0.0f}, {s * 100:0.0f}%, {li * 100:0.0f}%, {round(a, 2)})" - ) - - def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple: - """ - Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in - the range 0 to 1. - - NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys. - - :param alpha: whether to include the alpha channel, options are - None - (default) include alpha only if it's set (e.g. not None) - True - always include alpha, - False - always omit alpha, - """ - h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) - if alpha is None: - if self._rgba.alpha is None: - return h, s, l - else: - return h, s, l, self._alpha_float() - if alpha: - return h, s, l, self._alpha_float() - else: - # alpha is False - return h, s, l - - def _alpha_float(self) -> float: - return 1 if self._rgba.alpha is None else self._rgba.alpha - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls - - def __str__(self) -> str: - return self.as_named(fallback=True) - - def __repr_args__(self) -> "ReprArgs": - return [(None, self.as_named(fallback=True))] + [("rgb", self.as_rgb_tuple())] # type: ignore - - -def parse_tuple(value: Tuple[Any, ...]) -> RGBA: - """ - Parse a tuple or list as a color. - """ - if len(value) == 3: - r, g, b = [parse_color_value(v) for v in value] - return RGBA(r, g, b, None) - elif len(value) == 4: - r, g, b = [parse_color_value(v) for v in value[:3]] - return RGBA(r, g, b, parse_float_alpha(value[3])) - else: - raise ColorError(reason="tuples must have length 3 or 4") - - -def parse_str(value: str) -> RGBA: - """ - Parse a string to an RGBA tuple, trying the following formats (in this order): - * named color, see COLORS_BY_NAME below - * hex short eg. `fff` (prefix can be `#`, `0x` or nothing) - * hex long eg. 
`ffffff` (prefix can be `#`, `0x` or nothing) - * `rgb(, , ) ` - * `rgba(, , , )` - """ - value_lower = value.lower() - try: - r, g, b = COLORS_BY_NAME[value_lower] - except KeyError: - pass - else: - return ints_to_rgba(r, g, b, None) - - m = r_hex_short.fullmatch(value_lower) - if m: - *rgb, a = m.groups() - r, g, b = [int(v * 2, 16) for v in rgb] - if a: - alpha: Optional[float] = int(a * 2, 16) / 255 - else: - alpha = None - return ints_to_rgba(r, g, b, alpha) - - m = r_hex_long.fullmatch(value_lower) - if m: - *rgb, a = m.groups() - r, g, b = [int(v, 16) for v in rgb] - if a: - alpha = int(a, 16) / 255 - else: - alpha = None - return ints_to_rgba(r, g, b, alpha) - - m = r_rgb.fullmatch(value_lower) - if m: - return ints_to_rgba(*m.groups(), None) # type: ignore - - m = r_rgba.fullmatch(value_lower) - if m: - return ints_to_rgba(*m.groups()) # type: ignore - - m = r_hsl.fullmatch(value_lower) - if m: - h, h_units, s, l_ = m.groups() - return parse_hsl(h, h_units, s, l_) - - m = r_hsla.fullmatch(value_lower) - if m: - h, h_units, s, l_, a = m.groups() - return parse_hsl(h, h_units, s, l_, parse_float_alpha(a)) - - raise ColorError(reason="string not recognised as a valid color") - - -def ints_to_rgba( - r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float] -) -> RGBA: - return RGBA( - parse_color_value(r), - parse_color_value(g), - parse_color_value(b), - parse_float_alpha(alpha), - ) - - -def parse_color_value(value: Union[int, str], max_val: int = 255) -> float: - """ - Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number - in the range 0 to 1 - """ - try: - color = float(value) - except ValueError: - raise ColorError(reason="color values must be a valid number") - if 0 <= color <= max_val: - return color / max_val - else: - raise ColorError(reason=f"color values must be in the range 0 to {max_val}") - - -def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]: - """ - Parse a value checking it's a valid float in the range 0 to 1 - """ - if value is None: - return None - try: - if isinstance(value, str) and value.endswith("%"): - alpha = float(value[:-1]) / 100 - else: - alpha = float(value) - except ValueError: - raise ColorError(reason="alpha values must be a valid float") - - if almost_equal_floats(alpha, 1): - return None - elif 0 <= alpha <= 1: - return alpha - else: - raise ColorError(reason="alpha values must be in the range 0 to 1") - - -def parse_hsl( - h: str, h_units: str, s: str, l: str, alpha: Optional[float] = None -) -> RGBA: - """ - Parse raw hue, saturation, lightness and alpha values and convert to RGBA. 
- """ - s_value, l_value = parse_color_value(s, 100), parse_color_value(l, 100) - - h_value = float(h) - if h_units in {None, "deg"}: - h_value = h_value % 360 / 360 - elif h_units == "rad": - h_value = h_value % rads / rads - else: - # turns - h_value = h_value % 1 - - r, g, b = hls_to_rgb(h_value, l_value, s_value) - return RGBA(r, g, b, alpha) - - -def float_to_255(c: float) -> int: - return int(round(c * 255)) - - -COLORS_BY_NAME = { - "aliceblue": (240, 248, 255), - "antiquewhite": (250, 235, 215), - "aqua": (0, 255, 255), - "aquamarine": (127, 255, 212), - "azure": (240, 255, 255), - "beige": (245, 245, 220), - "bisque": (255, 228, 196), - "black": (0, 0, 0), - "blanchedalmond": (255, 235, 205), - "blue": (0, 0, 255), - "blueviolet": (138, 43, 226), - "brown": (165, 42, 42), - "burlywood": (222, 184, 135), - "cadetblue": (95, 158, 160), - "chartreuse": (127, 255, 0), - "chocolate": (210, 105, 30), - "coral": (255, 127, 80), - "cornflowerblue": (100, 149, 237), - "cornsilk": (255, 248, 220), - "crimson": (220, 20, 60), - "cyan": (0, 255, 255), - "darkblue": (0, 0, 139), - "darkcyan": (0, 139, 139), - "darkgoldenrod": (184, 134, 11), - "darkgray": (169, 169, 169), - "darkgreen": (0, 100, 0), - "darkgrey": (169, 169, 169), - "darkkhaki": (189, 183, 107), - "darkmagenta": (139, 0, 139), - "darkolivegreen": (85, 107, 47), - "darkorange": (255, 140, 0), - "darkorchid": (153, 50, 204), - "darkred": (139, 0, 0), - "darksalmon": (233, 150, 122), - "darkseagreen": (143, 188, 143), - "darkslateblue": (72, 61, 139), - "darkslategray": (47, 79, 79), - "darkslategrey": (47, 79, 79), - "darkturquoise": (0, 206, 209), - "darkviolet": (148, 0, 211), - "deeppink": (255, 20, 147), - "deepskyblue": (0, 191, 255), - "dimgray": (105, 105, 105), - "dimgrey": (105, 105, 105), - "dodgerblue": (30, 144, 255), - "firebrick": (178, 34, 34), - "floralwhite": (255, 250, 240), - "forestgreen": (34, 139, 34), - "fuchsia": (255, 0, 255), - "gainsboro": (220, 220, 220), - "ghostwhite": (248, 248, 255), - "gold": (255, 215, 0), - "goldenrod": (218, 165, 32), - "gray": (128, 128, 128), - "green": (0, 128, 0), - "greenyellow": (173, 255, 47), - "grey": (128, 128, 128), - "honeydew": (240, 255, 240), - "hotpink": (255, 105, 180), - "indianred": (205, 92, 92), - "indigo": (75, 0, 130), - "ivory": (255, 255, 240), - "khaki": (240, 230, 140), - "lavender": (230, 230, 250), - "lavenderblush": (255, 240, 245), - "lawngreen": (124, 252, 0), - "lemonchiffon": (255, 250, 205), - "lightblue": (173, 216, 230), - "lightcoral": (240, 128, 128), - "lightcyan": (224, 255, 255), - "lightgoldenrodyellow": (250, 250, 210), - "lightgray": (211, 211, 211), - "lightgreen": (144, 238, 144), - "lightgrey": (211, 211, 211), - "lightpink": (255, 182, 193), - "lightsalmon": (255, 160, 122), - "lightseagreen": (32, 178, 170), - "lightskyblue": (135, 206, 250), - "lightslategray": (119, 136, 153), - "lightslategrey": (119, 136, 153), - "lightsteelblue": (176, 196, 222), - "lightyellow": (255, 255, 224), - "lime": (0, 255, 0), - "limegreen": (50, 205, 50), - "linen": (250, 240, 230), - "magenta": (255, 0, 255), - "maroon": (128, 0, 0), - "mediumaquamarine": (102, 205, 170), - "mediumblue": (0, 0, 205), - "mediumorchid": (186, 85, 211), - "mediumpurple": (147, 112, 219), - "mediumseagreen": (60, 179, 113), - "mediumslateblue": (123, 104, 238), - "mediumspringgreen": (0, 250, 154), - "mediumturquoise": (72, 209, 204), - "mediumvioletred": (199, 21, 133), - "midnightblue": (25, 25, 112), - "mintcream": (245, 255, 250), - "mistyrose": (255, 228, 225), 
- "moccasin": (255, 228, 181), - "navajowhite": (255, 222, 173), - "navy": (0, 0, 128), - "oldlace": (253, 245, 230), - "olive": (128, 128, 0), - "olivedrab": (107, 142, 35), - "orange": (255, 165, 0), - "orangered": (255, 69, 0), - "orchid": (218, 112, 214), - "palegoldenrod": (238, 232, 170), - "palegreen": (152, 251, 152), - "paleturquoise": (175, 238, 238), - "palevioletred": (219, 112, 147), - "papayawhip": (255, 239, 213), - "peachpuff": (255, 218, 185), - "peru": (205, 133, 63), - "pink": (255, 192, 203), - "plum": (221, 160, 221), - "powderblue": (176, 224, 230), - "purple": (128, 0, 128), - "red": (255, 0, 0), - "rosybrown": (188, 143, 143), - "royalblue": (65, 105, 225), - "saddlebrown": (139, 69, 19), - "salmon": (250, 128, 114), - "sandybrown": (244, 164, 96), - "seagreen": (46, 139, 87), - "seashell": (255, 245, 238), - "sienna": (160, 82, 45), - "silver": (192, 192, 192), - "skyblue": (135, 206, 235), - "slateblue": (106, 90, 205), - "slategray": (112, 128, 144), - "slategrey": (112, 128, 144), - "snow": (255, 250, 250), - "springgreen": (0, 255, 127), - "steelblue": (70, 130, 180), - "tan": (210, 180, 140), - "teal": (0, 128, 128), - "thistle": (216, 191, 216), - "tomato": (255, 99, 71), - "turquoise": (64, 224, 208), - "violet": (238, 130, 238), - "wheat": (245, 222, 179), - "white": (255, 255, 255), - "whitesmoke": (245, 245, 245), - "yellow": (255, 255, 0), - "yellowgreen": (154, 205, 50), -} - -COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()} diff --git a/nornir/_vendor/pydantic/dataclasses.py b/nornir/_vendor/pydantic/dataclasses.py deleted file mode 100644 index 12aa0c35..00000000 --- a/nornir/_vendor/pydantic/dataclasses.py +++ /dev/null @@ -1,161 +0,0 @@ -import dataclasses -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Optional, - Type, - TypeVar, - Union, -) - -from .class_validators import gather_all_validators -from .error_wrappers import ValidationError -from .errors import DataclassTypeError -from .fields import Required -from .main import create_model, validate_model -from .typing import AnyType - -if TYPE_CHECKING: - from .main import BaseModel # noqa: F401 - - DataclassT = TypeVar("DataclassT", bound="DataclassType") - - class DataclassType: - __pydantic_model__: Type[BaseModel] - __initialised__: bool - - def __init__(self, *args: Any, **kwargs: Any) -> None: - pass - - @classmethod - def __validate__(cls: Type["DataclassT"], v: Any) -> "DataclassT": - pass - - def __call__(self: "DataclassT", *args: Any, **kwargs: Any) -> "DataclassT": - pass - - -def _validate_dataclass(cls: Type["DataclassT"], v: Any) -> "DataclassT": - if isinstance(v, cls): - return v - elif isinstance(v, (list, tuple)): - return cls(*v) - elif isinstance(v, dict): - return cls(**v) - else: - raise DataclassTypeError(class_name=cls.__name__) - - -def _get_validators(cls: Type["DataclassT"]) -> Generator[Any, None, None]: - yield cls.__validate__ - - -def setattr_validate_assignment(self: "DataclassType", name: str, value: Any) -> None: - if self.__initialised__: - d = dict(self.__dict__) - d.pop(name, None) - known_field = self.__pydantic_model__.__fields__.get(name, None) - if known_field: - value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) - if error_: - raise ValidationError([error_], type(self)) - - object.__setattr__(self, name, value) - - -def _process_class( - _cls: AnyType, - init: bool, - repr: bool, - eq: bool, - order: bool, - unsafe_hash: bool, - frozen: bool, - config: Optional[Type[Any]], -) -> 
"DataclassType": - post_init_original = getattr(_cls, "__post_init__", None) - if post_init_original and post_init_original.__name__ == "_pydantic_post_init": - post_init_original = None - if not post_init_original: - post_init_original = getattr(_cls, "__post_init_original__", None) - - post_init_post_parse = getattr(_cls, "__post_init_post_parse__", None) - - def _pydantic_post_init(self: "DataclassType", *initvars: Any) -> None: - if post_init_original is not None: - post_init_original(self, *initvars) - d, _, validation_error = validate_model( - self.__pydantic_model__, self.__dict__, cls=self.__class__ - ) - if validation_error: - raise validation_error - object.__setattr__(self, "__dict__", d) - object.__setattr__(self, "__initialised__", True) - if post_init_post_parse is not None: - post_init_post_parse(self, *initvars) - - _cls.__post_init__ = _pydantic_post_init - cls = dataclasses._process_class(_cls, init, repr, eq, order, unsafe_hash, frozen) # type: ignore - - fields: Dict[str, Any] = {} - for field in dataclasses.fields(cls): - - if field.default != dataclasses.MISSING: - field_value = field.default - # mypy issue 7020 and 708 - elif field.default_factory != dataclasses.MISSING: # type: ignore - field_value = field.default_factory() # type: ignore - else: - field_value = Required - - fields[field.name] = (field.type, field_value) - - validators = gather_all_validators(cls) - cls.__pydantic_model__ = create_model( - cls.__name__, - __config__=config, - __module__=_cls.__module__, - __validators__=validators, - **fields, - ) - - cls.__initialised__ = False - cls.__validate__ = classmethod(_validate_dataclass) - cls.__get_validators__ = classmethod(_get_validators) - if post_init_original: - cls.__post_init_original__ = post_init_original - - if cls.__pydantic_model__.__config__.validate_assignment and not frozen: - cls.__setattr__ = setattr_validate_assignment - - return cls - - -def dataclass( - _cls: Optional[AnyType] = None, - *, - init: bool = True, - repr: bool = True, - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, - frozen: bool = False, - config: Type[Any] = None, -) -> Union[Callable[[AnyType], "DataclassType"], "DataclassType"]: - """ - Like the python standard lib dataclasses but with type validation. - - Arguments are the same as for standard dataclasses, except for validate_assignment which has the same meaning - as Config.validate_assignment. - """ - - def wrap(cls: AnyType) -> "DataclassType": - return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config) - - if _cls is None: - return wrap - - return wrap(_cls) diff --git a/nornir/_vendor/pydantic/datetime_parse.py b/nornir/_vendor/pydantic/datetime_parse.py deleted file mode 100644 index 1ae6cfca..00000000 --- a/nornir/_vendor/pydantic/datetime_parse.py +++ /dev/null @@ -1,245 +0,0 @@ -""" -Functions to parse datetime objects. - -We're using regular expressions rather than time.strptime because: -- They provide both validation and parsing. -- They're more flexible for datetimes. -- The date/datetime/time constructors produce friendlier error messages. 
-
-Stolen from https://raw.githubusercontent.com/django/django/master/django/utils/dateparse.py at
-9718fa2e8abe430c3526a9278dd976443d4ae3c6
-
-Changed to:
-* use standard python datetime types not django.utils.timezone
-* raise ValueError when regex doesn't match rather than returning None
-* support parsing unix timestamps for dates and datetimes
-"""
-import re
-from datetime import date, datetime, time, timedelta, timezone
-from typing import Dict, Union
-
-from . import errors
-
-date_re = re.compile(r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$")
-
-time_re = re.compile(
-    r"(?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
-    r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
-)
-
-datetime_re = re.compile(
-    r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
-    r"[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
-    r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
-    r"(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$"
-)
-
-standard_duration_re = re.compile(
-    r"^"
-    r"(?:(?P<days>-?\d+) (days?, )?)?"
-    r"((?:(?P<hours>-?\d+):)(?=\d+:\d+))?"
-    r"(?:(?P<minutes>-?\d+):)?"
-    r"(?P<seconds>-?\d+)"
-    r"(?:\.(?P<microseconds>\d{1,6})\d{0,6})?"
-    r"$"
-)
-
-# Support the sections of ISO 8601 date representation that are accepted by timedelta
-iso8601_duration_re = re.compile(
-    r"^(?P<sign>[-+]?)"
-    r"P"
-    r"(?:(?P<days>\d+(.\d+)?)D)?"
-    r"(?:T"
-    r"(?:(?P<hours>\d+(.\d+)?)H)?"
-    r"(?:(?P<minutes>\d+(.\d+)?)M)?"
-    r"(?:(?P<seconds>\d+(.\d+)?)S)?"
-    r")?"
-    r"$"
-)
-
-EPOCH = datetime(1970, 1, 1)
-# if greater than this, the number is in ms, if less than or equal it's in seconds
-# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
-MS_WATERSHED = int(2e10)
-StrBytesIntFloat = Union[str, bytes, int, float]
-
-
-def get_numeric(
-    value: StrBytesIntFloat, native_expected_type: str
-) -> Union[None, int, float]:
-    if isinstance(value, (int, float)):
-        return value
-    try:
-        return float(value)
-    except ValueError:
-        return None
-    except TypeError:
-        raise TypeError(
-            f"invalid type; expected {native_expected_type}, string, bytes, int or float"
-        )
-
-
-def from_unix_seconds(seconds: Union[int, float]) -> datetime:
-    while seconds > MS_WATERSHED:
-        seconds /= 1000
-    dt = EPOCH + timedelta(seconds=seconds)
-    return dt.replace(tzinfo=timezone.utc)
-
-
-def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
-    """
-    Parse a date/int/float/string and return a datetime.date.
-
-    Raise ValueError if the input is well formatted but not a valid date.
-    Raise ValueError if the input isn't well formatted.
-    """
-    if isinstance(value, date):
-        if isinstance(value, datetime):
-            return value.date()
-        else:
-            return value
-
-    number = get_numeric(value, "date")
-    if number is not None:
-        return from_unix_seconds(number).date()
-
-    if isinstance(value, bytes):
-        value = value.decode()
-
-    match = date_re.match(value)  # type: ignore
-    if match is None:
-        raise errors.DateError()
-
-    kw = {k: int(v) for k, v in match.groupdict().items()}
-
-    try:
-        return date(**kw)
-    except ValueError:
-        raise errors.DateError()
-
-
-def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
-    """
-    Parse a time/string and return a datetime.time.
-
-    This function doesn't support time zone offsets.
-
-    Raise ValueError if the input is well formatted but not a valid time.
-    Raise ValueError if the input isn't well formatted, in particular if it contains an offset. 
- """ - if isinstance(value, time): - return value - - number = get_numeric(value, "time") - if number is not None: - if number >= 86400: - # doesn't make sense since the time time loop back around to 0 - raise errors.TimeError() - return (datetime.min + timedelta(seconds=number)).time() - - if isinstance(value, bytes): - value = value.decode() - - match = time_re.match(value) # type: ignore - if match is None: - raise errors.TimeError() - - kw = match.groupdict() - if kw["microsecond"]: - kw["microsecond"] = kw["microsecond"].ljust(6, "0") - - kw_ = {k: int(v) for k, v in kw.items() if v is not None} - - try: - return time(**kw_) # type: ignore - except ValueError: - raise errors.TimeError() - - -def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: - """ - Parse a datetime/int/float/string and return a datetime.datetime. - - This function supports time zone offsets. When the input contains one, - the output uses a timezone with a fixed offset from UTC. - - Raise ValueError if the input is well formatted but not a valid datetime. - Raise ValueError if the input isn't well formatted. - """ - if isinstance(value, datetime): - return value - - number = get_numeric(value, "datetime") - if number is not None: - return from_unix_seconds(number) - - if isinstance(value, bytes): - value = value.decode() - - match = datetime_re.match(value) # type: ignore - if match is None: - raise errors.DateTimeError() - - kw = match.groupdict() - if kw["microsecond"]: - kw["microsecond"] = kw["microsecond"].ljust(6, "0") - - tzinfo_str = kw.pop("tzinfo") - if tzinfo_str == "Z": - tzinfo = timezone.utc - elif tzinfo_str is not None: - offset_mins = int(tzinfo_str[-2:]) if len(tzinfo_str) > 3 else 0 - offset = 60 * int(tzinfo_str[1:3]) + offset_mins - if tzinfo_str[0] == "-": - offset = -offset - tzinfo = timezone(timedelta(minutes=offset)) - else: - tzinfo = None - - kw_: Dict[str, Union[int, timezone]] = { - k: int(v) for k, v in kw.items() if v is not None - } - kw_["tzinfo"] = tzinfo - - try: - return datetime(**kw_) # type: ignore - except ValueError: - raise errors.DateTimeError() - - -def parse_duration(value: StrBytesIntFloat) -> timedelta: - """ - Parse a duration int/float/string and return a datetime.timedelta. - - The preferred format for durations in Django is '%d %H:%M:%S.%f'. - - Also supports ISO 8601 representation. 
- """ - if isinstance(value, timedelta): - return value - - if isinstance(value, (int, float)): - # bellow code requires a string - value = str(value) - elif isinstance(value, bytes): - value = value.decode() - - try: - match = standard_duration_re.match(value) or iso8601_duration_re.match(value) - except TypeError: - raise TypeError("invalid type; expected timedelta, string, bytes, int or float") - - if not match: - raise errors.DurationError() - - kw = match.groupdict() - sign = -1 if kw.pop("sign", "+") == "-" else 1 - if kw.get("microseconds"): - kw["microseconds"] = kw["microseconds"].ljust(6, "0") - - if kw.get("seconds") and kw.get("microseconds") and kw["seconds"].startswith("-"): - kw["microseconds"] = "-" + kw["microseconds"] - - kw_ = {k: float(v) for k, v in kw.items() if v is not None} - - return sign * timedelta(**kw_) # type: ignore diff --git a/nornir/_vendor/pydantic/env_settings.py b/nornir/_vendor/pydantic/env_settings.py deleted file mode 100644 index b2b30e8c..00000000 --- a/nornir/_vendor/pydantic/env_settings.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import warnings -from typing import AbstractSet, Any, Dict, List, Mapping, Optional, Union - -from .fields import ModelField -from .main import BaseModel, Extra -from .typing import display_as_type -from .utils import deep_update, sequence_like - - -class SettingsError(ValueError): - pass - - -class BaseSettings(BaseModel): - """ - Base class for settings, allowing values to be overridden by environment variables. - - This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose), - Heroku and any 12 factor app design. - """ - - def __init__(__pydantic_self__, **values: Any) -> None: - # Uses something other than `self` the first arg to allow "self" as a settable attribute - super().__init__(**__pydantic_self__._build_values(values)) - - def _build_values(self, init_kwargs: Dict[str, Any]) -> Dict[str, Any]: - return deep_update(self._build_environ(), init_kwargs) - - def _build_environ(self) -> Dict[str, Optional[str]]: - """ - Build environment variables suitable for passing to the Model. - """ - d: Dict[str, Optional[str]] = {} - - if self.__config__.case_sensitive: - env_vars: Mapping[str, str] = os.environ - else: - env_vars = {k.lower(): v for k, v in os.environ.items()} - - for field in self.__fields__.values(): - env_val: Optional[str] = None - for env_name in field.field_info.extra["env_names"]: - env_val = env_vars.get(env_name) - if env_val is not None: - break - - if env_val is None: - continue - - if field.is_complex(): - try: - env_val = self.__config__.json_loads(env_val) # type: ignore - except ValueError as e: - raise SettingsError(f'error parsing JSON for "{env_name}"') from e - d[field.alias] = env_val - return d - - class Config: - env_prefix = "" - validate_all = True - extra = Extra.forbid - arbitrary_types_allowed = True - case_sensitive = False - - @classmethod - def prepare_field(cls, field: ModelField) -> None: - env_names: Union[List[str], AbstractSet[str]] - env = field.field_info.extra.get("env") - if env is None: - if field.has_alias: - warnings.warn( - "aliases are no longer used by BaseSettings to define which environment variables to read. " - 'Instead use the "env" field setting. 
' - "See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names", - FutureWarning, - ) - env_names = {cls.env_prefix + field.name} - elif isinstance(env, str): - env_names = {env} - elif isinstance(env, (set, frozenset)): - env_names = env - elif sequence_like(env): - env_names = list(env) - else: - raise TypeError( - f"invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set" - ) - - if not cls.case_sensitive: - env_names = type(env_names)(n.lower() for n in env_names) - field.field_info.extra["env_names"] = env_names - - __config__: Config # type: ignore diff --git a/nornir/_vendor/pydantic/error_wrappers.py b/nornir/_vendor/pydantic/error_wrappers.py deleted file mode 100644 index b43facee..00000000 --- a/nornir/_vendor/pydantic/error_wrappers.py +++ /dev/null @@ -1,169 +0,0 @@ -import json -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generator, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from .json import pydantic_encoder -from .utils import Representation - -if TYPE_CHECKING: - from .main import BaseConfig # noqa: F401 - from .types import ModelOrDc # noqa: F401 - from .typing import ReprArgs - - Loc = Tuple[Union[int, str], ...] - -__all__ = "ErrorWrapper", "ValidationError" - - -class ErrorWrapper(Representation): - __slots__ = "exc", "_loc" - - def __init__(self, exc: Exception, loc: Union[str, "Loc"]) -> None: - self.exc = exc - self._loc = loc - - def loc_tuple(self) -> "Loc": - if isinstance(self._loc, tuple): - return self._loc - else: - return (self._loc,) - - def __repr_args__(self) -> "ReprArgs": - return [("exc", self.exc), ("loc", self.loc_tuple())] - - -# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper] -# but recursive, therefore just use: -ErrorList = Union[Sequence[Any], ErrorWrapper] - - -class ValidationError(Representation, ValueError): - __slots__ = "raw_errors", "model", "_error_cache" - - def __init__(self, errors: Sequence[ErrorList], model: "ModelOrDc") -> None: - self.raw_errors = errors - self.model = model - self._error_cache: Optional[List[Dict[str, Any]]] = None - - def errors(self) -> List[Dict[str, Any]]: - if self._error_cache is None: - try: - config = self.model.__config__ # type: ignore - except AttributeError: - config = self.model.__pydantic_model__.__config__ # type: ignore - self._error_cache = list(flatten_errors(self.raw_errors, config)) - return self._error_cache - - def json(self, *, indent: Union[None, int, str] = 2) -> str: - return json.dumps(self.errors(), indent=indent, default=pydantic_encoder) - - def __str__(self) -> str: - errors = self.errors() - no_errors = len(errors) - return ( - f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n' - f"{display_errors(errors)}" - ) - - def __repr_args__(self) -> "ReprArgs": - return [("model", self.model.__name__), ("errors", self.errors())] - - -def display_errors(errors: List[Dict[str, Any]]) -> str: - return "\n".join( - f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' - for e in errors - ) - - -def _display_error_loc(error: Dict[str, Any]) -> str: - return " -> ".join(str(l) for l in error["loc"]) - - -def _display_error_type_and_ctx(error: Dict[str, Any]) -> str: - t = "type=" + error["type"] - ctx = error.get("ctx") - if ctx: - return t + "".join(f"; {k}={v}" for k, v in ctx.items()) - else: - return t - - -def flatten_errors( - errors: Sequence[Any], config: Type["BaseConfig"], loc: Optional["Loc"] = 
None -) -> Generator[Dict[str, Any], None, None]: - for error in errors: - if isinstance(error, ErrorWrapper): - - if loc: - error_loc = loc + error.loc_tuple() - else: - error_loc = error.loc_tuple() - - if isinstance(error.exc, ValidationError): - yield from flatten_errors(error.exc.raw_errors, config, error_loc) - else: - yield error_dict(error.exc, config, error_loc) - elif isinstance(error, list): - yield from flatten_errors(error, config, loc=loc) - else: - raise RuntimeError(f"Unknown error object: {error}") - - -def error_dict( - exc: Exception, config: Type["BaseConfig"], loc: "Loc" -) -> Dict[str, Any]: - type_ = get_exc_type(type(exc)) - msg_template = config.error_msg_templates.get(type_) or getattr( - exc, "msg_template", None - ) - ctx = exc.__dict__ - if msg_template: - msg = msg_template.format(**ctx) - else: - msg = str(exc) - - d: Dict[str, Any] = {"loc": loc, "msg": msg, "type": type_} - - if ctx: - d["ctx"] = ctx - - return d - - -_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {} - - -def get_exc_type(cls: Type[Exception]) -> str: - # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up - try: - return _EXC_TYPE_CACHE[cls] - except KeyError: - r = _get_exc_type(cls) - _EXC_TYPE_CACHE[cls] = r - return r - - -def _get_exc_type(cls: Type[Exception]) -> str: - if issubclass(cls, AssertionError): - return "assertion_error" - - base_name = "type_error" if issubclass(cls, TypeError) else "value_error" - if cls in (TypeError, ValueError): - # just TypeError or ValueError, no extra code - return base_name - - # if it's not a TypeError or ValueError, we just take the lowercase of the exception name - # no chaining or snake case logic, use "code" for more complex error types. - code = getattr(cls, "code", None) or cls.__name__.replace("Error", "").lower() - return base_name + "." + code diff --git a/nornir/_vendor/pydantic/errors.py b/nornir/_vendor/pydantic/errors.py deleted file mode 100644 index 589f0e6e..00000000 --- a/nornir/_vendor/pydantic/errors.py +++ /dev/null @@ -1,513 +0,0 @@ -from decimal import Decimal -from pathlib import Path -from typing import Any, Set, Union - -from .typing import AnyType, display_as_type - -# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc. 
-__all__ = ( - "PydanticTypeError", - "PydanticValueError", - "ConfigError", - "MissingError", - "ExtraError", - "NoneIsNotAllowedError", - "NoneIsAllowedError", - "WrongConstantError", - "BoolError", - "BytesError", - "DictError", - "EmailError", - "UrlError", - "UrlSchemeError", - "UrlSchemePermittedError", - "UrlUserInfoError", - "UrlHostError", - "UrlHostTldError", - "UrlExtraError", - "EnumError", - "IntegerError", - "FloatError", - "PathError", - "_PathValueError", - "PathNotExistsError", - "PathNotAFileError", - "PathNotADirectoryError", - "PyObjectError", - "SequenceError", - "ListError", - "SetError", - "FrozenSetError", - "TupleError", - "TupleLengthError", - "ListMinLengthError", - "ListMaxLengthError", - "AnyStrMinLengthError", - "AnyStrMaxLengthError", - "StrError", - "StrRegexError", - "_NumberBoundError", - "NumberNotGtError", - "NumberNotGeError", - "NumberNotLtError", - "NumberNotLeError", - "NumberNotMultipleError", - "DecimalError", - "DecimalIsNotFiniteError", - "DecimalMaxDigitsError", - "DecimalMaxPlacesError", - "DecimalWholeDigitsError", - "DateTimeError", - "DateError", - "TimeError", - "DurationError", - "UUIDError", - "UUIDVersionError", - "ArbitraryTypeError", - "ClassError", - "SubclassError", - "JsonError", - "JsonTypeError", - "PatternError", - "DataclassTypeError", - "CallableError", - "IPvAnyAddressError", - "IPvAnyInterfaceError", - "IPvAnyNetworkError", - "IPv4AddressError", - "IPv6AddressError", - "IPv4NetworkError", - "IPv6NetworkError", - "IPv4InterfaceError", - "IPv6InterfaceError", - "ColorError", - "StrictBoolError", - "NotDigitError", - "LuhnValidationError", - "InvalidLengthForBrand", - "InvalidByteSize", - "InvalidByteSizeUnit", -) - - -class PydanticErrorMixin: - code: str - msg_template: str - - def __init__(self, **ctx: Any) -> None: - self.__dict__ = ctx - - def __str__(self) -> str: - return self.msg_template.format(**self.__dict__) - - -class PydanticTypeError(PydanticErrorMixin, TypeError): - pass - - -class PydanticValueError(PydanticErrorMixin, ValueError): - pass - - -class ConfigError(RuntimeError): - pass - - -class MissingError(PydanticValueError): - msg_template = "field required" - - -class ExtraError(PydanticValueError): - msg_template = "extra fields not permitted" - - -class NoneIsNotAllowedError(PydanticTypeError): - code = "none.not_allowed" - msg_template = "none is not an allowed value" - - -class NoneIsAllowedError(PydanticTypeError): - code = "none.allowed" - msg_template = "value is not none" - - -class WrongConstantError(PydanticValueError): - code = "const" - - def __str__(self) -> str: - permitted = ", ".join(repr(v) for v in self.permitted) # type: ignore - return f"unexpected value; permitted: {permitted}" - - -class BoolError(PydanticTypeError): - msg_template = "value could not be parsed to a boolean" - - -class BytesError(PydanticTypeError): - msg_template = "byte type expected" - - -class DictError(PydanticTypeError): - msg_template = "value is not a valid dict" - - -class EmailError(PydanticValueError): - msg_template = "value is not a valid email address" - - -class UrlError(PydanticValueError): - code = "url" - - -class UrlSchemeError(UrlError): - code = "url.scheme" - msg_template = "invalid or missing URL scheme" - - -class UrlSchemePermittedError(UrlError): - code = "url.scheme" - msg_template = "URL scheme not permitted" - - def __init__(self, allowed_schemes: Set[str]): - super().__init__(allowed_schemes=allowed_schemes) - - -class UrlUserInfoError(UrlError): - code = "url.userinfo" - msg_template = 
"userinfo required in URL but missing" - - -class UrlHostError(UrlError): - code = "url.host" - msg_template = "URL host invalid" - - -class UrlHostTldError(UrlError): - code = "url.host" - msg_template = "URL host invalid, top level domain required" - - -class UrlExtraError(UrlError): - code = "url.extra" - msg_template = "URL invalid, extra characters found after valid URL: {extra!r}" - - -class EnumError(PydanticTypeError): - def __str__(self) -> str: - permitted = ", ".join(repr(v.value) for v in self.enum_values) # type: ignore - return f"value is not a valid enumeration member; permitted: {permitted}" - - -class IntegerError(PydanticTypeError): - msg_template = "value is not a valid integer" - - -class FloatError(PydanticTypeError): - msg_template = "value is not a valid float" - - -class PathError(PydanticTypeError): - msg_template = "value is not a valid path" - - -class _PathValueError(PydanticValueError): - def __init__(self, *, path: Path) -> None: - super().__init__(path=str(path)) - - -class PathNotExistsError(_PathValueError): - code = "path.not_exists" - msg_template = 'file or directory at path "{path}" does not exist' - - -class PathNotAFileError(_PathValueError): - code = "path.not_a_file" - msg_template = 'path "{path}" does not point to a file' - - -class PathNotADirectoryError(_PathValueError): - code = "path.not_a_directory" - msg_template = 'path "{path}" does not point to a directory' - - -class PyObjectError(PydanticTypeError): - msg_template = "ensure this value contains valid import path or valid callable: {error_message}" - - -class SequenceError(PydanticTypeError): - msg_template = "value is not a valid sequence" - - -class ListError(PydanticTypeError): - msg_template = "value is not a valid list" - - -class SetError(PydanticTypeError): - msg_template = "value is not a valid set" - - -class FrozenSetError(PydanticTypeError): - msg_template = "value is not a valid frozenset" - - -class TupleError(PydanticTypeError): - msg_template = "value is not a valid tuple" - - -class TupleLengthError(PydanticValueError): - code = "tuple.length" - msg_template = "wrong tuple length {actual_length}, expected {expected_length}" - - def __init__(self, *, actual_length: int, expected_length: int) -> None: - super().__init__(actual_length=actual_length, expected_length=expected_length) - - -class ListMinLengthError(PydanticValueError): - code = "list.min_items" - msg_template = "ensure this value has at least {limit_value} items" - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class ListMaxLengthError(PydanticValueError): - code = "list.max_items" - msg_template = "ensure this value has at most {limit_value} items" - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class AnyStrMinLengthError(PydanticValueError): - code = "any_str.min_length" - msg_template = "ensure this value has at least {limit_value} characters" - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class AnyStrMaxLengthError(PydanticValueError): - code = "any_str.max_length" - msg_template = "ensure this value has at most {limit_value} characters" - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class StrError(PydanticTypeError): - msg_template = "str type expected" - - -class StrRegexError(PydanticValueError): - code = "str.regex" - msg_template = 'string does not match regex "{pattern}"' - - def 
__init__(self, *, pattern: str) -> None: - super().__init__(pattern=pattern) - - -class _NumberBoundError(PydanticValueError): - def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None: - super().__init__(limit_value=limit_value) - - -class NumberNotGtError(_NumberBoundError): - code = "number.not_gt" - msg_template = "ensure this value is greater than {limit_value}" - - -class NumberNotGeError(_NumberBoundError): - code = "number.not_ge" - msg_template = "ensure this value is greater than or equal to {limit_value}" - - -class NumberNotLtError(_NumberBoundError): - code = "number.not_lt" - msg_template = "ensure this value is less than {limit_value}" - - -class NumberNotLeError(_NumberBoundError): - code = "number.not_le" - msg_template = "ensure this value is less than or equal to {limit_value}" - - -class NumberNotMultipleError(PydanticValueError): - code = "number.not_multiple" - msg_template = "ensure this value is a multiple of {multiple_of}" - - def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None: - super().__init__(multiple_of=multiple_of) - - -class DecimalError(PydanticTypeError): - msg_template = "value is not a valid decimal" - - -class DecimalIsNotFiniteError(PydanticValueError): - code = "decimal.not_finite" - msg_template = "value is not a valid decimal" - - -class DecimalMaxDigitsError(PydanticValueError): - code = "decimal.max_digits" - msg_template = "ensure that there are no more than {max_digits} digits in total" - - def __init__(self, *, max_digits: int) -> None: - super().__init__(max_digits=max_digits) - - -class DecimalMaxPlacesError(PydanticValueError): - code = "decimal.max_places" - msg_template = "ensure that there are no more than {decimal_places} decimal places" - - def __init__(self, *, decimal_places: int) -> None: - super().__init__(decimal_places=decimal_places) - - -class DecimalWholeDigitsError(PydanticValueError): - code = "decimal.whole_digits" - msg_template = "ensure that there are no more than {whole_digits} digits before the decimal point" - - def __init__(self, *, whole_digits: int) -> None: - super().__init__(whole_digits=whole_digits) - - -class DateTimeError(PydanticValueError): - msg_template = "invalid datetime format" - - -class DateError(PydanticValueError): - msg_template = "invalid date format" - - -class TimeError(PydanticValueError): - msg_template = "invalid time format" - - -class DurationError(PydanticValueError): - msg_template = "invalid duration format" - - -class UUIDError(PydanticTypeError): - msg_template = "value is not a valid uuid" - - -class UUIDVersionError(PydanticValueError): - code = "uuid.version" - msg_template = "uuid version {required_version} expected" - - def __init__(self, *, required_version: int) -> None: - super().__init__(required_version=required_version) - - -class ArbitraryTypeError(PydanticTypeError): - code = "arbitrary_type" - msg_template = "instance of {expected_arbitrary_type} expected" - - def __init__(self, *, expected_arbitrary_type: AnyType) -> None: - super().__init__( - expected_arbitrary_type=display_as_type(expected_arbitrary_type) - ) - - -class ClassError(PydanticTypeError): - code = "class" - msg_template = "a class is expected" - - -class SubclassError(PydanticTypeError): - code = "subclass" - msg_template = "subclass of {expected_class} expected" - - def __init__(self, *, expected_class: AnyType) -> None: - super().__init__(expected_class=display_as_type(expected_class)) - - -class JsonError(PydanticValueError): - msg_template = "Invalid JSON" - - -class 
JsonTypeError(PydanticTypeError): - code = "json" - msg_template = "JSON object must be str, bytes or bytearray" - - -class PatternError(PydanticValueError): - code = "regex_pattern" - msg_template = "Invalid regular expression" - - -class DataclassTypeError(PydanticTypeError): - code = "dataclass" - msg_template = "instance of {class_name}, tuple or dict expected" - - -class CallableError(PydanticTypeError): - msg_template = "{value} is not callable" - - -class IPvAnyAddressError(PydanticValueError): - msg_template = "value is not a valid IPv4 or IPv6 address" - - -class IPvAnyInterfaceError(PydanticValueError): - msg_template = "value is not a valid IPv4 or IPv6 interface" - - -class IPvAnyNetworkError(PydanticValueError): - msg_template = "value is not a valid IPv4 or IPv6 network" - - -class IPv4AddressError(PydanticValueError): - msg_template = "value is not a valid IPv4 address" - - -class IPv6AddressError(PydanticValueError): - msg_template = "value is not a valid IPv6 address" - - -class IPv4NetworkError(PydanticValueError): - msg_template = "value is not a valid IPv4 network" - - -class IPv6NetworkError(PydanticValueError): - msg_template = "value is not a valid IPv6 network" - - -class IPv4InterfaceError(PydanticValueError): - msg_template = "value is not a valid IPv4 interface" - - -class IPv6InterfaceError(PydanticValueError): - msg_template = "value is not a valid IPv6 interface" - - -class ColorError(PydanticValueError): - msg_template = "value is not a valid color: {reason}" - - -class StrictBoolError(PydanticValueError): - msg_template = "value is not a valid boolean" - - -class NotDigitError(PydanticValueError): - code = "payment_card_number.digits" - msg_template = "card number is not all digits" - - -class LuhnValidationError(PydanticValueError): - code = "payment_card_number.luhn_check" - msg_template = "card number is not luhn valid" - - -class InvalidLengthForBrand(PydanticValueError): - code = "payment_card_number.invalid_length_for_brand" - msg_template = "Length for a {brand} card must be {required_length}" - - -class InvalidByteSize(PydanticValueError): - msg_template = "could not parse value and unit from byte string" - - -class InvalidByteSizeUnit(PydanticValueError): - msg_template = "could not interpret byte unit: {unit}" diff --git a/nornir/_vendor/pydantic/fields.py b/nornir/_vendor/pydantic/fields.py deleted file mode 100644 index 86b5ccdf..00000000 --- a/nornir/_vendor/pydantic/fields.py +++ /dev/null @@ -1,733 +0,0 @@ -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Dict, - FrozenSet, - Generator, - Iterator, - List, - Mapping, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from . 
import errors as errors_
-from .class_validators import Validator, make_generic_validator, prep_validators
-from .error_wrappers import ErrorWrapper
-from .errors import NoneIsNotAllowedError
-from .types import Json, JsonWrapper
-from .typing import (
-    AnyType,
-    Callable,
-    ForwardRef,
-    NoneType,
-    display_as_type,
-    is_literal_type,
-)
-from .utils import PyObjectStr, Representation, lenient_issubclass, sequence_like
-from .validators import (
-    constant_validator,
-    dict_validator,
-    find_validators,
-    validate_json,
-)
-
-Required: Any = Ellipsis
-
-
-class UndefinedType:
-    def __repr__(self) -> str:
-        return "PydanticUndefined"
-
-
-Undefined = UndefinedType()
-
-if TYPE_CHECKING:
-    from .class_validators import ValidatorsList  # noqa: F401
-    from .error_wrappers import ErrorList
-    from .main import BaseConfig, BaseModel  # noqa: F401
-    from .types import ModelOrDc  # noqa: F401
-    from .typing import ReprArgs  # noqa: F401
-
-    ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]]
-    LocStr = Union[Tuple[Union[int, str], ...], str]
-    BoolUndefined = Union[bool, UndefinedType]
-
-
-class FieldInfo(Representation):
-    """
-    Captures extra information about a field.
-    """
-
-    __slots__ = (
-        "default",
-        "alias",
-        "title",
-        "description",
-        "const",
-        "gt",
-        "ge",
-        "lt",
-        "le",
-        "multiple_of",
-        "min_items",
-        "max_items",
-        "min_length",
-        "max_length",
-        "regex",
-        "extra",
-    )
-
-    def __init__(self, default: Any, **kwargs: Any) -> None:
-        self.default = default
-        self.alias = kwargs.pop("alias", None)
-        self.title = kwargs.pop("title", None)
-        self.description = kwargs.pop("description", None)
-        self.const = kwargs.pop("const", None)
-        self.gt = kwargs.pop("gt", None)
-        self.ge = kwargs.pop("ge", None)
-        self.lt = kwargs.pop("lt", None)
-        self.le = kwargs.pop("le", None)
-        self.multiple_of = kwargs.pop("multiple_of", None)
-        self.min_items = kwargs.pop("min_items", None)
-        self.max_items = kwargs.pop("max_items", None)
-        self.min_length = kwargs.pop("min_length", None)
-        self.max_length = kwargs.pop("max_length", None)
-        self.regex = kwargs.pop("regex", None)
-        self.extra = kwargs
-
-
-def Field(
-    default: Any,
-    *,
-    alias: str = None,
-    title: str = None,
-    description: str = None,
-    const: bool = None,
-    gt: float = None,
-    ge: float = None,
-    lt: float = None,
-    le: float = None,
-    multiple_of: float = None,
-    min_items: int = None,
-    max_items: int = None,
-    min_length: int = None,
-    max_length: int = None,
-    regex: str = None,
-    **extra: Any,
-) -> Any:
-    """
-    Used to provide extra information about a field, either for the model schema or complex validation. Some arguments
-    apply only to number fields (``int``, ``float``, ``Decimal``) and some apply only to ``str``.
-
-    :param default: since this is replacing the field’s default, its first argument is used
-      to set the default, use ellipsis (``...``) to indicate the field is required
-    :param alias: the public name of the field
-    :param title: can be any string, used in the schema
-    :param description: can be any string, used in the schema
-    :param const: this field is required and *must* take its default value
-    :param gt: only applies to numbers, requires the field to be "greater than". The schema
-      will have an ``exclusiveMinimum`` validation keyword
-    :param ge: only applies to numbers, requires the field to be "greater than or equal to". The
-      schema will have a ``minimum`` validation keyword
-    :param lt: only applies to numbers, requires the field to be "less than". The schema
-      will have an ``exclusiveMaximum`` validation keyword
-    :param le: only applies to numbers, requires the field to be "less than or equal to". The
-      schema will have a ``maximum`` validation keyword
-    :param multiple_of: only applies to numbers, requires the field to be "a multiple of". The
-      schema will have a ``multipleOf`` validation keyword
-    :param min_length: only applies to strings, requires the field to have a minimum length. The
-      schema will have a ``minLength`` validation keyword
-    :param max_length: only applies to strings, requires the field to have a maximum length. The
-      schema will have a ``maxLength`` validation keyword
-    :param regex: only applies to strings, requires the field to match against a regular expression
-      pattern string. The schema will have a ``pattern`` validation keyword
-    :param **extra: any additional keyword arguments will be added as is to the schema
-    """
-    return FieldInfo(
-        default,
-        alias=alias,
-        title=title,
-        description=description,
-        const=const,
-        gt=gt,
-        ge=ge,
-        lt=lt,
-        le=le,
-        multiple_of=multiple_of,
-        min_items=min_items,
-        max_items=max_items,
-        min_length=min_length,
-        max_length=max_length,
-        regex=regex,
-        **extra,
-    )
-
-
-def Schema(default: Any, **kwargs: Any) -> Any:
-    warnings.warn("`Schema` is deprecated, use `Field` instead", DeprecationWarning)
-    return Field(default, **kwargs)
-
-
-# used to be an enum but changed to ints for a small performance improvement (less access overhead)
-SHAPE_SINGLETON = 1
-SHAPE_LIST = 2
-SHAPE_SET = 3
-SHAPE_MAPPING = 4
-SHAPE_TUPLE = 5
-SHAPE_TUPLE_ELLIPSIS = 6
-SHAPE_SEQUENCE = 7
-SHAPE_FROZENSET = 8
-SHAPE_NAME_LOOKUP = {
-    SHAPE_LIST: "List[{}]",
-    SHAPE_SET: "Set[{}]",
-    SHAPE_TUPLE_ELLIPSIS: "Tuple[{}, ...]",
-    SHAPE_SEQUENCE: "Sequence[{}]",
-    SHAPE_FROZENSET: "FrozenSet[{}]",
-}
-
-
-class ModelField(Representation):
-    __slots__ = (
-        "type_",
-        "outer_type_",
-        "sub_fields",
-        "key_field",
-        "validators",
-        "pre_validators",
-        "post_validators",
-        "default",
-        "required",
-        "model_config",
-        "name",
-        "alias",
-        "has_alias",
-        "field_info",
-        "validate_always",
-        "allow_none",
-        "shape",
-        "class_validators",
-        "parse_json",
-    )
-
-    def __init__(
-        self,
-        *,
-        name: str,
-        type_: AnyType,
-        class_validators: Optional[Dict[str, Validator]],
-        model_config: Type["BaseConfig"],
-        default: Any = None,
-        required: "BoolUndefined" = Undefined,
-        alias: str = None,
-        field_info: Optional[FieldInfo] = None,
-    ) -> None:
-
-        self.name: str = name
-        self.has_alias: bool = bool(alias)
-        self.alias: str = alias or name
-        self.type_: Any = type_
-        self.outer_type_: Any = type_
-        self.class_validators = class_validators or {}
-        self.default: Any = default
-        self.required: "BoolUndefined" = required
-        self.model_config = model_config
-        self.field_info: FieldInfo = field_info or FieldInfo(default)
-
-        self.allow_none: bool = False
-        self.validate_always: bool = False
-        self.sub_fields: Optional[List[ModelField]] = None
-        self.key_field: Optional[ModelField] = None
-        self.validators: "ValidatorsList" = []
-        self.pre_validators: Optional["ValidatorsList"] = None
-        self.post_validators: Optional["ValidatorsList"] = None
-        self.parse_json: bool = False
-        self.shape: int = SHAPE_SINGLETON
-        self.model_config.prepare_field(self)
-        self.prepare()
-
-    @classmethod
-    def infer(
-        cls,
-        *,
-        name: str,
-        value: Any,
-        annotation: Any,
-        class_validators: Optional[Dict[str, Validator]],
-        config: Type["BaseConfig"],
-    ) -> "ModelField":
-        field_info_from_config = config.get_field_info(name)
-        from .schema import get_annotation_from_field_info
-
-        if isinstance(value, FieldInfo):
-            field_info = value
-            value = field_info.default
-        else:
-            field_info = FieldInfo(value, **field_info_from_config)
-        required: "BoolUndefined" = Undefined
-        if value is Required:
-            required = True
-            value = None
-        elif value is not Undefined:
-            required = False
-        field_info.alias = field_info.alias or field_info_from_config.get("alias")
-        annotation = get_annotation_from_field_info(annotation, field_info, name)
-        return cls(
-            name=name,
-            type_=annotation,
-            alias=field_info.alias,
-            class_validators=class_validators,
-            default=value,
-            required=required,
-            model_config=config,
-            field_info=field_info,
-        )
-
-    def set_config(self, config: Type["BaseConfig"]) -> None:
-        self.model_config = config
-        info_from_config = config.get_field_info(self.name)
-        config.prepare_field(self)
-        if info_from_config:
-            self.field_info.alias = (
-                info_from_config.get("alias") or self.field_info.alias or self.name
-            )
-            self.alias = cast(str, self.field_info.alias)
-
-    @property
-    def alt_alias(self) -> bool:
-        return self.name != self.alias
-
-    def prepare(self) -> None:
-        """
-        Prepare the field by inspecting self.default, self.type_ etc.
-
-        Note: this method is **not** idempotent (because _type_analysis is not idempotent),
-        e.g. calling it multiple times may modify the field and configure it incorrectly.
-        """
-        if self.default is not None and self.type_ is None:
-            self.type_ = type(self.default)
-            self.outer_type_ = self.type_
-
-        if self.type_ is None:
-            raise errors_.ConfigError(
-                f'unable to infer type for attribute "{self.name}"'
-            )
-
-        if type(self.type_) == ForwardRef:
-            # self.type_ is currently a ForwardRef and there's nothing we can do now,
-            # user will need to call model.update_forward_refs()
-            return
-
-        self.validate_always = getattr(self.type_, "validate_always", False) or any(
-            v.always for v in self.class_validators.values()
-        )
-
-        if self.required is False and self.default is None:
-            self.allow_none = True
-
-        self._type_analysis()
-        if self.required is Undefined:
-            self.required = True
-            self.field_info.default = Required
-        if self.default is Undefined:
-            self.default = None
-        self.populate_validators()
-
-    def _type_analysis(self) -> None:  # noqa: C901 (ignore complexity)
-        # typing interface is horrible, we have to do some ugly checks
-        if lenient_issubclass(self.type_, JsonWrapper):
-            self.type_ = self.type_.inner_type
-            self.parse_json = True
-        elif lenient_issubclass(self.type_, Json):
-            self.type_ = Any
-            self.parse_json = True
-        elif isinstance(self.type_, TypeVar):  # type: ignore
-            if self.type_.__bound__:
-                self.type_ = self.type_.__bound__
-            elif self.type_.__constraints__:
-                self.type_ = Union[self.type_.__constraints__]
-            else:
-                self.type_ = Any
-
-        if self.type_ is Any:
-            if self.required is Undefined:
-                self.required = False
-            self.allow_none = True
-            return
-        elif self.type_ is Pattern:
-            # python 3.7 only, Pattern is a typing object but without sub fields
-            return
-        elif is_literal_type(self.type_):
-            return
-
-        origin = getattr(self.type_, "__origin__", None)
-        if origin is None:
-            # field is not "typing" object eg. Union, Dict, List etc.
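
# --- A minimal sketch (not from the vendored source; the field name is
# hypothetical) of what this method's analysis computes for a typical
# annotation such as Optional[List[int]]:
#
#     f = ModelField(
#         name="ports",                 # hypothetical field name
#         type_=Optional[List[int]],
#         class_validators=None,
#         model_config=BaseConfig,      # the default config class above
#     )
#     assert f.allow_none               # NoneType arm of the Union
#     assert f.shape == SHAPE_LIST      # unwrapped by the List branch
#     assert f.type_ is int             # the refined inner type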
- return - if origin is Callable: - return - if origin is Union: - types_ = [] - for type_ in self.type_.__args__: - if type_ is NoneType: # type: ignore - if self.required is Undefined: - self.required = False - self.allow_none = True - continue - types_.append(type_) - - if len(types_) == 1: - # Optional[] - self.type_ = types_[0] - # this is the one case where the "outer type" isn't just the original type - self.outer_type_ = self.type_ - # re-run to correctly interpret the new self.type_ - self._type_analysis() - else: - self.sub_fields = [ - self._create_sub_type(t, f"{self.name}_{display_as_type(t)}") - for t in types_ - ] - return - - if issubclass(origin, Tuple): # type: ignore - self.shape = SHAPE_TUPLE - self.sub_fields = [] - for i, t in enumerate(self.type_.__args__): - if t is Ellipsis: - self.type_ = self.type_.__args__[0] - self.shape = SHAPE_TUPLE_ELLIPSIS - return - self.sub_fields.append(self._create_sub_type(t, f"{self.name}_{i}")) - return - - if issubclass(origin, List): - # Create self validators - get_validators = getattr(self.type_, "__get_validators__", None) - if get_validators: - self.class_validators.update( - { - f"list_{i}": Validator(validator, pre=True, always=True) - for i, validator in enumerate(get_validators()) - } - ) - - self.type_ = self.type_.__args__[0] - self.shape = SHAPE_LIST - elif issubclass(origin, Set): - self.type_ = self.type_.__args__[0] - self.shape = SHAPE_SET - elif issubclass(origin, FrozenSet): - self.type_ = self.type_.__args__[0] - self.shape = SHAPE_FROZENSET - elif issubclass(origin, Sequence): - self.type_ = self.type_.__args__[0] - self.shape = SHAPE_SEQUENCE - elif issubclass(origin, Mapping): - self.key_field = self._create_sub_type( - self.type_.__args__[0], "key_" + self.name, for_keys=True - ) - self.type_ = self.type_.__args__[1] - self.shape = SHAPE_MAPPING - elif issubclass(origin, Type): # type: ignore - return - else: - raise TypeError(f'Fields of type "{origin}" are not supported.') - - # type_ has been refined eg. as the type of a List and sub_fields needs to be populated - self.sub_fields = [self._create_sub_type(self.type_, "_" + self.name)] - - def _create_sub_type( - self, type_: AnyType, name: str, *, for_keys: bool = False - ) -> "ModelField": - return self.__class__( - type_=type_, - name=name, - class_validators=None - if for_keys - else {k: v for k, v in self.class_validators.items() if v.each_item}, - model_config=self.model_config, - ) - - def populate_validators(self) -> None: - """ - Prepare self.pre_validators, self.validators, and self.post_validators based on self.type_'s __get_validators__ - and class validators. This method should be idempotent, e.g. it should be safe to call multiple times - without mis-configuring the field. 
- """ - class_validators_ = self.class_validators.values() - if not self.sub_fields: - get_validators = getattr(self.type_, "__get_validators__", None) - v_funcs = ( - *[v.func for v in class_validators_ if v.each_item and v.pre], - *( - get_validators() - if get_validators - else list(find_validators(self.type_, self.model_config)) - ), - *[v.func for v in class_validators_ if v.each_item and not v.pre], - ) - self.validators = prep_validators(v_funcs) - - # Add const validator - self.pre_validators = [] - self.post_validators = [] - if self.field_info and self.field_info.const: - self.pre_validators = [make_generic_validator(constant_validator)] - - if class_validators_: - self.pre_validators += prep_validators( - v.func for v in class_validators_ if not v.each_item and v.pre - ) - self.post_validators = prep_validators( - v.func for v in class_validators_ if not v.each_item and not v.pre - ) - - if self.parse_json: - self.pre_validators.append(make_generic_validator(validate_json)) - - self.pre_validators = self.pre_validators or None - self.post_validators = self.post_validators or None - - def validate( - self, - v: Any, - values: Dict[str, Any], - *, - loc: "LocStr", - cls: Optional["ModelOrDc"] = None, - ) -> "ValidateReturn": - - errors: Optional["ErrorList"] - if self.pre_validators: - v, errors = self._apply_validators(v, values, loc, cls, self.pre_validators) - if errors: - return v, errors - - if v is None: - if self.allow_none: - if self.post_validators: - return self._apply_validators( - v, values, loc, cls, self.post_validators - ) - else: - return None, None - else: - return v, ErrorWrapper(NoneIsNotAllowedError(), loc) - - if self.shape == SHAPE_SINGLETON: - v, errors = self._validate_singleton(v, values, loc, cls) - elif self.shape == SHAPE_MAPPING: - v, errors = self._validate_mapping(v, values, loc, cls) - elif self.shape == SHAPE_TUPLE: - v, errors = self._validate_tuple(v, values, loc, cls) - else: - # sequence, list, set, generator, tuple with ellipsis, frozen set - v, errors = self._validate_sequence_like(v, values, loc, cls) - - if not errors and self.post_validators: - v, errors = self._apply_validators( - v, values, loc, cls, self.post_validators - ) - return v, errors - - def _validate_sequence_like( # noqa: C901 (ignore complexity) - self, v: Any, values: Dict[str, Any], loc: "LocStr", cls: Optional["ModelOrDc"] - ) -> "ValidateReturn": - """ - Validate sequence-like containers: lists, tuples, sets and generators - Note that large if-else blocks are necessary to enable Cython - optimization, which is why we disable the complexity check above. 
- """ - if not sequence_like(v): - e: errors_.PydanticTypeError - if self.shape == SHAPE_LIST: - e = errors_.ListError() - elif self.shape == SHAPE_SET: - e = errors_.SetError() - elif self.shape == SHAPE_FROZENSET: - e = errors_.FrozenSetError() - else: - e = errors_.SequenceError() - return v, ErrorWrapper(e, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result = [] - errors: List[ErrorList] = [] - for i, v_ in enumerate(v): - v_loc = *loc, i - r, ee = self._validate_singleton(v_, values, v_loc, cls) - if ee: - errors.append(ee) - else: - result.append(r) - - if errors: - return v, errors - - converted: Union[ - List[Any], Set[Any], FrozenSet[Any], Tuple[Any, ...], Iterator[Any] - ] = result - - if self.shape == SHAPE_SET: - converted = set(result) - elif self.shape == SHAPE_FROZENSET: - converted = frozenset(result) - elif self.shape == SHAPE_TUPLE_ELLIPSIS: - converted = tuple(result) - elif self.shape == SHAPE_SEQUENCE: - if isinstance(v, tuple): - converted = tuple(result) - elif isinstance(v, set): - converted = set(result) - elif isinstance(v, Generator): - converted = iter(result) - return converted, None - - def _validate_tuple( - self, v: Any, values: Dict[str, Any], loc: "LocStr", cls: Optional["ModelOrDc"] - ) -> "ValidateReturn": - e: Optional[Exception] = None - if not sequence_like(v): - e = errors_.TupleError() - else: - actual_length, expected_length = len(v), len(self.sub_fields) # type: ignore - if actual_length != expected_length: - e = errors_.TupleLengthError( - actual_length=actual_length, expected_length=expected_length - ) - - if e: - return v, ErrorWrapper(e, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result = [] - errors: List[ErrorList] = [] - for i, (v_, field) in enumerate(zip(v, self.sub_fields)): # type: ignore - v_loc = *loc, i - r, ee = field.validate(v_, values, loc=v_loc, cls=cls) - if ee: - errors.append(ee) - else: - result.append(r) - - if errors: - return v, errors - else: - return tuple(result), None - - def _validate_mapping( - self, v: Any, values: Dict[str, Any], loc: "LocStr", cls: Optional["ModelOrDc"] - ) -> "ValidateReturn": - try: - v_iter = dict_validator(v) - except TypeError as exc: - return v, ErrorWrapper(exc, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result, errors = {}, [] - for k, v_ in v_iter.items(): - v_loc = *loc, "__key__" - key_result, key_errors = self.key_field.validate(k, values, loc=v_loc, cls=cls) # type: ignore - if key_errors: - errors.append(key_errors) - continue - - v_loc = *loc, k - value_result, value_errors = self._validate_singleton( - v_, values, v_loc, cls - ) - if value_errors: - errors.append(value_errors) - continue - - result[key_result] = value_result - if errors: - return v, errors - else: - return result, None - - def _validate_singleton( - self, v: Any, values: Dict[str, Any], loc: "LocStr", cls: Optional["ModelOrDc"] - ) -> "ValidateReturn": - if self.sub_fields: - errors = [] - for field in self.sub_fields: - value, error = field.validate(v, values, loc=loc, cls=cls) - if error: - errors.append(error) - else: - return value, None - return v, errors - else: - return self._apply_validators(v, values, loc, cls, self.validators) - - def _apply_validators( - self, - v: Any, - values: Dict[str, Any], - loc: "LocStr", - cls: Optional["ModelOrDc"], - validators: "ValidatorsList", - ) -> "ValidateReturn": - for validator in validators: - try: - v = validator(cls, v, values, self, self.model_config) - except (ValueError, TypeError, AssertionError) as exc: - return 
v, ErrorWrapper(exc, loc) - return v, None - - def include_in_schema(self) -> bool: - """ - False if this is a simple field just allowing None as used in Unions/Optional. - """ - return self.type_ != NoneType # type: ignore - - def is_complex(self) -> bool: - """ - Whether the field is "complex" eg. env variables should be parsed as JSON. - """ - from .main import BaseModel # noqa: F811 - - return ( - self.shape != SHAPE_SINGLETON - or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict)) - or hasattr(self.type_, "__pydantic_model__") # pydantic dataclass - ) - - def _type_display(self) -> PyObjectStr: - t = display_as_type(self.type_) - - # have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6 - if self.shape == SHAPE_MAPPING: - t = f"Mapping[{display_as_type(self.key_field.type_)}, {t}]" # type: ignore - elif self.shape == SHAPE_TUPLE: - t = "Tuple[{}]".format(", ".join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore - elif self.shape != SHAPE_SINGLETON: - t = SHAPE_NAME_LOOKUP[self.shape].format(t) - - if self.allow_none and (self.shape != SHAPE_SINGLETON or not self.sub_fields): - t = f"Optional[{t}]" - return PyObjectStr(t) - - def __repr_args__(self) -> "ReprArgs": - args = [ - ("name", self.name), - ("type", self._type_display()), - ("required", self.required), - ] - - if not self.required: - args.append(("default", self.default)) - - if self.alt_alias: - args.append(("alias", self.alias)) - return args diff --git a/nornir/_vendor/pydantic/generics.py b/nornir/_vendor/pydantic/generics.py deleted file mode 100644 index 6754bacb..00000000 --- a/nornir/_vendor/pydantic/generics.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Tuple, - Type, - TypeVar, - Union, - get_type_hints, -) - -from .class_validators import gather_all_validators -from .main import BaseModel, create_model - -_generic_types_cache: Dict[ - Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel] -] = {} -GenericModelT = TypeVar("GenericModelT", bound="GenericModel") - - -class GenericModel(BaseModel): - __slots__ = () - __concrete__: ClassVar[bool] = False - - def __new__(cls, *args: Any, **kwargs: Any) -> Any: - if cls.__concrete__: - return super().__new__(cls) - else: - raise TypeError( - f"Type {cls.__name__} cannot be used without generic parameters, e.g. 
{cls.__name__}[T]" - ) - - # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings - def __class_getitem__( - cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]] - ) -> Type[Any]: - cached = _generic_types_cache.get((cls, params)) - if cached is not None: - return cached - if cls.__concrete__: - raise TypeError( - "Cannot parameterize a concrete instantiation of a generic model" - ) - if not isinstance(params, tuple): - params = (params,) - if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): # type: ignore - raise TypeError( - f"Type parameters should be placed on typing.Generic, not GenericModel" - ) - if Generic not in cls.__bases__: - raise TypeError( - f"Type {cls.__name__} must inherit from typing.Generic before being parameterized" - ) - - check_parameters_count(cls, params) - typevars_map: Dict[Any, Any] = dict(zip(cls.__parameters__, params)) # type: ignore - type_hints = get_type_hints(cls).items() - instance_type_hints = { - k: v - for k, v in type_hints - if getattr(v, "__origin__", None) is not ClassVar - } - concrete_type_hints: Dict[str, Type[Any]] = { - k: resolve_type_hint(v, typevars_map) - for k, v in instance_type_hints.items() - } - - model_name = cls.__concrete_name__(params) - validators = gather_all_validators(cls) - fields: Dict[str, Tuple[Type[Any], Any]] = { - k: (v, cls.__fields__[k].field_info) - for k, v in concrete_type_hints.items() - if k in cls.__fields__ - } - created_model = create_model( - model_name=model_name, - __module__=cls.__module__, - __base__=cls, - __config__=None, - __validators__=validators, - **fields, - ) - created_model.Config = cls.Config - created_model.__concrete__ = True # type: ignore - _generic_types_cache[(cls, params)] = created_model - if len(params) == 1: - _generic_types_cache[(cls, params[0])] = created_model - return created_model - - @classmethod - def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: - """ - This method can be overridden to achieve a custom naming scheme for GenericModels - """ - param_names = [ - param.__name__ if hasattr(param, "__name__") else str(param) - for param in params - ] - params_component = ", ".join(param_names) - return f"{cls.__name__}[{params_component}]" - - -def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]: - if hasattr(type_, "__origin__") and getattr(type_, "__parameters__", None): - concrete_type_args = tuple([typevars_map[x] for x in type_.__parameters__]) - return type_[concrete_type_args] - return typevars_map.get(type_, type_) - - -def check_parameters_count( - cls: Type[GenericModel], parameters: Tuple[Any, ...] 
-) -> None: - actual = len(parameters) - expected = len(cls.__parameters__) # type: ignore - if actual != expected: - description = "many" if actual > expected else "few" - raise TypeError( - f"Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}" - ) diff --git a/nornir/_vendor/pydantic/json.py b/nornir/_vendor/pydantic/json.py deleted file mode 100644 index c9e217d9..00000000 --- a/nornir/_vendor/pydantic/json.py +++ /dev/null @@ -1,89 +0,0 @@ -import datetime -from dataclasses import asdict, is_dataclass -from decimal import Decimal -from enum import Enum -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, -) -from pathlib import Path -from types import GeneratorType -from typing import Any, Callable, Dict, Type, Union -from uuid import UUID - -from .color import Color -from .types import SecretBytes, SecretStr - -__all__ = "pydantic_encoder", "custom_pydantic_encoder", "timedelta_isoformat" - - -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { - Color: str, - IPv4Address: str, - IPv6Address: str, - IPv4Interface: str, - IPv6Interface: str, - IPv4Network: str, - IPv6Network: str, - SecretStr: str, - SecretBytes: str, - UUID: str, - datetime.datetime: isoformat, - datetime.date: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - set: list, - frozenset: list, - GeneratorType: list, - bytes: lambda o: o.decode(), - Decimal: float, -} - - -def pydantic_encoder(obj: Any) -> Any: - from .main import BaseModel - - if isinstance(obj, BaseModel): - return obj.dict() - elif isinstance(obj, Enum): - return obj.value - elif isinstance(obj, Path): - return str(obj) - elif is_dataclass(obj): - return asdict(obj) - - try: - encoder = ENCODERS_BY_TYPE[type(obj)] - except KeyError: - raise TypeError( - f"Object of type '{obj.__class__.__name__}' is not JSON serializable" - ) - else: - return encoder(obj) - - -def custom_pydantic_encoder( - type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any -) -> Any: - encoder = type_encoders.get(type(obj)) - if encoder: - return encoder(obj) - else: - return pydantic_encoder(obj) - - -def timedelta_isoformat(td: datetime.timedelta) -> str: - """ - ISO 8601 encoding for timedeltas. 
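
    For example, a one-day, two-hour delta encodes as follows (a sketch;
    the expected string follows from the divmod arithmetic below)::

        timedelta_isoformat(datetime.timedelta(days=1, hours=2))
        # -> 'P1DT2H0M0.000000S'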
- """ - minutes, seconds = divmod(td.seconds, 60) - hours, minutes = divmod(minutes, 60) - return f"P{td.days}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S" diff --git a/nornir/_vendor/pydantic/main.py b/nornir/_vendor/pydantic/main.py deleted file mode 100644 index a1024b07..00000000 --- a/nornir/_vendor/pydantic/main.py +++ /dev/null @@ -1,972 +0,0 @@ -import json -import sys -import warnings -from abc import ABCMeta -from copy import deepcopy -from enum import Enum -from functools import partial -from pathlib import Path -from types import FunctionType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, - no_type_check, -) - -from .class_validators import ( - ROOT_KEY, - ValidatorGroup, - extract_root_validators, - extract_validators, - inherit_validators, -) -from .error_wrappers import ErrorWrapper, ValidationError -from .errors import ConfigError, DictError, ExtraError, MissingError -from .fields import SHAPE_MAPPING, ModelField, Undefined -from .json import custom_pydantic_encoder, pydantic_encoder -from .parse import Protocol, load_file, load_str_bytes -from .schema import model_schema -from .types import PyObject, StrBytes -from .typing import ( - AnyCallable, - AnyType, - ForwardRef, - is_classvar, - resolve_annotations, - update_field_forward_refs, -) -from .utils import ( - GetterDict, - Representation, - ValueItems, - lenient_issubclass, - sequence_like, - validate_field_name, -) - -if TYPE_CHECKING: - from .class_validators import ValidatorListDict - from .types import ModelOrDc - from .typing import CallableGenerator, TupleGenerator, DictStrAny, DictAny, SetStr - from .typing import AbstractSetIntStr, DictIntStrAny, ReprArgs # noqa: F401 - - ConfigType = Type["BaseConfig"] - Model = TypeVar("Model", bound="BaseModel") - -try: - import cython # type: ignore -except ImportError: - compiled: bool = False -else: # pragma: no cover - try: - compiled = cython.compiled - except AttributeError: - compiled = False - -__all__ = ( - "BaseConfig", - "BaseModel", - "Extra", - "compiled", - "create_model", - "validate_model", -) - - -class Extra(str, Enum): - allow = "allow" - ignore = "ignore" - forbid = "forbid" - - -class BaseConfig: - title = None - anystr_strip_whitespace = False - min_anystr_length = None - max_anystr_length = None - validate_all = False - extra = Extra.ignore - allow_mutation = True - allow_population_by_field_name = False - use_enum_values = False - fields: Dict[str, Union[str, Dict[str, str]]] = {} - validate_assignment = False - error_msg_templates: Dict[str, str] = {} - arbitrary_types_allowed = False - orm_mode: bool = False - getter_dict: Type[GetterDict] = GetterDict - alias_generator: Optional[Callable[[str], str]] = None - keep_untouched: Tuple[type, ...] 
= () - schema_extra: Union[Dict[str, Any], Callable[[Dict[str, Any]], None]] = {} - json_loads: Callable[[str], Any] = json.loads - json_dumps: Callable[..., str] = json.dumps - json_encoders: Dict[AnyType, AnyCallable] = {} - - @classmethod - def get_field_info(cls, name: str) -> Dict[str, Any]: - field_info = cls.fields.get(name) or {} - if isinstance(field_info, str): - field_info = {"alias": field_info} - elif cls.alias_generator and "alias" not in field_info: - alias = cls.alias_generator(name) - if not isinstance(alias, str): - raise TypeError( - f"Config.alias_generator must return str, not {type(alias)}" - ) - field_info["alias"] = alias - return field_info - - @classmethod - def prepare_field(cls, field: "ModelField") -> None: - """ - Optional hook to check or modify fields during model creation. - """ - pass - - -def inherit_config( - self_config: "ConfigType", parent_config: "ConfigType" -) -> "ConfigType": - if not self_config: - base_classes = (parent_config,) - elif self_config == parent_config: - base_classes = (self_config,) - else: - base_classes = self_config, parent_config # type: ignore - return type("Config", base_classes, {}) - - -EXTRA_LINK = "https://pydantic-docs.helpmanual.io/usage/model_config/" - - -def prepare_config(config: Type[BaseConfig], cls_name: str) -> None: - if not isinstance(config.extra, Extra): - try: - config.extra = Extra(config.extra) - except ValueError: - raise ValueError( - f'"{cls_name}": {config.extra} is not a valid value for "extra"' - ) - - if hasattr(config, "allow_population_by_alias"): - warnings.warn( - f'{cls_name}: "allow_population_by_alias" is deprecated and replaced by "allow_population_by_field_name"', - DeprecationWarning, - ) - config.allow_population_by_field_name = config.allow_population_by_alias # type: ignore - - if hasattr(config, "case_insensitive") and any( - "BaseSettings.Config" in c.__qualname__ for c in config.__mro__ - ): - warnings.warn( - f'{cls_name}: "case_insensitive" is deprecated on BaseSettings config and replaced by ' - f'"case_sensitive" (default False)', - DeprecationWarning, - ) - config.case_sensitive = not config.case_insensitive # type: ignore - - -def is_valid_field(name: str) -> bool: - if not name.startswith("_"): - return True - return ROOT_KEY == name - - -def validate_custom_root_type(fields: Dict[str, ModelField]) -> None: - if len(fields) > 1: - raise ValueError("__root__ cannot be mixed with other fields") - - -UNTOUCHED_TYPES = FunctionType, property, type, classmethod, staticmethod - - -class ModelMetaclass(ABCMeta): - @no_type_check # noqa C901 - def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 - fields: Dict[str, ModelField] = {} - config = BaseConfig - validators: "ValidatorListDict" = {} - pre_root_validators, post_root_validators = [], [] - for base in reversed(bases): - if issubclass(base, BaseModel) and base != BaseModel: - fields.update(deepcopy(base.__fields__)) - config = inherit_config(base.__config__, config) - validators = inherit_validators(base.__validators__, validators) - pre_root_validators += base.__pre_root_validators__ - post_root_validators += base.__post_root_validators__ - - config = inherit_config(namespace.get("Config"), config) - validators = inherit_validators(extract_validators(namespace), validators) - vg = ValidatorGroup(validators) - - for f in fields.values(): - f.set_config(config) - extra_validators = vg.get_validators(f.name) - if extra_validators: - f.class_validators.update(extra_validators) - # re-run prepare to add extra validators - 
f.populate_validators() - - prepare_config(config, name) - - class_vars = set() - if (namespace.get("__module__"), namespace.get("__qualname__")) != ( - "pydantic.main", - "BaseModel", - ): - annotations = resolve_annotations( - namespace.get("__annotations__", {}), namespace.get("__module__", None) - ) - untouched_types = UNTOUCHED_TYPES + config.keep_untouched - # annotation only fields need to come first in fields - for ann_name, ann_type in annotations.items(): - if is_classvar(ann_type): - class_vars.add(ann_name) - elif is_valid_field(ann_name): - validate_field_name(bases, ann_name) - value = namespace.get(ann_name, Undefined) - if ( - isinstance(value, untouched_types) - and ann_type != PyObject - and not lenient_issubclass( - getattr(ann_type, "__origin__", None), Type - ) - ): - continue - fields[ann_name] = ModelField.infer( - name=ann_name, - value=value, - annotation=ann_type, - class_validators=vg.get_validators(ann_name), - config=config, - ) - - for var_name, value in namespace.items(): - if ( - var_name not in annotations - and is_valid_field(var_name) - and not isinstance(value, untouched_types) - and var_name not in class_vars - ): - validate_field_name(bases, var_name) - inferred = ModelField.infer( - name=var_name, - value=value, - annotation=annotations.get(var_name), - class_validators=vg.get_validators(var_name), - config=config, - ) - if var_name in fields and inferred.type_ != fields[var_name].type_: - raise TypeError( - f"The type of {name}.{var_name} differs from the new default value; " - f"if you wish to change the type of this field, please use a type annotation" - ) - fields[var_name] = inferred - - _custom_root_type = ROOT_KEY in fields - if _custom_root_type: - validate_custom_root_type(fields) - vg.check_for_unused() - if config.json_encoders: - json_encoder = partial(custom_pydantic_encoder, config.json_encoders) - else: - json_encoder = pydantic_encoder - pre_rv_new, post_rv_new = extract_root_validators(namespace) - new_namespace = { - "__config__": config, - "__fields__": fields, - "__field_defaults__": { - n: f.default for n, f in fields.items() if not f.required - }, - "__validators__": vg.validators, - "__pre_root_validators__": pre_root_validators + pre_rv_new, - "__post_root_validators__": post_root_validators + post_rv_new, - "__schema_cache__": {}, - "__json_encoder__": staticmethod(json_encoder), - "__custom_root_type__": _custom_root_type, - **{n: v for n, v in namespace.items() if n not in fields}, - } - return super().__new__(mcs, name, bases, new_namespace, **kwargs) - - -class BaseModel(metaclass=ModelMetaclass): - if TYPE_CHECKING: - # populated by the metaclass, defined here to help IDEs only - __fields__: Dict[str, ModelField] = {} - __field_defaults__: Dict[str, Any] = {} - __validators__: Dict[str, AnyCallable] = {} - __pre_root_validators__: List[AnyCallable] - __post_root_validators__: List[Tuple[bool, AnyCallable]] - __config__: Type[BaseConfig] = BaseConfig - __root__: Any = None - __json_encoder__: Callable[[Any], Any] = lambda x: x - __schema_cache__: "DictAny" = {} - __custom_root_type__: bool = False - - Config = BaseConfig - __slots__ = ("__dict__", "__fields_set__") - # equivalent of inheriting from Representation - __repr_name__ = Representation.__repr_name__ - __repr_str__ = Representation.__repr_str__ - __pretty__ = Representation.__pretty__ - __str__ = Representation.__str__ - __repr__ = Representation.__repr__ - - def __init__(__pydantic_self__, **data: Any) -> None: - # Uses something other than `self` the first arg 
to allow "self" as a settable attribute - if TYPE_CHECKING: - __pydantic_self__.__dict__: Dict[str, Any] = {} - __pydantic_self__.__fields_set__: "SetStr" = set() - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - if validation_error: - raise validation_error - object.__setattr__(__pydantic_self__, "__dict__", values) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - - @no_type_check - def __setattr__(self, name, value): - if self.__config__.extra is not Extra.allow and name not in self.__fields__: - raise ValueError( - f'"{self.__class__.__name__}" object has no field "{name}"' - ) - elif not self.__config__.allow_mutation: - raise TypeError( - f'"{self.__class__.__name__}" is immutable and does not support item assignment' - ) - elif self.__config__.validate_assignment: - known_field = self.__fields__.get(name, None) - if known_field: - value, error_ = known_field.validate( - value, self.dict(exclude={name}), loc=name - ) - if error_: - raise ValidationError([error_], type(self)) - self.__dict__[name] = value - self.__fields_set__.add(name) - - def __getstate__(self) -> "DictAny": - return {"__dict__": self.__dict__, "__fields_set__": self.__fields_set__} - - def __setstate__(self, state: "DictAny") -> None: - object.__setattr__(self, "__dict__", state["__dict__"]) - object.__setattr__(self, "__fields_set__", state["__fields_set__"]) - - def dict( - self, - *, - include: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> "DictStrAny": - """ - Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. - """ - if skip_defaults is not None: - warnings.warn( - f'{self.__class__.__name__}.dict(): "skip_defaults" is deprecated and replaced by "exclude_unset"', - DeprecationWarning, - ) - exclude_unset = skip_defaults - get_key = self._get_key_factory(by_alias) - get_key = partial(get_key, self.__fields__) - - allowed_keys = self._calculate_keys( - include=include, exclude=exclude, exclude_unset=exclude_unset - ) - return { - get_key(k): v - for k, v in self._iter( - to_dict=True, - by_alias=by_alias, - allowed_keys=allowed_keys, - include=include, - exclude=exclude, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - } - - def _get_key_factory(self, by_alias: bool) -> Callable[..., str]: - if by_alias: - return lambda fields, key: fields[key].alias if key in fields else key - - return lambda _, key: key - - def json( - self, - *, - include: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - encoder: Optional[Callable[[Any], Any]] = None, - **dumps_kwargs: Any, - ) -> str: - """ - Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. - - `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. 
- """ - if skip_defaults is not None: - warnings.warn( - f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"', - DeprecationWarning, - ) - exclude_unset = skip_defaults - encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__) - data = self.dict( - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - if self.__custom_root_type__: - data = data[ROOT_KEY] - return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) - - @classmethod - def parse_obj(cls: Type["Model"], obj: Any) -> "Model": - if cls.__custom_root_type__ and ( - not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) - or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING - ): - obj = {ROOT_KEY: obj} - elif not isinstance(obj, dict): - try: - obj = dict(obj) - except (TypeError, ValueError) as e: - exc = TypeError( - f"{cls.__name__} expected dict not {type(obj).__name__}" - ) - raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e - return cls(**obj) - - @classmethod - def parse_raw( - cls: Type["Model"], - b: StrBytes, - *, - content_type: str = None, - encoding: str = "utf8", - proto: Protocol = None, - allow_pickle: bool = False, - ) -> "Model": - try: - obj = load_str_bytes( - b, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=cls.__config__.json_loads, - ) - except (ValueError, TypeError, UnicodeDecodeError) as e: - raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls) - return cls.parse_obj(obj) - - @classmethod - def parse_file( - cls: Type["Model"], - path: Union[str, Path], - *, - content_type: str = None, - encoding: str = "utf8", - proto: Protocol = None, - allow_pickle: bool = False, - ) -> "Model": - obj = load_file( - path, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=cls.__config__.json_loads, - ) - return cls.parse_obj(obj) - - @classmethod - def from_orm(cls: Type["Model"], obj: Any) -> "Model": - if not cls.__config__.orm_mode: - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = cls._decompose_class(obj) - m = cls.__new__(cls) - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - object.__setattr__(m, "__dict__", values) - object.__setattr__(m, "__fields_set__", fields_set) - return m - - @classmethod - def construct( - cls: Type["Model"], _fields_set: Optional["SetStr"] = None, **values: Any - ) -> "Model": - """ - Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. - Default values are respected, but no other validation is performed. - """ - m = cls.__new__(cls) - object.__setattr__( - m, "__dict__", {**deepcopy(cls.__field_defaults__), **values} - ) - if _fields_set is None: - _fields_set = set(values.keys()) - object.__setattr__(m, "__fields_set__", _fields_set) - return m - - def copy( - self: "Model", - *, - include: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - update: "DictStrAny" = None, - deep: bool = False, - ) -> "Model": - """ - Duplicate a model, optionally choose which fields to include, exclude and change. 
- - :param include: fields to include in new model - :param exclude: fields to exclude from new model, as with values this takes precedence over include - :param update: values to change/add in the new model. Note: the data is not validated before creating - the new model: you should trust this data - :param deep: set to `True` to make a deep copy of the model - :return: new model instance - """ - if include is None and exclude is None and update is None: - # skip constructing values if no arguments are passed - v = self.__dict__ - else: - allowed_keys = self._calculate_keys( - include=include, exclude=exclude, exclude_unset=False, update=update - ) - if allowed_keys is None: - v = {**self.__dict__, **(update or {})} - else: - v = { - **dict( - self._iter( - to_dict=False, - by_alias=False, - include=include, - exclude=exclude, - exclude_unset=False, - allowed_keys=allowed_keys, - ) - ), - **(update or {}), - } - - if deep: - v = deepcopy(v) - - cls = self.__class__ - m = cls.__new__(cls) - object.__setattr__(m, "__dict__", v) - object.__setattr__(m, "__fields_set__", self.__fields_set__.copy()) - return m - - @classmethod - def schema(cls, by_alias: bool = True) -> "DictStrAny": - cached = cls.__schema_cache__.get(by_alias) - if cached is not None: - return cached - s = model_schema(cls, by_alias=by_alias) - cls.__schema_cache__[by_alias] = s - return s - - @classmethod - def schema_json(cls, *, by_alias: bool = True, **dumps_kwargs: Any) -> str: - from .json import pydantic_encoder - - return cls.__config__.json_dumps( - cls.schema(by_alias=by_alias), default=pydantic_encoder, **dumps_kwargs - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls: Type["Model"], value: Any) -> "Model": - if isinstance(value, dict): - return cls(**value) - elif isinstance(value, cls): - return value.copy() - elif cls.__config__.orm_mode: - return cls.from_orm(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - @classmethod - def _decompose_class(cls: Type["Model"], obj: Any) -> GetterDict: - return cls.__config__.getter_dict(obj) - - @classmethod - @no_type_check - def _get_value( - cls, - v: Any, - to_dict: bool, - by_alias: bool, - include: Optional[Union["AbstractSetIntStr", "DictIntStrAny"]], - exclude: Optional[Union["AbstractSetIntStr", "DictIntStrAny"]], - exclude_unset: bool, - exclude_defaults: bool, - exclude_none: bool, - ) -> Any: - - if isinstance(v, BaseModel): - if to_dict: - return v.dict( - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=include, - exclude=exclude, - exclude_none=exclude_none, - ) - else: - return v.copy(include=include, exclude=exclude) - - value_exclude = ValueItems(v, exclude) if exclude else None - value_include = ValueItems(v, include) if include else None - - if isinstance(v, dict): - return { - k_: cls._get_value( - v_, - to_dict=to_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=value_include and value_include.for_element(k_), - exclude=value_exclude and value_exclude.for_element(k_), - exclude_none=exclude_none, - ) - for k_, v_ in v.items() - if (not value_exclude or not value_exclude.is_excluded(k_)) - and (not value_include or value_include.is_included(k_)) - } - - elif sequence_like(v): - return type(v)( - cls._get_value( - v_, - to_dict=to_dict, - by_alias=by_alias, - 
exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=value_include and value_include.for_element(i), - exclude=value_exclude and value_exclude.for_element(i), - exclude_none=exclude_none, - ) - for i, v_ in enumerate(v) - if (not value_exclude or not value_exclude.is_excluded(i)) - and (not value_include or value_include.is_included(i)) - ) - - else: - return v - - @classmethod - def update_forward_refs(cls, **localns: Any) -> None: - """ - Try to update ForwardRefs on fields based on this Model, globalns and localns. - """ - globalns = sys.modules[cls.__module__].__dict__ - globalns.setdefault(cls.__name__, cls) - for f in cls.__fields__.values(): - update_field_forward_refs(f, globalns=globalns, localns=localns) - - def __iter__(self) -> "TupleGenerator": - """ - so `dict(model)` works - """ - yield from self._iter() - - def _iter( - self, - to_dict: bool = False, - by_alias: bool = False, - allowed_keys: Optional["SetStr"] = None, - include: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "DictIntStrAny"] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> "TupleGenerator": - - value_exclude = ValueItems(self, exclude) if exclude else None - value_include = ValueItems(self, include) if include else None - - if exclude_defaults: - if allowed_keys is None: - allowed_keys = set(self.__fields__) - for k, v in self.__field_defaults__.items(): - if self.__dict__[k] == v: - allowed_keys.discard(k) - - for k, v in self.__dict__.items(): - if allowed_keys is None or k in allowed_keys: - value = self._get_value( - v, - to_dict=to_dict, - by_alias=by_alias, - include=value_include and value_include.for_element(k), - exclude=value_exclude and value_exclude.for_element(k), - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - if not (exclude_none and value is None): - yield k, value - - def _calculate_keys( - self, - include: Optional[Union["AbstractSetIntStr", "DictIntStrAny"]], - exclude: Optional[Union["AbstractSetIntStr", "DictIntStrAny"]], - exclude_unset: bool, - update: Optional["DictStrAny"] = None, - ) -> Optional["SetStr"]: - if include is None and exclude is None and exclude_unset is False: - return None - - if exclude_unset: - keys = self.__fields_set__.copy() - else: - keys = set(self.__dict__.keys()) - - if include is not None: - if isinstance(include, dict): - keys &= include.keys() - else: - keys &= include - - if update: - keys -= update.keys() - - if exclude: - if isinstance(exclude, dict): - keys -= {k for k, v in exclude.items() if v is ...} - else: - keys -= exclude - - return keys - - def __eq__(self, other: Any) -> bool: - if isinstance(other, BaseModel): - return self.dict() == other.dict() - else: - return self.dict() == other - - def __repr_args__(self) -> "ReprArgs": - return self.__dict__.items() # type: ignore - - @property - def fields(self) -> Dict[str, ModelField]: - warnings.warn( - "`fields` attribute is deprecated, use `__fields__` instead", - DeprecationWarning, - ) - return self.__fields__ - - def to_string(self, pretty: bool = False) -> str: - warnings.warn( - "`model.to_string()` method is deprecated, use `str(model)` instead", - DeprecationWarning, - ) - return str(self) - - @property - def __values__(self) -> "DictStrAny": - warnings.warn( - "`__values__` attribute is deprecated, use `__dict__` instead", - DeprecationWarning, - ) - return self.__dict__ - - -def create_model( - model_name: 
str,
-    *,
-    __config__: Type[BaseConfig] = None,
-    __base__: Type[BaseModel] = None,
-    __module__: Optional[str] = None,
-    __validators__: Dict[str, classmethod] = None,
-    **field_definitions: Any,
-) -> Type[BaseModel]:
-    """
-    Dynamically create a model.
-    :param model_name: name of the created model
-    :param __config__: config class to use for the new model
-    :param __base__: base class for the new model to inherit from
-    :param __validators__: a dict of method names and @validator class methods
-    :param **field_definitions: fields of the model (or extra fields if a base is supplied) in the format
-      `<name>=(<type>, <default>)` or `<name>=<default value>`, e.g. `foobar=(str, ...)` or `foobar=123`
-    """
-    if __base__:
-        if __config__ is not None:
-            raise ConfigError(
-                "to avoid confusion __config__ and __base__ cannot be used together"
-            )
-    else:
-        __base__ = BaseModel
-
-    fields = {}
-    annotations = {}
-
-    for f_name, f_def in field_definitions.items():
-        if not is_valid_field(f_name):
-            warnings.warn(
-                f'fields may not start with an underscore, ignoring "{f_name}"',
-                RuntimeWarning,
-            )
-        if isinstance(f_def, tuple):
-            try:
-                f_annotation, f_value = f_def
-            except ValueError as e:
-                raise ConfigError(
-                    f"field definitions should either be a tuple of (<type>, <default>) or just a "
-                    f"default value, unfortunately this means tuples as "
-                    f"default values are not allowed"
-                ) from e
-        else:
-            f_annotation, f_value = None, f_def
-
-        if f_annotation:
-            annotations[f_name] = f_annotation
-        fields[f_name] = f_value
-
-    namespace: "DictStrAny" = {"__annotations__": annotations, "__module__": __module__}
-    if __validators__:
-        namespace.update(__validators__)
-    namespace.update(fields)
-    if __config__:
-        namespace["Config"] = inherit_config(__config__, BaseConfig)
-
-    return type(model_name, (__base__,), namespace)
-
-
-_missing = object()
-
-
-def validate_model(  # noqa: C901 (ignore complexity)
-    model: Type[BaseModel], input_data: "DictStrAny", cls: "ModelOrDc" = None
-) -> Tuple["DictStrAny", "SetStr", Optional[ValidationError]]:
-    """
-    Validate data against a model.
-    """
-    values = {}
-    errors = []
-    # input_data names, possibly alias
-    names_used = set()
-    # field names, never aliases
-    fields_set = set()
-    config = model.__config__
-    check_extra = config.extra is not Extra.ignore
-    cls_ = cls or model
-
-    for validator in model.__pre_root_validators__:
-        try:
-            input_data = validator(cls_, input_data)
-        except (ValueError, TypeError, AssertionError) as exc:
-            return {}, set(), ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls_)
-
-    for name, field in model.__fields__.items():
-        if type(field.type_) == ForwardRef:
-            raise ConfigError(
-                f'field "{field.name}" not yet prepared so type is still a ForwardRef, '
-                f"you might need to call {cls_.__name__}.update_forward_refs()."
- ) - - value = input_data.get(field.alias, _missing) - using_name = False - if ( - value is _missing - and config.allow_population_by_field_name - and field.alt_alias - ): - value = input_data.get(field.name, _missing) - using_name = True - - if value is _missing: - if field.required: - errors.append(ErrorWrapper(MissingError(), loc=field.alias)) - continue - - if field.default is None: - # deepcopy is quite slow on None - value = None - else: - value = deepcopy(field.default) - - if not config.validate_all and not field.validate_always: - values[name] = value - continue - else: - fields_set.add(name) - if check_extra: - names_used.add(field.name if using_name else field.alias) - - v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - errors.extend(errors_) - else: - values[name] = v_ - - if check_extra: - if isinstance(input_data, GetterDict): - extra = input_data.extra_keys() - names_used - else: - extra = input_data.keys() - names_used - if extra: - fields_set |= extra - if config.extra is Extra.allow: - for f in extra: - values[f] = input_data[f] - else: - for f in sorted(extra): - errors.append(ErrorWrapper(ExtraError(), loc=f)) - - for skip_on_failure, validator in model.__post_root_validators__: - if skip_on_failure and errors: - continue - try: - values = validator(cls_, values) - except (ValueError, TypeError, AssertionError) as exc: - errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) - break - - if errors: - return values, fields_set, ValidationError(errors, cls_) - else: - return values, fields_set, None diff --git a/nornir/_vendor/pydantic/mypy.py b/nornir/_vendor/pydantic/mypy.py deleted file mode 100644 index 177dffce..00000000 --- a/nornir/_vendor/pydantic/mypy.py +++ /dev/null @@ -1,792 +0,0 @@ -from configparser import ConfigParser -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Type as TypingType, - Union, -) - -from mypy.errorcodes import ErrorCode -from mypy.nodes import ( - ARG_NAMED, - ARG_NAMED_OPT, - ARG_OPT, - ARG_POS, - ARG_STAR2, - MDEF, - Argument, - AssignmentStmt, - Block, - CallExpr, - ClassDef, - Context, - Decorator, - EllipsisExpr, - FuncBase, - FuncDef, - JsonDict, - MemberExpr, - NameExpr, - PassStmt, - PlaceholderNode, - RefExpr, - StrExpr, - SymbolNode, - SymbolTableNode, - TempNode, - TypeInfo, - TypeVarExpr, - Var, -) -from mypy.options import Options -from mypy.plugin import ( - CheckerPluginInterface, - ClassDefContext, - MethodContext, - Plugin, - SemanticAnalyzerPluginInterface, -) -from mypy.plugins import dataclasses -from mypy.semanal import set_callable_name # type: ignore -from mypy.server.trigger import make_wildcard_trigger -from mypy.types import ( - AnyType, - CallableType, - Instance, - NoneType, - Type, - TypeOfAny, - TypeType, - TypeVarDef, - TypeVarType, - UnionType, - get_proper_type, -) -from mypy.typevars import fill_typevars -from mypy.util import get_unique_redefinition_name - -CONFIGFILE_KEY = "pydantic-mypy" -METADATA_KEY = "pydantic-mypy-metadata" -BASEMODEL_FULLNAME = "pydantic.main.BaseModel" -BASESETTINGS_FULLNAME = "pydantic.env_settings.BaseSettings" -FIELD_FULLNAME = "pydantic.fields.Field" -DATACLASS_FULLNAME = "pydantic.dataclasses.dataclass" - - -def plugin(version: str) -> "TypingType[Plugin]": - """ - `version` is the mypy version string - - We might want to use this to print a warning if the mypy version being used is - newer, or especially older, than we expect 
(or need). - """ - return PydanticPlugin - - -class PydanticPlugin(Plugin): - def __init__(self, options: Options) -> None: - self.plugin_config = PydanticPluginConfig(options) - super().__init__(options) - - def get_base_class_hook( - self, fullname: str - ) -> "Optional[Callable[[ClassDefContext], None]]": - sym = self.lookup_fully_qualified(fullname) - if sym and isinstance(sym.node, TypeInfo): # pragma: no branch - # No branching may occur if the mypy cache has not been cleared - if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro): - return self._pydantic_model_class_maker_callback - return None - - def get_method_hook( - self, fullname: str - ) -> Optional[Callable[[MethodContext], Type]]: - if fullname.endswith(".from_orm"): - return from_orm_callback - return None - - def get_class_decorator_hook( - self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname == DATACLASS_FULLNAME: - return dataclasses.dataclass_class_maker_callback - return None - - def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: - transformer = PydanticModelTransformer(ctx, self.plugin_config) - transformer.transform() - - -class PydanticPluginConfig: - __slots__ = ( - "init_forbid_extra", - "init_typed", - "warn_required_dynamic_aliases", - "warn_untyped_fields", - ) - init_forbid_extra: bool - init_typed: bool - warn_required_dynamic_aliases: bool - warn_untyped_fields: bool - - def __init__(self, options: Options) -> None: - if options.config_file is None: # pragma: no cover - return - plugin_config = ConfigParser() - plugin_config.read(options.config_file) - for key in self.__slots__: - setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) - setattr(self, key, setting) - - -def from_orm_callback(ctx: MethodContext) -> Type: - """ - Raise an error if orm_mode is not enabled - """ - model_type: Instance - if isinstance(ctx.type, CallableType) and isinstance(ctx.type.ret_type, Instance): - model_type = ctx.type.ret_type # called on the class - elif isinstance(ctx.type, Instance): - model_type = ctx.type # called on an instance (unusual, but still valid) - else: # pragma: no cover - detail = f"ctx.type: {ctx.type} (of type {type(ctx.type).__name__})" - error_unexpected_behavior(detail, ctx.api, ctx.context) - return ctx.default_return_type - pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) - if pydantic_metadata is None: - return ctx.default_return_type - orm_mode = pydantic_metadata.get("config", {}).get("orm_mode") - if orm_mode is not True: - error_from_orm(get_name(model_type.type), ctx.api, ctx.context) - return ctx.default_return_type - - -class PydanticModelTransformer: - tracked_config_fields: Set[str] = { - "extra", - "allow_mutation", - "orm_mode", - "allow_population_by_field_name", - "alias_generator", - } - - def __init__( - self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig - ) -> None: - self._ctx = ctx - self.plugin_config = plugin_config - - def transform(self) -> None: - """ - Configures the BaseModel subclass according to the plugin settings. 
- - In particular: - * determines the model config and fields, - * adds a fields-aware signature for the initializer and construct methods - * freezes the class if allow_mutation = False - * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses - """ - ctx = self._ctx - info = self._ctx.cls.info - - config = self.collect_config() - fields = self.collect_fields(config) - for field in fields: - if info[field.name].type is None: - if not ctx.api.final_iteration: - ctx.api.defer() - is_settings = any( - get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1] - ) - self.add_initializer(fields, config, is_settings) - self.add_construct_method(fields) - self.set_frozen(fields, frozen=config.allow_mutation is False) - info.metadata[METADATA_KEY] = { - "fields": {field.name: field.serialize() for field in fields}, - "config": config.set_values_dict(), - } - - def collect_config(self) -> "ModelConfigData": - """ - Collects the values of the config attributes that are used by the plugin, accounting for parent classes. - """ - ctx = self._ctx - cls = ctx.cls - config = ModelConfigData() - for stmt in cls.defs.body: - if not isinstance(stmt, ClassDef): - continue - if stmt.name == "Config": - for substmt in stmt.defs.body: - if not isinstance(substmt, AssignmentStmt): - continue - config.update(self.get_config_update(substmt)) - if ( - config.has_alias_generator - and not config.allow_population_by_field_name - and self.plugin_config.warn_required_dynamic_aliases - ): - error_required_dynamic_aliases(ctx.api, stmt) - for info in cls.info.mro[1:]: # 0 is the current class - if METADATA_KEY not in info.metadata: - continue - - # Each class depends on the set of fields in its ancestors - ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) - for name, value in info.metadata[METADATA_KEY]["config"].items(): - config.setdefault(name, value) - return config - - def collect_fields( - self, model_config: "ModelConfigData" - ) -> List["PydanticModelField"]: - """ - Collects the fields for the model, accounting for parent classes - """ - # First, collect fields belonging to the current class. - ctx = self._ctx - cls = self._ctx.cls - fields = [] # type: List[PydanticModelField] - known_fields = set() # type: Set[str] - for stmt in cls.defs.body: - if not isinstance( - stmt, AssignmentStmt - ): # `and stmt.new_syntax` to require annotation - continue - - lhs = stmt.lvalues[0] - if not isinstance(lhs, NameExpr): - continue - - if not stmt.new_syntax and self.plugin_config.warn_untyped_fields: - error_untyped_fields(ctx.api, stmt) - - # if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet - # continue - - sym = cls.info.names.get(lhs.name) - if sym is None: # pragma: no cover - # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) - # This is the same logic used in the dataclasses plugin - continue - - node = sym.node - if isinstance(node, PlaceholderNode): # pragma: no cover - # See the PlaceholderNode docstring for more detail about how this can occur - # Basically, it is an edge case when dealing with complex import logic - # This is the same logic used in the dataclasses plugin - continue - assert isinstance(node, Var) - - # x: ClassVar[int] is ignored by dataclasses. 
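# Illustrative aside: the options gathered by PydanticPluginConfig above come
# from the "[pydantic-mypy]" section named by CONFIGFILE_KEY. A minimal
# mypy.ini wiring (assumed layout; unset keys fall back to False) would be:
#
#     [mypy]
#     plugins = pydantic.mypy
#
#     [pydantic-mypy]
#     init_typed = True
#     init_forbid_extra = True
#     warn_required_dynamic_aliases = True
#     warn_untyped_fields = True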
- if node.is_classvar: - continue - - is_required = self.get_is_required(cls, stmt, lhs) - alias, has_dynamic_alias = self.get_alias_info(stmt) - if ( - has_dynamic_alias - and not model_config.allow_population_by_field_name - and self.plugin_config.warn_required_dynamic_aliases - ): - error_required_dynamic_aliases(ctx.api, stmt) - fields.append( - PydanticModelField( - name=lhs.name, - is_required=is_required, - alias=alias, - has_dynamic_alias=has_dynamic_alias, - line=stmt.line, - column=stmt.column, - ) - ) - known_fields.add(lhs.name) - all_fields = fields.copy() - for info in cls.info.mro[ - 1: - ]: # 0 is the current class, -2 is BaseModel, -1 is object - if METADATA_KEY not in info.metadata: - continue - - superclass_fields = [] - # Each class depends on the set of fields in its ancestors - ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) - - for name, data in info.metadata[METADATA_KEY]["fields"].items(): - if name not in known_fields: - field = PydanticModelField.deserialize(info, data) - known_fields.add(name) - superclass_fields.append(field) - else: - (field,) = [a for a in all_fields if a.name == name] - all_fields.remove(field) - superclass_fields.append(field) - all_fields = superclass_fields + all_fields - return all_fields - - def add_initializer( - self, - fields: List["PydanticModelField"], - config: "ModelConfigData", - is_settings: bool, - ) -> None: - """ - Adds a fields-aware `__init__` method to the class. - - The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. - """ - ctx = self._ctx - typed = self.plugin_config.init_typed - use_alias = config.allow_population_by_field_name is not True - force_all_optional = is_settings or bool( - config.has_alias_generator and not config.allow_population_by_field_name - ) - init_arguments = self.get_field_arguments( - fields, - typed=typed, - force_all_optional=force_all_optional, - use_alias=use_alias, - ) - if not self.should_init_forbid_extra(fields, config): - var = Var("kwargs") - init_arguments.append( - Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2) - ) - add_method(ctx, "__init__", init_arguments, NoneType()) - - def add_construct_method(self, fields: List["PydanticModelField"]) -> None: - """ - Adds a fully typed `construct` classmethod to the class. - - Similar to the fields-aware __init__ method, but always uses the field names (not aliases), - and does not treat settings fields as optional. - """ - ctx = self._ctx - set_str = ctx.api.named_type( - "__builtins__.set", [ctx.api.named_type("__builtins__.str")] - ) - optional_set_str = UnionType([set_str, NoneType()]) - fields_set_argument = Argument( - Var("_fields_set", optional_set_str), optional_set_str, None, ARG_OPT - ) - construct_arguments = self.get_field_arguments( - fields, typed=True, force_all_optional=False, use_alias=False - ) - construct_arguments = [fields_set_argument] + construct_arguments - - obj_type = ctx.api.named_type("__builtins__.object") - self_tvar_name = "Model" - tvar_fullname = ctx.cls.fullname + "." 
+ self_tvar_name - tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type) - self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type) - ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) - self_type = TypeVarType(tvd) - add_method( - ctx, - "construct", - construct_arguments, - return_type=self_type, - self_type=self_type, - tvar_def=tvd, - is_classmethod=True, - ) - - def set_frozen(self, fields: List["PydanticModelField"], frozen: bool) -> None: - """ - Marks all fields as properties so that attempts to set them trigger mypy errors. - - This is the same approach used by the attrs and dataclasses plugins. - """ - info = self._ctx.cls.info - for field in fields: - sym_node = info.names.get(field.name) - if sym_node is not None: - var = sym_node.node - assert isinstance(var, Var) - var.is_property = frozen - else: - var = field.to_var(info, use_alias=False) - var.info = info - var.is_property = frozen - var._fullname = get_fullname(info) + "." + get_name(var) - info.names[get_name(var)] = SymbolTableNode(MDEF, var) - - def get_config_update(self, substmt: AssignmentStmt) -> Optional["ModelConfigData"]: - """ - Determines the config update due to a single statement in the Config class definition. - - Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) - """ - lhs = substmt.lvalues[0] - if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields): - return None - if lhs.name == "extra": - if isinstance(substmt.rvalue, StrExpr): - forbid_extra = substmt.rvalue.value == "forbid" - elif isinstance(substmt.rvalue, MemberExpr): - forbid_extra = substmt.rvalue.name == "forbid" - else: - error_invalid_config_value(lhs.name, self._ctx.api, substmt) - return None - return ModelConfigData(forbid_extra=forbid_extra) - if lhs.name == "alias_generator": - has_alias_generator = True - if ( - isinstance(substmt.rvalue, NameExpr) - and substmt.rvalue.fullname == "builtins.None" - ): - has_alias_generator = False - return ModelConfigData(has_alias_generator=has_alias_generator) - if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ( - "builtins.True", - "builtins.False", - ): - return ModelConfigData( - **{lhs.name: substmt.rvalue.fullname == "builtins.True"} - ) - error_invalid_config_value(lhs.name, self._ctx.api, substmt) - return None - - @staticmethod - def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool: - """ - Returns a boolean indicating whether the field defined in `stmt` is a required field. 
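# Illustrative examples of the rules implemented below: `x: int` and
# `x: int = ...` are required, as is `x: int = Field(...)`; `x: int = 5`,
# `x: int = Field(5)`, and an annotation-only `x: Optional[int]` are not.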
- """ - expr = stmt.rvalue - if isinstance(expr, TempNode): - # TempNode means annotation-only, so only non-required if Optional - value_type = get_proper_type(cls.info[lhs.name].type) - if isinstance(value_type, UnionType) and any( - isinstance(item, NoneType) for item in value_type.items - ): - # Annotated as Optional, or otherwise having NoneType in the union - return False - return True - if ( - isinstance(expr, CallExpr) - and isinstance(expr.callee, RefExpr) - and expr.callee.fullname == FIELD_FULLNAME - ): - # The "default value" is a call to `Field`; at this point, the field is - # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) - return len(expr.args) > 0 and type(expr.args[0]) is EllipsisExpr - # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) - return isinstance(expr, EllipsisExpr) - - @staticmethod - def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]: - """ - Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. - - `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. - If `has_dynamic_alias` is True, `alias` will be None. - """ - expr = stmt.rvalue - if isinstance(expr, TempNode): - # TempNode means annotation-only - return None, False - - if not ( - isinstance(expr, CallExpr) - and isinstance(expr.callee, RefExpr) - and expr.callee.fullname == FIELD_FULLNAME - ): - # Assigned value is not a call to pydantic.fields.Field - return None, False - - for i, arg_name in enumerate(expr.arg_names): - if arg_name != "alias": - continue - arg = expr.args[i] - if isinstance(arg, StrExpr): - return arg.value, False - else: - return None, True - return None, False - - def get_field_arguments( - self, - fields: List["PydanticModelField"], - typed: bool, - force_all_optional: bool, - use_alias: bool, - ) -> List[Argument]: - """ - Helper function used during the construction of the `__init__` and `construct` method signatures. - - Returns a list of mypy Argument instances for use in the generated signatures. - """ - info = self._ctx.cls.info - arguments = [ - field.to_argument( - info, - typed=typed, - force_optional=force_all_optional, - use_alias=use_alias, - ) - for field in fields - if not (use_alias and field.has_dynamic_alias) - ] - return arguments - - def should_init_forbid_extra( - self, fields: List["PydanticModelField"], config: "ModelConfigData" - ) -> bool: - """ - Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature - - We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, - *unless* a required dynamic alias is present (since then we can't determine a valid signature). - """ - if not config.allow_population_by_field_name: - if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): - return False - if config.forbid_extra: - return True - return self.plugin_config.init_forbid_extra - - @staticmethod - def is_dynamic_alias_present( - fields: List["PydanticModelField"], has_alias_generator: bool - ) -> bool: - """ - Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be - determined during static analysis. 
- """ - for field in fields: - if field.has_dynamic_alias: - return True - if has_alias_generator: - for field in fields: - if field.alias is None: - return True - return False - - -class PydanticModelField: - def __init__( - self, - name: str, - is_required: bool, - alias: Optional[str], - has_dynamic_alias: bool, - line: int, - column: int, - ): - self.name = name - self.is_required = is_required - self.alias = alias - self.has_dynamic_alias = has_dynamic_alias - self.line = line - self.column = column - - def to_var(self, info: TypeInfo, use_alias: bool) -> Var: - name = self.name - if use_alias and self.alias is not None: - name = self.alias - return Var(name, info[self.name].type) - - def to_argument( - self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool - ) -> Argument: - if typed and info[self.name].type is not None: - type_annotation = info[self.name].type - else: - type_annotation = AnyType(TypeOfAny.explicit) - return Argument( - variable=self.to_var(info, use_alias), - type_annotation=type_annotation, - initializer=None, - kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED, - ) - - def serialize(self) -> JsonDict: - return self.__dict__ - - @classmethod - def deserialize(cls, info: TypeInfo, data: JsonDict) -> "PydanticModelField": - return cls(**data) - - -class ModelConfigData: - def __init__( - self, - forbid_extra: Optional[bool] = None, - allow_mutation: Optional[bool] = None, - orm_mode: Optional[bool] = None, - allow_population_by_field_name: Optional[bool] = None, - has_alias_generator: Optional[bool] = None, - ): - self.forbid_extra = forbid_extra - self.allow_mutation = allow_mutation - self.orm_mode = orm_mode - self.allow_population_by_field_name = allow_population_by_field_name - self.has_alias_generator = has_alias_generator - - def set_values_dict(self) -> Dict[str, Any]: - return {k: v for k, v in self.__dict__.items() if v is not None} - - def update(self, config: Optional["ModelConfigData"]) -> None: - if config is None: - return - for k, v in config.set_values_dict().items(): - setattr(self, k, v) - - def setdefault(self, key: str, value: Any) -> None: - if getattr(self, key) is None: - setattr(self, key, value) - - -ERROR_ORM = ErrorCode("pydantic-orm", "Invalid from_orm call", "Pydantic") -ERROR_CONFIG = ErrorCode("pydantic-config", "Invalid config value", "Pydantic") -ERROR_ALIAS = ErrorCode("pydantic-alias", "Dynamic alias disallowed", "Pydantic") -ERROR_UNEXPECTED = ErrorCode("pydantic-unexpected", "Unexpected behavior", "Pydantic") -ERROR_UNTYPED = ErrorCode("pydantic-field", "Untyped field disallowed", "Pydantic") - - -def error_from_orm( - model_name: str, api: CheckerPluginInterface, context: Context -) -> None: - api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM) - - -def error_invalid_config_value( - name: str, api: SemanticAnalyzerPluginInterface, context: Context -) -> None: - api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) - - -def error_required_dynamic_aliases( - api: SemanticAnalyzerPluginInterface, context: Context -) -> None: - api.fail("Required dynamic aliases disallowed", context, code=ERROR_ALIAS) - - -def error_unexpected_behavior( - detail: str, api: CheckerPluginInterface, context: Context -) -> None: # pragma: no cover - # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path - link = "https://github.com/samuelcolvin/pydantic/issues/new/choose" - full_message = f"The pydantic 
mypy plugin ran into unexpected behavior: {detail}\n" - full_message += ( - f"Please consider reporting this bug at {link} so we can try to fix it!" - ) - api.fail(full_message, context, code=ERROR_UNEXPECTED) - - -def error_untyped_fields( - api: SemanticAnalyzerPluginInterface, context: Context -) -> None: - api.fail("Untyped fields disallowed", context, code=ERROR_UNTYPED) - - -def add_method( - ctx: ClassDefContext, - name: str, - args: List[Argument], - return_type: Type, - self_type: Optional[Type] = None, - tvar_def: Optional[TypeVarDef] = None, - is_classmethod: bool = False, - is_new: bool = False, - # is_staticmethod: bool = False, -) -> None: - """ - Adds a new method to a class. - - This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged - """ - info = ctx.cls.info - - # First remove any previously generated methods with the same name - # to avoid clashes and problems in the semantic analyzer. - if name in info.names: - sym = info.names[name] - if sym.plugin_generated and isinstance(sym.node, FuncDef): - ctx.cls.defs.body.remove(sym.node) - - self_type = self_type or fill_typevars(info) - if is_classmethod or is_new: - first = [ - Argument(Var("_cls"), TypeType.make_normalized(self_type), None, ARG_POS) - ] - # elif is_staticmethod: - # first = [] - else: - self_type = self_type or fill_typevars(info) - first = [Argument(Var("self"), self_type, None, ARG_POS)] - args = first + args - arg_types, arg_names, arg_kinds = [], [], [] - for arg in args: - assert arg.type_annotation, "All arguments must be fully typed." - arg_types.append(arg.type_annotation) - arg_names.append(get_name(arg.variable)) - arg_kinds.append(arg.kind) - - function_type = ctx.api.named_type("__builtins__.function") - signature = CallableType( - arg_types, arg_kinds, arg_names, return_type, function_type - ) - if tvar_def: - signature.variables = [tvar_def] - - func = FuncDef(name, args, Block([PassStmt()])) - func.info = info - func.type = set_callable_name(signature, func) - func.is_class = is_classmethod - # func.is_static = is_staticmethod - func._fullname = get_fullname(info) + "." + name - func.line = info.line - - # NOTE: we would like the plugin generated node to dominate, but we still - # need to keep any existing definitions so they get semantically analyzed. - if name in info.names: - # Get a nice unique name instead. - r_name = get_unique_redefinition_name(name, info.names) - info.names[r_name] = info.names[name] - - if is_classmethod: # or is_staticmethod: - func.is_decorated = True - v = Var(name, func.type) - v.info = info - v._fullname = func._fullname - # if is_classmethod: - v.is_classmethod = True - dec = Decorator(func, [NameExpr("classmethod")], v) - # else: - # v.is_staticmethod = True - # dec = Decorator(func, [NameExpr('staticmethod')], v) - - dec.line = info.line - sym = SymbolTableNode(MDEF, dec) - else: - sym = SymbolTableNode(MDEF, func) - sym.plugin_generated = True - - info.names[name] = sym - info.defn.defs.body.append(func) - - -def get_fullname(x: Union[FuncBase, SymbolNode]) -> str: - """ - Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. - """ - fn = x.fullname - if callable(fn): # pragma: no cover - return fn() - return fn - - -def get_name(x: Union[FuncBase, SymbolNode]) -> str: - """ - Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. 
- """ - fn = x.name - if callable(fn): # pragma: no cover - return fn() - return fn diff --git a/nornir/_vendor/pydantic/networks.py b/nornir/_vendor/pydantic/networks.py deleted file mode 100644 index c2c7cfd4..00000000 --- a/nornir/_vendor/pydantic/networks.py +++ /dev/null @@ -1,459 +0,0 @@ -import re -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, - _BaseAddress, - _BaseNetwork, -) -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generator, - Optional, - Set, - Tuple, - Type, - Union, - cast, - no_type_check, -) - -from . import errors -from .utils import Representation, update_not_none -from .validators import constr_length_validator, str_validator - -if TYPE_CHECKING: - from .fields import ModelField - from .main import BaseConfig # noqa: F401 - from .typing import AnyCallable - - CallableGenerator = Generator[AnyCallable, None, None] - -try: - import email_validator -except ImportError: - email_validator = None - -NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]] - -__all__ = [ - "AnyUrl", - "AnyHttpUrl", - "HttpUrl", - "stricturl", - "EmailStr", - "NameEmail", - "IPvAnyAddress", - "IPvAnyInterface", - "IPvAnyNetwork", - "PostgresDsn", - "RedisDsn", - "validate_email", -] - -host_part_names = ("domain", "ipv4", "ipv6") -url_regex = re.compile( - r"(?:(?P[a-z0-9]+?)://)?" # scheme - r"(?:(?P[^\s:]+)(?::(?P\S*))?@)?" # user info - r"(?:" - r"(?P(?:\d{1,3}\.){3}\d{1,3})|" # ipv4 - r"(?P\[[A-F0-9]*:[A-F0-9:]+\])|" # ipv6 - r"(?P[^\s/:?#]+)" # domain, validation occurs later - r")?" - r"(?::(?P\d+))?" # port - r"(?P/[^\s?]*)?" # path - r"(?:\?(?P[^\s#]+))?" # query - r"(?:#(?P\S+))?", # fragment - re.IGNORECASE, -) -_ascii_chunk = r"[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?" -_domain_ending = r"(?P\.[a-z]{2,63})?\.?" -ascii_domain_regex = re.compile( - fr"(?:{_ascii_chunk}\.)*?{_ascii_chunk}{_domain_ending}", re.IGNORECASE -) - -_int_chunk = r"[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?" 
-int_domain_regex = re.compile( - fr"(?:{_int_chunk}\.)*?{_int_chunk}{_domain_ending}", re.IGNORECASE -) - - -class AnyUrl(str): - strip_whitespace = True - min_length = 1 - max_length = 2 ** 16 - allowed_schemes: Optional[Set[str]] = None - tld_required: bool = False - user_required: bool = False - - __slots__ = ( - "scheme", - "user", - "password", - "host", - "tld", - "host_type", - "port", - "path", - "query", - "fragment", - ) - - @no_type_check - def __new__(cls, url: Optional[str], **kwargs) -> object: - return str.__new__(cls, cls.build(**kwargs) if url is None else url) - - def __init__( - self, - url: str, - *, - scheme: str, - user: Optional[str] = None, - password: Optional[str] = None, - host: str, - tld: Optional[str] = None, - host_type: str = "domain", - port: Optional[str] = None, - path: Optional[str] = None, - query: Optional[str] = None, - fragment: Optional[str] = None, - ) -> None: - str.__init__(url) - self.scheme = scheme - self.user = user - self.password = password - self.host = host - self.tld = tld - self.host_type = host_type - self.port = port - self.path = path - self.query = query - self.fragment = fragment - - @classmethod - def build( - cls, - *, - scheme: str, - user: Optional[str] = None, - password: Optional[str] = None, - host: str, - port: Optional[str] = None, - path: Optional[str] = None, - query: Optional[str] = None, - fragment: Optional[str] = None, - **kwargs: str, - ) -> str: - url = scheme + "://" - if user: - url += user - if password: - url += ":" + password - url += "@" - url += host - if port: - url += ":" + port - if path: - url += path - if query: - url += "?" + query - if fragment: - url += "#" + fragment - return url - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - minLength=cls.min_length, - maxLength=cls.max_length, - format="uri", - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate( - cls, value: Any, field: "ModelField", config: "BaseConfig" - ) -> "AnyUrl": - if type(value) == cls: - return value - value = str_validator(value) - if cls.strip_whitespace: - value = value.strip() - url: str = cast(str, constr_length_validator(value, field, config)) - - m = url_regex.match(url) - # the regex should always match, if it doesn't please report with details of the URL tried - assert m, "URL regex failed unexpectedly" - - parts = m.groupdict() - scheme = parts["scheme"] - if scheme is None: - raise errors.UrlSchemeError() - if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes: - raise errors.UrlSchemePermittedError(cls.allowed_schemes) - - user = parts["user"] - if cls.user_required and user is None: - raise errors.UrlUserInfoError() - - host, tld, host_type, rebuild = cls.validate_host(parts) - - if m.end() != len(url): - raise errors.UrlExtraError(extra=url[m.end() :]) - - return cls( - None if rebuild else url, - scheme=scheme, - user=user, - password=parts["password"], - host=host, - tld=tld, - host_type=host_type, - port=parts["port"], - path=parts["path"], - query=parts["query"], - fragment=parts["fragment"], - ) - - @classmethod - def validate_host( - cls, parts: Dict[str, str] - ) -> Tuple[str, Optional[str], str, bool]: - host, tld, host_type, rebuild = None, None, None, False - for f in ("domain", "ipv4", "ipv6"): - host = parts[f] - if host: - host_type = f - break - - if host is None: - raise errors.UrlHostError() - elif host_type == "domain": - d = 
ascii_domain_regex.fullmatch(host) - if d is None: - d = int_domain_regex.fullmatch(host) - if not d: - raise errors.UrlHostError() - host_type = "int_domain" - rebuild = True - host = host.encode("idna").decode("ascii") - - tld = d.group("tld") - if tld is not None: - tld = tld[1:] - elif cls.tld_required: - raise errors.UrlHostTldError() - return host, tld, host_type, rebuild # type: ignore - - def __repr__(self) -> str: - extra = ", ".join( - f"{n}={getattr(self, n)!r}" - for n in self.__slots__ - if getattr(self, n) is not None - ) - return f"{self.__class__.__name__}({super().__repr__()}, {extra})" - - -class AnyHttpUrl(AnyUrl): - allowed_schemes = {"http", "https"} - - -class HttpUrl(AnyUrl): - allowed_schemes = {"http", "https"} - tld_required = True - # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers - max_length = 2083 - - -class PostgresDsn(AnyUrl): - allowed_schemes = {"postgres", "postgresql"} - user_required = True - - -class RedisDsn(AnyUrl): - allowed_schemes = {"redis"} - user_required = True - - -def stricturl( - *, - strip_whitespace: bool = True, - min_length: int = 1, - max_length: int = 2 ** 16, - tld_required: bool = True, - allowed_schemes: Optional[Set[str]] = None, -) -> Type[AnyUrl]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, - min_length=min_length, - max_length=max_length, - tld_required=tld_required, - allowed_schemes=allowed_schemes, - ) - return type("UrlValue", (AnyUrl,), namespace) - - -class EmailStr(str): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="email") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - # included here and below so the error happens straight away - if email_validator is None: - raise ImportError( - "email-validator is not installed, run `pip install pydantic[email]`" - ) - - yield str_validator - yield cls.validate - - @classmethod - def validate(cls, value: str) -> str: - return validate_email(value)[1] - - -class NameEmail(Representation): - __slots__ = "name", "email" - - def __init__(self, name: str, email: str): - self.name = name - self.email = email - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="name-email") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - if email_validator is None: - raise ImportError( - "email-validator is not installed, run `pip install pydantic[email]`" - ) - - yield str_validator - yield cls.validate - - @classmethod - def validate(cls, value: str) -> "NameEmail": - return cls(*validate_email(value)) - - def __str__(self) -> str: - return f"{self.name} <{self.email}>" - - -class IPvAnyAddress(_BaseAddress): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="ipvanyaddress") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]: - try: - return IPv4Address(value) - except ValueError: - pass - - try: - return IPv6Address(value) - except ValueError: - raise errors.IPvAnyAddressError() - - -class IPvAnyInterface(_BaseAddress): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", 
format="ipvanyinterface") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]: - try: - return IPv4Interface(value) - except ValueError: - pass - - try: - return IPv6Interface(value) - except ValueError: - raise errors.IPvAnyInterfaceError() - - -class IPvAnyNetwork(_BaseNetwork): # type: ignore - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="ipvanynetwork") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]: - # Assume IP Network is defined with a default value for ``strict`` argument. - # Define your own class if you want to specify network address check strictness. - try: - return IPv4Network(value) - except ValueError: - pass - - try: - return IPv6Network(value) - except ValueError: - raise errors.IPvAnyNetworkError() - - -pretty_email_regex = re.compile(r"([\w ]*?) *<(.*)> *") - - -def validate_email(value: str) -> Tuple[str, str]: - """ - Brutally simple email address validation. Note unlike most email address validation - * raw ip address (literal) domain parts are not allowed. - * "John Doe " style "pretty" email addresses are processed - * the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better - solution is really possible. - * spaces are striped from the beginning and end of addresses but no error is raised - - See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email! - """ - if email_validator is None: - raise ImportError( - "email-validator is not installed, run `pip install pydantic[email]`" - ) - - m = pretty_email_regex.fullmatch(value) - name: Optional[str] = None - if m: - name, value = m.groups() - - email = value.strip() - - try: - email_validator.validate_email(email, check_deliverability=False) - except email_validator.EmailNotValidError as e: - raise errors.EmailError() from e - - at_index = email.index("@") - local_part = email[:at_index] # RFC 5321, local part must be case-sensitive. 
- global_part = email[at_index:].lower() - - return name or local_part, local_part + global_part diff --git a/nornir/_vendor/pydantic/parse.py b/nornir/_vendor/pydantic/parse.py deleted file mode 100644 index 833d2213..00000000 --- a/nornir/_vendor/pydantic/parse.py +++ /dev/null @@ -1,71 +0,0 @@ -import json -import pickle -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Union - -from .types import StrBytes - - -class Protocol(str, Enum): - json = "json" - pickle = "pickle" - - -def load_str_bytes( - b: StrBytes, - *, - content_type: str = None, - encoding: str = "utf8", - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, -) -> Any: - if proto is None and content_type: - if content_type.endswith(("json", "javascript")): - pass - elif allow_pickle and content_type.endswith("pickle"): - proto = Protocol.pickle - else: - raise TypeError(f"Unknown content-type: {content_type}") - - proto = proto or Protocol.json - - if proto == Protocol.json: - if isinstance(b, bytes): - b = b.decode(encoding) - return json_loads(b) - elif proto == Protocol.pickle: - if not allow_pickle: - raise RuntimeError("Trying to decode with pickle with allow_pickle=False") - bb = b if isinstance(b, bytes) else b.encode() - return pickle.loads(bb) - else: - raise TypeError(f"Unknown protocol: {proto}") - - -def load_file( - path: Union[str, Path], - *, - content_type: str = None, - encoding: str = "utf8", - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, -) -> Any: - path = Path(path) - b = path.read_bytes() - if content_type is None: - if path.suffix in (".js", ".json"): - proto = Protocol.json - elif path.suffix == ".pkl": - proto = Protocol.pickle - - return load_str_bytes( - b, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=json_loads, - ) diff --git a/nornir/_vendor/pydantic/py.typed b/nornir/_vendor/pydantic/py.typed deleted file mode 100644 index e69de29b..00000000 diff --git a/nornir/_vendor/pydantic/schema.py b/nornir/_vendor/pydantic/schema.py deleted file mode 100644 index 830566eb..00000000 --- a/nornir/_vendor/pydantic/schema.py +++ /dev/null @@ -1,902 +0,0 @@ -import inspect -import re -import warnings -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, -) -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - FrozenSet, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) -from uuid import UUID - -from .class_validators import ROOT_KEY -from .fields import ( - SHAPE_FROZENSET, - SHAPE_LIST, - SHAPE_MAPPING, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_SINGLETON, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - FieldInfo, - ModelField, -) -from .json import pydantic_encoder -from .networks import AnyUrl, EmailStr -from .types import ( - ConstrainedDecimal, - ConstrainedFloat, - ConstrainedInt, - ConstrainedList, - ConstrainedStr, - conbytes, - condecimal, - confloat, - conint, - conlist, - constr, -) -from .typing import ( - ForwardRef, - Literal, - is_callable_type, - is_literal_type, - is_new_type, - literal_values, - new_type_supertype, -) -from .utils import get_model, lenient_issubclass, sequence_like - -if TYPE_CHECKING: - from .main import BaseModel # noqa: F401 - from 
.dataclasses import DataclassType # noqa: F401 - -default_prefix = "#/definitions/" - - -def schema( - models: Sequence[Union[Type["BaseModel"], Type["DataclassType"]]], - *, - by_alias: bool = True, - title: Optional[str] = None, - description: Optional[str] = None, - ref_prefix: Optional[str] = None, -) -> Dict[str, Any]: - """ - Process a list of models and generate a single JSON Schema with all of them defined in the ``definitions`` - top-level JSON key, including their sub-models. - - :param models: a list of models to include in the generated JSON Schema - :param by_alias: generate the schemas using the aliases defined, if any - :param title: title for the generated schema that includes the definitions - :param description: description for the generated schema - :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the - default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere - else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the - top-level key ``definitions``, so you can extract them from there. But all the references will have the set - prefix. - :return: dict with the JSON Schema with a ``definitions`` top-level key including the schema definitions for - the models and sub-models passed in ``models``. - """ - clean_models = [get_model(model) for model in models] - ref_prefix = ref_prefix or default_prefix - flat_models = get_flat_models_from_models(clean_models) - model_name_map = get_model_name_map(flat_models) - definitions = {} - output_schema: Dict[str, Any] = {} - if title: - output_schema["title"] = title - if description: - output_schema["description"] = description - for model in clean_models: - m_schema, m_definitions, m_nested_models = model_process_schema( - model, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ) - definitions.update(m_definitions) - model_name = model_name_map[model] - definitions[model_name] = m_schema - if definitions: - output_schema["definitions"] = definitions - return output_schema - - -def model_schema( - model: Union[Type["BaseModel"], Type["DataclassType"]], - by_alias: bool = True, - ref_prefix: Optional[str] = None, -) -> Dict[str, Any]: - """ - Generate a JSON Schema for one model. With all the sub-models defined in the ``definitions`` top-level - JSON key. - - :param model: a Pydantic model (a class that inherits from BaseModel) - :param by_alias: generate the schemas using the aliases defined, if any - :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the - default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere - else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the - top-level key ``definitions``, so you can extract them from there. But all the references will have the set - prefix. 
- :return: dict with the JSON Schema for the passed ``model`` - """ - model = get_model(model) - ref_prefix = ref_prefix or default_prefix - flat_models = get_flat_models_from_model(model) - model_name_map = get_model_name_map(flat_models) - model_name = model_name_map[model] - m_schema, m_definitions, nested_models = model_process_schema( - model, by_alias=by_alias, model_name_map=model_name_map, ref_prefix=ref_prefix - ) - if model_name in nested_models: - # model_name is in Nested models, it has circular references - m_definitions[model_name] = m_schema - m_schema = {"$ref": ref_prefix + model_name} - if m_definitions: - m_schema.update({"definitions": m_definitions}) - return m_schema - - -def field_schema( - field: ModelField, - *, - by_alias: bool = True, - model_name_map: Dict[Type["BaseModel"], str], - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]] = None, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - Process a Pydantic field and return a tuple with a JSON Schema for it as the first item. - Also return a dictionary of definitions with models as keys and their schemas as values. If the passed field - is a model and has sub-models, and those sub-models don't have overrides (as ``title``, ``default``, etc), they - will be included in the definitions and referenced in the schema instead of included recursively. - - :param field: a Pydantic ``ModelField`` - :param by_alias: use the defined alias (if any) in the returned schema - :param model_name_map: used to generate the JSON Schema references to other models included in the definitions - :param ref_prefix: the JSON Pointer prefix to use for references to other schemas, if None, the default of - #/definitions/ will be used - :param known_models: used to solve circular references - :return: tuple of the schema for this field and additional definitions - """ - ref_prefix = ref_prefix or default_prefix - schema_overrides = False - s = dict(title=field.field_info.title or field.alias.title().replace("_", " ")) - if field.field_info.title: - schema_overrides = True - - if field.field_info.description: - s["description"] = field.field_info.description - schema_overrides = True - - if not field.required and not field.field_info.const and field.default is not None: - s["default"] = encode_default(field.default) - schema_overrides = True - - validation_schema = get_field_schema_validations(field) - if validation_schema: - s.update(validation_schema) - schema_overrides = True - - f_schema, f_definitions, f_nested_models = field_type_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - known_models=known_models or set(), - ) - # $ref will only be returned when there are no schema_overrides - if "$ref" in f_schema: - return f_schema, f_definitions, f_nested_models - else: - s.update(f_schema) - return s, f_definitions, f_nested_models - - -numeric_types = (int, float, Decimal) -_str_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( - ("max_length", numeric_types, "maxLength"), - ("min_length", numeric_types, "minLength"), - ("regex", str, "pattern"), -) - -_numeric_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] 
= ( - ("gt", numeric_types, "exclusiveMinimum"), - ("lt", numeric_types, "exclusiveMaximum"), - ("ge", numeric_types, "minimum"), - ("le", numeric_types, "maximum"), - ("multiple_of", numeric_types, "multipleOf"), -) - - -def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: - """ - Get the JSON Schema validation keywords for a ``field`` with an annotation of - a Pydantic ``FieldInfo`` with validation arguments. - """ - f_schema: Dict[str, Any] = {} - if lenient_issubclass(field.type_, (str, bytes)): - for attr_name, t, keyword in _str_types_attrs: - attr = getattr(field.field_info, attr_name, None) - if isinstance(attr, t): - f_schema[keyword] = attr - if lenient_issubclass(field.type_, numeric_types) and not issubclass( - field.type_, bool - ): - for attr_name, t, keyword in _numeric_types_attrs: - attr = getattr(field.field_info, attr_name, None) - if isinstance(attr, t): - f_schema[keyword] = attr - if field.field_info is not None and field.field_info.const: - f_schema["const"] = field.default - if field.field_info.extra: - f_schema.update(field.field_info.extra) - return f_schema - - -def get_model_name_map( - unique_models: Set[Type["BaseModel"]], -) -> Dict[Type["BaseModel"], str]: - """ - Process a set of models and generate unique names for them to be used as keys in the JSON Schema - definitions. By default the names are the same as the class name. But if two models in different Python - modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be - based on the Python module path for those conflicting models to prevent name collisions. - - :param unique_models: a Python set of models - :return: dict mapping models to names - """ - name_model_map = {} - conflicting_names: Set[str] = set() - for model in unique_models: - model_name = model.__name__ - model_name = re.sub(r"[^a-zA-Z0-9.\-_]", "_", model_name) - if model_name in conflicting_names: - model_name = get_long_model_name(model) - name_model_map[model_name] = model - elif model_name in name_model_map: - conflicting_names.add(model_name) - conflicting_model = name_model_map.pop(model_name) - name_model_map[get_long_model_name(conflicting_model)] = conflicting_model - name_model_map[get_long_model_name(model)] = model - else: - name_model_map[model_name] = model - return {v: k for k, v in name_model_map.items()} - - -def get_flat_models_from_model( - model: Type["BaseModel"], known_models: Set[Type["BaseModel"]] = None -) -> Set[Type["BaseModel"]]: - """ - Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e. if you pass - model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also - subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also subclass of ``BaseModel``), - the return value will be ``set([Foo, Bar, Baz])``. 
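# The Foo/Bar/Baz example from the docstring above, as a minimal sketch:
#
#     class Baz(BaseModel):
#         x: int
#
#     class Bar(BaseModel):
#         baz: Baz
#
#     class Foo(BaseModel):
#         bar: Bar
#
#     get_flat_models_from_model(Foo) == {Foo, Bar, Baz}  # True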
-
-    :param model: a Pydantic ``BaseModel`` subclass
-    :param known_models: used to solve circular references
-    :return: a set with the initial model and all its sub-models
-    """
-    known_models = known_models or set()
-    flat_models: Set[Type["BaseModel"]] = set()
-    flat_models.add(model)
-    known_models |= flat_models
-    fields = cast(Sequence[ModelField], model.__fields__.values())
-    flat_models |= get_flat_models_from_fields(fields, known_models=known_models)
-    return flat_models
-
-
-def get_flat_models_from_field(
-    field: ModelField, known_models: Set[Type["BaseModel"]]
-) -> Set[Type["BaseModel"]]:
-    """
-    Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a subclass of BaseModel
-    (so, it could be a submodel), and generate a set with its model and all the sub-models in the tree.
-    I.e. if you pass a field that was declared to be of type ``Foo`` (subclass of BaseModel) as ``field``, and that
-    model ``Foo`` has a field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of
-    type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``.
-
-    :param field: a Pydantic ``ModelField``
-    :param known_models: used to solve circular references
-    :return: a set with the model used in the declaration for this field, if any, and all its sub-models
-    """
-    from .main import BaseModel  # noqa: F811
-
-    flat_models: Set[Type[BaseModel]] = set()
-    # Handle dataclass-based models
-    field_type = field.type_
-    if lenient_issubclass(getattr(field_type, "__pydantic_model__", None), BaseModel):
-        field_type = field_type.__pydantic_model__
-    if field.sub_fields:
-        flat_models |= get_flat_models_from_fields(
-            field.sub_fields, known_models=known_models
-        )
-    elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models:
-        flat_models |= get_flat_models_from_model(field_type, known_models=known_models)
-    return flat_models
-
-
-def get_flat_models_from_fields(
-    fields: Sequence[ModelField], known_models: Set[Type["BaseModel"]]
-) -> Set[Type["BaseModel"]]:
-    """
-    Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as subclasses of ``BaseModel``
-    (so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree.
-    I.e. if you pass the fields of a model ``Foo`` (subclass of ``BaseModel``) as ``fields``, and one of them has a
-    field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of
-    type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``.
-
-    :param fields: a list of Pydantic ``ModelField``s
-    :param known_models: used to solve circular references
-    :return: a set with any model declared in the fields, and all their sub-models
-    """
-    flat_models: Set[Type["BaseModel"]] = set()
-    for field in fields:
-        flat_models |= get_flat_models_from_field(field, known_models=known_models)
-    return flat_models
-
-
-def get_flat_models_from_models(
-    models: Sequence[Type["BaseModel"]],
-) -> Set[Type["BaseModel"]]:
-    """
-    Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass
-    a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has
-    a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``.
- """ - flat_models: Set[Type["BaseModel"]] = set() - for model in models: - flat_models |= get_flat_models_from_model(model) - return flat_models - - -def get_long_model_name(model: Type["BaseModel"]) -> str: - return f"{model.__module__}__{model.__name__}".replace(".", "__") - - -def field_type_schema( - field: ModelField, - *, - by_alias: bool, - model_name_map: Dict[Type["BaseModel"], str], - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]], -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - Used by ``field_schema()``, you probably should be using that function. - - Take a single ``field`` and generate the schema for its type only, not including additional - information as title, etc. Also return additional schema definitions, from sub-models. - """ - definitions = {} - nested_models: Set[str] = set() - f_schema: Dict[str, Any] - ref_prefix = ref_prefix or default_prefix - if field.shape in { - SHAPE_LIST, - SHAPE_TUPLE_ELLIPSIS, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_FROZENSET, - }: - items_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - f_schema = {"type": "array", "items": items_schema} - if field.shape in {SHAPE_SET, SHAPE_FROZENSET}: - f_schema["uniqueItems"] = True - - elif field.shape == SHAPE_MAPPING: - f_schema = {"type": "object"} - key_field = cast(ModelField, field.key_field) - regex = getattr(key_field.type_, "regex", None) - items_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - if regex: - # Dict keys have a regex pattern - # items_schema might be a schema or empty dict, add it either way - f_schema["patternProperties"] = {regex.pattern: items_schema} - elif items_schema: - # The dict values are not simply Any, so they need a schema - f_schema["additionalProperties"] = items_schema - elif field.shape == SHAPE_TUPLE: - sub_schema = [] - sub_fields = cast(List[ModelField], field.sub_fields) - for sf in sub_fields: - sf_schema, sf_definitions, sf_nested_models = field_type_schema( - sf, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(sf_definitions) - nested_models.update(sf_nested_models) - sub_schema.append(sf_schema) - if len(sub_schema) == 1: - sub_schema = sub_schema[0] # type: ignore - f_schema = {"type": "array", "items": sub_schema} - else: - assert field.shape == SHAPE_SINGLETON, field.shape - f_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - - # check field type to avoid repeated calls to the same __modify_schema__ method - if field.type_ != field.outer_type_: - modify_schema = getattr(field.outer_type_, "__modify_schema__", None) - if modify_schema: - modify_schema(f_schema) - return f_schema, definitions, nested_models - - -def model_process_schema( - model: Type["BaseModel"], - *, - by_alias: bool = True, - model_name_map: 
Dict[Type["BaseModel"], str], - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]] = None, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - Used by ``model_schema()``, you probably should be using that function. - - Take a single ``model`` and generate its schema. Also return additional schema definitions, from sub-models. The - sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All - the definitions are returned as the second value. - """ - ref_prefix = ref_prefix or default_prefix - known_models = known_models or set() - s = {"title": model.__config__.title or model.__name__} - doc = inspect.getdoc(model) - if doc: - s["description"] = doc - known_models.add(model) - m_schema, m_definitions, nested_models = model_type_schema( - model, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - s.update(m_schema) - schema_extra = model.__config__.schema_extra - if callable(schema_extra): - schema_extra(s) - else: - s.update(schema_extra) - return s, m_definitions, nested_models - - -def model_type_schema( - model: Type["BaseModel"], - *, - by_alias: bool, - model_name_map: Dict[Type["BaseModel"], str], - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]], -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - You probably should be using ``model_schema()``, this function is indirectly used by that function. - - Take a single ``model`` and generate the schema for its type only, not including additional - information as title, etc. Also return additional schema definitions, from sub-models. - """ - ref_prefix = ref_prefix or default_prefix - properties = {} - required = [] - definitions: Dict[str, Any] = {} - nested_models: Set[str] = set() - for k, f in model.__fields__.items(): - try: - f_schema, f_definitions, f_nested_models = field_schema( - f, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - except SkipField as skip: - warnings.warn(skip.message, UserWarning) - continue - definitions.update(f_definitions) - nested_models.update(f_nested_models) - if by_alias: - properties[f.alias] = f_schema - if f.required: - required.append(f.alias) - else: - properties[k] = f_schema - if f.required: - required.append(k) - if ROOT_KEY in properties: - out_schema = properties[ROOT_KEY] - out_schema["title"] = model.__config__.title or model.__name__ - else: - out_schema = {"type": "object", "properties": properties} - if required: - out_schema["required"] = required - if model.__config__.extra == "forbid": - out_schema["additionalProperties"] = False - return out_schema, definitions, nested_models - - -def field_singleton_sub_fields_schema( - sub_fields: Sequence[ModelField], - *, - by_alias: bool, - model_name_map: Dict[Type["BaseModel"], str], - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]], -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - This function is indirectly used by ``field_schema()``, you probably should be using that function. - - Take a list of Pydantic ``ModelField`` from the declaration of a type with parameters, and generate their - schema. I.e., fields used as "type parameters", like ``str`` and ``int`` in ``Tuple[str, int]``. 
- """ - ref_prefix = ref_prefix or default_prefix - definitions = {} - nested_models: Set[str] = set() - sub_fields = [sf for sf in sub_fields if sf.include_in_schema()] - if len(sub_fields) == 1: - return field_type_schema( - sub_fields[0], - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - known_models=known_models, - ) - else: - sub_field_schemas = [] - for sf in sub_fields: - sub_schema, sub_definitions, sub_nested_models = field_type_schema( - sf, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(sub_definitions) - sub_field_schemas.append(sub_schema) - nested_models.update(sub_nested_models) - return {"anyOf": sub_field_schemas}, definitions, nested_models - - -# Order is important, e.g. subclasses of str must go before str -# this is used only for standard library types, custom types should use __modify_schema__ instead -field_class_to_schema: Tuple[Tuple[Any, Dict[str, Any]], ...] = ( - (Path, {"type": "string", "format": "path"}), - (datetime, {"type": "string", "format": "date-time"}), - (date, {"type": "string", "format": "date"}), - (time, {"type": "string", "format": "time"}), - (timedelta, {"type": "number", "format": "time-delta"}), - (IPv4Network, {"type": "string", "format": "ipv4network"}), - (IPv6Network, {"type": "string", "format": "ipv6network"}), - (IPv4Interface, {"type": "string", "format": "ipv4interface"}), - (IPv6Interface, {"type": "string", "format": "ipv6interface"}), - (IPv4Address, {"type": "string", "format": "ipv4"}), - (IPv6Address, {"type": "string", "format": "ipv6"}), - (str, {"type": "string"}), - (bytes, {"type": "string", "format": "binary"}), - (bool, {"type": "boolean"}), - (int, {"type": "integer"}), - (float, {"type": "number"}), - (Decimal, {"type": "number"}), - (UUID, {"type": "string", "format": "uuid"}), - (dict, {"type": "object"}), - (list, {"type": "array", "items": {}}), - (tuple, {"type": "array", "items": {}}), - (set, {"type": "array", "items": {}, "uniqueItems": True}), -) - -json_scheme = {"type": "string", "format": "json-string"} - - -def field_singleton_schema( # noqa: C901 (ignore complexity) - field: ModelField, - *, - by_alias: bool, - model_name_map: Dict[Type["BaseModel"], str], - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: Set[Type["BaseModel"]], -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - This function is indirectly used by ``field_schema()``, you should probably be using that function. - - Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models. - """ - from .main import BaseModel # noqa: F811 - - ref_prefix = ref_prefix or default_prefix - definitions: Dict[str, Any] = {} - nested_models: Set[str] = set() - if field.sub_fields: - return field_singleton_sub_fields_schema( - field.sub_fields, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - known_models=known_models, - ) - if field.type_ is Any or type(field.type_) == TypeVar: - return {}, definitions, nested_models # no restrictions - if is_callable_type(field.type_): - raise SkipField( - f"Callable {field.name} was excluded from schema since JSON schema has no equivalent type." 
- ) - f_schema: Dict[str, Any] = {} - if field.field_info is not None and field.field_info.const: - f_schema["const"] = field.default - field_type = field.type_ - if is_new_type(field_type): - field_type = new_type_supertype(field_type) - if is_literal_type(field_type): - values = literal_values(field_type) - if len(values) > 1: - return field_schema( - multivalue_literal_field_for_schema(values, field), - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - literal_value = values[0] - field_type = type(literal_value) - f_schema["const"] = literal_value - - if issubclass(field_type, Enum): - f_schema.update({"enum": [item.value for item in field_type]}) - # Don't return immediately, to allow adding specific types - - for type_, t_schema in field_class_to_schema: - if issubclass(field_type, type_): - f_schema.update(t_schema) - break - - modify_schema = getattr(field_type, "__modify_schema__", None) - if modify_schema: - modify_schema(f_schema) - - if f_schema: - return f_schema, definitions, nested_models - - # Handle dataclass-based models - if lenient_issubclass(getattr(field_type, "__pydantic_model__", None), BaseModel): - field_type = field_type.__pydantic_model__ - - if issubclass(field_type, BaseModel): - model_name = model_name_map[field_type] - if field_type not in known_models: - sub_schema, sub_definitions, sub_nested_models = model_process_schema( - field_type, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - known_models=known_models, - ) - definitions.update(sub_definitions) - definitions[model_name] = sub_schema - nested_models.update(sub_nested_models) - else: - nested_models.add(model_name) - schema_ref = {"$ref": ref_prefix + model_name} - if not schema_overrides: - return schema_ref, definitions, nested_models - else: - return {"allOf": [schema_ref]}, definitions, nested_models - - raise ValueError(f"Value not declarable with JSON Schema, field: {field}") - - -def multivalue_literal_field_for_schema( - values: Tuple[Any, ...], field: ModelField -) -> ModelField: - return ModelField( - name=field.name, - type_=Union[tuple(Literal[value] for value in values)], - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - required=field.required, - alias=field.alias, - field_info=field.field_info, - ) - - -def encode_default(dft: Any) -> Any: - if isinstance(dft, (int, float, str)): - return dft - elif sequence_like(dft): - t = type(dft) - return t(encode_default(v) for v in dft) - elif isinstance(dft, dict): - return {encode_default(k): encode_default(v) for k, v in dft.items()} - elif dft is None: - return None - else: - return pydantic_encoder(dft) - - -_map_types_constraint: Dict[Any, Callable[..., type]] = { - int: conint, - float: confloat, - Decimal: condecimal, -} -_field_constraints = { - "min_length", - "max_length", - "regex", - "gt", - "lt", - "ge", - "le", - "multiple_of", - "min_items", - "max_items", -} - - -def get_annotation_from_field_info( - annotation: Any, field_info: FieldInfo, field_name: str -) -> Type[Any]: # noqa: C901 - """ - Get an annotation with validation implemented for numbers and strings based on the field_info. 
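-
- A hypothetical doctest sketch (argument values assumed for illustration); the
- constraints carried by ``field_info`` are baked into a constrained subtype:
-
- >>> t = get_annotation_from_field_info(int, FieldInfo(gt=0), "x")
- >>> issubclass(t, int)
- True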
- - :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` - :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema - :param field_name: name of the field for use in error messages - :return: the same ``annotation`` if unmodified or a new annotation with validation in place - """ - constraints = {f for f in _field_constraints if getattr(field_info, f) is not None} - if not constraints: - return annotation - used_constraints: Set[str] = set() - - def go(type_: Any) -> Type[Any]: - if ( - is_literal_type(annotation) - or isinstance(type_, ForwardRef) - or lenient_issubclass(type_, ConstrainedList) - ): - return type_ - origin = getattr(type_, "__origin__", None) - if origin is not None: - args: Tuple[Any, ...] = type_.__args__ - if any(isinstance(a, ForwardRef) for a in args): - # forward refs cause infinite recursion below - return type_ - - if origin is Union: - return Union[tuple(go(a) for a in args)] - - if issubclass(origin, List) and ( - field_info.min_items is not None or field_info.max_items is not None - ): - used_constraints.update({"min_items", "max_items"}) - return conlist( - go(args[0]), - min_items=field_info.min_items, - max_items=field_info.max_items, - ) - - for t in (Tuple, List, Set, FrozenSet, Sequence): - if issubclass(origin, t): # type: ignore - return t[tuple(go(a) for a in args)] # type: ignore - - if issubclass(origin, Dict): - return Dict[args[0], go(args[1])] # type: ignore - - attrs: Optional[Tuple[str, ...]] = None - constraint_func: Optional[Callable[..., type]] = None - if isinstance(type_, type): - if issubclass(type_, str) and not issubclass( - type_, (EmailStr, AnyUrl, ConstrainedStr) - ): - attrs = ("max_length", "min_length", "regex") - constraint_func = constr - elif issubclass(type_, bytes): - attrs = ("max_length", "min_length", "regex") - constraint_func = conbytes - elif issubclass(type_, numeric_types) and not issubclass( - type_, - ( - ConstrainedInt, - ConstrainedFloat, - ConstrainedDecimal, - ConstrainedList, - bool, - ), - ): - # Is numeric type - attrs = ("gt", "lt", "ge", "le", "multiple_of") - numeric_type = next( - t for t in numeric_types if issubclass(type_, t) - ) # pragma: no branch - constraint_func = _map_types_constraint[numeric_type] - - if attrs: - used_constraints.update(set(attrs)) - kwargs = { - attr_name: attr - for attr_name, attr in ( - (attr_name, getattr(field_info, attr_name)) for attr_name in attrs - ) - if attr is not None - } - if kwargs: - constraint_func = cast(Callable[..., type], constraint_func) - return constraint_func(**kwargs) - return type_ - - ans = go(annotation) - - unused_constraints = constraints - used_constraints - if unused_constraints: - raise ValueError( - f'On field "{field_name}" the following field constraints are set but not enforced: ' - f'{", ".join(unused_constraints)}. ' - f"\nFor more details see https://pydantic-docs.helpmanual.io/usage/schema/#unenforced-field-constraints" - ) - - return ans - - -class SkipField(Exception): - """ - Utility exception used to exclude fields from schema. 
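-
- Hypothetical usage sketch mirroring how ``model_type_schema()`` consumes it:
-
- >>> try:
- ...     raise SkipField("no JSON schema equivalent")
- ... except SkipField as skip:
- ...     print(skip.message)
- no JSON schema equivalent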
- """ - - def __init__(self, message: str) -> None: - self.message = message diff --git a/nornir/_vendor/pydantic/tools.py b/nornir/_vendor/pydantic/tools.py deleted file mode 100644 index 7d2829b0..00000000 --- a/nornir/_vendor/pydantic/tools.py +++ /dev/null @@ -1,59 +0,0 @@ -import json -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Optional, Type, TypeVar, Union - -from nornir._vendor.pydantic.parse import Protocol, load_file - -from .typing import display_as_type - -__all__ = ("parse_file_as", "parse_obj_as") - -NameFactory = Union[str, Callable[[Type[Any]], str]] - - -def _generate_parsing_type_name(type_: Any) -> str: - return f"ParsingModel[{display_as_type(type_)}]" - - -@lru_cache(maxsize=2048) -def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any: - from pydantic.main import create_model - - if type_name is None: - type_name = _generate_parsing_type_name - if not isinstance(type_name, str): - type_name = type_name(type_) - return create_model(type_name, __root__=(type_, ...)) - - -T = TypeVar("T") - - -def parse_obj_as( - type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None -) -> T: - model_type = _get_parsing_type(type_, type_name=type_name) - return model_type(__root__=obj).__root__ - - -def parse_file_as( - type_: Type[T], - path: Union[str, Path], - *, - content_type: str = None, - encoding: str = "utf8", - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, - type_name: Optional[NameFactory] = None, -) -> T: - obj = load_file( - path, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=json_loads, - ) - return parse_obj_as(type_, obj, type_name=type_name) diff --git a/nornir/_vendor/pydantic/types.py b/nornir/_vendor/pydantic/types.py deleted file mode 100644 index a4d3877b..00000000 --- a/nornir/_vendor/pydantic/types.py +++ /dev/null @@ -1,830 +0,0 @@ -import re -import warnings -from decimal import Decimal -from enum import Enum -from pathlib import Path -from types import new_class -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - List, - Optional, - Pattern, - Type, - TypeVar, - Union, - cast, -) -from uuid import UUID - -from . 
import errors -from .typing import AnyType -from .utils import import_string, update_not_none -from .validators import ( - bytes_validator, - constr_length_validator, - constr_strip_whitespace, - decimal_validator, - float_validator, - int_validator, - list_validator, - number_multiple_validator, - number_size_validator, - path_exists_validator, - path_validator, - str_validator, - strict_float_validator, - strict_int_validator, - strict_str_validator, -) - -__all__ = [ - "NoneStr", - "NoneBytes", - "StrBytes", - "NoneStrBytes", - "StrictStr", - "ConstrainedBytes", - "conbytes", - "ConstrainedList", - "conlist", - "ConstrainedStr", - "constr", - "PyObject", - "ConstrainedInt", - "conint", - "PositiveInt", - "NegativeInt", - "ConstrainedFloat", - "confloat", - "PositiveFloat", - "NegativeFloat", - "ConstrainedDecimal", - "condecimal", - "UUID1", - "UUID3", - "UUID4", - "UUID5", - "FilePath", - "DirectoryPath", - "Json", - "JsonWrapper", - "SecretStr", - "SecretBytes", - "StrictBool", - "StrictInt", - "StrictFloat", - "PaymentCardNumber", - "ByteSize", -] - -NoneStr = Optional[str] -NoneBytes = Optional[bytes] -StrBytes = Union[str, bytes] -NoneStrBytes = Optional[StrBytes] -OptionalInt = Optional[int] -OptionalIntFloat = Union[OptionalInt, float] -OptionalIntFloatDecimal = Union[OptionalIntFloat, Decimal] -StrIntFloat = Union[str, int, float] - -if TYPE_CHECKING: - from .dataclasses import DataclassType # noqa: F401 - from .main import BaseModel, BaseConfig # noqa: F401 - from .typing import CallableGenerator - - ModelOrDc = Type[Union["BaseModel", "DataclassType"]] - - -class ConstrainedBytes(bytes): - strip_whitespace = False - min_length: OptionalInt = None - max_length: OptionalInt = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, minLength=cls.min_length, maxLength=cls.max_length - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield bytes_validator - yield constr_strip_whitespace - yield constr_length_validator - - -def conbytes( - *, strip_whitespace: bool = False, min_length: int = None, max_length: int = None -) -> Type[bytes]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, min_length=min_length, max_length=max_length - ) - return type("ConstrainedBytesValue", (ConstrainedBytes,), namespace) - - -T = TypeVar("T") - - -# This types superclass should be List[T], but cython chokes on that... 
-class ConstrainedList(list): # type: ignore - # Needed for pydantic to detect that this is a list - __origin__ = list - __args__: List[Type[T]] # type: ignore - - min_items: Optional[int] = None - max_items: Optional[int] = None - item_type: Type[T] # type: ignore - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield list_validator - yield cls.list_length_validator - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) - - @classmethod - def list_length_validator(cls, v: "List[T]") -> "List[T]": - v_len = len(v) - - if cls.min_items is not None and v_len < cls.min_items: - raise errors.ListMinLengthError(limit_value=cls.min_items) - - if cls.max_items is not None and v_len > cls.max_items: - raise errors.ListMaxLengthError(limit_value=cls.max_items) - - return v - - -def conlist( - item_type: Type[T], *, min_items: int = None, max_items: int = None -) -> Type[List[T]]: - # __args__ is needed to conform to typing generics api - namespace = { - "min_items": min_items, - "max_items": max_items, - "item_type": item_type, - "__args__": [item_type], - } - # We use new_class to be able to deal with Generic types - return new_class( - "ConstrainedListValue", (ConstrainedList,), {}, lambda ns: ns.update(namespace) - ) - - -class ConstrainedStr(str): - strip_whitespace = False - min_length: OptionalInt = None - max_length: OptionalInt = None - curtail_length: OptionalInt = None - regex: Optional[Pattern[str]] = None - strict = False - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - minLength=cls.min_length, - maxLength=cls.max_length, - pattern=cls.regex and cls.regex.pattern, - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield strict_str_validator if cls.strict else str_validator - yield constr_strip_whitespace - yield constr_length_validator - yield cls.validate - - @classmethod - def validate(cls, value: Union[str]) -> Union[str]: - if cls.curtail_length and len(value) > cls.curtail_length: - value = value[: cls.curtail_length] - - if cls.regex: - if not cls.regex.match(value): - raise errors.StrRegexError(pattern=cls.regex.pattern) - - return value - - -def constr( - *, - strip_whitespace: bool = False, - strict: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, -) -> Type[str]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, - strict=strict, - min_length=min_length, - max_length=max_length, - curtail_length=curtail_length, - regex=regex and re.compile(regex), - ) - return type("ConstrainedStrValue", (ConstrainedStr,), namespace) - - -class StrictStr(ConstrainedStr): - strict = True - - -if TYPE_CHECKING: - StrictBool = bool -else: - - class StrictBool(int): - """ - StrictBool to allow for bools which are not type-coerced. - """ - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="boolean") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> bool: - """ - Ensure that we only allow bools. 
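-
- Illustrative doctest (model name assumed): coerced values such as ``1`` or
- ``"true"`` raise a validation error, while real bools pass through.
-
- >>> class M(BaseModel): flag: StrictBool
- >>> M(flag=False).flag
- False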
- """ - if isinstance(value, bool): - return value - - raise errors.StrictBoolError() - - -class PyObject: - validate_always = True - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> Any: - if isinstance(value, Callable): # type: ignore - return value - - try: - value = str_validator(value) - except errors.StrError: - raise errors.PyObjectError( - error_message="value is neither a valid import path not a valid callable" - ) - - try: - return import_string(value) - except ImportError as e: - raise errors.PyObjectError(error_message=str(e)) - - -class ConstrainedNumberMeta(type): - def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]) -> "ConstrainedInt": # type: ignore - new_cls = cast("ConstrainedInt", type.__new__(cls, name, bases, dct)) - - if new_cls.gt is not None and new_cls.ge is not None: - raise errors.ConfigError( - "bounds gt and ge cannot be specified at the same time" - ) - if new_cls.lt is not None and new_cls.le is not None: - raise errors.ConfigError( - "bounds lt and le cannot be specified at the same time" - ) - - return new_cls - - -class ConstrainedInt(int, metaclass=ConstrainedNumberMeta): - strict: bool = False - gt: OptionalInt = None - ge: OptionalInt = None - lt: OptionalInt = None - le: OptionalInt = None - multiple_of: OptionalInt = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - - yield strict_int_validator if cls.strict else int_validator - yield number_size_validator - yield number_multiple_validator - - -def conint( - *, - strict: bool = False, - gt: int = None, - ge: int = None, - lt: int = None, - le: int = None, - multiple_of: int = None, -) -> Type[int]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) - return type("ConstrainedIntValue", (ConstrainedInt,), namespace) - - -class PositiveInt(ConstrainedInt): - gt = 0 - - -class NegativeInt(ConstrainedInt): - lt = 0 - - -class StrictInt(ConstrainedInt): - strict = True - - -class ConstrainedFloat(float, metaclass=ConstrainedNumberMeta): - strict: bool = False - gt: OptionalIntFloat = None - ge: OptionalIntFloat = None - lt: OptionalIntFloat = None - le: OptionalIntFloat = None - multiple_of: OptionalIntFloat = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield strict_float_validator if cls.strict else float_validator - yield number_size_validator - yield number_multiple_validator - - -def confloat( - *, - strict: bool = False, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, -) -> Type[float]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) - return type("ConstrainedFloatValue", (ConstrainedFloat,), namespace) - - -class PositiveFloat(ConstrainedFloat): - gt = 0 - - -class 
NegativeFloat(ConstrainedFloat): - lt = 0 - - -class StrictFloat(ConstrainedFloat): - strict = True - - -class ConstrainedDecimal(Decimal, metaclass=ConstrainedNumberMeta): - gt: OptionalIntFloatDecimal = None - ge: OptionalIntFloatDecimal = None - lt: OptionalIntFloatDecimal = None - le: OptionalIntFloatDecimal = None - max_digits: OptionalInt = None - decimal_places: OptionalInt = None - multiple_of: OptionalIntFloatDecimal = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield decimal_validator - yield number_size_validator - yield number_multiple_validator - yield cls.validate - - @classmethod - def validate(cls, value: Decimal) -> Decimal: - digit_tuple, exponent = value.as_tuple()[1:] - if exponent in {"F", "n", "N"}: - raise errors.DecimalIsNotFiniteError() - - if exponent >= 0: - # A positive exponent adds that many trailing zeros. - digits = len(digit_tuple) + exponent - decimals = 0 - else: - # If the absolute value of the negative exponent is larger than the - # number of digits, then it's the same as the number of digits, - # because it'll consume all of the digits in digit_tuple and then - # add abs(exponent) - len(digit_tuple) leading zeros after the - # decimal point. - if abs(exponent) > len(digit_tuple): - digits = decimals = abs(exponent) - else: - digits = len(digit_tuple) - decimals = abs(exponent) - whole_digits = digits - decimals - - if cls.max_digits is not None and digits > cls.max_digits: - raise errors.DecimalMaxDigitsError(max_digits=cls.max_digits) - - if cls.decimal_places is not None and decimals > cls.decimal_places: - raise errors.DecimalMaxPlacesError(decimal_places=cls.decimal_places) - - if cls.max_digits is not None and cls.decimal_places is not None: - expected = cls.max_digits - cls.decimal_places - if whole_digits > expected: - raise errors.DecimalWholeDigitsError(whole_digits=expected) - - return value - - -def condecimal( - *, - gt: Decimal = None, - ge: Decimal = None, - lt: Decimal = None, - le: Decimal = None, - max_digits: int = None, - decimal_places: int = None, - multiple_of: Decimal = None, -) -> Type[Decimal]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - gt=gt, - ge=ge, - lt=lt, - le=le, - max_digits=max_digits, - decimal_places=decimal_places, - multiple_of=multiple_of, - ) - return type("ConstrainedDecimalValue", (ConstrainedDecimal,), namespace) - - -class UUID1(UUID): - _required_version = 1 - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format=f"uuid{cls._required_version}") - - -class UUID3(UUID1): - _required_version = 3 - - -class UUID4(UUID1): - _required_version = 4 - - -class UUID5(UUID1): - _required_version = 5 - - -class FilePath(Path): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(format="file-path") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield path_validator - yield path_exists_validator - yield cls.validate - - @classmethod - def validate(cls, value: Path) -> Path: - if not value.is_file(): - raise errors.PathNotAFileError(path=value) - - return value - - -class DirectoryPath(Path): - @classmethod - def __modify_schema__(cls, field_schema: 
Dict[str, Any]) -> None: - field_schema.update(format="directory-path") - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield path_validator - yield path_exists_validator - yield cls.validate - - @classmethod - def validate(cls, value: Path) -> Path: - if not value.is_dir(): - raise errors.PathNotADirectoryError(path=value) - - return value - - -class JsonWrapper: - pass - - -class JsonMeta(type): - def __getitem__(self, t: AnyType) -> Type[JsonWrapper]: - return type("JsonWrapperValue", (JsonWrapper,), {"inner_type": t}) - - -class Json(metaclass=JsonMeta): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", format="json-string") - - -class SecretStr: - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", writeOnly=True) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield str_validator - yield cls.validate - - @classmethod - def validate(cls, value: str) -> "SecretStr": - return cls(value) - - def __init__(self, value: str): - self._secret_value = value - - def __repr__(self) -> str: - return f"SecretStr('{self}')" - - def __str__(self) -> str: - return "**********" if self._secret_value else "" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SecretStr) - and self.get_secret_value() == other.get_secret_value() - ) - - def display(self) -> str: - warnings.warn( - "`secret_str.display()` is deprecated, use `str(secret_str)` instead", - DeprecationWarning, - ) - return str(self) - - def get_secret_value(self) -> str: - return self._secret_value - - -class SecretBytes: - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type="string", writeOnly=True) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield bytes_validator - yield cls.validate - - @classmethod - def validate(cls, value: bytes) -> "SecretBytes": - return cls(value) - - def __init__(self, value: bytes): - self._secret_value = value - - def __repr__(self) -> str: - return f"SecretBytes(b'{self}')" - - def __str__(self) -> str: - return "**********" if self._secret_value else "" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SecretBytes) - and self.get_secret_value() == other.get_secret_value() - ) - - def display(self) -> str: - warnings.warn( - "`secret_bytes.display()` is deprecated, use `str(secret_bytes)` instead", - DeprecationWarning, - ) - return str(self) - - def get_secret_value(self) -> bytes: - return self._secret_value - - -class PaymentCardBrand(Enum): - amex = "American Express" - mastercard = "Mastercard" - visa = "Visa" - other = "other" - - def __str__(self) -> str: - return self.value - - -class PaymentCardNumber(str): - """ - Based on: https://en.wikipedia.org/wiki/Payment_card_number - """ - - strip_whitespace: ClassVar[bool] = True - min_length: ClassVar[int] = 12 - max_length: ClassVar[int] = 19 - bin: str - last4: str - brand: PaymentCardBrand - - def __init__(self, card_number: str): - self.bin = card_number[:6] - self.last4 = card_number[-4:] - self.brand = self._get_brand(card_number) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield str_validator - yield constr_strip_whitespace - yield constr_length_validator - yield cls.validate_digits - yield cls.validate_luhn_check_digit - yield cls - yield cls.validate_length_for_brand - - @property - def masked(self) -> str: - 
num_masked = len(self) - 10  # len(bin) + len(last4) == 10
- return f'{self.bin}{"*" * num_masked}{self.last4}'
-
- @classmethod
- def validate_digits(cls, card_number: str) -> str:
- if not card_number.isdigit():
- raise errors.NotDigitError
- return card_number
-
- @classmethod
- def validate_luhn_check_digit(cls, card_number: str) -> str:
- """
- Based on: https://en.wikipedia.org/wiki/Luhn_algorithm
- """
- sum_ = int(card_number[-1])
- length = len(card_number)
- parity = length % 2
- for i in range(length - 1):
- digit = int(card_number[i])
- if i % 2 == parity:
- digit *= 2
- if digit > 9:
- digit -= 9  # doubled digits above 9 contribute their digit sum (Luhn)
- sum_ += digit
- valid = sum_ % 10 == 0
- if not valid:
- raise errors.LuhnValidationError
- return card_number
-
- @classmethod
- def validate_length_for_brand(
- cls, card_number: "PaymentCardNumber"
- ) -> "PaymentCardNumber":
- """
- Validate length based on BIN for major brands:
- https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN)
- """
- required_length: Optional[int] = None
- if card_number.brand in (PaymentCardBrand.visa, PaymentCardBrand.mastercard):
- required_length = 16
- valid = len(card_number) == required_length
- elif card_number.brand is PaymentCardBrand.amex:
- required_length = 15
- valid = len(card_number) == required_length
- else:
- valid = True
- if not valid:
- raise errors.InvalidLengthForBrand(
- brand=card_number.brand, required_length=required_length
- )
- return card_number
-
- @staticmethod
- def _get_brand(card_number: str) -> PaymentCardBrand:
- if card_number[0] == "4":
- brand = PaymentCardBrand.visa
- elif 51 <= int(card_number[:2]) <= 55:
- brand = PaymentCardBrand.mastercard
- elif card_number[:2] in {"34", "37"}:
- brand = PaymentCardBrand.amex
- else:
- brand = PaymentCardBrand.other
- return brand
-
-
-BYTE_SIZES = {
- "b": 1,
- "kb": 10 ** 3,
- "mb": 10 ** 6,
- "gb": 10 ** 9,
- "tb": 10 ** 12,
- "pb": 10 ** 15,
- "eb": 10 ** 18,
- "kib": 2 ** 10,
- "mib": 2 ** 20,
- "gib": 2 ** 30,
- "tib": 2 ** 40,
- "pib": 2 ** 50,
- "eib": 2 ** 60,
-}
-BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if "i" not in k})
-byte_string_re = re.compile(r"^\s*(\d*\.?\d+)\s*(\w+)?", re.IGNORECASE)
-
-
-class ByteSize(int):
- @classmethod
- def __get_validators__(cls) -> "CallableGenerator":
- yield cls.validate
-
- @classmethod
- def validate(cls, v: StrIntFloat) -> "ByteSize":
-
- try:
- return cls(int(v))
- except ValueError:
- pass
-
- str_match = byte_string_re.match(str(v))
- if str_match is None:
- raise errors.InvalidByteSize()
-
- scalar, unit = str_match.groups()
- if unit is None:
- unit = "b"
-
- try:
- unit_mult = BYTE_SIZES[unit.lower()]
- except KeyError:
- raise errors.InvalidByteSizeUnit(unit=unit)
-
- return cls(int(float(scalar) * unit_mult))
-
- def human_readable(self, decimal: bool = False) -> str:
-
- if decimal:
- divisor = 1000
- units = ["B", "KB", "MB", "GB", "TB", "PB"]
- final_unit = "EB"
- else:
- divisor = 1024
- units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
- final_unit = "EiB"
-
- num = float(self)
- for unit in units:
- if abs(num) < divisor:
- return f"{num:0.1f}{unit}"
- num /= divisor
-
- return f"{num:0.1f}{final_unit}"
-
- def to(self, unit: str) -> float:
-
- try:
- unit_div = BYTE_SIZES[unit.lower()]
- except KeyError:
- raise errors.InvalidByteSizeUnit(unit=unit)
-
- return self / unit_div
diff --git a/nornir/_vendor/pydantic/typing.py b/nornir/_vendor/pydantic/typing.py
deleted file mode 100644
index 71ecae09..00000000
--- a/nornir/_vendor/pydantic/typing.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import
sys -from enum import Enum -from typing import ( # type: ignore - TYPE_CHECKING, - AbstractSet, - Any, - ClassVar, - Dict, - Generator, - List, - NewType, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, - _eval_type, -) - -try: - from typing import _TypingBase as typing_base # type: ignore -except ImportError: - from typing import _Final as typing_base # type: ignore - -try: - from typing import ForwardRef # type: ignore - - def evaluate_forwardref( - type_: ForwardRef, globalns: Any, localns: Any - ) -> Type[Any]: - return type_._evaluate(globalns, localns) - - -except ImportError: - # python 3.6 - from typing import _ForwardRef as ForwardRef # type: ignore - - def evaluate_forwardref( - type_: ForwardRef, globalns: Any, localns: Any - ) -> Type[Any]: - return type_._eval_type(globalns, localns) - - -if sys.version_info < (3, 7): - from typing import Callable as Callable - - AnyCallable = Callable[..., Any] -else: - from collections.abc import Callable as Callable - from typing import Callable as TypingCallable - - AnyCallable = TypingCallable[..., Any] - -if sys.version_info < (3, 8): - if TYPE_CHECKING: - from typing_extensions import Literal - else: # due to different mypy warnings raised during CI for python 3.7 and 3.8 - try: - from typing_extensions import Literal - except ImportError: - Literal = None -else: - from typing import Literal - -if TYPE_CHECKING: - from .fields import ModelField - - TupleGenerator = Generator[Tuple[str, Any], None, None] - DictStrAny = Dict[str, Any] - DictAny = Dict[Any, Any] - SetStr = Set[str] - ListStr = List[str] - IntStr = Union[int, str] - AbstractSetIntStr = AbstractSet[IntStr] - DictIntStrAny = Dict[IntStr, Any] - CallableGenerator = Generator[AnyCallable, None, None] - ReprArgs = Sequence[Tuple[Optional[str], Any]] - -__all__ = ( - "ForwardRef", - "Callable", - "AnyCallable", - "AnyType", - "NoneType", - "display_as_type", - "resolve_annotations", - "is_callable_type", - "is_literal_type", - "literal_values", - "Literal", - "is_new_type", - "new_type_supertype", - "is_classvar", - "update_field_forward_refs", - "TupleGenerator", - "DictStrAny", - "DictAny", - "SetStr", - "ListStr", - "IntStr", - "AbstractSetIntStr", - "DictIntStrAny", - "CallableGenerator", - "ReprArgs", - "CallableGenerator", -) - - -AnyType = Type[Any] -NoneType = type(None) - - -def display_as_type(v: AnyType) -> str: - if not isinstance(v, typing_base) and not isinstance(v, type): - v = type(v) - - if isinstance(v, type) and issubclass(v, Enum): - if issubclass(v, int): - return "int" - elif issubclass(v, str): - return "str" - else: - return "enum" - - try: - return v.__name__ - except AttributeError: - # happens with typing objects - return str(v).replace("typing.", "") - - -def resolve_annotations( - raw_annotations: Dict[str, AnyType], module_name: Optional[str] -) -> Dict[str, AnyType]: - """ - Partially taken from typing.get_type_hints. - - Resolve string or ForwardRef annotations into type objects if possible. 
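-
- A small doctest sketch (module name chosen for illustration):
-
- >>> resolve_annotations({"x": "int"}, module_name="builtins")
- {'x': <class 'int'>}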
- """ - if module_name: - base_globals: Optional[Dict[str, Any]] = sys.modules[module_name].__dict__ - else: - base_globals = None - annotations = {} - for name, value in raw_annotations.items(): - if isinstance(value, str): - if sys.version_info >= (3, 7): - value = ForwardRef(value, is_argument=False) - else: - value = ForwardRef(value) - try: - value = _eval_type(value, base_globals, None) - except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value - return annotations - - -def is_callable_type(type_: AnyType) -> bool: - return type_ is Callable or getattr(type_, "__origin__", None) is Callable - - -if sys.version_info >= (3, 7): - - def is_literal_type(type_: AnyType) -> bool: - return Literal is not None and getattr(type_, "__origin__", None) is Literal - - def literal_values(type_: AnyType) -> Tuple[Any, ...]: - return type_.__args__ - - -else: - - def is_literal_type(type_: AnyType) -> bool: - return ( - Literal is not None - and hasattr(type_, "__values__") - and type_ == Literal[type_.__values__] - ) - - def literal_values(type_: AnyType) -> Tuple[Any, ...]: - return type_.__values__ - - -test_type = NewType("test_type", str) - - -def is_new_type(type_: AnyType) -> bool: - return isinstance(type_, type(test_type)) and hasattr(type_, "__supertype__") - - -def new_type_supertype(type_: AnyType) -> AnyType: - while hasattr(type_, "__supertype__"): - type_ = type_.__supertype__ - return type_ - - -def _check_classvar(v: AnyType) -> bool: - return type(v) == type(ClassVar) and ( - sys.version_info < (3, 7) or getattr(v, "_name", None) == "ClassVar" - ) - - -def is_classvar(ann_type: AnyType) -> bool: - return _check_classvar(ann_type) or _check_classvar( - getattr(ann_type, "__origin__", None) - ) - - -def update_field_forward_refs(field: "ModelField", globalns: Any, localns: Any) -> None: - """ - Try to update ForwardRefs on fields based on this ModelField, globalns and localns. - """ - if type(field.type_) == ForwardRef: - field.type_ = evaluate_forwardref(field.type_, globalns, localns or None) - field.prepare() - if field.sub_fields: - for sub_f in field.sub_fields: - update_field_forward_refs(sub_f, globalns=globalns, localns=localns) - - -def get_class(type_: AnyType) -> Union[None, bool, AnyType]: - """ - Tries to get the class of a Type[T] annotation. Returns True if Type is used - without brackets. Otherwise returns None. 
- """ - try: - origin = getattr(type_, "__origin__") - if origin is None: # Python 3.6 - origin = type_ - if issubclass(origin, Type): # type: ignore - if type_.__args__ is None or not isinstance(type_.__args__[0], type): - return True - return type_.__args__[0] - except AttributeError: - pass - return None diff --git a/nornir/_vendor/pydantic/utils.py b/nornir/_vendor/pydantic/utils.py deleted file mode 100644 index 589a6c2b..00000000 --- a/nornir/_vendor/pydantic/utils.py +++ /dev/null @@ -1,379 +0,0 @@ -import inspect -import platform -import sys -import warnings -from importlib import import_module -from pathlib import Path -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Callable, - Dict, - Generator, - Iterator, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, - no_type_check, -) - -from .typing import AnyType, display_as_type - -if TYPE_CHECKING: - from .main import BaseModel # noqa: F401 - from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, ReprArgs # noqa: F401 - from .dataclasses import DataclassType # noqa: F401 - -KeyType = TypeVar("KeyType") - - -def import_string(dotted_path: str) -> Any: - """ - Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the - last name in the path. Raise ImportError if the import fails. - """ - try: - module_path, class_name = dotted_path.strip(" ").rsplit(".", 1) - except ValueError as e: - raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e - - module = import_module(module_path) - try: - return getattr(module, class_name) - except AttributeError as e: - raise ImportError( - f'Module "{module_path}" does not define a "{class_name}" attribute' - ) from e - - -def truncate(v: Union[str], *, max_len: int = 80) -> str: - """ - Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long - """ - warnings.warn( - "`truncate` is no-longer used by pydantic and is deprecated", DeprecationWarning - ) - if isinstance(v, str) and len(v) > (max_len - 2): - # -3 so quote + string + … + quote has correct length - return (v[: (max_len - 3)] + "…").__repr__() - try: - v = v.__repr__() - except TypeError: - v = type(v).__repr__(v) # in case v is a type - if len(v) > max_len: - v = v[: max_len - 1] + "…" - return v - - -ExcType = Type[Exception] - - -def sequence_like(v: AnyType) -> bool: - return isinstance(v, (list, tuple, set, frozenset)) or inspect.isgenerator(v) - - -def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: - """ - Ensure that the field's name does not shadow an existing attribute of the model. - """ - for base in bases: - if getattr(base, field_name, None): - raise NameError( - f'Field name "{field_name}" shadows a BaseModel attribute; ' - f"use a different field name with \"alias='{field_name}'\"." - ) - - -def lenient_issubclass( - cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]] -) -> bool: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) - - -def in_ipython() -> bool: - """ - Check whether we're in an ipython environment, including jupyter notebooks. 
- """ - try: - eval("__IPYTHON__") - except NameError: - return False - else: # pragma: no cover - return True - - -def deep_update( - mapping: Dict[KeyType, Any], updating_mapping: Dict[KeyType, Any] -) -> Dict[KeyType, Any]: - updated_mapping = mapping.copy() - for k, v in updating_mapping.items(): - if k in mapping and isinstance(mapping[k], dict) and isinstance(v, dict): - updated_mapping[k] = deep_update(mapping[k], v) - else: - updated_mapping[k] = v - return updated_mapping - - -def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None: - mapping.update({k: v for k, v in update.items() if v is not None}) - - -def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: - """ - Return True if two floats are almost equal - """ - return abs(value_1 - value_2) <= delta - - -def get_model( - obj: Union[Type["BaseModel"], Type["DataclassType"]] -) -> Type["BaseModel"]: - from .main import BaseModel # noqa: F811 - - try: - model_cls = obj.__pydantic_model__ # type: ignore - except AttributeError: - model_cls = obj - - if not issubclass(model_cls, BaseModel): - raise TypeError("Unsupported type, must be either BaseModel or dataclass") - return model_cls - - -class PyObjectStr(str): - """ - String class where repr doesn't include quotes. Useful with Representation when you want to return a string - representation of something that valid (or pseudo-valid) python. - """ - - def __repr__(self) -> str: - return str(self) - - -class Representation: - """ - Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details. - - __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations - of objects. - """ - - def __repr_args__(self) -> "ReprArgs": - """ - Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden. - - Can either return: - * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]` - * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]` - """ - attrs = ((s, getattr(self, s)) for s in self.__slots__) - return [(a, v) for a, v in attrs if v is not None] - - def __repr_name__(self) -> str: - """ - Name of the instance's class, used in __repr__. - """ - return self.__class__.__name__ - - def __repr_str__(self, join_str: str) -> str: - return join_str.join( - repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__() - ) - - def __pretty__( - self, fmt: Callable[[Any], Any], **kwargs: Any - ) -> Generator[Any, None, None]: - """ - Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects - """ - yield self.__repr_name__() + "(" - yield 1 - for name, value in self.__repr_args__(): - if name is not None: - yield name + "=" - yield fmt(value) - yield "," - yield 0 - yield -1 - yield ")" - - def __str__(self) -> str: - return self.__repr_str__(" ") - - def __repr__(self) -> str: - return f'{self.__repr_name__()}({self.__repr_str__(", ")})' - - -class GetterDict(Representation): - """ - Hack to make object's smell just enough like dicts for validate_model. - - We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves. 
- """ - - __slots__ = ("_obj",) - - def __init__(self, obj: Any): - self._obj = obj - - def __getitem__(self, key: str) -> Any: - try: - return getattr(self._obj, key) - except AttributeError as e: - raise KeyError(key) from e - - def get(self, key: Any, default: Any = None) -> Any: - return getattr(self._obj, key, default) - - def extra_keys(self) -> Set[Any]: - """ - We don't want to get any other attributes of obj if the model didn't explicitly ask for them - """ - return set() - - def keys(self) -> List[Any]: - """ - Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python - dictionaries. - """ - return list(self) - - def values(self) -> List[Any]: - return [self[k] for k in self] - - def items(self) -> Iterator[Tuple[str, Any]]: - for k in self: - yield k, self.get(k) - - def __iter__(self) -> Iterator[str]: - for name in dir(self._obj): - if not name.startswith("_"): - yield name - - def __len__(self) -> int: - return sum(1 for _ in self) - - def __contains__(self, item: Any) -> bool: - return item in self.keys() - - def __eq__(self, other: Any) -> bool: - return dict(self) == dict(other.items()) # type: ignore - - def __repr_args__(self) -> "ReprArgs": - return [(None, dict(self))] # type: ignore - - def __repr_name__(self) -> str: - return f"GetterDict[{display_as_type(self._obj)}]" - - -class ValueItems(Representation): - """ - Class for more convenient calculation of excluded or included fields on values. - """ - - __slots__ = ("_items", "_type") - - def __init__( - self, value: Any, items: Union["AbstractSetIntStr", "DictIntStrAny"] - ) -> None: - if TYPE_CHECKING: - self._items: Union["AbstractSetIntStr", "DictIntStrAny"] - self._type: Type[Union[set, dict]] # type: ignore - - # For further type checks speed-up - if isinstance(items, dict): - self._type = dict - elif isinstance(items, AbstractSet): - self._type = set - else: - raise TypeError(f"Unexpected type of exclude value {type(items)}") - - if isinstance(value, (list, tuple)): - items = self._normalize_indexes(items, len(value)) - - self._items = items - - @no_type_check - def is_excluded(self, item: Any) -> bool: - """ - Check if item is fully excluded - (value considered excluded if self._type is set and item contained in self._items - or self._type is dict and self._items.get(item) is ... - - :param item: key or index of a value - """ - if self._type is set: - return item in self._items - return self._items.get(item) is ... - - @no_type_check - def is_included(self, item: Any) -> bool: - """ - Check if value is contained in self._items - - :param item: key or index of value - """ - return item in self._items - - @no_type_check - def for_element( - self, e: "IntStr" - ) -> Optional[Union["AbstractSetIntStr", "DictIntStrAny"]]: - """ - :param e: key or index of element on value - :return: raw values for elemet if self._items is dict and contain needed element - """ - - if self._type is dict: - item = self._items.get(e) - return item if item is not ... 
else None - return None - - @no_type_check - def _normalize_indexes( - self, items: Union["AbstractSetIntStr", "DictIntStrAny"], v_length: int - ) -> Union["AbstractSetIntStr", "DictIntStrAny"]: - """ - :param items: dict or set of indexes which will be normalized - :param v_length: length of sequence indexes of which will be - - >>> self._normalize_indexes({0, -2, -1}, 4) - {0, 2, 3} - """ - if self._type is set: - return {v_length + i if i < 0 else i for i in items} - else: - return {v_length + i if i < 0 else i: v for i, v in items.items()} - - def __repr_args__(self) -> "ReprArgs": - return [(None, self._items)] - - -def version_info() -> str: - from .main import compiled - from .version import VERSION - - optional_deps = [] - for p in ("typing-extensions", "email-validator", "devtools"): - try: - import_module(p.replace("-", "_")) - except ImportError: - continue - optional_deps.append(p) - - info = { - "pydantic version": VERSION, - "pydantic compiled": compiled, - "install path": Path(__file__).resolve().parent, - "python version": sys.version, - "platform": platform.platform(), - "optional deps. installed": optional_deps, - } - return "\n".join( - "{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items() - ) diff --git a/nornir/_vendor/pydantic/validators.py b/nornir/_vendor/pydantic/validators.py deleted file mode 100644 index 454ced8a..00000000 --- a/nornir/_vendor/pydantic/validators.py +++ /dev/null @@ -1,615 +0,0 @@ -import re -import sys -from collections import OrderedDict -from datetime import date, datetime, time, timedelta -from decimal import Decimal, DecimalException -from enum import Enum, IntEnum -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, -) -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - FrozenSet, - Generator, - List, - Optional, - Pattern, - Set, - Tuple, - Type, - TypeVar, - Union, -) -from uuid import UUID - -from . import errors -from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from .typing import ( - AnyCallable, - AnyType, - ForwardRef, - display_as_type, - get_class, - is_callable_type, - is_literal_type, -) -from .utils import almost_equal_floats, lenient_issubclass, sequence_like - -if TYPE_CHECKING: - from .fields import ModelField - from .main import BaseConfig - from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt - - ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt] - AnyOrderedDict = OrderedDict[Any, Any] - Number = Union[int, float, Decimal] - StrBytes = Union[str, bytes] - - -def str_validator(v: Any) -> Optional[str]: - if isinstance(v, str): - if isinstance(v, Enum): - return v.value - else: - return v - elif isinstance(v, (float, int, Decimal)): - # is there anything else we want to add here? If you think so, create an issue. 
- return str(v) - elif isinstance(v, (bytes, bytearray)): - return v.decode() - else: - raise errors.StrError() - - -def strict_str_validator(v: Any) -> Union[str]: - if isinstance(v, str): - return v - raise errors.StrError() - - -def bytes_validator(v: Any) -> bytes: - if isinstance(v, bytes): - return v - elif isinstance(v, bytearray): - return bytes(v) - elif isinstance(v, str): - return v.encode() - elif isinstance(v, (float, int, Decimal)): - return str(v).encode() - else: - raise errors.BytesError() - - -BOOL_FALSE = {0, "0", "off", "f", "false", "n", "no"} -BOOL_TRUE = {1, "1", "on", "t", "true", "y", "yes"} - - -def bool_validator(v: Any) -> bool: - if v is True or v is False: - return v - if isinstance(v, bytes): - v = v.decode() - if isinstance(v, str): - v = v.lower() - try: - if v in BOOL_TRUE: - return True - if v in BOOL_FALSE: - return False - except TypeError: - raise errors.BoolError() - raise errors.BoolError() - - -def int_validator(v: Any) -> int: - if isinstance(v, int) and not (v is True or v is False): - return v - - try: - return int(v) - except (TypeError, ValueError): - raise errors.IntegerError() - - -def strict_int_validator(v: Any) -> int: - if isinstance(v, int) and not (v is True or v is False): - return v - raise errors.IntegerError() - - -def float_validator(v: Any) -> float: - if isinstance(v, float): - return v - - try: - return float(v) - except (TypeError, ValueError): - raise errors.FloatError() - - -def strict_float_validator(v: Any) -> float: - if isinstance(v, float): - return v - raise errors.FloatError() - - -def number_multiple_validator(v: "Number", field: "ModelField") -> "Number": - field_type: ConstrainedNumber = field.type_ - if field_type.multiple_of is not None: - mod = float(v) / float(field_type.multiple_of) % 1 - if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0): - raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of) - return v - - -def number_size_validator(v: "Number", field: "ModelField") -> "Number": - field_type: ConstrainedNumber = field.type_ - if field_type.gt is not None and not v > field_type.gt: - raise errors.NumberNotGtError(limit_value=field_type.gt) - elif field_type.ge is not None and not v >= field_type.ge: - raise errors.NumberNotGeError(limit_value=field_type.ge) - - if field_type.lt is not None and not v < field_type.lt: - raise errors.NumberNotLtError(limit_value=field_type.lt) - if field_type.le is not None and not v <= field_type.le: - raise errors.NumberNotLeError(limit_value=field_type.le) - - return v - - -def constant_validator(v: "Any", field: "ModelField") -> "Any": - """Validate ``const`` fields. - - The value provided for a ``const`` field must be equal to the default value - of the field. This is to support the keyword of the same name in JSON - Schema. 
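-
- A hedged doctest sketch (model and field names assumed); any value other than the
- declared default raises ``WrongConstantError``:
-
- >>> class M(BaseModel): x: int = Field(2, const=True)
- >>> M(x=2).x
- 2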
- """ - if v != field.default: - raise errors.WrongConstantError(given=v, permitted=[field.default]) - - return v - - -def anystr_length_validator(v: "StrBytes", config: "BaseConfig") -> "StrBytes": - v_len = len(v) - - min_length = config.min_anystr_length - if min_length is not None and v_len < min_length: - raise errors.AnyStrMinLengthError(limit_value=min_length) - - max_length = config.max_anystr_length - if max_length is not None and v_len > max_length: - raise errors.AnyStrMaxLengthError(limit_value=max_length) - - return v - - -def anystr_strip_whitespace(v: "StrBytes") -> "StrBytes": - return v.strip() - - -def ordered_dict_validator(v: Any) -> "AnyOrderedDict": - if isinstance(v, OrderedDict): - return v - - try: - return OrderedDict(v) - except (TypeError, ValueError): - raise errors.DictError() - - -def dict_validator(v: Any) -> Dict[Any, Any]: - if isinstance(v, dict): - return v - - try: - return dict(v) - except (TypeError, ValueError): - raise errors.DictError() - - -def list_validator(v: Any) -> List[Any]: - if isinstance(v, list): - return v - elif sequence_like(v): - return list(v) - else: - raise errors.ListError() - - -def tuple_validator(v: Any) -> Tuple[Any, ...]: - if isinstance(v, tuple): - return v - elif sequence_like(v): - return tuple(v) - else: - raise errors.TupleError() - - -def set_validator(v: Any) -> Set[Any]: - if isinstance(v, set): - return v - elif sequence_like(v): - return set(v) - else: - raise errors.SetError() - - -def frozenset_validator(v: Any) -> FrozenSet[Any]: - if isinstance(v, frozenset): - return v - elif sequence_like(v): - return frozenset(v) - else: - raise errors.FrozenSetError() - - -def enum_validator(v: Any, field: "ModelField", config: "BaseConfig") -> Enum: - try: - enum_v = field.type_(v) - except ValueError: - # field.type_ should be an enum, so will be iterable - raise errors.EnumError(enum_values=list(field.type_)) - return enum_v.value if config.use_enum_values else enum_v - - -def uuid_validator(v: Any, field: "ModelField") -> UUID: - try: - if isinstance(v, str): - v = UUID(v) - elif isinstance(v, (bytes, bytearray)): - v = UUID(v.decode()) - except ValueError: - raise errors.UUIDError() - - if not isinstance(v, UUID): - raise errors.UUIDError() - - required_version = getattr(field.type_, "_required_version", None) - if required_version and v.version != required_version: - raise errors.UUIDVersionError(required_version=required_version) - - return v - - -def decimal_validator(v: Any) -> Decimal: - if isinstance(v, Decimal): - return v - elif isinstance(v, (bytes, bytearray)): - v = v.decode() - - v = str(v).strip() - - try: - v = Decimal(v) - except DecimalException: - raise errors.DecimalError() - - if not v.is_finite(): - raise errors.DecimalIsNotFiniteError() - - return v - - -def ip_v4_address_validator(v: Any) -> IPv4Address: - if isinstance(v, IPv4Address): - return v - - try: - return IPv4Address(v) - except ValueError: - raise errors.IPv4AddressError() - - -def ip_v6_address_validator(v: Any) -> IPv6Address: - if isinstance(v, IPv6Address): - return v - - try: - return IPv6Address(v) - except ValueError: - raise errors.IPv6AddressError() - - -def ip_v4_network_validator(v: Any) -> IPv4Network: - """ - Assume IPv4Network initialised with a default ``strict`` argument - - See more: - https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network - """ - if isinstance(v, IPv4Network): - return v - - try: - return IPv4Network(v) - except ValueError: - raise errors.IPv4NetworkError() - - -def 
ip_v6_network_validator(v: Any) -> IPv6Network: - """ - Assume IPv6Network initialised with a default ``strict`` argument - - See more: - https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network - """ - if isinstance(v, IPv6Network): - return v - - try: - return IPv6Network(v) - except ValueError: - raise errors.IPv6NetworkError() - - -def ip_v4_interface_validator(v: Any) -> IPv4Interface: - if isinstance(v, IPv4Interface): - return v - - try: - return IPv4Interface(v) - except ValueError: - raise errors.IPv4InterfaceError() - - -def ip_v6_interface_validator(v: Any) -> IPv6Interface: - if isinstance(v, IPv6Interface): - return v - - try: - return IPv6Interface(v) - except ValueError: - raise errors.IPv6InterfaceError() - - -def path_validator(v: Any) -> Path: - if isinstance(v, Path): - return v - - try: - return Path(v) - except TypeError: - raise errors.PathError() - - -def path_exists_validator(v: Any) -> Path: - if not v.exists(): - raise errors.PathNotExistsError(path=v) - - return v - - -def callable_validator(v: Any) -> AnyCallable: - """ - Perform a simple check if the value is callable. - - Note: complete matching of argument type hints and return types is not performed - """ - if callable(v): - return v - - raise errors.CallableError(value=v) - - -def make_literal_validator(type_: Any) -> Callable[[Any], Any]: - if sys.version_info >= (3, 7): - permitted_choices = type_.__args__ - else: - permitted_choices = type_.__values__ - allowed_choices_set = set(permitted_choices) - - def literal_validator(v: Any) -> Any: - if v not in allowed_choices_set: - raise errors.WrongConstantError(given=v, permitted=permitted_choices) - return v - - return literal_validator - - -def constr_length_validator( - v: "StrBytes", field: "ModelField", config: "BaseConfig" -) -> "StrBytes": - v_len = len(v) - - min_length = field.type_.min_length or config.min_anystr_length - if min_length is not None and v_len < min_length: - raise errors.AnyStrMinLengthError(limit_value=min_length) - - max_length = field.type_.max_length or config.max_anystr_length - if max_length is not None and v_len > max_length: - raise errors.AnyStrMaxLengthError(limit_value=max_length) - - return v - - -def constr_strip_whitespace( - v: "StrBytes", field: "ModelField", config: "BaseConfig" -) -> "StrBytes": - strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace - if strip_whitespace: - v = v.strip() - - return v - - -def validate_json(v: Any, config: "BaseConfig") -> Any: - if v is None: - # pass None through to other validators - return v - try: - return config.json_loads(v) # type: ignore - except ValueError: - raise errors.JsonError() - except TypeError: - raise errors.JsonTypeError() - - -T = TypeVar("T") - - -def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]: - def arbitrary_type_validator(v: Any) -> T: - if isinstance(v, type_): - return v - raise errors.ArbitraryTypeError(expected_arbitrary_type=type_) - - return arbitrary_type_validator - - -def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]: - def class_validator(v: Any) -> Type[T]: - if lenient_issubclass(v, type_): - return v - raise errors.SubclassError(expected_class=type_) - - return class_validator - - -def any_class_validator(v: Any) -> Type[T]: - if isinstance(v, type): - return v - raise errors.ClassError() - - -def pattern_validator(v: Any) -> Pattern[str]: - try: - return re.compile(v) - except re.error: - raise errors.PatternError() - - -class IfConfig: - def __init__(self, 
validator: AnyCallable, *config_attr_names: str) -> None: - self.validator = validator - self.config_attr_names = config_attr_names - - def check(self, config: Type["BaseConfig"]) -> bool: - return any( - getattr(config, name) not in {None, False} - for name in self.config_attr_names - ) - - -pattern_validators = [str_validator, pattern_validator] -# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same, -# IPv4Interface before IPv4Address, etc -_VALIDATORS: List[Tuple[AnyType, List[Any]]] = [ - (IntEnum, [int_validator, enum_validator]), - (Enum, [enum_validator]), - ( - str, - [ - str_validator, - IfConfig(anystr_strip_whitespace, "anystr_strip_whitespace"), - IfConfig(anystr_length_validator, "min_anystr_length", "max_anystr_length"), - ], - ), - ( - bytes, - [ - bytes_validator, - IfConfig(anystr_strip_whitespace, "anystr_strip_whitespace"), - IfConfig(anystr_length_validator, "min_anystr_length", "max_anystr_length"), - ], - ), - (bool, [bool_validator]), - (int, [int_validator]), - (float, [float_validator]), - (Path, [path_validator]), - (datetime, [parse_datetime]), - (date, [parse_date]), - (time, [parse_time]), - (timedelta, [parse_duration]), - (OrderedDict, [ordered_dict_validator]), - (dict, [dict_validator]), - (list, [list_validator]), - (tuple, [tuple_validator]), - (set, [set_validator]), - (frozenset, [frozenset_validator]), - (UUID, [uuid_validator]), - (Decimal, [decimal_validator]), - (IPv4Interface, [ip_v4_interface_validator]), - (IPv6Interface, [ip_v6_interface_validator]), - (IPv4Address, [ip_v4_address_validator]), - (IPv6Address, [ip_v6_address_validator]), - (IPv4Network, [ip_v4_network_validator]), - (IPv6Network, [ip_v6_network_validator]), -] - - -def find_validators( # noqa: C901 (ignore complexity) - type_: AnyType, config: Type["BaseConfig"] -) -> Generator[AnyCallable, None, None]: - if type_ is Any: - return - type_type = type(type_) - if type_type == ForwardRef or type_type == TypeVar: - return - if type_ is Pattern: - yield from pattern_validators - return - if is_callable_type(type_): - yield callable_validator - return - if is_literal_type(type_): - yield make_literal_validator(type_) - return - - class_ = get_class(type_) - if class_ is not None: - if isinstance(class_, type): - yield make_class_validator(class_) - else: - yield any_class_validator - return - - supertype = _find_supertype(type_) - if supertype is not None: - type_ = supertype - - for val_type, validators in _VALIDATORS: - try: - if issubclass(type_, val_type): - for v in validators: - if isinstance(v, IfConfig): - if v.check(config): - yield v.validator - else: - yield v - return - except TypeError: - raise RuntimeError( - f"error checking inheritance of {type_!r} (type: {display_as_type(type_)})" - ) - - if config.arbitrary_types_allowed: - yield make_arbitrary_type_validator(type_) - else: - raise RuntimeError( - f"no validator found for {type_}, see `arbitrary_types_allowed` in Config" - ) - - -def _find_supertype(type_: AnyType) -> Optional[AnyType]: - if not _is_new_type(type_): - return None - - supertype = type_.__supertype__ - if _is_new_type(supertype): - supertype = _find_supertype(supertype) - - return supertype - - -def _is_new_type(type_: AnyType) -> bool: - return hasattr(type_, "__name__") and hasattr(type_, "__supertype__") diff --git a/nornir/_vendor/pydantic/version.py b/nornir/_vendor/pydantic/version.py deleted file mode 100644 index 624787ab..00000000 --- a/nornir/_vendor/pydantic/version.py +++ /dev/null 
@@ -1,5 +0,0 @@
-from distutils.version import StrictVersion
-
-__all__ = ["VERSION"]
-
-VERSION = StrictVersion("1.3")
diff --git a/nornir/core/configuration.py b/nornir/core/configuration.py
index c2d40e91..10773500 100644
--- a/nornir/core/configuration.py
+++ b/nornir/core/configuration.py
@@ -1,57 +1,151 @@
+import ast
 import logging
 import logging.handlers
+import os
 import sys
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Type, List
+from typing import Any, Dict, Optional, Type, TYPE_CHECKING, List, TypeVar

 from nornir.core.exceptions import ConflictingConfigurationWarning

+import ruamel.yaml
+
 if TYPE_CHECKING:
     from nornir.core.deserializer.inventory import Inventory  # noqa


+DEFAULT_SSH_CONFIG = str(Path("~/.ssh/config").expanduser())
+
+T = TypeVar("T")
+
+
+class Parameter:
+    def __init__(
+        self,
+        envvar: str,
+        typ: Optional[Type[T]] = None,
+        help: str = "",
+        default: Optional[T] = None,
+    ) -> None:
+        if typ is not None:
+            self.type: Type[T] = typ
+        elif default is not None:
+            self.type = default.__class__
+        else:
+            raise TypeError("either typ or default needs to be specified")
+        self.envvar = envvar
+        self.help = help
+        self.default = default or self.type()
+
+    def resolve(self, value: Optional[T]) -> T:
+        v: Optional[Any] = value
+        if value is None:
+            t = os.environ.get(self.envvar)
+            if self.type is bool and t:
+                v = t in ["true", "True", "1", "yes"]
+            elif self.type is str and t:
+                v = t
+            elif t:
+                v = ast.literal_eval(t) if t is not None else None
+
+        if v is None:
+            v = self.default
+        return v
+
+
 class SSHConfig(object):
-    __slots__ = "config_file"
+    __slots__ = ("config_file",)

-    def __init__(self, config_file: str) -> None:
-        self.config_file = config_file
+    class Parameters:
+        config_file = Parameter(
+            default=DEFAULT_SSH_CONFIG, envvar="NORNIR_SSH_CONFIG_FILE"
+        )
+
+    def __init__(self, config_file: Optional[str] = None) -> None:
+        self.config_file = self.Parameters.config_file.resolve(config_file)
+
+    def dict(self) -> Dict[str, Any]:
+        return {"config_file": self.config_file}


 class InventoryConfig(object):
     __slots__ = "plugin", "options", "transform_function", "transform_function_options"

+    class Parameters:
+        plugin = Parameter(typ=str, envvar="NORNIR_INVENTORY_PLUGIN")
+        options = Parameter(default={}, envvar="NORNIR_INVENTORY_OPTIONS")
+        transform_function = Parameter(
+            typ=str, envvar="NORNIR_INVENTORY_TRANSFORM_FUNCTION"
+        )
+        transform_function_options = Parameter(
+            default={}, envvar="NORNIR_INVENTORY_TRANSFORM_FUNCTION_OPTIONS"
+        )
+
     def __init__(
         self,
-        plugin: Type["Inventory"],
-        options: Dict[str, Any],
-        transform_function: Optional[Callable[..., Any]],
-        transform_function_options: Optional[Dict[str, Any]],
+        plugin: Optional[str] = None,
+        options: Optional[Dict[str, Any]] = None,
+        transform_function: Optional[str] = None,
+        transform_function_options: Optional[Dict[str, Any]] = None,
     ) -> None:
-        self.plugin = plugin
-        self.options = options
-        self.transform_function = transform_function
-        self.transform_function_options = transform_function_options
+        self.plugin = self.Parameters.plugin.resolve(plugin)
+        self.options = self.Parameters.options.resolve(options) or {}
+        self.transform_function = self.Parameters.transform_function.resolve(
+            transform_function
+        )
+        self.transform_function_options = self.Parameters.transform_function_options.resolve(
+            transform_function_options
+        )
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "plugin": self.plugin,
+            "options": self.options,
+            "transform_function": self.transform_function,
+            "transform_function_options": self.transform_function_options,
+        }


 class LoggingConfig(object):
-    __slots__ = "enabled", "level", "file", "format", "to_console", "loggers"
+    __slots__ = "enabled", "level", "log_file", "format", "to_console", "loggers"
+
+    class Parameters:
+        enabled = Parameter(default=True, envvar="NORNIR_LOGGING_ENABLED")
+        level = Parameter(default="INFO", envvar="NORNIR_LOGGING_LEVEL")
+        log_file = Parameter(default="nornir.log", envvar="NORNIR_LOGGING_LOG_FILE")
+        format = Parameter(
+            default="%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s",
+            envvar="NORNIR_LOGGING_FORMAT",
+        )
+        to_console = Parameter(default=False, envvar="NORNIR_LOGGING_TO_CONSOLE")
+        loggers = Parameter(default=["nornir"], envvar="NORNIR_LOGGING_LOGGERS")

     def __init__(
         self,
-        enabled: Optional[bool],
-        level: str,
-        file_: str,
-        format_: str,
-        to_console: bool,
-        loggers: List[str],
+        enabled: Optional[bool] = None,
+        level: Optional[str] = None,
+        log_file: Optional[str] = None,
+        format: Optional[str] = None,
+        to_console: Optional[bool] = None,
+        loggers: Optional[List[str]] = None,
     ) -> None:
-        self.enabled = enabled
-        self.level = level
-        self.file = file_
-        self.format = format_
-        self.to_console = to_console
-        self.loggers = loggers
+        self.enabled = self.Parameters.enabled.resolve(enabled)
+        self.level = self.Parameters.level.resolve(level)
+        self.log_file = self.Parameters.log_file.resolve(log_file)
+        self.format = self.Parameters.format.resolve(format)
+        self.to_console = self.Parameters.to_console.resolve(to_console)
+        self.loggers = self.Parameters.loggers.resolve(loggers)
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "enabled": self.enabled,
+            "level": self.level,
+            "log_file": self.log_file,
+            "format": self.format,
+            "to_console": self.to_console,
+            "loggers": self.loggers,
+        }

     def configure(self) -> None:
         if not self.enabled:
@@ -94,9 +188,9 @@ def configure(self) -> None:
                 # logging.config.dictConfig configuring 'nornir' logger, etc.
                # The warning is not emitted in this scenario
                continue

-            if self.file:
+            if self.log_file:
                 handler = logging.handlers.RotatingFileHandler(
-                    str(Path(self.file)), maxBytes=1024 * 1024 * 10, backupCount=20
+                    str(Path(self.log_file)), maxBytes=1024 * 1024 * 10, backupCount=20
                 )
                 handler.setFormatter(formatter)
                 logger_.addHandler(handler)
@@ -106,19 +200,24 @@ def configure(self) -> None:
             logger_.addHandler(stderr_handler)


-class Jinja2Config(object):
-    __slots__ = "filters"
-
-    def __init__(self, filters: Optional[Dict[str, Callable[..., Any]]]) -> None:
-        self.filters = filters or {}
-
-
 class CoreConfig(object):
     __slots__ = ("num_workers", "raise_on_error")

-    def __init__(self, num_workers: int, raise_on_error: bool) -> None:
-        self.num_workers = num_workers
-        self.raise_on_error = raise_on_error
+    class Parameters:
+        num_workers = Parameter(default=20, envvar="NORNIR_CORE_NUM_WORKERS")
+        raise_on_error = Parameter(default=False, envvar="NORNIR_CORE_RAISE_ON_ERROR")
+
+    def __init__(
+        self, num_workers: Optional[int] = None, raise_on_error: Optional[bool] = None
+    ) -> None:
+        self.num_workers = self.Parameters.num_workers.resolve(num_workers)
+        self.raise_on_error = self.Parameters.raise_on_error.resolve(raise_on_error)
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "raise_on_error": self.raise_on_error,
+        }


 class Config(object):
@@ -126,16 +225,66 @@ class Config(object):

     def __init__(
         self,
-        inventory: InventoryConfig,
-        ssh: SSHConfig,
-        logging: LoggingConfig,
-        jinja2: Jinja2Config,
-        core: CoreConfig,
-        user_defined: Dict[str, Any],
+        inventory: Optional[InventoryConfig] = None,
+        ssh: Optional[SSHConfig] = None,
+        logging: Optional[LoggingConfig] = None,
+        core: Optional[CoreConfig] = None,
+        user_defined: Optional[Dict[str, Any]] = None,
     ) -> None:
-        self.inventory = inventory
-        self.ssh = ssh
-        self.logging = logging
-        self.jinja2 = jinja2
-        self.core = core
-        self.user_defined = user_defined
+        self.inventory = inventory or InventoryConfig()
+        self.ssh = ssh or SSHConfig()
+        self.logging = logging or LoggingConfig()
+        self.core = core or CoreConfig()
+        self.user_defined = user_defined or {}
+
+    @classmethod
+    def from_dict(
+        cls,
+        inventory: Optional[Dict[str, Any]] = None,
+        ssh: Optional[Dict[str, Any]] = None,
+        logging: Optional[Dict[str, Any]] = None,
+        core: Optional[Dict[str, Any]] = None,
+        user_defined: Optional[Dict[str, Any]] = None,
+    ) -> "Config":
+        return cls(
+            inventory=InventoryConfig(**inventory or {}),
+            ssh=SSHConfig(**ssh or {}),
+            logging=LoggingConfig(**logging or {}),
+            core=CoreConfig(**core or {}),
+            user_defined=user_defined or {},
+        )
+
+    @classmethod
+    def from_file(
+        cls,
+        config_file: str,
+        inventory: Optional[Dict[str, Any]] = None,
+        ssh: Optional[Dict[str, Any]] = None,
+        logging: Optional[Dict[str, Any]] = None,
+        core: Optional[Dict[str, Any]] = None,
+        user_defined: Optional[Dict[str, Any]] = None,
+    ) -> "Config":
+        inventory = inventory or {}
+        ssh = ssh or {}
+        logging = logging or {}
+        core = core or {}
+        user_defined = user_defined or {}
+        with open(config_file, "r") as f:
+            yml = ruamel.yaml.YAML(typ="safe")
+            data = yml.load(f)
+        return cls(
+            inventory=InventoryConfig(**{**data.get("inventory", {}), **inventory}),
+            ssh=SSHConfig(**{**data.get("ssh", {}), **ssh}),
+            logging=LoggingConfig(**{**data.get("logging", {}), **logging}),
+            core=CoreConfig(**{**data.get("core", {}), **core}),
+            user_defined={**data.get("user_defined", {}), **user_defined},
+        )
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "inventory": self.inventory.dict(),
+            "ssh": self.ssh.dict(),
+            "logging": self.logging.dict(),
+            "core": self.core.dict(),
+            "user_defined": self.user_defined,
+        }
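
A quick sketch of the resolution order implemented by the ``Parameter`` helper above
(explicit argument > ``NORNIR_*`` environment variable > hardcoded default), using only
the classes added in this file::

    import os

    from nornir.core.configuration import CoreConfig

    os.environ["NORNIR_CORE_NUM_WORKERS"] = "50"
    assert CoreConfig().num_workers == 50              # envvar overrides the default (20)
    assert CoreConfig(num_workers=5).num_workers == 5  # explicit value overrides the envvar
    os.environ.pop("NORNIR_CORE_NUM_WORKERS")
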
"inventory": self.inventory.dict(), + "ssh": self.ssh.dict(), + "logging": self.logging.dict(), + "core": self.core.dict(), + "user_defined": self.user_defined, + } diff --git a/nornir/core/connections.py b/nornir/core/connections.py index a554914e..7b7f7825 100644 --- a/nornir/core/connections.py +++ b/nornir/core/connections.py @@ -4,8 +4,8 @@ from nornir.core.configuration import Config from nornir.core.exceptions import ( - ConnectionPluginAlreadyRegistered, - ConnectionPluginNotRegistered, + PluginAlreadyRegistered, + PluginNotRegistered, ) @@ -68,7 +68,7 @@ def register(cls, name: str, plugin: Type[ConnectionPlugin]) -> None: if existing_plugin is None: cls.available[name] = plugin elif existing_plugin != plugin: - raise ConnectionPluginAlreadyRegistered( + raise PluginAlreadyRegistered( f"Connection plugin {plugin.__name__} can't be registered as " f"{name!r} because plugin {existing_plugin.__name__} " f"was already registered under this name" @@ -85,9 +85,7 @@ def deregister(cls, name: str) -> None: :obj:`nornir.core.exceptions.ConnectionPluginNotRegistered` """ if name not in cls.available: - raise ConnectionPluginNotRegistered( - f"Connection {name!r} is not registered" - ) + raise PluginNotRegistered(f"Connection {name!r} is not registered") cls.available.pop(name) @classmethod @@ -106,7 +104,5 @@ def get_plugin(cls, name: str) -> Type[ConnectionPlugin]: :obj:`nornir.core.exceptions.ConnectionPluginNotRegistered` """ if name not in cls.available: - raise ConnectionPluginNotRegistered( - f"Connection {name!r} is not registered" - ) + raise PluginNotRegistered(f"Connection {name!r} is not registered") return cls.available[name] diff --git a/nornir/core/deserializer/__init__.py b/nornir/core/deserializer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/nornir/core/deserializer/configuration.py b/nornir/core/deserializer/configuration.py deleted file mode 100644 index d7b5fc8a..00000000 --- a/nornir/core/deserializer/configuration.py +++ /dev/null @@ -1,221 +0,0 @@ -import importlib -import logging -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, Union, List, cast - -from nornir.core import configuration -from nornir.core.deserializer.inventory import Inventory - -from nornir._vendor.pydantic import BaseSettings, Field - -import ruamel.yaml - - -logger = logging.getLogger(__name__) - - -class BaseNornirSettings(BaseSettings): - def _build_values(self, init_kwargs: Dict[str, Any]) -> Dict[str, Any]: - config_settings = init_kwargs.pop("__config_settings__", {}) - return {**config_settings, **self._build_environ(), **init_kwargs} - - -class SSHConfig(BaseNornirSettings): - config_file: str = Field( - default="~/.ssh/config", description="Path to ssh configuration file" - ) - - class Config: - env_prefix = "NORNIR_SSH_" - ignore_extra = False - - @classmethod - def deserialize(cls, **kwargs: Any) -> configuration.SSHConfig: - s = SSHConfig(**kwargs) - s.config_file = str(Path(s.config_file).expanduser()) - return configuration.SSHConfig(**s.dict()) - - -class InventoryConfig(BaseNornirSettings): - plugin: str = Field( - default="nornir.plugins.inventory.simple.SimpleInventory", - description="Import path to inventory plugin", - ) - options: Dict[str, Any] = Field( - default={}, description="kwargs to pass to the inventory plugin" - ) - transform_function: str = Field( - default="", - description=( - "Path to transform function. 
The transform_function " - "you provide will run against each host in the inventory" - ), - ) - transform_function_options: Dict[str, Any] = Field( - default={}, description="kwargs to pass to the transform_function" - ) - - class Config: - env_prefix = "NORNIR_INVENTORY_" - ignore_extra = False - - @classmethod - def deserialize(cls, **kwargs: Any) -> configuration.InventoryConfig: - inv = InventoryConfig(**kwargs) - return configuration.InventoryConfig( - plugin=cast(Type[Inventory], _resolve_import_from_string(inv.plugin)), - options=inv.options, - transform_function=_resolve_import_from_string(inv.transform_function), - transform_function_options=inv.transform_function_options, - ) - - -class LoggingConfig(BaseNornirSettings): - enabled: Optional[bool] = Field( - default=None, description="Whether to configure logging or not" - ) - level: str = Field(default="INFO", description="Logging level") - file: str = Field(default="nornir.log", description="Logging file") - format: str = Field( - default="%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s", - description="Logging format", - ) - to_console: bool = Field( - default=False, description="Whether to log to console or not" - ) - loggers: List[str] = Field(default=["nornir"], description="Loggers to configure") - - class Config: - env_prefix = "NORNIR_LOGGING_" - ignore_extra = False - - @classmethod - def deserialize(cls, **kwargs) -> configuration.LoggingConfig: - conf = cls(**kwargs) - return configuration.LoggingConfig( - enabled=conf.enabled, - level=conf.level.upper(), - file_=conf.file, - format_=conf.format, - to_console=conf.to_console, - loggers=conf.loggers, - ) - - -class Jinja2Config(BaseNornirSettings): - filters: str = Field( - default="", description="Path to callable returning jinja filters to be used" - ) - - class Config: - env_prefix = "NORNIR_JINJA2_" - ignore_extra = False - - @classmethod - def deserialize(cls, **kwargs: Any) -> configuration.Jinja2Config: - c = Jinja2Config(**kwargs) - jinja_filter_func = _resolve_import_from_string(c.filters) - jinja_filters = jinja_filter_func() if jinja_filter_func else {} - return configuration.Jinja2Config(filters=jinja_filters) - - -class CoreConfig(BaseNornirSettings): - num_workers: int = Field( - default=20, - description="Number of Nornir worker threads that are run at the same time by default", - ) - raise_on_error: bool = Field( - default=False, - description=( - "If set to ``True``, (:obj:`nornir.core.Nornir.run`) method of " - "will raise exception :obj:`nornir.core.exceptions.NornirExecutionError` " - "if at least a host failed" - ), - ) - - class Config: - env_prefix = "NORNIR_CORE_" - ignore_extra = False - - @classmethod - def deserialize(cls, **kwargs: Any) -> configuration.CoreConfig: - c = CoreConfig(**kwargs) - return configuration.CoreConfig(**c.dict()) - - -class Config(BaseNornirSettings): - core: CoreConfig = CoreConfig() - inventory: InventoryConfig = InventoryConfig() - ssh: SSHConfig = SSHConfig() - logging: LoggingConfig = LoggingConfig() - jinja2: Jinja2Config = Jinja2Config() - user_defined: Dict[str, Any] = Field( - default={}, description="User-defined pairs" - ) - - class Config: - env_prefix = "NORNIR_" - ignore_extra = False - - @classmethod - def deserialize( - cls, __config_settings__: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> configuration.Config: - __config_settings__ = __config_settings__ or {} - c = Config( - core=CoreConfig( - __config_settings__=__config_settings__.pop("core", {}), - 
**kwargs.pop("core", {}), - ), - ssh=SSHConfig( - __config_settings__=__config_settings__.pop("ssh", {}), - **kwargs.pop("ssh", {}), - ), - inventory=InventoryConfig( - __config_settings__=__config_settings__.pop("inventory", {}), - **kwargs.pop("inventory", {}), - ), - logging=LoggingConfig( - __config_settings__=__config_settings__.pop("logging", {}), - **kwargs.pop("logging", {}), - ), - jinja2=Jinja2Config( - __config_settings__=__config_settings__.pop("jinja2", {}), - **kwargs.pop("jinja2", {}), - ), - __config_settings__=__config_settings__, - **kwargs, - ) - return configuration.Config( - core=CoreConfig.deserialize(**c.core.dict()), - inventory=InventoryConfig.deserialize(**c.inventory.dict()), - ssh=SSHConfig.deserialize(**c.ssh.dict()), - logging=LoggingConfig.deserialize(**c.logging.dict()), - jinja2=Jinja2Config.deserialize(**c.jinja2.dict()), - user_defined=c.user_defined, - ) - - @classmethod - def load_from_file(cls, config_file: str, **kwargs: Any) -> configuration.Config: - config_dict: Dict[str, Any] = {} - if config_file: - yml = ruamel.yaml.YAML(typ="safe") - with open(config_file, "r") as f: - config_dict = yml.load(f) or {} - return Config.deserialize(__config_settings__=config_dict, **kwargs) - - -def _resolve_import_from_string( - import_path: Union[Callable[..., Any], str] -) -> Optional[Callable[..., Any]]: - try: - if not import_path: - return None - elif callable(import_path): - return import_path - module_name, obj_name = import_path.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, obj_name) - except Exception: - logger.error("Failed to import %r", import_path, exc_info=True) - raise diff --git a/nornir/core/deserializer/inventory.py b/nornir/core/deserializer/inventory.py deleted file mode 100644 index 1ae28d6a..00000000 --- a/nornir/core/deserializer/inventory.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional, Union - -from nornir.core import inventory - -from nornir._vendor.pydantic import BaseModel - - -VarsDict = Dict[str, Any] -HostsDict = Dict[str, VarsDict] -GroupsDict = Dict[str, VarsDict] -DefaultsDict = VarsDict - - -class BaseAttributes(BaseModel): - hostname: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None - platform: Optional[str] = None - - class Config: - ignore_extra = False - - -class ConnectionOptions(BaseAttributes): - extras: Optional[Dict[str, Any]] - - @classmethod - def serialize(cls, i: inventory.ConnectionOptions) -> "ConnectionOptions": - return ConnectionOptions( - hostname=i.hostname, - port=i.port, - username=i.username, - password=i.password, - platform=i.platform, - extras=i.extras, - ) - - -class InventoryElement(BaseAttributes): - groups: List[str] = [] - data: Dict[str, Any] = {} - connection_options: Dict[str, ConnectionOptions] = {} - - @classmethod - def deserialize( - cls, - name: str, - hostname: Optional[str] = None, - port: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, - platform: Optional[str] = None, - groups: Optional[List[str]] = None, - data: Optional[Dict[str, Any]] = None, - connection_options: Optional[Dict[str, Dict[str, Any]]] = None, - defaults: inventory.Defaults = None, - ) -> Dict[str, Any]: - parent_groups = inventory.ParentGroups(groups) - connection_options = connection_options or {} - conn_opts = { - k: inventory.ConnectionOptions(**v) for k, v in connection_options.items() - } - return { - "name": name, - 
"hostname": hostname, - "port": port, - "username": username, - "password": password, - "platform": platform, - "groups": parent_groups, - "data": data, - "connection_options": conn_opts, - "defaults": defaults, - } - - @classmethod - def deserialize_host(cls, **kwargs: Any) -> inventory.Host: - return inventory.Host(**cls.deserialize(**kwargs)) - - @classmethod - def deserialize_group(cls, **kwargs: Any) -> inventory.Group: - return inventory.Group(**cls.deserialize(**kwargs)) - - @classmethod - def serialize(cls, e: Union[inventory.Host, inventory.Group]) -> "InventoryElement": - d = {} - for f in cls.__fields__: - d[f] = object.__getattribute__(e, f) - d["groups"] = list(d["groups"]) - d["connection_options"] = { - k: ConnectionOptions.serialize(v) - for k, v in d["connection_options"].items() - } - return InventoryElement(**d) - - -class Defaults(BaseAttributes): - data: Dict[str, Any] = {} - connection_options: Dict[str, ConnectionOptions] = {} - - @classmethod - def serialize(cls, defaults: inventory.Defaults) -> "Defaults": - d = {} - for f in cls.__fields__: - d[f] = getattr(defaults, f) - - d["connection_options"] = { - k: ConnectionOptions.serialize(v) - for k, v in d["connection_options"].items() - } - return Defaults(**d) - - -class Inventory(BaseModel): - hosts: Dict[str, InventoryElement] - groups: Dict[str, InventoryElement] - defaults: Defaults - - @classmethod - def deserialize( - cls, - transform_function: Optional[Callable[..., Any]] = None, - transform_function_options: Optional[Dict[str, Any]] = None, - *args: Any, - **kwargs: Any - ) -> inventory.Inventory: - transform_function_options = transform_function_options or {} - deserialized = cls(*args, **kwargs) - - defaults_dict = deserialized.defaults.dict() - for k, v in defaults_dict["connection_options"].items(): - defaults_dict["connection_options"][k] = inventory.ConnectionOptions(**v) - defaults = inventory.Defaults(**defaults_dict) - - hosts = inventory.Hosts() - for n, h in deserialized.hosts.items(): - hosts[n] = InventoryElement.deserialize_host( - defaults=defaults, name=n, **h.dict() - ) - - groups = inventory.Groups() - for n, g in deserialized.groups.items(): - groups[n] = InventoryElement.deserialize_group(name=n, **g.dict()) - - return inventory.Inventory( - hosts=hosts, - groups=groups, - defaults=defaults, - transform_function=transform_function, - transform_function_options=transform_function_options, - ) - - @classmethod - def serialize(cls, inv: inventory.Inventory) -> "Inventory": - hosts = {} - for n, h in inv.hosts.items(): - hosts[n] = InventoryElement.serialize(h) - groups = {} - for n, g in inv.groups.items(): - groups[n] = InventoryElement.serialize(g) - defaults = Defaults.serialize(inv.defaults) - return Inventory(hosts=hosts, groups=groups, defaults=defaults) diff --git a/nornir/core/exceptions.py b/nornir/core/exceptions.py index 768032e8..5388a61f 100644 --- a/nornir/core/exceptions.py +++ b/nornir/core/exceptions.py @@ -31,7 +31,7 @@ class ConnectionNotOpen(ConnectionException): pass -class ConnectionPluginAlreadyRegistered(ConnectionException): +class PluginAlreadyRegistered(Exception): """ Raised when trying to register an already registered plugin """ @@ -39,7 +39,7 @@ class ConnectionPluginAlreadyRegistered(ConnectionException): pass -class ConnectionPluginNotRegistered(ConnectionException): +class PluginNotRegistered(Exception): """ Raised when trying to access a plugin that is not registered """ diff --git a/nornir/core/inventory.py b/nornir/core/inventory.py index 
diff --git a/nornir/core/inventory.py b/nornir/core/inventory.py
index d1c6cdd5..7a89160b 100644
--- a/nornir/core/inventory.py
+++ b/nornir/core/inventory.py
@@ -1,8 +1,18 @@
-import warnings
-from collections import UserList
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Union,
+    KeysView,
+    ValuesView,
+    ItemsView,
+    Iterator,
+    TypeVar,
+)

-from nornir.core import deserializer
 from nornir.core.configuration import Config
 from nornir.core.connections import (
     ConnectionPlugin,
@@ -10,6 +20,11 @@
 )
 from nornir.core.exceptions import ConnectionAlreadyOpen, ConnectionNotOpen

+from mypy_extensions import Arg, KwArg
+
+
+HostOrGroup = TypeVar("HostOrGroup", "Host", "Group")
+

 class BaseAttributes(object):
     __slots__ = ("hostname", "port", "username", "password", "platform")
@@ -28,33 +43,50 @@ def __init__(
         self.password = password
         self.platform = platform

-    def dict(self):
-        w = f"{self.dict.__qualname__} is deprecated, use nornir.core.deserializer instead"
-        warnings.warn(w)
-        return (
-            getattr(deserializer.inventory, self.__class__.__name__)
-            .serialize(self)
-            .dict()
-        )
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "hostname": self.hostname,
+            "port": self.port,
+            "username": self.username,
+            "password": self.password,
+            "platform": self.platform,
+        }


 class ConnectionOptions(BaseAttributes):
     __slots__ = ("extras",)

-    def __init__(self, extras: Optional[Dict[str, Any]] = None, **kwargs) -> None:
+    def __init__(
+        self,
+        hostname: Optional[str] = None,
+        port: Optional[int] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        platform: Optional[str] = None,
+        extras: Optional[Dict[str, Any]] = None,
+    ) -> None:
         self.extras = extras
-        super().__init__(**kwargs)
-
+        super().__init__(
+            hostname=hostname,
+            port=port,
+            username=username,
+            password=password,
+            platform=platform,
+        )

-class ParentGroups(UserList):
-    __slots__ = "refs"
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "extras": self.extras,
+            **super().dict(),
+        }

-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.refs: List["Group"] = kwargs.get("refs", [])

-    def __contains__(self, value) -> bool:
-        return value in self.data or value in self.refs
+class ParentGroups(List["Group"]):
+    def __contains__(self, value: object) -> bool:
+        if isinstance(value, str):
+            return any([value == g.name for g in self])
+        else:
+            return any([value == g for g in self])


 class InventoryElement(BaseAttributes):
@@ -62,15 +94,35 @@ class InventoryElement(BaseAttributes):

     def __init__(
         self,
+        hostname: Optional[str] = None,
+        port: Optional[int] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        platform: Optional[str] = None,
         groups: Optional[ParentGroups] = None,
         data: Optional[Dict[str, Any]] = None,
         connection_options: Optional[Dict[str, ConnectionOptions]] = None,
-        **kwargs,
     ) -> None:
         self.groups = groups or ParentGroups()
         self.data = data or {}
         self.connection_options = connection_options or {}
-        super().__init__(**kwargs)
+        super().__init__(
+            hostname=hostname,
+            port=port,
+            username=username,
+            password=password,
+            platform=platform,
+        )
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "groups": [g.name for g in self.groups],
+            "data": self.data,
+            "connection_options": {
+                k: v.dict() for k, v in self.connection_options.items()
+            },
+            **super().dict(),
+        }


 class Defaults(BaseAttributes):
@@ -78,33 +130,71 @@ class Defaults(BaseAttributes):

     def __init__(
         self,
+        hostname: Optional[str] = None,
+        port: Optional[int] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        platform: Optional[str] = None,
         data: Optional[Dict[str, Any]] = None,
         connection_options: Optional[Dict[str, ConnectionOptions]] = None,
-        **kwargs,
     ) -> None:
         self.data = data or {}
         self.connection_options = connection_options or {}
-        super().__init__(**kwargs)
+        super().__init__(
+            hostname=hostname,
+            port=port,
+            username=username,
+            password=password,
+            platform=platform,
+        )
+
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "data": self.data,
+            "connection_options": {
+                k: v.dict() for k, v in self.connection_options.items()
+            },
+            **super().dict(),
+        }


 class Host(InventoryElement):
     __slots__ = ("name", "connections", "defaults")

     def __init__(
-        self, name: str, defaults: Optional[Defaults] = None, **kwargs
+        self,
+        name: str,
+        hostname: Optional[str] = None,
+        port: Optional[int] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        platform: Optional[str] = None,
+        groups: Optional[ParentGroups] = None,
+        data: Optional[Dict[str, Any]] = None,
+        connection_options: Optional[Dict[str, ConnectionOptions]] = None,
+        defaults: Optional[Defaults] = None,
     ) -> None:
         self.name = name
-        self.defaults = defaults or Defaults()
+        self.defaults = defaults or Defaults(None, None, None, None, None, None, None)
         self.connections: Connections = Connections()
-        super().__init__(**kwargs)
+        super().__init__(
+            hostname=hostname,
+            port=port,
+            username=username,
+            password=password,
+            platform=platform,
+            groups=groups,
+            data=data,
+            connection_options=connection_options,
+        )

-    def _resolve_data(self):
+    def _resolve_data(self) -> Dict[str, Any]:
         processed = []
         result = {}
         for k, v in self.data.items():
             processed.append(k)
             result[k] = v
-        for g in self.groups.refs:
+        for g in self.groups:
             for k, v in g.items():
                 if k not in processed:
                     processed.append(k)
@@ -115,22 +205,31 @@ def _resolve_data(self):
                 result[k] = v
         return result

-    def keys(self):
+    def dict(self) -> Dict[str, Any]:
+        return {
+            "name": self.name,
+            "connection_options": {
+                k: v.dict() for k, v in self.connection_options.items()
+            },
+            **super().dict(),
+        }
+
+    def keys(self) -> KeysView[str]:
         """Returns the keys of the attribute ``data`` and of the parent(s) groups."""
         return self._resolve_data().keys()

-    def values(self):
+    def values(self) -> ValuesView[Any]:
         """Returns the values of the attribute ``data`` and of the parent(s) groups."""
         return self._resolve_data().values()

-    def items(self):
+    def items(self) -> ItemsView[str, Any]:
         """
         Returns all the data accessible from a device, including
         the one inherited from parent groups
         """
         return self._resolve_data().items()

-    def has_parent_group(self, group):
+    def has_parent_group(self, group: Union[str, "Group"]) -> bool:
         """Returns whether the object is a child of the :obj:`Group` ``group``"""
         if isinstance(group, str):
             return self._has_parent_group_by_name(group)
@@ -138,22 +237,24 @@ def has_parent_group(self, group):
         else:
             return self._has_parent_group_by_object(group)

-    def _has_parent_group_by_name(self, group):
-        for g in self.groups.refs:
+    def _has_parent_group_by_name(self, group: str) -> bool:
+        for g in self.groups:
             if g.name == group or g.has_parent_group(group):
                 return True
+        return False

-    def _has_parent_group_by_object(self, group):
-        for g in self.groups.refs:
+    def _has_parent_group_by_object(self, group: "Group") -> bool:
+        for g in self.groups:
             if g is group or g.has_parent_group(group):
                 return True
+        return False

-    def __getitem__(self, item):
+    def __getitem__(self, item: str) -> Any:
         try:
             return self.data[item]
         except KeyError:
-            for g in self.groups.refs:
+            for g in self.groups:
                 try:
                     r = g[item]
                     return r
@@ -166,12 +267,12 @@ def __getitem__(self, item):

             raise

-    def __getattribute__(self, name):
+    def __getattribute__(self, name: str) -> Any:
         if name not in ("hostname", "port", "username", "password", "platform"):
             return object.__getattribute__(self, name)
         v = object.__getattribute__(self, name)
         if v is None:
-            for g in self.groups.refs:
+            for g in self.groups:
                 r = getattr(g, name)
                 if r is not None:
                     return r
@@ -180,25 +281,25 @@ def __getattribute__(self, name):
         else:
             return v

-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.name)

-    def __setitem__(self, item, value):
+    def __setitem__(self, item: str, value: Any) -> None:
         self.data[item] = value

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._resolve_data().keys())

-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return self.data.__iter__()

-    def __str__(self):
+    def __str__(self) -> str:
         return self.name

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "{}: {}".format(self.__class__.__name__, self.name or "")

-    def get(self, item, default=None):
+    def get(self, item: str, default: Any = None) -> Any:
         """
         Returns the value ``item`` from the host or hosts group variables.

@@ -253,9 +354,9 @@ def _get_connection_options_recursively(
     ) -> Optional[ConnectionOptions]:
         p = self.connection_options.get(connection)
         if p is None:
-            p = ConnectionOptions()
+            p = ConnectionOptions(None, None, None, None, None, None)

-        for g in self.groups.refs:
+        for g in self.groups:
             sp = g._get_connection_options_recursively(connection)
             if sp is not None:
                 p.hostname = p.hostname if p.hostname is not None else sp.hostname
@@ -365,7 +466,7 @@ def open_connection(
             configuration=configuration,
         )
         self.connections[conn_name] = conn_obj
-        return connection
+        return conn_obj

     def close_connection(self, connection: str) -> None:
         """ Close the connection"""
@@ -396,6 +497,10 @@ class Groups(Dict[str, Group]):
     pass


+TransformFunction = Callable[[Arg(Host), KwArg(Any)], None]
+FilterObj = Callable[[Arg(Host), KwArg(Any)], bool]
+
+
 class Inventory(object):
     __slots__ = ("hosts", "groups", "defaults")

@@ -404,111 +509,51 @@ def __init__(
         hosts: Hosts,
         groups: Optional[Groups] = None,
         defaults: Optional[Defaults] = None,
-        transform_function=None,
-        transform_function_options=None,
+        transform_function: TransformFunction = None,
+        transform_function_options: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.hosts = hosts
         self.groups = groups or Groups()
-        self.defaults = defaults or Defaults()
-
-        for host in self.hosts.values():
-            host.groups.refs = [self.groups[p] for p in host.groups]
-        for group in self.groups.values():
-            group.groups.refs = [self.groups[p] for p in group.groups]
+        self.defaults = defaults or Defaults(None, None, None, None, None, None, None)

-        if transform_function:
-            for h in self.hosts.values():
-                transform_function(h, **transform_function_options)
-
-    def filter(self, filter_obj=None, filter_func=None, *args, **kwargs):
+    def filter(
+        self, filter_obj: FilterObj = None, filter_func: FilterObj = None, **kwargs: Any
+    ) -> "Inventory":
         filter_func = filter_obj or filter_func
         if filter_func:
-            filtered = {n: h for n, h in self.hosts.items() if filter_func(h, **kwargs)}
+            filtered = Hosts(
+                {n: h for n, h in self.hosts.items() if filter_func(h, **kwargs)}
+            )
         else:
-            filtered = {
-                n: h
-                for n, h in self.hosts.items()
-                if all(h.get(k) == v for k, v in kwargs.items())
-            }
+            filtered = Hosts(
+                {
+                    n: h
+                    for n, h in self.hosts.items()
+                    if all(h.get(k) == v for k, v in kwargs.items())
+                }
+            )
         return Inventory(hosts=filtered, groups=self.groups, defaults=self.defaults)

-    def __len__(self):
+    def __len__(self) -> int:
         return self.hosts.__len__()

-    def _update_group_refs(self, inventory_element: InventoryElement) -> None:
-        """
-        Returns inventory_element with updated group references for the supplied
-        inventory element
-        """
-        if hasattr(inventory_element, "groups"):
-            inventory_element.groups.refs = [
-                self.groups[p] for p in inventory_element.groups
-            ]
-        return inventory_element
-
     def children_of_group(self, group: Union[str, Group]) -> Set[Host]:
         """
         Returns set of hosts that belongs to a group including those that belong
         indirectly via inheritance
         """
-        hosts: List[Host] = set()
+        hosts: Set[Host] = set()
         for host in self.hosts.values():
             if host.has_parent_group(group):
                 hosts.add(host)
         return hosts

-    def add_host(self, name: str, **kwargs) -> None:
-        """
-        Add a host to the inventory after initialization
-        """
-        host_element = deserializer.inventory.InventoryElement.deserialize_host(
-            name=name, defaults=self.defaults, **kwargs
-        )
-        host = {name: self._update_group_refs(host_element)}
-        self.hosts.update(host)
-
-    def add_group(self, name: str, **kwargs) -> None:
-        """
-        Add a group to the inventory after initialization
-        """
-        group_element = deserializer.inventory.InventoryElement.deserialize_group(
-            name=name, defaults=self.defaults, **kwargs
-        )
-        group = {name: self._update_group_refs(group_element)}
-        self.groups.update(group)
-
-    def dict(self) -> Dict:
-        """
-        Return serialized dictionary of inventory
-        """
-        return deserializer.inventory.Inventory.serialize(self).dict()
-
-    def get_inventory_dict(self) -> Dict:
+    def dict(self) -> Dict[str, Any]:
         """
         Return serialized dictionary of inventory
         """
-        return self.dict()
-
-    def get_defaults_dict(self) -> Dict:
-        """
-        Returns serialized dictionary of defaults from inventory
-        """
-        return deserializer.inventory.Defaults.serialize(self.defaults).dict()
-
-    def get_groups_dict(self) -> Dict:
-        """
-        Returns serialized dictionary of groups from inventory
-        """
-        return {
-            k: deserializer.inventory.InventoryElement.serialize(v).dict()
-            for k, v in self.groups.items()
-        }
-
-    def get_hosts_dict(self) -> Dict:
-        """
-        Returns serialized dictionary of hosts from inventory
-        """
         return {
-            k: deserializer.inventory.InventoryElement.serialize(v).dict()
-            for k, v in self.hosts.items()
+            "hosts": {n: h.dict() for n, h in self.hosts.items()},
+            "groups": {n: g.dict() for n, g in self.groups.items()},
+            "defaults": self.defaults.dict(),
         }
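
An illustrative sketch of the reworked inventory model (host/group names are made up,
and ``Group`` is assumed to mirror ``Host``'s signature): ``ParentGroups`` now holds
real ``Group`` objects, so data resolution walks parents directly instead of going
through the old ``groups.refs`` indirection::

    from nornir.core.inventory import Group, Host, ParentGroups

    eu = Group(name="eu", data={"domain": "acme.local"})
    rtr1 = Host(name="rtr1", groups=ParentGroups([eu]))

    assert rtr1["domain"] == "acme.local"  # inherited from the parent group
    assert "eu" in rtr1.groups             # ParentGroups matches by name or by object
    rtr1.dict()                            # serialization now lives on the objects themselves
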
diff --git a/nornir/_vendor/__init__.py b/nornir/core/plugins/__init__.py
similarity index 100%
rename from nornir/_vendor/__init__.py
rename to nornir/core/plugins/__init__.py
diff --git a/nornir/core/plugins/inventory.py b/nornir/core/plugins/inventory.py
new file mode 100644
index 00000000..44f63f53
--- /dev/null
+++ b/nornir/core/plugins/inventory.py
@@ -0,0 +1,34 @@
+from typing import Any, Type
+
+from nornir.core.inventory import Inventory, TransformFunction
+from nornir.core.plugins.register import PluginRegister
+
+from typing_extensions import Protocol
+
+
+INVENTORY_PLUGIN_PATH = "nornir.plugins.inventory"
+TRANSFORM_FUNCTION_PLUGIN_PATH = "nornir.plugins.transform_function"
+
+
+class InventoryPlugin(Protocol):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """
+        This method configures the plugin
+        """
+        raise NotImplementedError("needs to be implemented by the plugin")
+
+    def load(self) -> Inventory:
+        """
+        This method implements the plugin's business logic
+        """
+        raise NotImplementedError("needs to be implemented by the plugin")
+
+
+InventoryPluginRegister: PluginRegister[Type[InventoryPlugin]] = PluginRegister(
+    INVENTORY_PLUGIN_PATH
+)
+
+
+TransformFunctionRegister: PluginRegister[TransformFunction] = PluginRegister(
+    TRANSFORM_FUNCTION_PLUGIN_PATH
+)
diff --git a/nornir/core/plugins/register.py b/nornir/core/plugins/register.py
new file mode 100644
index 00000000..51f94484
--- /dev/null
+++ b/nornir/core/plugins/register.py
@@ -0,0 +1,76 @@
+import pkg_resources
+from typing import Dict, TypeVar, Generic
+
+
+from nornir.core.exceptions import (
+    PluginAlreadyRegistered,
+    PluginNotRegistered,
+)
+
+T = TypeVar("T")
+
+
+class PluginRegister(Generic[T]):
+    available: Dict[str, T] = {}
+
+    def __init__(self, entry_point: str) -> None:
+        self._entry_point = entry_point
+
+    def auto_register(self) -> None:
+        discovered_plugins: Dict[str, T] = {
+            entry_point.name: entry_point.load()
+            for entry_point in pkg_resources.iter_entry_points(self._entry_point)
+        }
+        for k, v in discovered_plugins.items():
+            self.register(k, v)
+
+    def register(self, name: str, plugin: T) -> None:
+        """Registers a plugin with a specified name
+
+        Args:
+            name: name of the plugin to register
+            plugin: plugin class
+
+        Raises:
+            :obj:`nornir.core.exceptions.PluginAlreadyRegistered` if
+                another plugin with the specified name was already registered
+        """
+        existing_plugin = self.available.get(name)
+        if existing_plugin is None:
+            self.available[name] = plugin
+        elif existing_plugin != plugin:
+            raise PluginAlreadyRegistered(
+                f"plugin {plugin} can't be registered as "
+                f"{name!r} because plugin {existing_plugin} "
+                f"was already registered under this name"
+            )
+
+    def deregister(self, name: str) -> None:
+        """Deregisters a registered plugin by its name
+
+        Args:
+            name: name of the plugin to deregister
+
+        Raises:
+            :obj:`nornir.core.exceptions.PluginNotRegistered`
+        """
+        if name not in self.available:
+            raise PluginNotRegistered(f"plugin {name!r} is not registered")
+        self.available.pop(name)
+
+    def deregister_all(self) -> None:
+        """Deregisters all registered plugins"""
+        self.available = {}
+
+    def get_plugin(self, name: str) -> T:
+        """Fetches the plugin by name if already registered
+
+        Args:
+            name: name of the plugin
+
+        Raises:
+            :obj:`nornir.core.exceptions.PluginNotRegistered`
+        """
+        if name not in self.available:
+            raise PluginNotRegistered(f"plugin {name!r} is not registered")
+        return self.available[name]
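
Minimal usage sketch for the new generic register ("my-inventory" and ``MyInventory``
are hypothetical names)::

    from nornir.core.plugins.inventory import InventoryPluginRegister

    class MyInventory:
        def __init__(self, **options):
            self.options = options

        def load(self):
            ...  # must build and return an Inventory

    InventoryPluginRegister.register("my-inventory", MyInventory)
    assert InventoryPluginRegister.get_plugin("my-inventory") is MyInventory
    # Registering a *different* plugin under an existing name raises PluginAlreadyRegistered.
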
diff --git a/nornir/init_nornir.py b/nornir/init_nornir.py
index 5e93c4a2..ca9788bd 100644
--- a/nornir/init_nornir.py
+++ b/nornir/init_nornir.py
@@ -1,10 +1,14 @@
 import pkg_resources
-import warnings
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Dict, Type

 from nornir.core import Nornir
+from nornir.core.configuration import Config
 from nornir.core.connections import Connections, ConnectionPlugin
-from nornir.core.deserializer.configuration import Config
+from nornir.core.inventory import Inventory
+from nornir.core.plugins.inventory import (
+    InventoryPluginRegister,
+    TransformFunctionRegister,
+)
 from nornir.core.state import GlobalState


@@ -17,16 +21,22 @@ def register_default_connection_plugins() -> None:
         Connections.register(k, v)


-def cls_to_string(cls: Callable[..., Any]) -> str:
-    return f"{cls.__module__}.{cls.__name__}"
+def load_inventory(config: Config,) -> Inventory:
+    InventoryPluginRegister.auto_register()
+    inventory_plugin = InventoryPluginRegister.get_plugin(config.inventory.plugin or "")
+    inv = inventory_plugin(**config.inventory.options).load()
+
+    if config.inventory.transform_function:
+        transform_function = TransformFunctionRegister.get_plugin(
+            config.inventory.transform_function
+        )
+        for h in inv.hosts.values():
+            transform_function(h, **(config.inventory.transform_function_options or {}))
+
+    return inv


-def InitNornir(
-    config_file: str = "",
-    dry_run: bool = False,
-    configure_logging: Optional[bool] = None,
-    **kwargs: Dict[str, Any],
-) -> Nornir:
+def InitNornir(config_file: str = "", dry_run: bool = False, **kwargs: Any,) -> Nornir:
     """
     Arguments:
         config_file(str): Path to the configuration file (optional)
@@ -42,39 +52,13 @@ def InitNornir(
     """
     register_default_connection_plugins()

-    if callable(kwargs.get("inventory", {}).get("plugin", "")):
-        kwargs["inventory"]["plugin"] = cls_to_string(kwargs["inventory"]["plugin"])
-
-    if callable(kwargs.get("inventory", {}).get("transform_function", "")):
-        kwargs["inventory"]["transform_function"] = cls_to_string(
-            kwargs["inventory"]["transform_function"]
-        )
-
-    conf = Config.load_from_file(config_file, **kwargs)
+    if config_file:
+        config = Config.from_file(config_file, **kwargs)
+    else:
+        config = Config.from_dict(**kwargs)

     data = GlobalState(dry_run=dry_run)

-    if configure_logging is not None:
-        msg = (
-            "'configure_logging' argument is deprecated, please use "
-            "'logging.enabled' parameter in the configuration instead: "
-            "https://nornir.readthedocs.io/en/stable/configuration/index.html"
-        )
-        warnings.warn(msg, DeprecationWarning)
-
-    if conf.logging.enabled is None:
-        if configure_logging is not None:
-            conf.logging.enabled = configure_logging
-        else:
-            conf.logging.enabled = True
-
-    conf.logging.configure()
-
-    inv = conf.inventory.plugin.deserialize(
-        transform_function=conf.inventory.transform_function,
-        transform_function_options=conf.inventory.transform_function_options,
-        config=conf,
-        **conf.inventory.options,
-    )
+    config.logging.configure()

-    return Nornir(inventory=inv, config=conf, data=data)
+    return Nornir(inventory=load_inventory(config), config=config, data=data,)
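
With the deserializer gone, ``InitNornir`` resolves the inventory plugin by its
registered name; a sketch reusing the hypothetical "my-inventory" plugin from above::

    from nornir import InitNornir

    nr = InitNornir(
        inventory={"plugin": "my-inventory", "options": {}},
        core={"num_workers": 50},
    )
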
os.path.expanduser(defaults_file) - if os.path.exists(defaults_file): - with open(defaults_file, "r") as f: - defaults = yml.load(f) or {} - else: - logger.debug("File %r was not found", defaults_file) - defaults = {} - super().__init__(hosts=hosts, groups=groups, defaults=defaults, *args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 0c241f3b..bcadfb6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "nornir" -version = "3.0.0a0" +version = "3.0.0a2" description = "Pluggable multi-threaded framework with inventory management to help operate collections of devices" authors = ["David Barroso "] readme = "README.md" diff --git a/setup.cfg b/setup.cfg index 0c85fc19..9473b87a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,9 +52,6 @@ ignore_errors = True [mypy-nornir.core.deserializer.configuration] ignore_errors = True -[mypy-nornir.core.inventory] -ignore_errors = True - [mypy-tests.*] ignore_errors = True diff --git a/tests/conftest.py b/tests/conftest.py index 7c820e24..b4831935 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,31 +1,116 @@ import os - -from nornir import InitNornir +from nornir.core import Nornir +from nornir.core.inventory import ( + Inventory, + Host, + Hosts, + Group, + Groups, + Defaults, + ParentGroups, + ConnectionOptions, +) from nornir.core.state import GlobalState +import ruamel.yaml import pytest global_data = GlobalState(dry_run=True) +def inventory_from_yaml(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + yml = ruamel.yaml.YAML(typ="safe") + + def get_connection_options(data): + cp = {} + for cn, c in data.items(): + cp[cn] = ConnectionOptions( + hostname=c.get("hostname"), + port=c.get("port"), + username=c.get("username"), + password=c.get("password"), + platform=c.get("platform"), + extras=c.get("extras"), + ) + return cp + + def get_defaults(): + defaults_file = f"{dir_path}/inventory_data/defaults.yaml" + with open(defaults_file, "r") as f: + defaults_dict = yml.load(f) + + defaults = Defaults( + hostname=defaults_dict.get("hostname"), + port=defaults_dict.get("port"), + username=defaults_dict.get("username"), + password=defaults_dict.get("password"), + platform=defaults_dict.get("platform"), + data=defaults_dict.get("data"), + connection_options=get_connection_options( + defaults_dict.get("connection_options", {}) + ), + ) + + return defaults + + def get_inventory_element(typ, data, name, defaults): + return typ( + name=name, + hostname=data.get("hostname"), + port=data.get("port"), + username=data.get("username"), + password=data.get("password"), + platform=data.get("platform"), + data=data.get("data"), + groups=data.get( + "groups" + ), # this is a hack, we will convert it later to the correct type + defaults=defaults, + connection_options=get_connection_options( + data.get("connection_options", {}) + ), + ) + + host_file = f"{dir_path}/inventory_data/hosts.yaml" + group_file = f"{dir_path}/inventory_data/groups.yaml" + + defaults = get_defaults() + + hosts = Hosts() + with open(host_file, "r") as f: + hosts_dict = yml.load(f) + + for n, h in hosts_dict.items(): + hosts[n] = get_inventory_element(Host, h, n, defaults) + + groups = Groups() + with open(group_file, "r") as f: + groups_dict = yml.load(f) + + for n, g in groups_dict.items(): + groups[n] = get_inventory_element(Group, g, n, defaults) + + for h in hosts.values(): + h.groups = ParentGroups([groups[g] for g in h.groups]) + + for g in groups.values(): + g.groups = 
ParentGroups([groups[g] for g in g.groups]) + + return Inventory(hosts=hosts, groups=groups, defaults=defaults) + + +@pytest.fixture(scope="session", autouse=True) +def inv(request): + return inventory_from_yaml() + + @pytest.fixture(scope="session", autouse=True) def nornir(request): """Initializes nornir""" - dir_path = os.path.dirname(os.path.realpath(__file__)) - - nornir = InitNornir( - inventory={ - "options": { - "host_file": "{}/inventory_data/hosts.yaml".format(dir_path), - "group_file": "{}/inventory_data/groups.yaml".format(dir_path), - "defaults_file": "{}/inventory_data/defaults.yaml".format(dir_path), - } - }, - dry_run=True, - ) - nornir.data = global_data - return nornir + nr = Nornir(inventory=inventory_from_yaml(), data=global_data) + return nr @pytest.fixture(scope="function", autouse=True) diff --git a/tests/core/deserializer/__init__.py b/tests/core/deserializer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/core/deserializer/my_jinja_filters.py b/tests/core/deserializer/my_jinja_filters.py deleted file mode 100644 index 554582ae..00000000 --- a/tests/core/deserializer/my_jinja_filters.py +++ /dev/null @@ -1,10 +0,0 @@ -def upper(blah: str) -> str: - return blah.upper() - - -def lower(blah: str) -> str: - return blah.lower() - - -def jinja_filters(): - return {"upper": upper, "lower": lower} diff --git a/tests/core/deserializer/test_configuration.py b/tests/core/deserializer/test_configuration.py deleted file mode 100644 index cdf41a34..00000000 --- a/tests/core/deserializer/test_configuration.py +++ /dev/null @@ -1,206 +0,0 @@ -import os -from pathlib import Path - -from nornir.core.configuration import Config -from nornir.plugins.inventory.simple import SimpleInventory -from nornir.core.deserializer.configuration import Config as ConfigDeserializer - -from tests.core.deserializer import my_jinja_filters - -import pytest - -dir_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "test_configuration" -) - - -DEFAULT_LOG_FORMAT = ( - "%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s" -) - - -class DummyInventory(SimpleInventory): - pass - - -class Test(object): - def test_config_defaults(self): - c = ConfigDeserializer() - assert c.dict() == { - "core": {"num_workers": 20, "raise_on_error": False}, - "inventory": { - "plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "options": {}, - "transform_function": "", - "transform_function_options": {}, - }, - "ssh": {"config_file": "~/.ssh/config"}, - "logging": { - "enabled": None, - "level": "INFO", - "file": "nornir.log", - "format": DEFAULT_LOG_FORMAT, - "to_console": False, - "loggers": ["nornir"], - }, - "jinja2": {"filters": ""}, - "user_defined": {}, - } - - def test_config_basic(self): - c = ConfigDeserializer( - core={"num_workers": 30}, - logging={"file": ""}, - user_defined={"my_opt": True}, - ) - assert c.dict() == { - "inventory": { - "plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "options": {}, - "transform_function": "", - "transform_function_options": {}, - }, - "ssh": {"config_file": "~/.ssh/config"}, - "logging": { - "enabled": None, - "level": "INFO", - "file": "", - "format": DEFAULT_LOG_FORMAT, - "to_console": False, - "loggers": ["nornir"], - }, - "jinja2": {"filters": ""}, - "core": {"num_workers": 30, "raise_on_error": False}, - "user_defined": {"my_opt": True}, - } - - def test_deserialize_defaults(self): - c = ConfigDeserializer.deserialize() - assert isinstance(c, Config) - - assert 
diff --git a/tests/core/deserializer/__init__.py b/tests/core/deserializer/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/core/deserializer/my_jinja_filters.py b/tests/core/deserializer/my_jinja_filters.py
deleted file mode 100644
index 554582ae..00000000
--- a/tests/core/deserializer/my_jinja_filters.py
+++ /dev/null
@@ -1,10 +0,0 @@
-def upper(blah: str) -> str:
-    return blah.upper()
-
-
-def lower(blah: str) -> str:
-    return blah.lower()
-
-
-def jinja_filters():
-    return {"upper": upper, "lower": lower}
diff --git a/tests/core/deserializer/test_configuration.py b/tests/core/deserializer/test_configuration.py
deleted file mode 100644
index cdf41a34..00000000
--- a/tests/core/deserializer/test_configuration.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import os
-from pathlib import Path
-
-from nornir.core.configuration import Config
-from nornir.plugins.inventory.simple import SimpleInventory
-from nornir.core.deserializer.configuration import Config as ConfigDeserializer
-
-from tests.core.deserializer import my_jinja_filters
-
-import pytest
-
-dir_path = os.path.join(
-    os.path.dirname(os.path.realpath(__file__)), "test_configuration"
-)
-
-
-DEFAULT_LOG_FORMAT = (
-    "%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s"
-)
-
-
-class DummyInventory(SimpleInventory):
-    pass
-
-
-class Test(object):
-    def test_config_defaults(self):
-        c = ConfigDeserializer()
-        assert c.dict() == {
-            "core": {"num_workers": 20, "raise_on_error": False},
-            "inventory": {
-                "plugin": "nornir.plugins.inventory.simple.SimpleInventory",
-                "options": {},
-                "transform_function": "",
-                "transform_function_options": {},
-            },
-            "ssh": {"config_file": "~/.ssh/config"},
-            "logging": {
-                "enabled": None,
-                "level": "INFO",
-                "file": "nornir.log",
-                "format": DEFAULT_LOG_FORMAT,
-                "to_console": False,
-                "loggers": ["nornir"],
-            },
-            "jinja2": {"filters": ""},
-            "user_defined": {},
-        }
-
-    def test_config_basic(self):
-        c = ConfigDeserializer(
-            core={"num_workers": 30},
-            logging={"file": ""},
-            user_defined={"my_opt": True},
-        )
-        assert c.dict() == {
-            "inventory": {
-                "plugin": "nornir.plugins.inventory.simple.SimpleInventory",
-                "options": {},
-                "transform_function": "",
-                "transform_function_options": {},
-            },
-            "ssh": {"config_file": "~/.ssh/config"},
-            "logging": {
-                "enabled": None,
-                "level": "INFO",
-                "file": "",
-                "format": DEFAULT_LOG_FORMAT,
-                "to_console": False,
-                "loggers": ["nornir"],
-            },
-            "jinja2": {"filters": ""},
-            "core": {"num_workers": 30, "raise_on_error": False},
-            "user_defined": {"my_opt": True},
-        }
-
-    def test_deserialize_defaults(self):
-        c = ConfigDeserializer.deserialize()
-        assert isinstance(c, Config)
-
-        assert c.core.num_workers == 20
-        assert not c.core.raise_on_error
-        assert c.user_defined == {}
-
-        assert c.logging.enabled is None
-        assert c.logging.level == "INFO"
-        assert c.logging.file == "nornir.log"
-        assert c.logging.format == DEFAULT_LOG_FORMAT
-        assert not c.logging.to_console
-
-        assert c.ssh.config_file == str(Path("~/.ssh/config").expanduser())
-
-        assert c.inventory.plugin == SimpleInventory
-        assert c.inventory.options == {}
-        assert c.inventory.transform_function is None
-        assert c.inventory.transform_function_options == {}
-
-    def test_deserialize_basic(self):
-        c = ConfigDeserializer.deserialize(
-            core={"num_workers": 30},
-            user_defined={"my_opt": True},
-            logging={"file": "", "level": "DEBUG"},
-            ssh={"config_file": "~/.ssh/alt_config"},
-            inventory={
-                "plugin": "tests.core.deserializer.test_configuration.DummyInventory"
-            },
-        )
-        assert isinstance(c, Config)
-
-        assert c.core.num_workers == 30
-        assert not c.core.raise_on_error
-        assert c.user_defined == {"my_opt": True}
-
-        assert c.logging.enabled is None
-        assert c.logging.level == "DEBUG"
-        assert c.logging.file == ""
-        assert c.logging.format == DEFAULT_LOG_FORMAT
-        assert not c.logging.to_console
-
-        assert c.ssh.config_file == str(Path("~/.ssh/alt_config").expanduser())
-
-        assert c.inventory.plugin == DummyInventory
-        assert c.inventory.options == {}
-        assert c.inventory.transform_function is None
-        assert c.inventory.transform_function_options == {}
-
-    def test_jinja_filters(self):
-        c = ConfigDeserializer.deserialize(
-            jinja2={"filters": "tests.core.deserializer.my_jinja_filters.jinja_filters"}
-        )
-        assert c.jinja2.filters == my_jinja_filters.jinja_filters()
-
-    def test_jinja_filters_error(self):
-        with pytest.raises(ModuleNotFoundError):
-            ConfigDeserializer.deserialize(jinja2={"filters": "asdasd.asdasd"})
-
-    def test_configuration_file_empty(self):
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "empty.yaml"), user_defined={"asd": "qwe"}
-        )
-        assert config.user_defined["asd"] == "qwe"
-        assert config.core.num_workers == 20
-        assert not config.core.raise_on_error
-        assert config.inventory.plugin == SimpleInventory
-
-    def test_configuration_file_normal(self):
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml")
-        )
-        assert config.core.num_workers == 10
-        assert not config.core.raise_on_error
-        assert config.inventory.plugin == DummyInventory
-
-    def test_configuration_file_override_argument(self):
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml"),
-            core={"num_workers": 20, "raise_on_error": True},
-        )
-        assert config.core.num_workers == 20
-        assert config.core.raise_on_error
-
-    def test_configuration_file_override_env(self):
-        os.environ["NORNIR_CORE_NUM_WORKERS"] = "30"
-        os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "1"
-        os.environ["NORNIR_SSH_CONFIG_FILE"] = "/user/ssh_config"
-        config = ConfigDeserializer.deserialize()
-        assert config.core.num_workers == 30
-        assert config.core.raise_on_error
-        assert config.ssh.config_file == "/user/ssh_config"
-        os.environ.pop("NORNIR_CORE_NUM_WORKERS")
-        os.environ.pop("NORNIR_CORE_RAISE_ON_ERROR")
-        os.environ.pop("NORNIR_SSH_CONFIG_FILE")
-
-    def test_configuration_bool_env(self):
-        os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "0"
-        config = ConfigDeserializer.deserialize()
-        assert config.core.num_workers == 20
-        assert not config.core.raise_on_error
-
-    def test_get_user_defined_from_file(self):
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml")
-        )
-        assert config.user_defined["asd"] == "qwe"
-
-    def test_order_of_resolution_config_is_lowest(self):
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml")
-        )
-        assert config.core.num_workers == 10
-
-    def test_order_of_resolution_env_is_higher_than_config(self):
-        os.environ["NORNIR_CORE_NUM_WORKERS"] = "20"
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml")
-        )
-        os.environ.pop("NORNIR_CORE_NUM_WORKERS")
-        assert config.core.num_workers == 20
-
-    def test_order_of_resolution_code_is_higher_than_env(self):
-        os.environ["NORNIR_CORE_NUM_WORKERS"] = "20"
-        config = ConfigDeserializer.load_from_file(
-            os.path.join(dir_path, "config.yaml"), core={"num_workers": 30}
-        )
-        os.environ.pop("NORNIR_CORE_NUM_WORKERS")
-        assert config.core.num_workers == 30
"plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "transform_function": "tests.core.test_InitNornir.transform_func", - "options": { - "host_file": "tests/inventory_data/hosts.yaml", - "group_file": "tests/inventory_data/groups.yaml", - }, - }, - ) - for host in nr.inventory.hosts.values(): - assert host["processed_by_transform_function"] - - def test_InitNornir_different_transform_function_imported(self): - nr = InitNornir( - config_file=os.path.join(dir_path, "a_config.yaml"), - inventory={ - "plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "transform_function": transform_func, + "plugin": "inventory-test", + "transform_function": "transform_func", "options": { "host_file": "tests/inventory_data/hosts.yaml", "group_file": "tests/inventory_data/groups.yaml", @@ -147,8 +135,8 @@ def test_InitNornir_different_transform_function_by_string_with_options(self): nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), inventory={ - "plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "transform_function": "tests.core.test_InitNornir.transform_func_with_options", + "plugin": "inventory-test", + "transform_function": "transform_func_with_options", "transform_function_options": {"a": 1}, "options": { "host_file": "tests/inventory_data/hosts.yaml", @@ -164,8 +152,8 @@ def test_InitNornir_different_transform_function_by_string_with_bad_options(self nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), inventory={ - "plugin": "nornir.plugins.inventory.simple.SimpleInventory", - "transform_function": "tests.core.test_InitNornir.transform_func_with_options", + "plugin": "inventory-test", + "transform_function": "transform_func_with_options", "transform_function_options": {"a": 1, "b": 0}, "options": { "host_file": "tests/inventory_data/hosts.yaml", @@ -238,16 +226,6 @@ def test_InitNornir_logging_disabled(self): assert nornir_logger.level == logging.NOTSET - def test_InitNornir_logging_disabled_alt(self): - self.cleanup() - with pytest.warns(DeprecationWarning): - InitNornir( - config_file=os.path.join(dir_path, "a_config.yaml"), - configure_logging=False, - ) - nornir_logger = logging.getLogger("nornir") - assert nornir_logger.level == logging.NOTSET - def test_InitNornir_logging_basicConfig(self): self.cleanup() logging.basicConfig() diff --git a/tests/core/test_InitNornir/a_config.yaml b/tests/core/test_InitNornir/a_config.yaml index ccf8712b..396be013 100644 --- a/tests/core/test_InitNornir/a_config.yaml +++ b/tests/core/test_InitNornir/a_config.yaml @@ -2,7 +2,7 @@ core: num_workers: 100 inventory: - plugin: nornir.plugins.inventory.simple.SimpleInventory + plugin: "inventory-test" options: host_file: "tests/inventory_data/hosts.yaml" group_file: "tests/inventory_data/groups.yaml" diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py new file mode 100644 index 00000000..a3f46532 --- /dev/null +++ b/tests/core/test_configuration.py @@ -0,0 +1,129 @@ +import os +from pathlib import Path + +from nornir.core.configuration import Config + + +dir_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "test_configuration" +) + + +DEFAULT_LOG_FORMAT = ( + "%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s" +) + + +class Test(object): + def test_config_defaults(self): + c = Config() + assert c.dict() == { + "core": {"num_workers": 20, "raise_on_error": False}, + "inventory": { + "plugin": "", + "options": {}, + "transform_function": "", + "transform_function_options": {}, + }, + 
"ssh": {"config_file": str(Path("~/.ssh/config").expanduser())}, + "logging": { + "enabled": True, + "level": "INFO", + "log_file": "nornir.log", + "format": DEFAULT_LOG_FORMAT, + "to_console": False, + "loggers": ["nornir"], + }, + "user_defined": {}, + } + + def test_config_from_dict_defaults(self): + c = Config.from_dict() + assert c.dict() == { + "core": {"num_workers": 20, "raise_on_error": False}, + "inventory": { + "plugin": "", + "options": {}, + "transform_function": "", + "transform_function_options": {}, + }, + "ssh": {"config_file": str(Path("~/.ssh/config").expanduser())}, + "logging": { + "enabled": True, + "level": "INFO", + "log_file": "nornir.log", + "format": DEFAULT_LOG_FORMAT, + "to_console": False, + "loggers": ["nornir"], + }, + "user_defined": {}, + } + + def test_config_basic(self): + c = Config.from_dict( + inventory={"plugin": "an-inventory"}, + core={"num_workers": 30}, + logging={"log_file": ""}, + user_defined={"my_opt": True}, + ) + assert c.dict() == { + "inventory": { + "plugin": "an-inventory", + "options": {}, + "transform_function": "", + "transform_function_options": {}, + }, + "ssh": {"config_file": str(Path("~/.ssh/config").expanduser())}, + "logging": { + "enabled": True, + "level": "INFO", + "log_file": "", + "format": DEFAULT_LOG_FORMAT, + "to_console": False, + "loggers": ["nornir"], + }, + "core": {"num_workers": 30, "raise_on_error": False}, + "user_defined": {"my_opt": True}, + } + + def test_configuration_file_override_argument(self): + config = Config.from_file( + os.path.join(dir_path, "config.yaml"), + core={"num_workers": 20, "raise_on_error": True}, + ) + assert config.core.num_workers == 20 + assert config.core.raise_on_error + + def test_configuration_file_override_env(self): + os.environ["NORNIR_CORE_NUM_WORKERS"] = "30" + os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "1" + os.environ["NORNIR_SSH_CONFIG_FILE"] = "/user/ssh_config" + config = Config.from_dict(inventory={"plugin": "an-inventory"}) + assert config.core.num_workers == 30 + assert config.core.raise_on_error + assert config.ssh.config_file == "/user/ssh_config" + os.environ.pop("NORNIR_CORE_NUM_WORKERS") + os.environ.pop("NORNIR_CORE_RAISE_ON_ERROR") + os.environ.pop("NORNIR_SSH_CONFIG_FILE") + + def test_configuration_bool_env(self): + os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "0" + config = Config.from_dict(inventory={"plugin": "an-inventory"}) + assert config.core.num_workers == 20 + assert not config.core.raise_on_error + + def test_get_user_defined_from_file(self): + config = Config.from_file(os.path.join(dir_path, "config.yaml")) + assert config.user_defined["asd"] == "qwe" + + def test_order_of_resolution_config_is_lowest(self): + config = Config.from_file(os.path.join(dir_path, "config.yaml")) + assert config.core.num_workers == 10 + + def test_order_of_resolution_code_is_higher_than_env(self): + os.environ["NORNIR_CORE_NUM_WORKERS"] = "20" + config = Config.from_file( + os.path.join(dir_path, "config.yaml"), core={"num_workers": 30} + ) + os.environ.pop("NORNIR_CORE_NUM_WORKERS") + assert config.core.num_workers == 30 diff --git a/tests/core/deserializer/test_configuration/config.yaml b/tests/core/test_configuration/config.yaml similarity index 100% rename from tests/core/deserializer/test_configuration/config.yaml rename to tests/core/test_configuration/config.yaml diff --git a/tests/core/deserializer/test_configuration/empty.yaml b/tests/core/test_configuration/empty.yaml similarity index 100% rename from tests/core/deserializer/test_configuration/empty.yaml rename 
diff --git a/tests/core/test_connections.py b/tests/core/test_connections.py
index a4ac1d7b..3615bde2 100644
--- a/tests/core/test_connections.py
+++ b/tests/core/test_connections.py
@@ -5,8 +5,8 @@
 from nornir.core.exceptions import (
     ConnectionAlreadyOpen,
     ConnectionNotOpen,
-    ConnectionPluginAlreadyRegistered,
-    ConnectionPluginNotRegistered,
+    PluginAlreadyRegistered,
+    PluginNotRegistered,
 )
 from nornir.init_nornir import register_default_connection_plugins
 
@@ -219,7 +219,7 @@ def test_register_already_registered_same(self):
         assert Connections.available["dummy"] == DummyConnectionPlugin
 
     def test_register_already_registered_new(self):
-        with pytest.raises(ConnectionPluginAlreadyRegistered):
+        with pytest.raises(PluginAlreadyRegistered):
             Connections.register("dummy", AnotherDummyConnectionPlugin)
 
     def test_deregister_existing(self):
@@ -228,7 +228,7 @@ def test_deregister_existing(self):
         assert "dummy" not in Connections.available
 
     def test_deregister_nonexistent(self):
-        with pytest.raises(ConnectionPluginNotRegistered):
+        with pytest.raises(PluginNotRegistered):
             Connections.deregister("nonexistent_dummy")
 
     def test_deregister_all(self):
@@ -241,5 +241,5 @@ def test_get_plugin(self):
         assert len(Connections.available) == 2
 
     def test_nonexistent_plugin(self):
-        with pytest.raises(ConnectionPluginNotRegistered):
+        with pytest.raises(PluginNotRegistered):
             Connections.get_plugin("nonexistent_dummy")
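
NOTE: the connection-specific exceptions are gone; PluginAlreadyRegistered and PluginNotRegistered now cover every plugin registry, so callers only need to update their imports. A hedged sketch (the helper is hypothetical; per the tests above, re-registering the same plugin under the same name is a no-op, while a different plugin raises):

    from nornir.core.exceptions import PluginAlreadyRegistered

    def register_once(registry, name, plugin):
        # Hypothetical helper: swallow the clash if another plugin already
        # owns this name, e.g. when a module happens to be imported twice.
        try:
            registry.register(name, plugin)
        except PluginAlreadyRegistered:
            pass
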
"port": None, + "username": None, + } + }, + "data": { + "my_var": "comes_from_defaults", + "only_default": "only_defined_in_default", + }, + "hostname": None, + "password": "docker", + "platform": "linux", + "port": None, + "username": "root", + }, + "groups": { + "group_1": { + "connection_options": {}, + "data": {"my_var": "comes_from_group_1", "site": "site1"}, + "groups": ["parent_group"], + "hostname": None, + "name": "group_1", + "password": "from_group1", + "platform": "linux", + "port": None, + "username": "root", + }, + "group_2": { + "connection_options": {}, + "data": {"site": "site2"}, + "groups": [], + "hostname": None, + "name": "group_2", + "password": "docker", + "platform": "linux", + "port": None, + "username": "root", + }, + "group_3": { + "connection_options": {}, + "data": {"site": "site2"}, + "groups": [], + "hostname": None, + "name": "group_3", + "password": "docker", + "platform": "linux", + "port": None, + "username": "root", + }, + "parent_group": { + "connection_options": { + "dummy": { + "extras": {"blah": "from_group"}, + "hostname": "dummy_from_parent_group", + "password": None, + "platform": None, + "port": None, + "username": None, + }, + "dummy2": { + "extras": {"blah": "from_group"}, + "hostname": "dummy2_from_parent_group", + "password": None, + "platform": None, + "port": None, + "username": None, + }, + }, + "data": {"a_false_var": False, "a_var": "blah"}, + "groups": [], + "hostname": None, + "name": "parent_group", + "password": "from_parent_group", + "platform": "linux", + "port": None, + "username": "root", + }, + }, + "hosts": { + "dev1.group_1": { + "connection_options": { + "dummy": { + "extras": {"blah": "from_host"}, + "hostname": "dummy_from_host", + "password": None, + "platform": None, + "port": None, + "username": None, + }, + "paramiko": { + "extras": {}, + "hostname": None, + "password": "docker", + "platform": "linux", + "port": 65020, + "username": "root", + }, + }, + "data": { + "my_var": "comes_from_dev1.group_1", + "nested_data": { + "a_dict": {"a": 1, "b": 2}, + "a_list": [1, 2], + "a_string": "asdasd", + }, + "role": "www", + "www_server": "nginx", + }, + "groups": ["group_1"], + "hostname": "localhost", + "name": "dev1.group_1", + "password": "a_password", + "platform": "eos", + "port": 65020, + "username": "root", + }, + "dev2.group_1": { + "connection_options": { + "dummy2": { + "extras": None, + "hostname": None, + "password": None, + "platform": None, + "port": None, + "username": "dummy2_from_host", + }, + "paramiko": { + "extras": {}, + "hostname": None, + "password": "docker", + "platform": "linux", + "port": None, + "username": "root", + }, + }, + "data": { + "nested_data": { + "a_dict": {"b": 2, "c": 3}, + "a_list": [2, 3], + "a_string": "qwe", + }, + "role": "db", + }, + "groups": ["group_1"], + "hostname": "localhost", + "name": "dev2.group_1", + "password": "from_group1", + "platform": "junos", + "port": 65021, + "username": "root", + }, + "dev3.group_2": { + "connection_options": { + "nornir_napalm.napalm": { + "extras": {}, + "hostname": None, + "password": None, + "platform": "mock", + "port": None, + "username": None, + } + }, + "data": {"role": "www", "www_server": "apache"}, + "groups": ["group_2"], + "hostname": "localhost", + "name": "dev3.group_2", + "password": "docker", + "platform": "linux", + "port": 65022, + "username": "root", + }, + "dev4.group_2": { + "connection_options": { + "netmiko": { + "extras": {}, + "hostname": "localhost", + "password": "docker", + "platform": "linux", + "port": None, + 
"username": "root", + }, + "paramiko": { + "extras": {}, + "hostname": "localhost", + "password": "docker", + "platform": "linux", + "port": None, + "username": "root", + }, + }, + "data": {"my_var": "comes_from_dev4.group_2", "role": "db"}, + "groups": ["parent_group", "group_2"], + "hostname": "localhost", + "name": "dev4.group_2", + "password": "from_parent_group", + "platform": "linux", + "port": 65023, + "username": "root", + }, + "dev5.no_group": { + "connection_options": {}, + "data": {}, + "groups": [], + "hostname": "localhost", + "name": "dev5.no_group", + "password": "docker", + "platform": "linux", + "port": 65024, + "username": "root", + }, + }, + } + + def test_filtering(self, inv): unfiltered = sorted(list(inv.hosts.keys())) assert unfiltered == [ "dev1.group_1", @@ -105,8 +311,7 @@ def test_filtering(self): ) assert www_site1 == ["dev1.group_1"] - def test_filtering_func(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_filtering_func(self, inv): long_names = sorted( list(inv.filter(filter_func=lambda x: len(x["my_var"]) > 20).hosts.keys()) ) @@ -120,13 +325,11 @@ def longer_than(dev, length): ) assert long_names == ["dev1.group_1", "dev4.group_2"] - def test_filter_unique_keys(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_filter_unique_keys(self, inv): filtered = sorted(list(inv.filter(www_server="nginx").hosts.keys())) assert filtered == ["dev1.group_1"] - def test_var_resolution(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_var_resolution(self, inv): assert inv.hosts["dev1.group_1"]["my_var"] == "comes_from_dev1.group_1" assert inv.hosts["dev2.group_1"]["my_var"] == "comes_from_group_1" assert inv.hosts["dev3.group_2"]["my_var"] == "comes_from_defaults" @@ -140,68 +343,50 @@ def test_var_resolution(self): inv.hosts["dev3.group_2"].data["my_var"] assert inv.hosts["dev4.group_2"].data["my_var"] == "comes_from_dev4.group_2" - def test_attributes_resolution(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_attributes_resolution(self, inv): assert inv.hosts["dev1.group_1"].password == "a_password" assert inv.hosts["dev2.group_1"].password == "from_group1" assert inv.hosts["dev3.group_2"].password == "docker" assert inv.hosts["dev4.group_2"].password == "from_parent_group" assert inv.hosts["dev5.no_group"].password == "docker" - def test_has_parents(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_has_parents(self, inv): assert inv.hosts["dev1.group_1"].has_parent_group(inv.groups["group_1"]) assert not inv.hosts["dev1.group_1"].has_parent_group(inv.groups["group_2"]) assert inv.hosts["dev1.group_1"].has_parent_group("group_1") assert not inv.hosts["dev1.group_1"].has_parent_group("group_2") - def test_to_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) - inv_serialized = deserializer.Inventory.serialize(inv).dict() - for k, v in inv_dict.items(): - assert v == inv_serialized[k] - - def test_get_connection_parameters(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_get_connection_parameters(self, inv): p1 = inv.hosts["dev1.group_1"].get_connection_parameters("dummy") - assert deserializer.ConnectionOptions.serialize(p1).dict() == { - "port": 65020, - "hostname": "dummy_from_host", - "username": "root", - "password": "a_password", - "platform": "eos", - "extras": {"blah": "from_host"}, - } + assert p1.port == 65020 + assert p1.hostname == "dummy_from_host" + assert p1.username == "root" + assert p1.password == 
"a_password" + assert p1.platform == "eos" + assert p1.extras == {"blah": "from_host"} p2 = inv.hosts["dev1.group_1"].get_connection_parameters("asd") - assert deserializer.ConnectionOptions.serialize(p2).dict() == { - "port": 65020, - "hostname": "localhost", - "username": "root", - "password": "a_password", - "platform": "eos", - "extras": {}, - } + assert p2.port == 65020 + assert p2.hostname == "localhost" + assert p2.username == "root" + assert p2.password == "a_password" + assert p2.platform == "eos" + assert p2.extras == {} p3 = inv.hosts["dev2.group_1"].get_connection_parameters("dummy") - assert deserializer.ConnectionOptions.serialize(p3).dict() == { - "port": 65021, - "hostname": "dummy_from_parent_group", - "username": "root", - "password": "from_group1", - "platform": "junos", - "extras": {"blah": "from_group"}, - } + assert p3.port == 65021 + assert p3.hostname == "dummy_from_parent_group" + assert p3.username == "root" + assert p3.password == "from_group1" + assert p3.platform == "junos" + assert p3.extras == {"blah": "from_group"} p4 = inv.hosts["dev3.group_2"].get_connection_parameters("dummy") - assert deserializer.ConnectionOptions.serialize(p4).dict() == { - "port": 65022, - "hostname": "dummy_from_defaults", - "username": "root", - "password": "docker", - "platform": "linux", - "extras": {"blah": "from_defaults"}, - } - - def test_defaults(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + assert p4.port == 65022 + assert p4.hostname == "dummy_from_defaults" + assert p4.username == "root" + assert p4.password == "docker" + assert p4.platform == "linux" + assert p4.extras == {"blah": "from_defaults"} + + def test_defaults(self, inv): inv.defaults.password = "asd" assert inv.defaults.password == "asd" assert inv.hosts["dev2.group_1"].password == "from_group1" @@ -209,8 +394,7 @@ def test_defaults(self): assert inv.hosts["dev4.group_2"].password == "from_parent_group" assert inv.hosts["dev5.no_group"].password == "asd" - def test_children_of_str(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_children_of_str(self, inv): assert inv.children_of_group("parent_group") == { inv.hosts["dev1.group_1"], inv.hosts["dev2.group_1"], @@ -229,8 +413,7 @@ def test_children_of_str(self): assert inv.children_of_group("blah") == set() - def test_children_of_obj(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_children_of_obj(self, inv): assert inv.children_of_group(inv.groups["parent_group"]) == { inv.hosts["dev1.group_1"], inv.hosts["dev2.group_1"], @@ -251,21 +434,24 @@ def test_add_host(self): data = {"test_var": "test_value"} defaults = inventory.Defaults(data=data) g1 = inventory.Group(name="g1") - g2 = inventory.Group(name="g2", groups=inventory.ParentGroups(["g1"])) - h1 = inventory.Host(name="h1", groups=inventory.ParentGroups(["g1", "g2"])) + g2 = inventory.Group(name="g2", groups=inventory.ParentGroups([g1])) + h1 = inventory.Host(name="h1", groups=inventory.ParentGroups([g1, g2])) h2 = inventory.Host(name="h2") hosts = {"h1": h1, "h2": h2} groups = {"g1": g1, "g2": g2} inv = inventory.Inventory(hosts=hosts, groups=groups, defaults=defaults) - h3_connection_options = {"netmiko": {"extras": {"device_type": "cisco_ios"}}} - inv.add_host( + h3_connection_options = inventory.ConnectionOptions( + extras={"device_type": "cisco_ios"} + ) + inv.hosts["h3"] = inventory.Host( name="h3", - groups=["g1"], + groups=[g1], platform="TestPlatform", - connection_options=h3_connection_options, + connection_options={"netmiko": 
h3_connection_options}, + defaults=defaults, ) assert "h3" in inv.hosts - assert "g1" in [i.name for i in inv.hosts["h3"].groups.refs] + assert "g1" in [i.name for i in inv.hosts["h3"].groups] assert "test_var" in inv.hosts["h3"].defaults.data.keys() assert inv.hosts["h3"].defaults.data.get("test_var") == "test_value" assert inv.hosts["h3"].platform == "TestPlatform" @@ -273,28 +459,28 @@ def test_add_host(self): inv.hosts["h3"].connection_options["netmiko"].extras["device_type"] == "cisco_ios" ) - with pytest.raises(KeyError): - inv.add_host(name="h4", groups=["not_defined"]) - # Test with one good and one undefined group - with pytest.raises(KeyError): - inv.add_host(name="h5", groups=["g1", "not_defined"]) def test_add_group(self): connection_options = {"username": "test_user", "password": "test_pass"} data = {"test_var": "test_value"} defaults = inventory.Defaults(data=data, connection_options=connection_options) g1 = inventory.Group(name="g1") - g2 = inventory.Group(name="g2", groups=inventory.ParentGroups(["g1"])) - h1 = inventory.Host(name="h1", groups=inventory.ParentGroups(["g1", "g2"])) + g2 = inventory.Group(name="g2", groups=inventory.ParentGroups([g1])) + h1 = inventory.Host(name="h1", groups=inventory.ParentGroups([g1, g2])) h2 = inventory.Host(name="h2") hosts = {"h1": h1, "h2": h2} groups = {"g1": g1, "g2": g2} inv = inventory.Inventory(hosts=hosts, groups=groups, defaults=defaults) - g3_connection_options = {"netmiko": {"extras": {"device_type": "cisco_ios"}}} - inv.add_group( - name="g3", username="test_user", connection_options=g3_connection_options + g3_connection_options = inventory.ConnectionOptions( + extras={"device_type": "cisco_ios"} + ) + inv.groups["g3"] = inventory.Group( + name="g3", + username="test_user", + connection_options={"netmiko": g3_connection_options}, + defaults=defaults, ) - assert "g1" in [i.name for i in inv.groups["g2"].groups.refs] + assert "g1" in [i.name for i in inv.groups["g2"].groups] assert "g3" in inv.groups assert ( inv.groups["g3"].defaults.connection_options.get("username") == "test_user" @@ -308,15 +494,8 @@ def test_add_group(self): inv.groups["g3"].connection_options["netmiko"].extras["device_type"] == "cisco_ios" ) - # Test with one undefined parent group - with pytest.raises(KeyError): - inv.add_group(name="g4", groups=["undefined"]) - # Test with one defined and one undefined parent group - with pytest.raises(KeyError): - inv.add_group(name="g4", groups=["g1", "undefined"]) - def test_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) + def test_dict(self, inv): inventory_dict = inv.dict() def_extras = inventory_dict["defaults"]["connection_options"]["dummy"]["extras"] grp_data = inventory_dict["groups"]["group_1"]["data"] @@ -327,37 +506,22 @@ def test_dict(self): assert "my_var" and "site" in grp_data assert "www_server" and "role" in host_data - def test_get_inventory_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) - inventory_dict = inv.get_inventory_dict() - def_extras = inventory_dict["defaults"]["connection_options"]["dummy"]["extras"] - grp_data = inventory_dict["groups"]["group_1"]["data"] - host_data = inventory_dict["hosts"]["dev1.group_1"]["data"] - assert type(inventory_dict) == dict - assert inventory_dict["defaults"]["username"] == "root" - assert def_extras["blah"] == "from_defaults" - assert "my_var" and "site" in grp_data - assert "www_server" and "role" in host_data - - def test_get_defaults_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) - defaults_dict = 
inv.get_defaults_dict() + def test_get_defaults_dict(self, inv): + defaults_dict = inv.defaults.dict() con_options = defaults_dict["connection_options"]["dummy"] assert type(defaults_dict) == dict assert defaults_dict["username"] == "root" assert con_options["hostname"] == "dummy_from_defaults" assert "blah" in con_options["extras"] - def test_get_groups_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) - groups_dict = inv.get_groups_dict() + def test_get_groups_dict(self, inv): + groups_dict = {n: g.dict() for n, g in inv.groups.items()} assert type(groups_dict) == dict assert groups_dict["group_1"]["password"] == "from_group1" assert groups_dict["group_2"]["data"]["site"] == "site2" - def test_get_hosts_dict(self): - inv = deserializer.Inventory.deserialize(**inv_dict) - hosts_dict = inv.get_hosts_dict() + def test_get_hosts_dict(self, inv): + hosts_dict = {n: h.dict() for n, h in inv.hosts.items()} dev1_groups = hosts_dict["dev1.group_1"]["groups"] dev2_paramiko_opts = hosts_dict["dev2.group_1"]["connection_options"][ "paramiko" diff --git a/tests/inventory_data/hosts.yaml b/tests/inventory_data/hosts.yaml index 179f0aa6..8f11e479 100644 --- a/tests/inventory_data/hosts.yaml +++ b/tests/inventory_data/hosts.yaml @@ -93,8 +93,8 @@ dev4.group_2: my_var: comes_from_dev4.group_2 role: db groups: - - group_2 - parent_group + - group_2 connection_options: paramiko: port: diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/plugins/inventory/__init__.py b/tests/plugins/inventory/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/plugins/inventory/test_simple.py b/tests/plugins/inventory/test_simple.py deleted file mode 100644 index a862f737..00000000 --- a/tests/plugins/inventory/test_simple.py +++ /dev/null @@ -1,68 +0,0 @@ -import os - -from nornir.plugins.inventory import simple -from nornir.core.deserializer.inventory import Inventory - -BASE_PATH = os.path.join(os.path.dirname(__file__), "nsot") - - -class Test(object): - def test_inventory(self): - hosts = { - "host1": { - "username": "user", - "groups": ["group_a"], - "data": {"a": 1, "b": 2}, - }, - "host2": {"username": "user2", "data": {"a": 1, "b": 2}}, - } - groups = {"group_a": {"platform": "linux"}} - defaults = {"data": {"a_default": "asd"}} - inv = simple.SimpleInventory.deserialize( - hosts=hosts, groups=groups, defaults=defaults - ) - assert Inventory.serialize(inv).dict() == { - "hosts": { - "host1": { - "hostname": None, - "port": None, - "username": "user", - "password": None, - "platform": None, - "groups": ["group_a"], - "data": {"a": 1, "b": 2}, - "connection_options": {}, - }, - "host2": { - "hostname": None, - "port": None, - "username": "user2", - "password": None, - "platform": None, - "groups": [], - "data": {"a": 1, "b": 2}, - "connection_options": {}, - }, - }, - "groups": { - "group_a": { - "hostname": None, - "port": None, - "username": None, - "password": None, - "platform": "linux", - "groups": [], - "data": {}, - "connection_options": {}, - } - }, - "defaults": { - "hostname": None, - "port": None, - "username": None, - "password": None, - "platform": None, - "data": {"a_default": "asd"}, - "connection_options": {}, - }, - }
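
NOTE: for reference, the removed inventory helpers map onto one-liners over the new objects (pattern taken from the rewritten tests above; inv is a nornir.core.inventory.Inventory):

    inventory_dict = inv.dict()  # was inv.get_inventory_dict()
    defaults_dict = inv.defaults.dict()  # was inv.get_defaults_dict()
    groups_dict = {n: g.dict() for n, g in inv.groups.items()}  # was inv.get_groups_dict()
    hosts_dict = {n: h.dict() for n, h in inv.hosts.items()}  # was inv.get_hosts_dict()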