diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index ecd0d7cc..863bb9bd 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -169,8 +169,8 @@ class TraitError(Exception): # ----------------------------------------------------------------------------- -def isidentifier(s: t.Any) -> bool: - return t.cast(bool, s.isidentifier()) +def isidentifier(s: str) -> bool: + return s.isidentifier() def _safe_literal_eval(s: str) -> t.Any: @@ -293,13 +293,21 @@ class link: updating = False - def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None: + def __init__( + self, source: t.Any, target: t.Any, transform: t.Iterable[FuncT] | None = None + ) -> None: _validate_link(source, target) self.source, self.target = source, target - self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2 - + if transform: + self._transform, self._transform_inv = transform # type:ignore[method-assign] self.link() + def _transform(self, x: T) -> T: + """default transform: no-op""" + return x + + _transform_inv = _transform + def link(self) -> None: try: setattr( @@ -597,12 +605,12 @@ def default(self, obj: t.Any = None) -> G | None: in the same way that dynamic defaults defined by ``@default`` are. """ if self.default_value is not Undefined: - return t.cast(G, self.default_value) + return self.default_value # type:ignore[no-any-return] elif hasattr(self, "make_dynamic_default"): - return t.cast(G, self.make_dynamic_default()) + return self.make_dynamic_default() # type:ignore[no-any-return] else: # Undefined will raise in TraitType.get - return t.cast(G, self.default_value) + return self.default_value # type:ignore[no-any-return] def get_default_value(self) -> G | None: """DEPRECATED: Retrieve the static default value for this trait. @@ -613,7 +621,7 @@ def get_default_value(self) -> G | None: DeprecationWarning, stacklevel=2, ) - return t.cast(G, self.default_value) + return self.default_value # type:ignore[no-any-return] def init_default_value(self, obj: t.Any) -> G | None: """DEPRECATED: Set the static default value for the trait type.""" @@ -658,12 +666,12 @@ def get(self, obj: HasTraits, cls: type[t.Any] | None = None) -> G | None: type="default", ) ) - return t.cast(G, value) + return value # type:ignore[no-any-return] except Exception as e: # This should never be reached. raise TraitError("Unexpected error in TraitType: default value not set properly") from e else: - return t.cast(G, value) + return value # type:ignore[no-any-return] @t.overload def __get__(self, obj: None, cls: type[t.Any]) -> Self: @@ -684,7 +692,7 @@ def __get__(self, obj: HasTraits | None, cls: type[t.Any]) -> Self | G: if obj is None: return self else: - return t.cast(G, self.get(obj, cls)) # the G should encode the Optional + return self.get(obj, cls) # type:ignore[return-value] def set(self, obj: HasTraits, value: S) -> None: new_value = self._validate(obj, value) @@ -722,7 +730,7 @@ def _validate(self, obj: t.Any, value: t.Any) -> G | None: value = self.validate(obj, value) if obj._cross_validation_lock is False: value = self._cross_validate(obj, value) - return t.cast(G, value) + return value # type:ignore[no-any-return] def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None: if self.name in obj._trait_validators: @@ -738,7 +746,7 @@ def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None: "use @validate decorator instead.", ) value = cross_validate(value, self) - return t.cast(G, value) + return value # type:ignore[no-any-return] def __or__(self, other: TraitType[t.Any, t.Any]) -> Union: if isinstance(other, Union): @@ -1142,7 +1150,7 @@ def compatible_observer( ) return func(self, change) - return t.cast(FuncT, compatible_observer) + return compatible_observer # type:ignore[return-value] def validate(*names: Sentinel | str) -> ValidateHandler: @@ -1894,7 +1902,7 @@ def trait_defaults(self, *names: str, **metadata: t.Any) -> dict[str, t.Any] | S raise TraitError(f"'{n}' is not a trait of '{type(self).__name__}' instances") if len(names) == 1 and len(metadata) == 0: - return t.cast(Sentinel, self._get_trait_default_generator(names[0])(self)) + return self._get_trait_default_generator(names[0])(self) # type:ignore[no-any-return] trait_names = self.trait_names(**metadata) trait_names.extend(names) @@ -2144,7 +2152,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G: ) from e try: if issubclass(value, self.klass): # type:ignore[arg-type] - return t.cast(G, value) + return value # type:ignore[no-any-return] except Exception: pass @@ -2306,7 +2314,7 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None: if self.allow_none and value is None: return value if isinstance(value, self.klass): # type:ignore[arg-type] - return t.cast(T, value) + return value # type:ignore[no-any-return] else: self.error(obj, value) @@ -2338,7 +2346,7 @@ def default_value_repr(self) -> str: return repr(self.make_dynamic_default()) def from_string(self, s: str) -> T | None: - return t.cast(T, _safe_literal_eval(s)) + return _safe_literal_eval(s) # type:ignore[no-any-return] class ForwardDeclaredMixin: @@ -2635,12 +2643,12 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: if not isinstance(value, int): self.error(obj, value) - return t.cast(G, _validate_bounds(self, obj, value)) + return _validate_bounds(self, obj, value) # type:ignore[no-any-return] def from_string(self, s: str) -> G: if self.allow_none and s == "None": - return t.cast(G, None) - return t.cast(G, int(s)) + return None # type:ignore[return-value] + return int(s) # type:ignore[return-value] def subclass_init(self, cls: type[t.Any]) -> None: pass # fully opt out of instance_init @@ -2691,7 +2699,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G: value = int(value) except Exception: self.error(obj, value) - return t.cast(G, _validate_bounds(self, obj, value)) + return _validate_bounds(self, obj, value) # type:ignore[no-any-return] Long, CLong = Int, CInt @@ -2753,12 +2761,12 @@ def validate(self, obj: t.Any, value: t.Any) -> G: value = float(value) if not isinstance(value, float): self.error(obj, value) - return t.cast(G, _validate_bounds(self, obj, value)) + return _validate_bounds(self, obj, value) # type:ignore[no-any-return] def from_string(self, s: str) -> G: if self.allow_none and s == "None": - return t.cast(G, None) - return t.cast(G, float(s)) + return None # type:ignore[return-value] + return float(s) # type:ignore[return-value] def subclass_init(self, cls: type[t.Any]) -> None: pass # fully opt out of instance_init @@ -2809,7 +2817,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G: value = float(value) except Exception: self.error(obj, value) - return t.cast(G, _validate_bounds(self, obj, value)) + return _validate_bounds(self, obj, value) # type:ignore[no-any-return] class Complex(TraitType[complex, t.Union[complex, float, int]]): @@ -2935,10 +2943,10 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: if isinstance(value, str): - return t.cast(G, value) + return value # type:ignore[return-value] if isinstance(value, bytes): try: - return t.cast(G, value.decode("ascii", "strict")) + return value.decode("ascii", "strict") # type:ignore[return-value] except UnicodeDecodeError as e: msg = "Could not decode {!r} for unicode trait '{}' of {} instance." raise TraitError(msg.format(value, self.name, class_of(obj))) from e @@ -2946,7 +2954,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G: def from_string(self, s: str) -> G: if self.allow_none and s == "None": - return t.cast(G, None) + return None # type:ignore[return-value] s = os.path.expanduser(s) if len(s) >= 2: # handle deprecated "1" @@ -2960,7 +2968,7 @@ def from_string(self, s: str) -> G: DeprecationWarning, stacklevel=2, ) - return t.cast(G, s) + return s # type:ignore[return-value] def subclass_init(self, cls: type[t.Any]) -> None: pass # fully opt out of instance_init @@ -3008,7 +3016,7 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: try: - return t.cast(G, str(value)) + return str(value) # type:ignore[return-value] except Exception: self.error(obj, value) @@ -3091,22 +3099,22 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: if isinstance(value, bool): - return t.cast(G, value) + return value # type:ignore[return-value] elif isinstance(value, int): if value == 1: - return t.cast(G, True) + return True # type:ignore[return-value] elif value == 0: - return t.cast(G, False) + return False # type:ignore[return-value] self.error(obj, value) def from_string(self, s: str) -> G: if self.allow_none and s == "None": - return t.cast(G, None) + return None # type:ignore[return-value] s = s.lower() if s in {"true", "1"}: - return t.cast(G, True) + return True # type:ignore[return-value] elif s in {"false", "0"}: - return t.cast(G, False) + return False # type:ignore[return-value] else: raise ValueError("%r is not 1, 0, true, or false") @@ -3163,7 +3171,7 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: try: - return t.cast(G, bool(value)) + return bool(value) # type:ignore[return-value] except Exception: self.error(obj, value) @@ -3220,7 +3228,7 @@ def __init__( def validate(self, obj: t.Any, value: t.Any) -> G: if self.values and value in self.values: - return t.cast(G, value) + return value # type:ignore[no-any-return] self.error(obj, value) def _choices_str(self, as_rst: bool = False) -> str: @@ -3247,7 +3255,7 @@ def from_string(self, s: str) -> G: try: return self.validate(None, s) except TraitError: - return t.cast(G, _safe_literal_eval(s)) + return _safe_literal_eval(s) # type:ignore[no-any-return] def subclass_init(self, cls: type[t.Any]) -> None: pass # fully opt out of instance_init @@ -3275,7 +3283,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G: for v in self.values or []: assert isinstance(v, str) if v.lower() == value.lower(): - return t.cast(G, v) + return v # type:ignore[return-value] self.error(obj, value) def _info(self, as_rst: bool = False) -> str: @@ -3479,14 +3487,12 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None: if value is None: return value - value = self.validate_elements(obj, value) - - return t.cast(T, value) + return self.validate_elements(obj, value) def validate_elements(self, obj: t.Any, value: t.Any) -> T | None: validated = [] if self._trait is None or isinstance(self._trait, Any): - return t.cast(T, value) + return value # type:ignore[no-any-return] for v in value: try: v = self._trait._validate(obj, v) @@ -3553,7 +3559,7 @@ def from_string_list(self, s_list: list[str]) -> T | None: else: # backward-compat: allow item_from_string to ignore index arg def item_from_string(s: str, index: int | None = None) -> T | str: - return t.cast(T, self.item_from_string(s)) + return self.item_from_string(s) return self.klass( # type:ignore[call-arg] [item_from_string(s, index=idx) for idx, s in enumerate(s_list)] @@ -3565,7 +3571,7 @@ def item_from_string(self, s: str, index: int | None = None) -> T | str: Evaluated when parsing CLI configuration from a string """ if self._trait: - return t.cast(T, self._trait.from_string(s)) + return self._trait.from_string(s) # type:ignore[no-any-return] else: return s @@ -4051,7 +4057,7 @@ def from_string(self, s: str) -> dict[K, V] | None: if not isinstance(s, str): raise TypeError(f"from_string expects a string, got {s!r} of type {type(s)}") try: - return t.cast("dict[K, V]", self.from_string_list([s])) + return self.from_string_list([s]) # type:ignore[no-any-return] except Exception: test = _safe_literal_eval(s) if isinstance(test, dict): @@ -4109,7 +4115,7 @@ def item_from_string(self, s: str) -> dict[K, V]: value_trait = (self._per_key_traits or {}).get(key, self._value_trait) if value_trait: value = value_trait.from_string(value) - return t.cast("dict[K, V]", {key: value}) + return {key: value} # type:ignore[dict-item] class TCPAddress(TraitType[G, S]): @@ -4165,17 +4171,17 @@ def validate(self, obj: t.Any, value: t.Any) -> G: if isinstance(value[0], str) and isinstance(value[1], int): port = value[1] if port >= 0 and port <= 65535: - return t.cast(G, value) + return value # type:ignore[return-value] self.error(obj, value) def from_string(self, s: str) -> G: if self.allow_none and s == "None": - return t.cast(G, None) + return None # type:ignore[return-value] if ":" not in s: raise ValueError("Require `ip:port`, got %r" % s) ip, port_str = s.split(":", 1) port = int(port_str) - return t.cast(G, (ip, port)) + return (ip, port) # type:ignore[return-value] class CRegExp(TraitType["re.Pattern[t.Any]", t.Union["re.Pattern[t.Any]", str]]):