diff --git a/pint/_typing.py b/pint/_typing.py new file mode 100644 index 000000000..2c8ebd069 --- /dev/null +++ b/pint/_typing.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union + +if TYPE_CHECKING: + from .unit import Unit + from .util import UnitsContainer + +UnitLike = Union["Unit", "UnitsContainer", str] + +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) + +ScalarLike = TypeVar("ScalarLike", float, int, complex) +Shape = Tuple[int, ...] diff --git a/pint/context.py b/pint/context.py index 6cd440e8a..5c0a203d5 100644 --- a/pint/context.py +++ b/pint/context.py @@ -11,10 +11,11 @@ import re import weakref from collections import ChainMap, defaultdict +from typing import Dict, Hashable, Iterable, List, Optional, Tuple from .definitions import Definition, UnitDefinition from .errors import DefinitionSyntaxError -from .util import ParserHelper, SourceIterator, to_units_container +from .util import ParserHelper, SourceIterator, UnitsContainer, to_units_container #: Regex to match the header parts of a context. _header_re = re.compile( @@ -84,7 +85,12 @@ class Context: >>> c.redefine("pound = 0.5 kg") """ - def __init__(self, name=None, aliases=(), defaults=None): + def __init__( + self, + name: Optional[str] = None, + aliases: Iterable[str] = (), + defaults: Dict[str, float] = None, + ) -> None: self.name = name self.aliases = aliases @@ -96,7 +102,7 @@ def __init__(self, name=None, aliases=(), defaults=None): self.defaults = defaults or {} # Store Definition objects that are context-specific - self.redefinitions = [] + self.redefinitions: List[Definition] = [] # Flag set to True by the Registry the first time the context is enabled self.checked = False @@ -106,7 +112,7 @@ def __init__(self, name=None, aliases=(), defaults=None): self.relation_to_context = weakref.WeakValueDictionary() @classmethod - def from_context(cls, context, **defaults): + def from_context(cls, context: "Context", **defaults) -> "Context": """Creates a new context that shares the funcs dictionary with the original context. The default values are copied from the original context and updated with the new defaults. @@ -223,14 +229,14 @@ def to_num(val): return ctx - def add_transformation(self, src, dst, func): + def add_transformation(self, src, dst, func) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) self.funcs[_key] = func self.relation_to_context[_key] = self - def remove_transformation(self, src, dst): + def remove_transformation(self, src, dst) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) @@ -238,7 +244,7 @@ def remove_transformation(self, src, dst): del self.relation_to_context[_key] @staticmethod - def __keytransform__(src, dst): + def __keytransform__(src, dst) -> Tuple[UnitsContainer, UnitsContainer]: return to_units_container(src), to_units_container(dst) def transform(self, src, dst, registry, value): @@ -270,7 +276,7 @@ def redefine(self, definition: str) -> None: raise DefinitionSyntaxError("Can't define base units within a context") self.redefinitions.append(d) - def hashable(self): + def hashable(self) -> Tuple[Hashable, ...]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. @@ -293,13 +299,13 @@ class ContextChain(ChainMap): to transform from one dimension to another. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self.contexts = [] + self.contexts: List[Context] = [] self.maps.clear() # Remove default empty map self._graph = None - def insert_contexts(self, *contexts): + def insert_contexts(self, *contexts: Context) -> None: """Insert one or more contexts in reversed order the chained map. (A rule in last context will take precedence) @@ -311,7 +317,7 @@ def insert_contexts(self, *contexts): self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps self._graph = None - def remove_contexts(self, n: int = None): + def remove_contexts(self, n: Optional[int] = None) -> None: """Remove the last n inserted contexts from the chain. Parameters @@ -345,7 +351,7 @@ def transform(self, src, dst, registry, value): """ return self[(src, dst)].transform(src, dst, registry, value) - def hashable(self): + def hashable(self) -> Tuple[Hashable, ...]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. diff --git a/pint/converters.py b/pint/converters.py index eae71ad59..9981646da 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -16,11 +16,11 @@ class Converter: """Base class for value converters.""" @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return True @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return False def to_reference(self, value, inplace=False): @@ -116,11 +116,11 @@ def __init__(self, scale, logbase, logfactor): self.logfactor = logfactor @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return False @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return True def from_reference(self, value, inplace=False): diff --git a/pint/definitions.py b/pint/definitions.py index 7e30c8942..63d61b934 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -9,8 +9,9 @@ """ from collections import namedtuple +from typing import Callable, Iterable, Optional, Union -from .converters import LogarithmicConverter, OffsetConverter, ScaleConverter +from .converters import Converter, LogarithmicConverter, OffsetConverter, ScaleConverter from .errors import DefinitionSyntaxError from .util import ParserHelper, UnitsContainer, _is_dim @@ -42,7 +43,7 @@ class PreprocessedDefinition( """ @classmethod - def from_string(cls, definition): + def from_string(cls, definition: str) -> "PreprocessedDefinition": name, definition = definition.split("=", 1) name = name.strip() @@ -64,7 +65,7 @@ def __init__(self, value): self.value = value -def numeric_parse(s, non_int_type=float): +def numeric_parse(s: str, non_int_type: type = float): """Try parse a string into a number (without using eval). Parameters @@ -103,7 +104,13 @@ class Definition: converter : callable or Converter or None """ - def __init__(self, name, symbol, aliases, converter): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Optional[Union[Callable, Converter]], + ): if isinstance(converter, str): raise TypeError( @@ -112,19 +119,21 @@ def __init__(self, name, symbol, aliases, converter): self._name = name self._symbol = symbol - self._aliases = aliases + self._aliases = tuple(aliases) self._converter = converter @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return self._converter.is_multiplicative @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return self._converter.is_logarithmic @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "Definition": """Parse a definition. Parameters @@ -150,30 +159,30 @@ def from_string(cls, definition, non_int_type=float): return UnitDefinition.from_string(definition, non_int_type) @property - def name(self): + def name(self) -> str: return self._name @property - def symbol(self): + def symbol(self) -> str: return self._symbol or self._name @property - def has_symbol(self): + def has_symbol(self) -> bool: return bool(self._symbol) @property - def aliases(self): + def aliases(self) -> Iterable[str]: return self._aliases - def add_aliases(self, *alias): + def add_aliases(self, *alias: str) -> None: alias = tuple(a for a in alias if a not in self._aliases) self._aliases = self._aliases + alias @property - def converter(self): + def converter(self) -> Converter: return self._converter - def __str__(self): + def __str__(self) -> str: return self.name @@ -188,7 +197,9 @@ class PrefixDefinition(Definition): """ @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "PrefixDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -226,14 +237,24 @@ class UnitDefinition(Definition): """ - def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Converter, + reference: Optional[UnitsContainer] = None, + is_base: bool = False, + ) -> None: self.reference = reference self.is_base = is_base super().__init__(name, symbol, aliases, converter) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "UnitDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -305,14 +326,24 @@ class DimensionDefinition(Definition): [density] = [mass] / [volume] """ - def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Optional[Union[Callable, Converter]], + reference: Optional[UnitsContainer] = None, + is_base: bool = False, + ) -> None: self.reference = reference self.is_base = is_base super().__init__(name, symbol, aliases, converter=None) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "DimensionDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -350,11 +381,13 @@ class AliasDefinition(Definition): @alias meter = my_meter """ - def __init__(self, name, aliases): + def __init__(self, name: str, aliases: Iterable[str]) -> None: super().__init__(name=name, symbol=None, aliases=aliases, converter=None) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "AliasDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) diff --git a/pint/formatting.py b/pint/formatting.py index afc51fe8e..57de8834c 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -118,9 +118,9 @@ def _pretty_fmt_exponent(num): def formatter( - items, - as_ratio=True, - single_denominator=False, + items: list, + as_ratio: bool = True, + single_denominator: bool = False, product_fmt=" * ", division_fmt=" / ", power_fmt="{} ** {}", diff --git a/pint/quantity.py b/pint/quantity.py index 82cc1af39..0665337c7 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -17,7 +17,24 @@ import operator import re import warnings -from typing import List +from decimal import Decimal +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) from packaging import version @@ -70,6 +87,15 @@ to_units_container, ) +if TYPE_CHECKING: + import numpy as np # noqa: F811 + + from pint.context import Context + from pint.registry import BaseRegistry + + from .measurement import Measurement + from .unit import Unit + class _Exception(Exception): # pragma: no cover def __init__(self, internal): @@ -157,7 +183,13 @@ def printoptions(*args, **kwargs): np.set_printoptions(**opts) -class Quantity(PrettyIPython, SharedRegistryObject): +from ._typing import Shape, UnitLike + +Magnitude = TypeVar("Magnitude", bound=Union[float, int, Decimal, "np.ndarray"]) +TQuantity = TypeVar("TQuantity", bound="Quantity") + + +class Quantity(PrettyIPython, SharedRegistryObject, Generic[Magnitude]): """Implements a class to describe a physical quantity: the product of a numerical value and a unit of measurement. @@ -177,18 +209,27 @@ class Quantity(PrettyIPython, SharedRegistryObject): default_format = "" @property - def force_ndarray(self): + def force_ndarray(self) -> bool: return self._REGISTRY.force_ndarray @property - def force_ndarray_like(self): + def force_ndarray_like(self) -> bool: return self._REGISTRY.force_ndarray_like @property - def UnitsContainer(self): + def UnitsContainer(self) -> Type[UnitsContainer]: return self._REGISTRY.UnitsContainer - def __reduce__(self): + @classmethod + def build_quantity_class( + cls: Type["Quantity"], registry: "BaseRegistry" + ) -> Type["Quantity[Magnitude]"]: + class Quantity(cls): + _REGISTRY = registry + + return Quantity + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: """Allow pickling quantities. Since UnitRegistries are not pickled, upon unpickling the new object is always attached to the application registry. """ @@ -198,6 +239,22 @@ def __reduce__(self): # build_quantity_class can't be pickled return _unpickle_quantity, (Quantity, self.magnitude, self._units) + @overload + def __new__(cls, value: str, units: Optional[UnitLike] = None) -> "Quantity[Any]": + ... + + @overload + def __new__( + cls, value: Union[list, tuple], units: Optional[UnitLike] = None + ) -> "Quantity[np.ndarray]": + ... + + @overload + def __new__( + cls, value: Magnitude, units: Optional[UnitLike] = None + ) -> "Quantity[Magnitude]": + ... + def __new__(cls, value, units=None): if is_upcast_type(type(value)): raise TypeError(f"Quantity cannot wrap upcast type {type(value)}") @@ -254,10 +311,10 @@ def __new__(cls, value, units=None): return inst @property - def debug_used(self): + def debug_used(self) -> bool: return self.__used - def __iter__(self): + def __iter__(self: "Quantity[Iterable]") -> Iterator: # Make sure that, if self.magnitude is not iterable, we raise TypeError as soon # as one calls iter(self) without waiting for the first element to be drawn from # the iterator @@ -269,34 +326,34 @@ def it_outer(): return it_outer() - def __copy__(self): + def __copy__(self: TQuantity) -> TQuantity: ret = self.__class__(copy.copy(self._magnitude), self._units) ret.__used = self.__used return ret - def __deepcopy__(self, memo): + def __deepcopy__(self: TQuantity, memo: Optional[dict]) -> TQuantity: ret = self.__class__( copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo) ) ret.__used = self.__used return ret - def __str__(self): + def __str__(self) -> str: if self._REGISTRY.fmt_locale is not None: return self.format_babel() return format(self) - def __bytes__(self): + def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) - def __repr__(self): + def __repr__(self) -> str: if isinstance(self._magnitude, float): return f"" else: return f"" - def __hash__(self): + def __hash__(self) -> int: self_base = self.to_base_units() if self_base.dimensionless: return hash(self_base.magnitude) @@ -305,7 +362,7 @@ def __hash__(self): _exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") - def __format__(self, spec): + def __format__(self, spec: str) -> str: if self._REGISTRY.fmt_locale is not None: return self.format_babel(spec) @@ -407,7 +464,7 @@ def _repr_pretty_(self, p, cycle): p.text(" ") p.pretty(self.units) - def format_babel(self, spec="", **kwspec): + def format_babel(self, spec: str = "", **kwspec: Any) -> str: spec = spec or self.default_format # standard cases @@ -432,16 +489,16 @@ def format_babel(self, spec="", **kwspec): ).replace("\n", "") @property - def magnitude(self): + def magnitude(self) -> Magnitude: """Quantity's magnitude. Long form for `m`""" return self._magnitude @property - def m(self): + def m(self) -> Magnitude: """Quantity's magnitude. Short form for `magnitude`""" return self._magnitude - def m_as(self, units): + def m_as(self, units) -> Magnitude: """Quantity's magnitude expressed in particular units. Parameters @@ -456,22 +513,22 @@ def m_as(self, units): return self.to(units).magnitude @property - def units(self): + def units(self) -> "Unit": """Quantity's units. Long form for `u`""" return self._REGISTRY.Unit(self._units) @property - def u(self): + def u(self) -> "Unit": """Quantity's units. Short form for `units`""" return self._REGISTRY.Unit(self._units) @property - def unitless(self): + def unitless(self) -> bool: """ """ return not bool(self.to_root_units()._units) @property - def dimensionless(self): + def dimensionless(self) -> bool: """ """ tmp = self.to_root_units() @@ -480,7 +537,7 @@ def dimensionless(self): _dimensionality = None @property - def dimensionality(self): + def dimensionality(self) -> UnitsContainer: """ Returns ------- @@ -492,12 +549,14 @@ def dimensionality(self): return self._dimensionality - def check(self, dimension): + def check(self, dimension) -> bool: """Return true if the quantity's dimension matches passed dimension.""" return self.dimensionality == self._REGISTRY.get_dimensionality(dimension) @classmethod - def from_list(cls, quant_list, units=None): + def from_list( + cls, quant_list: "List[Quantity]", units: Optional[UnitLike] = None + ) -> "Quantity[np.ndarray]": """Transforms a list of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. Same as from_sequence. @@ -519,7 +578,11 @@ def from_list(cls, quant_list, units=None): return cls.from_sequence(quant_list, units=units) @classmethod - def from_sequence(cls, seq, units=None): + def from_sequence( + cls, + seq: "Sequence[Quantity]", + units: Optional[UnitLike] = None, + ) -> "Quantity[np.ndarray]": """Transforms a sequence of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. @@ -554,20 +617,25 @@ def from_sequence(cls, seq, units=None): return cls(a, units) @classmethod - def from_tuple(cls, tup): + def from_tuple(cls, tup: Tuple[Magnitude, Tuple]) -> "Quantity[Magnitude]": return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1])) - def to_tuple(self): + def to_tuple(self) -> Tuple[Magnitude, Tuple]: return self.m, tuple(self._units.items()) - def compatible_units(self, *contexts): + def compatible_units(self, *contexts: Union[str, "Context"]): if contexts: with self._REGISTRY.context(*contexts): return self._REGISTRY.get_compatible_units(self._units) return self._REGISTRY.get_compatible_units(self._units) - def is_compatible_with(self, other, *contexts, **ctx_kwargs): + def is_compatible_with( + self, + other: Union["Unit", TQuantity, str], + *contexts: Union[str, "Context"], + **ctx_kwargs: Any, + ) -> bool: """check if the other object is compatible Parameters @@ -608,7 +676,9 @@ def _convert_magnitude_not_inplace(self, other, *contexts, **ctx_kwargs): return self._REGISTRY.convert(self._magnitude, self._units, other) - def _convert_magnitude(self, other, *contexts, **ctx_kwargs): + def _convert_magnitude( + self, other, *contexts: Union[str, "Context"], **ctx_kwargs + ) -> Magnitude: if contexts: with self._REGISTRY.context(*contexts, **ctx_kwargs): return self._REGISTRY.convert(self._magnitude, self._units, other) @@ -620,7 +690,7 @@ def _convert_magnitude(self, other, *contexts, **ctx_kwargs): inplace=is_duck_array_type(type(self._magnitude)), ) - def ito(self, other=None, *contexts, **ctx_kwargs): + def ito(self, other=None, *contexts: Union[str, "Context"], **ctx_kwargs) -> None: """Inplace rescale to different units. Parameters @@ -639,7 +709,9 @@ def ito(self, other=None, *contexts, **ctx_kwargs): return None - def to(self, other=None, *contexts, **ctx_kwargs): + def to( + self: TQuantity, other=None, *contexts: Union[str, "Context"], **ctx_kwargs + ) -> TQuantity: """Return Quantity rescaled to different units. Parameters @@ -661,7 +733,7 @@ def to(self, other=None, *contexts, **ctx_kwargs): return self.__class__(magnitude, other) - def ito_root_units(self): + def ito_root_units(self) -> None: """Return Quantity rescaled to root units.""" _, other = self._REGISTRY._get_root_units(self._units) @@ -671,7 +743,7 @@ def ito_root_units(self): return None - def to_root_units(self): + def to_root_units(self: TQuantity) -> TQuantity: """Return Quantity rescaled to root units.""" _, other = self._REGISTRY._get_root_units(self._units) @@ -680,7 +752,7 @@ def to_root_units(self): return self.__class__(magnitude, other) - def ito_base_units(self): + def ito_base_units(self) -> None: """Return Quantity rescaled to base units.""" _, other = self._REGISTRY._get_base_units(self._units) @@ -690,7 +762,7 @@ def ito_base_units(self): return None - def to_base_units(self): + def to_base_units(self: TQuantity) -> TQuantity: """Return Quantity rescaled to base units.""" _, other = self._REGISTRY._get_base_units(self._units) @@ -699,7 +771,7 @@ def to_base_units(self): return self.__class__(magnitude, other) - def ito_reduced_units(self): + def ito_reduced_units(self) -> None: """Return Quantity scaled in place to reduced units, i.e. one unit per dimension. This will not reduce compound units (e.g., 'J/kg' will not be reduced to m**2/s**2), nor can it make use of contexts at this time. @@ -727,7 +799,7 @@ def ito_reduced_units(self): return self.ito(newunits) - def to_reduced_units(self): + def to_reduced_units(self: TQuantity) -> TQuantity: """Return Quantity scaled in place to reduced units, i.e. one unit per dimension. This will not reduce compound units (intentionally), nor can it make use of contexts at this time. @@ -738,7 +810,7 @@ def to_reduced_units(self): newq.ito_reduced_units() return newq - def to_compact(self, unit=None): + def to_compact(self: TQuantity, unit: Optional[UnitLike] = None) -> TQuantity: """ "Return Quantity rescaled to compact, human-readable units. To get output in terms of a different unit, use the unit parameter. @@ -772,18 +844,18 @@ def to_compact(self, unit=None): ): return self - SI_prefixes = {} + SI_prefixes_map: Dict[int, str] = {} for prefix in self._REGISTRY._prefixes.values(): try: scale = prefix.converter.scale # Kludgy way to check if this is an SI prefix log10_scale = int(math.log10(scale)) if log10_scale == math.log10(scale): - SI_prefixes[log10_scale] = prefix.name + SI_prefixes_map[log10_scale] = prefix.name except Exception: - SI_prefixes[0] = "" + SI_prefixes_map[0] = "" - SI_prefixes = sorted(SI_prefixes.items()) + SI_prefixes: List[Tuple[int, str]] = sorted(SI_prefixes_map.items()) SI_powers = [item[0] for item in SI_prefixes] SI_bases = [item[1] for item in SI_prefixes] @@ -822,23 +894,23 @@ def to_compact(self, unit=None): return self.to(new_unit_container) # Mathematical operations - def __int__(self): + def __int__(self) -> int: if self.dimensionless: return int(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") - def __float__(self): + def __float__(self) -> float: if self.dimensionless: return float(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") - def __complex__(self): + def __complex__(self) -> complex: if self.dimensionless: return complex(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") @check_implemented - def _iadd_sub(self, other, op): + def _iadd_sub(self: TQuantity, other, op: Callable) -> TQuantity: """Perform addition or subtraction operation in-place and return the result. Parameters @@ -951,7 +1023,7 @@ def _iadd_sub(self, other, op): return self @check_implemented - def _add_sub(self, other, op): + def _add_sub(self: TQuantity, other, op: Callable) -> TQuantity: """Perform addition or subtraction operation and return the result. Parameters @@ -1063,6 +1135,14 @@ def _add_sub(self, other, op): return self.__class__(magnitude, units) + @overload + def __iadd__(self, other: datetime.datetime) -> datetime.timedelta: + ... + + @overload + def __iadd__(self: TQuantity, other) -> TQuantity: + ... + def __iadd__(self, other): if isinstance(other, datetime.datetime): return self.to_timedelta() + other @@ -1071,6 +1151,14 @@ def __iadd__(self, other): else: return self._add_sub(other, operator.add) + @overload + def __add__(self, other: datetime.datetime) -> datetime.timedelta: + ... + + @overload + def __add__(self: TQuantity, other) -> TQuantity: + ... + def __add__(self, other): if isinstance(other, datetime.datetime): return self.to_timedelta() + other @@ -1079,13 +1167,13 @@ def __add__(self, other): __radd__ = __add__ - def __isub__(self, other): + def __isub__(self: TQuantity, other) -> TQuantity: if is_duck_array_type(type(self._magnitude)): return self._iadd_sub(other, operator.isub) else: return self._add_sub(other, operator.sub) - def __sub__(self, other): + def __sub__(self: TQuantity, other) -> TQuantity: return self._add_sub(other, operator.sub) def __rsub__(self, other): @@ -1096,7 +1184,7 @@ def __rsub__(self, other): @check_implemented @ireduce_dimensions - def _imul_div(self, other, magnitude_op, units_op=None): + def _imul_div(self, other, magnitude_op: Callable, units_op=None): """Perform multiplication or division operation in-place and return the result. @@ -1167,7 +1255,7 @@ def _imul_div(self, other, magnitude_op, units_op=None): @check_implemented @ireduce_dimensions - def _mul_div(self, other, magnitude_op, units_op=None): + def _mul_div(self, other, magnitude_op: Callable, units_op=None): """Perform multiplication or division operation and return the result. Parameters @@ -1239,13 +1327,13 @@ def _mul_div(self, other, magnitude_op, units_op=None): return self.__class__(magnitude, units) - def __imul__(self, other): + def __imul__(self: TQuantity, other) -> TQuantity: if is_duck_array_type(type(self._magnitude)): return self._imul_div(other, operator.imul) else: return self._mul_div(other, operator.mul) - def __mul__(self, other): + def __mul__(self: TQuantity, other) -> TQuantity: return self._mul_div(other, operator.mul) __rmul__ = __mul__ @@ -1259,16 +1347,16 @@ def __matmul__(self, other): __rmatmul__ = __matmul__ - def __itruediv__(self, other): + def __itruediv__(self: TQuantity, other) -> TQuantity: if is_duck_array_type(type(self._magnitude)): return self._imul_div(other, operator.itruediv) else: return self._mul_div(other, operator.truediv) - def __truediv__(self, other): + def __truediv__(self: TQuantity, other) -> TQuantity: return self._mul_div(other, operator.truediv) - def __rtruediv__(self, other): + def __rtruediv__(self: TQuantity, other) -> TQuantity: try: other_magnitude = _to_magnitude( other, self.force_ndarray, self.force_ndarray_like @@ -1290,7 +1378,7 @@ def __rtruediv__(self, other): __rdiv__ = __rtruediv__ __idiv__ = __itruediv__ - def __ifloordiv__(self, other): + def __ifloordiv__(self: TQuantity, other) -> TQuantity: if self._check(other): self._magnitude //= other.to(self._units)._magnitude elif self.dimensionless: @@ -1301,7 +1389,7 @@ def __ifloordiv__(self, other): return self @check_implemented - def __floordiv__(self, other): + def __floordiv__(self: TQuantity, other) -> TQuantity: if self._check(other): magnitude = self._magnitude // other.to(self._units)._magnitude elif self.dimensionless: @@ -1311,7 +1399,7 @@ def __floordiv__(self, other): return self.__class__(magnitude, self.UnitsContainer({})) @check_implemented - def __rfloordiv__(self, other): + def __rfloordiv__(self: TQuantity, other) -> TQuantity: if self._check(other): magnitude = other._magnitude // self.to(other._units)._magnitude elif self.dimensionless: @@ -1321,21 +1409,21 @@ def __rfloordiv__(self, other): return self.__class__(magnitude, self.UnitsContainer({})) @check_implemented - def __imod__(self, other): + def __imod__(self: TQuantity, other) -> TQuantity: if not self._check(other): other = self.__class__(other, self.UnitsContainer({})) self._magnitude %= other.to(self._units)._magnitude return self @check_implemented - def __mod__(self, other): + def __mod__(self: TQuantity, other) -> TQuantity: if not self._check(other): other = self.__class__(other, self.UnitsContainer({})) magnitude = self._magnitude % other.to(self._units)._magnitude return self.__class__(magnitude, self._units) @check_implemented - def __rmod__(self, other): + def __rmod__(self: TQuantity, other) -> TQuantity: if self._check(other): magnitude = other._magnitude % self.to(other._units)._magnitude return self.__class__(magnitude, other._units) @@ -1346,7 +1434,7 @@ def __rmod__(self, other): raise DimensionalityError(self._units, "dimensionless") @check_implemented - def __divmod__(self, other): + def __divmod__(self: TQuantity, other) -> Tuple[TQuantity, TQuantity]: if not self._check(other): other = self.__class__(other, self.UnitsContainer({})) q, r = divmod(self._magnitude, other.to(self._units)._magnitude) @@ -1356,7 +1444,7 @@ def __divmod__(self, other): ) @check_implemented - def __rdivmod__(self, other): + def __rdivmod__(self: TQuantity, other) -> Tuple[TQuantity, TQuantity]: if self._check(other): q, r = divmod(other._magnitude, self.to(other._units)._magnitude) unit = other._units @@ -1432,7 +1520,7 @@ def __ipow__(self, other): return self @check_implemented - def __pow__(self, other): + def __pow__(self: TQuantity, other: Any) -> TQuantity: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1497,7 +1585,7 @@ def __pow__(self, other): return self.__class__(magnitude, units) @check_implemented - def __rpow__(self, other): + def __rpow__(self: TQuantity, other) -> TQuantity: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1510,16 +1598,16 @@ def __rpow__(self, other): new_self = self.to_root_units() return other ** new_self._magnitude - def __abs__(self): + def __abs__(self: TQuantity) -> TQuantity: return self.__class__(abs(self._magnitude), self._units) - def __round__(self, ndigits=0): + def __round__(self: TQuantity, ndigits: Optional[int] = 0) -> TQuantity: return self.__class__(round(self._magnitude, ndigits=ndigits), self._units) - def __pos__(self): + def __pos__(self: TQuantity) -> TQuantity: return self.__class__(operator.pos(self._magnitude), self._units) - def __neg__(self): + def __neg__(self: TQuantity) -> TQuantity: return self.__class__(operator.neg(self._magnitude), self._units) @check_implemented @@ -1628,7 +1716,7 @@ def compare(self, other, op): __ge__ = lambda self, other: self.compare(other, op=operator.ge) __gt__ = lambda self, other: self.compare(other, op=operator.gt) - def __bool__(self): + def __bool__(self) -> bool: # Only cast when non-ambiguous (when multiplicative unit) if self._is_multiplicative: return bool(self._magnitude) @@ -1696,7 +1784,7 @@ def _numpy_method_wrap(self, func, *args, **kwargs): else: return value - def __array__(self, t=None): + def __array__(self, t=None) -> "np.ndarray": warnings.warn( "The unit of the quantity is stripped when downcasting to ndarray.", UnitStrippedWarning, @@ -1748,15 +1836,15 @@ def put(self, indices, values, mode="raise"): self.magnitude.put(indices, values, mode) @property - def real(self): + def real(self: TQuantity) -> TQuantity: return self.__class__(self._magnitude.real, self._units) @property - def imag(self): + def imag(self: TQuantity) -> TQuantity: return self.__class__(self._magnitude.imag, self._units) @property - def T(self): + def T(self: "Quantity[np.ndarray]") -> "Quantity[np.ndarray]": return self.__class__(self._magnitude.T, self._units) @property @@ -1765,11 +1853,11 @@ def flat(self): yield self.__class__(v, self._units) @property - def shape(self): + def shape(self: "Quantity[np.ndarray]") -> Shape: return self._magnitude.shape @shape.setter - def shape(self, value): + def shape(self, value: Shape) -> None: self._magnitude.shape = value def searchsorted(self, v, side="left", sorter=None): @@ -1812,10 +1900,10 @@ def __ito_if_needed(self, to_units): self.ito(to_units) - def __len__(self): + def __len__(self: "Quantity[np.ndarray]") -> int: return len(self._magnitude) - def __getattr__(self, item): + def __getattr__(self, item) -> Any: if item.startswith("__array_"): # Handle array protocol attributes other than `__array__` raise AttributeError(f"Array protocol attribute {item} not available.") @@ -1857,7 +1945,7 @@ def __getitem__(self, key): "supports indexing".format(self._magnitude) ) - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: try: if np.ma.is_masked(value) or math.isnan(value): self._magnitude[key] = value @@ -1894,7 +1982,7 @@ def __setitem__(self, key, value): "supports indexing" ) from exc - def tolist(self): + def tolist(self) -> Any: units = self._units try: @@ -1914,7 +2002,9 @@ def tolist(self): ) # Measurement support - def plus_minus(self, error, relative=False): + def plus_minus( + self, error: Union["Quantity[float]", float], relative: bool = False + ) -> "Measurement": if isinstance(error, self.__class__): if relative: raise ValueError("{} is not a valid relative error.".format(error)) @@ -1965,7 +2055,7 @@ def _has_compatible_delta(self, unit: str) -> bool: self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units=None): + def _ok_for_muldiv(self, no_offset_units=None) -> bool: """Checks if Quantity object can be multiplied or divided""" is_ok = True @@ -1985,7 +2075,7 @@ def _ok_for_muldiv(self, no_offset_units=None): is_ok = False return is_ok - def to_timedelta(self): + def to_timedelta(self: "Quantity[float]") -> datetime.timedelta: return datetime.timedelta(microseconds=self.to("microseconds").magnitude) # Dask.array.Array ducking @@ -2074,11 +2164,11 @@ def visualize(self, **kwargs): visualize(self, **kwargs) -_Quantity = Quantity +# _Quantity = Quantity -def build_quantity_class(registry): - class Quantity(_Quantity): - _REGISTRY = registry +# def build_quantity_class(registry: "BaseRegistry") -> Type[TQuantity]: +# class Quantity(_Quantity): +# _REGISTRY = registry - return Quantity +# return Quantity diff --git a/pint/registry.py b/pint/registry.py index 6b2697f88..d49d66b16 100644 --- a/pint/registry.py +++ b/pint/registry.py @@ -45,6 +45,21 @@ from fractions import Fraction from io import StringIO from tokenize import NAME, NUMBER +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) from . import registry_helpers, systems from .compat import babel_parse, tokenizer @@ -64,6 +79,7 @@ UndefinedUnitError, ) from .pint_eval import build_eval_tree +from .systems import Group, System from .util import ( ParserHelper, SourceIterator, @@ -85,6 +101,9 @@ # Backport for Python < 3.7 import importlib_resources +if TYPE_CHECKING: + from pint.unit import Unit + _BLOCK_RE = re.compile(r"[ (]") @@ -116,15 +135,15 @@ def __call__(self, *args, **kwargs): class RegistryCache: """Cache to speed up unit registries""" - def __init__(self): + def __init__(self) -> None: #: Maps dimensionality (UnitsContainer) to Units (str) - self.dimensional_equivalents = {} + self.dimensional_equivalents: Dict[UnitsContainer, Set[str]] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) - self.root_units = {} + self.root_units: Dict[UnitsContainer, UnitsContainer] = {} #: Maps dimensionality (UnitsContainer) to Units (UnitsContainer) - self.dimensionality = {} + self.dimensionality: Dict[UnitsContainer, UnitsContainer] = {} #: Cache the unit name associated to user input. ('mV' -> 'millivolt') - self.parse_unit = {} + self.parse_unit: Dict[str, str] = {} class ContextCacheOverlay: @@ -132,13 +151,17 @@ class ContextCacheOverlay: active contexts which contain unit redefinitions. """ - def __init__(self, registry_cache: RegistryCache): + def __init__(self, registry_cache: RegistryCache) -> None: self.dimensional_equivalents = registry_cache.dimensional_equivalents - self.root_units = {} + self.root_units: Dict[UnitsContainer, UnitsContainer] = {} self.dimensionality = registry_cache.dimensionality self.parse_unit = registry_cache.parse_unit +NON_INT_TYPE = Type[Union[float, Decimal, Fraction]] +PreprocessorType = Callable[[str], str] + + class BaseRegistry(metaclass=RegistryMeta): """Base class for all registries. @@ -179,23 +202,23 @@ class BaseRegistry(metaclass=RegistryMeta): #: Map context prefix to function #: type: Dict[str, (SourceIterator -> None)] - _parsers = None + _parsers: Dict[str, Callable[[SourceIterator], None]] = None #: Babel.Locale instance or None - fmt_locale = None + fmt_locale: Optional[str] = None def __init__( self, filename="", - force_ndarray=False, - force_ndarray_like=False, - on_redefinition="warn", - auto_reduce_dimensions=False, - preprocessors=None, - fmt_locale=None, - non_int_type=float, - case_sensitive=True, - ): + force_ndarray: bool = False, + force_ndarray_like: bool = False, + on_redefinition: str = "warn", + auto_reduce_dimensions: bool = False, + preprocessors: Optional[List[PreprocessorType]] = None, + fmt_locale: Optional[str] = None, + non_int_type: NON_INT_TYPE = float, + case_sensitive: bool = True, + ) -> None: self._register_parsers() self._init_dynamic_classes() @@ -221,47 +244,49 @@ def __init__( #: Map between name (string) and value (string) of defaults stored in the #: definitions file. - self._defaults = {} + self._defaults: Dict[str, str] = {} #: Map dimension name (string) to its definition (DimensionDefinition). - self._dimensions = {} + self._dimensions: Dict[str, DimensionDefinition] = {} #: Map unit name (string) to its definition (UnitDefinition). #: Might contain prefixed units. - self._units = {} + self._units: Mapping[str, UnitDefinition] = {} #: Map unit name in lower case (string) to a set of unit names with the right #: case. #: Does not contain prefixed units. #: e.g: 'hz' - > set('Hz', ) - self._units_casei = defaultdict(set) + self._units_casei: Dict[str, Set[str]] = defaultdict(set) #: Map prefix name (string) to its definition (PrefixDefinition). - self._prefixes = {"": PrefixDefinition("", "", (), 1)} + self._prefixes: Dict[str, PrefixDefinition] = { + "": PrefixDefinition("", "", (), 1) + } #: Map suffix name (string) to canonical , and unit alias to canonical unit name - self._suffixes = {"": "", "s": ""} + self._suffixes: Dict[str, str] = {"": "", "s": ""} #: Map contexts to RegistryCache self._cache = RegistryCache() self._initialized = False - def _init_dynamic_classes(self): + def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" from .unit import build_unit_class self.Unit = build_unit_class(self) - from .quantity import build_quantity_class + from .quantity import Quantity - self.Quantity = build_quantity_class(self) + self.Quantity = Quantity.build_quantity_class(self) from .measurement import build_measurement_class self.Measurement = build_measurement_class(self) - def _after_init(self): + def _after_init(self) -> None: """This should be called after all __init__""" if self._filename == "": @@ -272,7 +297,7 @@ def _after_init(self): self._build_cache() self._initialized = True - def _register_parsers(self): + def _register_parsers(self) -> None: self._register_parser("@defaults", self._parse_defaults) def _parse_defaults(self, ifile): @@ -282,7 +307,7 @@ def _parse_defaults(self, ifile): k, v = part.split("=") self._defaults[k.strip()] = v.strip() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> "BaseRegistry": new = object.__new__(type(self)) new.__dict__ = copy.deepcopy(self.__dict__, memo) new._init_dynamic_classes() @@ -299,7 +324,7 @@ def __getitem__(self, item): ) return self.parse_expression(item) - def __contains__(self, item): + def __contains__(self, item) -> bool: """Support checking prefixed units with the `in` operator""" try: self.__getattr__(item) @@ -307,12 +332,12 @@ def __contains__(self, item): except UndefinedUnitError: return False - def __dir__(self): + def __dir__(self) -> List[str]: #: Calling dir(registry) gives all units, methods, and attributes. #: Also used for autocompletion in IPython. return list(self._units.keys()) + list(object.__dir__(self)) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Allows for listing all units in registry with `list(ureg)`. Returns @@ -321,7 +346,7 @@ def __iter__(self): """ return iter(sorted(self._units.keys())) - def set_fmt_locale(self, loc): + def set_fmt_locale(self, loc: Optional[str]) -> None: """Change the locale used by default by `format_babel`. Parameters @@ -338,11 +363,11 @@ def set_fmt_locale(self, loc): self.fmt_locale = loc - def UnitsContainer(self, *args, **kwargs): + def UnitsContainer(self, *args, **kwargs) -> UnitsContainer: return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs) @property - def default_format(self): + def default_format(self) -> str: """Default formatting string for quantities.""" return self.Quantity.default_format @@ -351,7 +376,7 @@ def default_format(self, value): self.Unit.default_format = value self.Quantity.default_format = value - def define(self, definition): + def define(self, definition: Union[str, Definition]) -> None: """Add unit to the registry. Parameters @@ -366,7 +391,9 @@ def define(self, definition): else: self._define(definition) - def _define(self, definition): + def _define( + self, definition: Union[str, Definition] + ) -> Tuple[Definition, dict, dict]: """Add unit to the registry. This method defines only multiplicative units, converting any other type @@ -493,7 +520,7 @@ def _define_alias(self, definition, unit_dict, casei_unit_dict): unit_dict[alias] = unit casei_unit_dict[alias.lower()].add(alias) - def _register_parser(self, prefix, parserfunc): + def _register_parser(self, prefix: str, parserfunc): """Register a loader for a given @ directive.. Parameters @@ -515,7 +542,7 @@ def _register_parser(self, prefix, parserfunc): else: raise ValueError("Prefix directives must start with '@'") - def load_definitions(self, file, is_resource=False): + def load_definitions(self, file, is_resource=False) -> None: """Add units and prefixes defined in a definition text file. Parameters @@ -590,7 +617,7 @@ def load_definitions(self, file, is_resource=False): except Exception as ex: logger.error("In line {}, cannot add '{}' {}".format(no, line, ex)) - def _build_cache(self): + def _build_cache(self) -> None: """Build a cache of dimensionality and base units.""" self._cache = RegistryCache() @@ -665,7 +692,7 @@ def get_name(self, name_or_alias, case_sensitive=None): return unit_name - def get_symbol(self, name_or_alias, case_sensitive=None): + def get_symbol(self, name_or_alias: str, case_sensitive=None) -> str: """Return the preferred alias for a unit.""" candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: @@ -695,7 +722,7 @@ def get_dimensionality(self, input_units): return self._get_dimensionality(input_units) - def _get_dimensionality(self, input_units): + def _get_dimensionality(self, input_units) -> UnitsContainer: """Convert a UnitsContainer to base dimensions.""" if not input_units: return self.UnitsContainer() @@ -997,7 +1024,9 @@ def _convert(self, value, src, dst, inplace=False, check_dimensionality=True): return value - def parse_unit_name(self, unit_name, case_sensitive=None): + def parse_unit_name( + self, unit_name: str, case_sensitive: Optional[bool] = None + ) -> Tuple[Tuple[str, str, str], ...]: """Parse a unit to identify prefix, unit name and suffix by walking the list of prefix and suffix. In case of equivalent combinations (e.g. ('kilo', 'gram', '') and @@ -1020,7 +1049,9 @@ def parse_unit_name(self, unit_name, case_sensitive=None): self._parse_unit_name(unit_name, case_sensitive=case_sensitive) ) - def _parse_unit_name(self, unit_name, case_sensitive=None): + def _parse_unit_name( + self, unit_name: str, case_sensitive: Optional[bool] = None + ) -> Iterator[Tuple[str, str, str]]: """Helper of parse_unit_name.""" case_sensitive = ( self.case_sensitive if case_sensitive is None else case_sensitive @@ -1050,7 +1081,9 @@ def _parse_unit_name(self, unit_name, case_sensitive=None): ) @staticmethod - def _dedup_candidates(candidates): + def _dedup_candidates( + candidates: Iterable[Tuple[str, str, str]] + ) -> Tuple[Tuple[str, str, str], ...]: """Helper of parse_unit_name. Given an iterable of unit triplets (prefix, name, suffix), remove those with @@ -1068,7 +1101,12 @@ def _dedup_candidates(candidates): candidates.pop(("", cp + cu, ""), None) return tuple(candidates) - def parse_units(self, input_string, as_delta=None, case_sensitive=None): + def parse_units( + self, + input_string: str, + as_delta: Optional[bool] = None, + case_sensitive: Optional[bool] = None, + ) -> "Unit": """Parse a units expression and returns a UnitContainer with the canonical names. @@ -1093,7 +1131,9 @@ def parse_units(self, input_string, as_delta=None, case_sensitive=None): units = self._parse_units(input_string, as_delta, case_sensitive) return self.Unit(units) - def _parse_units(self, input_string, as_delta=True, case_sensitive=None): + def _parse_units( + self, input_string, as_delta=True, case_sensitive=None + ) -> "UnitsContainer": """Parse a units expression and returns a UnitContainer with the canonical names. """ @@ -1165,7 +1205,12 @@ def _eval_token(self, token, case_sensitive=None, use_decimal=False, **values): raise Exception("unknown token type") def parse_pattern( - self, input_string, pattern, case_sensitive=None, use_decimal=False, many=False + self, + input_string: str, + pattern: str, + case_sensitive: Optional[None] = None, + use_decimal: bool = False, + many: bool = False, ): """Parse a string with a given regex pattern and returns result. @@ -1221,7 +1266,11 @@ def parse_pattern( return results def parse_expression( - self, input_string, case_sensitive=None, use_decimal=False, **values + self, + input_string: str, + case_sensitive: Optional[bool] = None, + use_decimal: bool = False, + **values, ): """Parse a mathematical expression including units and return a quantity object. @@ -1286,8 +1335,11 @@ class NonMultiplicativeRegistry(BaseRegistry): """ def __init__( - self, default_as_delta=True, autoconvert_offset_to_baseunit=False, **kwargs - ): + self, + default_as_delta: bool = True, + autoconvert_offset_to_baseunit: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) #: When performing a multiplication of units, interpret @@ -1298,14 +1350,19 @@ def __init__( # base units on multiplication and division. self.autoconvert_offset_to_baseunit = autoconvert_offset_to_baseunit - def _parse_units(self, input_string, as_delta=None, case_sensitive=None): + def _parse_units( + self, + input_string: str, + as_delta: Optional[bool] = None, + case_sensitive: Optional[bool] = None, + ): """""" if as_delta is None: as_delta = self.default_as_delta return super()._parse_units(input_string, as_delta, case_sensitive) - def _define(self, definition): + def _define(self, definition: Union[str, Definition]): """Add unit to the registry. In addition to what is done by the BaseRegistry, @@ -1331,7 +1388,7 @@ def _define(self, definition): return definition, d, di - def _is_multiplicative(self, u): + def _is_multiplicative(self, u) -> bool: if u in self._units: return self._units[u].is_multiplicative @@ -1479,9 +1536,9 @@ class ContextRegistry(BaseRegistry): - Parse @context directive. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: # Map context name (string) or abbreviation to context. - self._contexts = {} + self._contexts: Dict[str, Context] = {} # Stores active contexts. self._active_ctx = ContextChain() # Map context chain to cache @@ -1492,13 +1549,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) # Allow contexts to add override layers to the units - self._units = ChainMap(self._units) + self._units: "ChainMap[str, UnitDefinition]" = ChainMap(self._units) - def _register_parsers(self): + def _register_parsers(self) -> None: super()._register_parsers() self._register_parser("@context", self._parse_context) - def _parse_context(self, ifile): + def _parse_context(self, ifile) -> None: try: self.add_context( Context.from_lines( @@ -1630,7 +1687,9 @@ def _redefine(self, definition: UnitDefinition) -> None: # Write into the context-specific self._units.maps[0] and self._cache.root_units self.define(definition) - def enable_contexts(self, *names_or_contexts, **kwargs) -> None: + def enable_contexts( + self, *names_or_contexts: Union[str, Context], **kwargs + ) -> None: """Enable contexts provided by name or by object. Parameters @@ -1670,10 +1729,10 @@ def enable_contexts(self, *names_or_contexts, **kwargs) -> None: ctx.checked = True # and create a new one with the new defaults. - ctxs = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs) + contexts = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs) # Finally we add them to the active context. - self._active_ctx.insert_contexts(*ctxs) + self._active_ctx.insert_contexts(*contexts) self._switch_context_cache_and_units() def disable_contexts(self, n: int = None) -> None: @@ -1745,7 +1804,7 @@ def context(self, *names, **kwargs): # the added contexts are removed from the active one. self.disable_contexts(len(names)) - def with_context(self, name, **kwargs): + def with_context(self, name, **kwargs) -> Callable: """Decorator to wrap a function call in a Pint context. Use it to ensure that a certain context is active when @@ -1847,7 +1906,7 @@ def _get_compatible_units(self, input_units, group_or_system): class SystemRegistry(BaseRegistry): """Handle of Systems and Groups. - Conversion between units with different dimenstions according + Conversion between units with different dimensions according to previously established relations (contexts). (e.g. in the spectroscopy, conversion between frequency and energy is possible) @@ -1864,14 +1923,14 @@ def __init__(self, system=None, **kwargs): #: Map system name to system. #: :type: dict[ str | System] - self._systems = {} + self._systems: Dict[str, System] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) self._base_units_cache = dict() #: Map group name to group. #: :type: dict[ str | Group] - self._groups = {} + self._groups: Dict[str, Group] = {} self._groups["root"] = self.Group("root") self._default_system = system @@ -1907,20 +1966,20 @@ def _after_init(self): "system", None ) - def _register_parsers(self): + def _register_parsers(self) -> None: super()._register_parsers() self._register_parser("@group", self._parse_group) self._register_parser("@system", self._parse_system) - def _parse_group(self, ifile): + def _parse_group(self, ifile) -> None: self.Group.from_lines(ifile.block_iter(), self.define, self.non_int_type) - def _parse_system(self, ifile): + def _parse_system(self, ifile) -> None: self.System.from_lines( ifile.block_iter(), self.get_root_units, self.non_int_type ) - def get_group(self, name, create_if_needed=True): + def get_group(self, name, create_if_needed=True) -> Group: """Return a Group. Parameters @@ -1962,7 +2021,7 @@ def default_system(self, name): self._default_system = name - def get_system(self, name, create_if_needed=True): + def get_system(self, name: str, create_if_needed: bool = True) -> System: """Return a Group. Parameters @@ -2176,7 +2235,7 @@ def pi_theorem(self, quantities): """ return pi_theorem(quantities, self) - def setup_matplotlib(self, enable=True): + def setup_matplotlib(self, enable: bool = True) -> None: """Set up handlers for matplotlib's unit support. Parameters diff --git a/pint/systems.py b/pint/systems.py index 881b83e44..16ec00f5d 100644 --- a/pint/systems.py +++ b/pint/systems.py @@ -99,7 +99,7 @@ def members(self): return self._computed_members - def invalidate_members(self): + def invalidate_members(self) -> None: """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None d = self._REGISTRY._groups diff --git a/pint/unit.py b/pint/unit.py index f104c83cd..cb28e36dd 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -12,6 +12,9 @@ import locale import operator from numbers import Number +from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Union + +from pint.quantity import Magnitude from .compat import NUMERIC_TYPES, is_upcast_type from .definitions import UnitDefinition @@ -19,6 +22,11 @@ from .formatting import siunitx_format_unit from .util import PrettyIPython, SharedRegistryObject, UnitsContainer +if TYPE_CHECKING: + from .context import Context + from .quantity import Quantity + from .registry import BaseRegistry + class Unit(PrettyIPython, SharedRegistryObject): """Implements a class to describe a unit supporting math operations.""" @@ -26,13 +34,13 @@ class Unit(PrettyIPython, SharedRegistryObject): #: Default formatting string. default_format = "" - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: # See notes in Quantity.__reduce__ from . import _unpickle_unit return _unpickle_unit, (Unit, self._units) - def __init__(self, units): + def __init__(self, units) -> None: super().__init__() if isinstance(units, (UnitsContainer, UnitDefinition)): self._units = units @@ -50,29 +58,29 @@ def __init__(self, units): self.__handling = None @property - def debug_used(self): + def debug_used(self) -> bool: return self.__used - def __copy__(self): + def __copy__(self) -> "Unit": ret = self.__class__(self._units) ret.__used = self.__used return ret - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> "Unit": ret = self.__class__(copy.deepcopy(self._units, memo)) ret.__used = self.__used return ret - def __str__(self): + def __str__(self) -> str: return format(self) - def __bytes__(self): + def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) - def __repr__(self): + def __repr__(self) -> str: return "".format(self._units) - def __format__(self, spec): + def __format__(self, spec: str) -> str: spec = spec or self.default_format # special cases if "Lx" in spec: # the LaTeX siunitx code @@ -93,7 +101,7 @@ def __format__(self, spec): return format(units, spec) - def format_babel(self, spec="", **kwspec): + def format_babel(self, spec: str = "", **kwspec: Any) -> str: spec = spec or self.default_format if "~" in spec: @@ -112,12 +120,12 @@ def format_babel(self, spec="", **kwspec): return "%s" % (units.format_babel(spec, **kwspec)) @property - def dimensionless(self): + def dimensionless(self) -> bool: """Return True if the Unit is dimensionless; False otherwise.""" return not bool(self.dimensionality) @property - def dimensionality(self): + def dimensionality(self) -> UnitsContainer: """ Returns ------- @@ -139,7 +147,9 @@ def compatible_units(self, *contexts): return self._REGISTRY.get_compatible_units(self) - def is_compatible_with(self, other, *contexts, **ctx_kwargs): + def is_compatible_with( + self, other: Any, *contexts: Union[str, "Context"], **ctx_kwargs: Any + ) -> bool: """check if the other object is compatible Parameters @@ -173,7 +183,7 @@ def is_compatible_with(self, other, *contexts, **ctx_kwargs): return self.dimensionless - def __mul__(self, other): + def __mul__(self, other) -> Union["Quantity", "Unit"]: if self._check(other): if isinstance(other, self.__class__): return self.__class__(self._units * other._units) @@ -188,7 +198,7 @@ def __mul__(self, other): __rmul__ = __mul__ - def __truediv__(self, other): + def __truediv__(self, other) -> Union["Quantity", "Unit"]: if self._check(other): if isinstance(other, self.__class__): return self.__class__(self._units / other._units) @@ -198,7 +208,7 @@ def __truediv__(self, other): return self._REGISTRY.Quantity(1 / other, self._units) - def __rtruediv__(self, other): + def __rtruediv__(self, other) -> Union["Quantity", "Unit"]: # As Unit and Quantity both handle truediv with each other rtruediv can # only be called for something different. if isinstance(other, NUMERIC_TYPES): @@ -211,7 +221,7 @@ def __rtruediv__(self, other): __div__ = __truediv__ __rdiv__ = __rtruediv__ - def __pow__(self, other): + def __pow__(self, other) -> "Unit": if isinstance(other, NUMERIC_TYPES): return self.__class__(self._units ** other) @@ -219,10 +229,10 @@ def __pow__(self, other): mess = "Cannot power Unit by {}".format(type(other)) raise TypeError(mess) - def __hash__(self): + def __hash__(self) -> int: return self._units.__hash__() - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: # We compare to the base class of Unit because each Unit class is # unique. if self._check(other): @@ -237,10 +247,10 @@ def __eq__(self, other): else: return self._units == other - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def compare(self, other, op): + def compare(self, other: Any, op: Callable[[Any], bool]) -> bool: self_q = self._REGISTRY.Quantity(1, self) if isinstance(other, NUMERIC_TYPES): @@ -255,13 +265,13 @@ def compare(self, other, op): __ge__ = lambda self, other: self.compare(other, op=operator.ge) __gt__ = lambda self, other: self.compare(other, op=operator.gt) - def __int__(self): + def __int__(self) -> int: return int(self._REGISTRY.Quantity(1, self._units)) - def __float__(self): + def __float__(self) -> float: return float(self._REGISTRY.Quantity(1, self._units)) - def __complex__(self): + def __complex__(self) -> complex: return complex(self._REGISTRY.Quantity(1, self._units)) __array_priority__ = 17 @@ -302,7 +312,7 @@ def systems(self): out.add(sname) return frozenset(out) - def from_(self, value, strict=True, name="value"): + def from_(self, value, strict: bool = True, name: str = "value") -> "Quantity": """Converts a numerical value or quantity to this unit Parameters @@ -329,7 +339,9 @@ def from_(self, value, strict=True, name="value"): else: return value * self - def m_from(self, value, strict=True, name="value"): + def m_from( + self, value: "Quantity[Magnitude]", strict: bool = True, name: str = "value" + ) -> Magnitude: """Converts a numerical value or quantity to this unit, then returns the magnitude of the converted value @@ -354,7 +366,7 @@ def m_from(self, value, strict=True, name="value"): _Unit = Unit -def build_unit_class(registry): +def build_unit_class(registry: "BaseRegistry") -> Type[Unit]: class Unit(_Unit): _REGISTRY = registry diff --git a/pint/util.py b/pint/util.py index f2162f456..39736b0f9 100644 --- a/pint/util.py +++ b/pint/util.py @@ -18,12 +18,16 @@ from logging import NullHandler from numbers import Number from token import NAME, NUMBER +from typing import TYPE_CHECKING, Any, ClassVar, Optional from .compat import NUMERIC_TYPES, tokenizer from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree +if TYPE_CHECKING: + from .registry import BaseRegistry + logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) @@ -345,10 +349,10 @@ def __init__(self, *args, **kwargs): d[key] = self._non_int_type(value) self._hash = None - def copy(self): + def copy(self) -> "UnitsContainer": return self.__copy__() - def add(self, key, value): + def add(self, key, value) -> "UnitsContainer": newval = self._d[key] + value new = self.copy() if newval: @@ -376,7 +380,7 @@ def remove(self, keys): new._hash = None return new - def rename(self, oldkey, newkey): + def rename(self, oldkey: str, newkey: str) -> "UnitsContainer": """Create a new UnitsContainer in which an entry has been renamed. Parameters @@ -398,16 +402,16 @@ def rename(self, oldkey, newkey): def __iter__(self): return iter(self._d) - def __len__(self): + def __len__(self) -> int: return len(self._d) def __getitem__(self, key): return self._d[key] - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self._d - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(frozenset(self._d.items())) return self._hash @@ -416,10 +420,10 @@ def __hash__(self): def __getstate__(self): return self._d, self._hash, self._one, self._non_int_type - def __setstate__(self, state): + def __setstate__(self, state) -> None: self._d, self._hash, self._one, self._non_int_type = state - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, UnitsContainer): # UnitsContainer.__hash__(self) is not the same as hash(self); see # ParserHelper.__hash__ and __eq__. @@ -440,22 +444,22 @@ def __eq__(self, other): return dict.__eq__(self._d, other) - def __str__(self): + def __str__(self) -> str: return self.__format__("") - def __repr__(self): + def __repr__(self) -> str: tmp = "{%s}" % ", ".join( ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())] ) return "".format(tmp) - def __format__(self, spec): + def __format__(self, spec) -> str: return format_unit(self, spec) - def format_babel(self, spec, **kwspec): + def format_babel(self, spec, **kwspec) -> str: return format_unit(self, spec, **kwspec) - def __copy__(self): + def __copy__(self) -> "UnitsContainer": # Skip expensive health checks performed by __init__ out = object.__new__(self.__class__) out._d = self._d.copy() @@ -464,7 +468,7 @@ def __copy__(self): out._one = self._one return out - def __mul__(self, other): + def __mul__(self, other) -> "UnitsContainer": if not isinstance(other, self.__class__): err = "Cannot multiply UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -480,7 +484,7 @@ def __mul__(self, other): __rmul__ = __mul__ - def __pow__(self, other): + def __pow__(self, other) -> "UnitsContainer": if not isinstance(other, NUMERIC_TYPES): err = "Cannot power UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -491,7 +495,7 @@ def __pow__(self, other): new._hash = None return new - def __truediv__(self, other): + def __truediv__(self, other) -> "UnitsContainer": if not isinstance(other, self.__class__): err = "Cannot divide UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -505,7 +509,7 @@ def __truediv__(self, other): new._hash = None return new - def __rtruediv__(self, other): + def __rtruediv__(self, other) -> "UnitsContainer": if not isinstance(other, self.__class__) and other != 1: err = "Cannot divide {} by UnitsContainer" raise TypeError(err.format(type(other))) @@ -533,7 +537,7 @@ class ParserHelper(UnitsContainer): __slots__ = ("scale",) - def __init__(self, scale=1, *args, **kwargs): + def __init__(self, scale=1, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.scale = scale @@ -763,7 +767,7 @@ def __rtruediv__(self, other): _pretty_exp_re = re.compile(r"⁻?[⁰¹²³⁴⁵⁶⁷⁸⁹]+(?:\.[⁰¹²³⁴⁵⁶⁷⁸⁹]*)?") -def string_preprocessor(input_string): +def string_preprocessor(input_string: str) -> str: input_string = input_string.replace(",", "") input_string = input_string.replace(" per ", "/") @@ -781,7 +785,7 @@ def string_preprocessor(input_string): return input_string -def _is_dim(name): +def _is_dim(name: str) -> bool: return name[0] == "[" and name[-1] == "]" @@ -799,7 +803,10 @@ class SharedRegistryObject: """ - def __new__(cls, *args, **kwargs): + _REGISTRY: ClassVar["BaseRegistry"] + _units: UnitsContainer + + def __new__(cls, *args, **kwargs) -> "SharedRegistryObject": inst = object.__new__(cls) if not hasattr(cls, "_REGISTRY"): # Base class, not subclasses dynamically by @@ -809,7 +816,7 @@ def __new__(cls, *args, **kwargs): inst._REGISTRY = _APP_REGISTRY return inst - def _check(self, other): + def _check(self, other: Any) -> bool: """Check if the other object use a registry and if so that it is the same registry. @@ -840,13 +847,13 @@ def _check(self, other): class PrettyIPython: """Mixin to add pretty-printers for IPython""" - def _repr_html_(self): + def _repr_html_(self) -> str: if "~" in self.default_format: return "{:~H}".format(self) else: return "{:H}".format(self) - def _repr_latex_(self): + def _repr_latex_(self) -> str: if "~" in self.default_format: return "${:~L}$".format(self) else: @@ -859,7 +866,9 @@ def _repr_pretty_(self, p, cycle): p.text("{:P}".format(self)) -def to_units_container(unit_like, registry=None): +def to_units_container( + unit_like, registry: Optional["BaseRegistry"] = None +) -> UnitsContainer: """Convert a unit compatible type to a UnitsContainer. Parameters @@ -871,6 +880,7 @@ def to_units_container(unit_like, registry=None): Returns ------- + units_container """ mro = type(unit_like).mro() @@ -890,7 +900,7 @@ def to_units_container(unit_like, registry=None): return UnitsContainer(unit_like) -def infer_base_unit(q): +def infer_base_unit(q: SharedRegistryObject): """ Parameters @@ -915,7 +925,7 @@ def infer_base_unit(q): return UnitsContainer({k: v for k, v in d.items() if v != 0}) -def getattr_maybe_raise(self, item): +def getattr_maybe_raise(self, item: str) -> Any: """Helper function invoked at start of all overridden ``__getattr__``. Raise AttributeError if the user tries to ask for a _ or __ attribute, @@ -1014,7 +1024,7 @@ def __next__(self): next = __next__ -def iterable(y): +def iterable(y) -> bool: """Check whether or not an object can be iterated over. Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright @@ -1036,7 +1046,7 @@ def iterable(y): return True -def sized(y): +def sized(y) -> bool: """Check whether or not an object has a defined length. Parameters