Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid using return t.cast which can prevent attribute access during process teardown #913

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 59 additions & 53 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class TraitError(Exception):
# -----------------------------------------------------------------------------


def isidentifier(s: t.Any) -> bool:
return t.cast(bool, s.isidentifier())
def isidentifier(s: str) -> bool:
return s.isidentifier()


def _safe_literal_eval(s: str) -> t.Any:
Expand Down Expand Up @@ -293,13 +293,21 @@ class link:

updating = False

def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None:
def __init__(
self, source: t.Any, target: t.Any, transform: t.Iterable[FuncT] | None = None
) -> None:
_validate_link(source, target)
self.source, self.target = source, target
self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2

if transform:
self._transform, self._transform_inv = transform # type:ignore[method-assign]
self.link()

def _transform(self, x: T) -> T:
"""default transform: no-op"""
return x

_transform_inv = _transform

def link(self) -> None:
try:
setattr(
Expand Down Expand Up @@ -597,12 +605,12 @@ def default(self, obj: t.Any = None) -> G | None:
in the same way that dynamic defaults defined by ``@default`` are.
"""
if self.default_value is not Undefined:
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]
elif hasattr(self, "make_dynamic_default"):
return t.cast(G, self.make_dynamic_default())
return self.make_dynamic_default() # type:ignore[no-any-return]
else:
# Undefined will raise in TraitType.get
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]

def get_default_value(self) -> G | None:
"""DEPRECATED: Retrieve the static default value for this trait.
Expand All @@ -613,7 +621,7 @@ def get_default_value(self) -> G | None:
DeprecationWarning,
stacklevel=2,
)
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]

def init_default_value(self, obj: t.Any) -> G | None:
"""DEPRECATED: Set the static default value for the trait type."""
Expand Down Expand Up @@ -658,12 +666,12 @@ def get(self, obj: HasTraits, cls: type[t.Any] | None = None) -> G | None:
type="default",
)
)
return t.cast(G, value)
return value # type:ignore[no-any-return]
except Exception as e:
# This should never be reached.
raise TraitError("Unexpected error in TraitType: default value not set properly") from e
else:
return t.cast(G, value)
return value # type:ignore[no-any-return]

@t.overload
def __get__(self, obj: None, cls: type[t.Any]) -> Self:
Expand All @@ -684,7 +692,7 @@ def __get__(self, obj: HasTraits | None, cls: type[t.Any]) -> Self | G:
if obj is None:
return self
else:
return t.cast(G, self.get(obj, cls)) # the G should encode the Optional
return self.get(obj, cls) # type:ignore[return-value]

def set(self, obj: HasTraits, value: S) -> None:
new_value = self._validate(obj, value)
Expand Down Expand Up @@ -722,7 +730,7 @@ def _validate(self, obj: t.Any, value: t.Any) -> G | None:
value = self.validate(obj, value)
if obj._cross_validation_lock is False:
value = self._cross_validate(obj, value)
return t.cast(G, value)
return value # type:ignore[no-any-return]

def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
if self.name in obj._trait_validators:
Expand All @@ -738,7 +746,7 @@ def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
"use @validate decorator instead.",
)
value = cross_validate(value, self)
return t.cast(G, value)
return value # type:ignore[no-any-return]

def __or__(self, other: TraitType[t.Any, t.Any]) -> Union:
if isinstance(other, Union):
Expand Down Expand Up @@ -1142,7 +1150,7 @@ def compatible_observer(
)
return func(self, change)

return t.cast(FuncT, compatible_observer)
return compatible_observer # type:ignore[return-value]


def validate(*names: Sentinel | str) -> ValidateHandler:
Expand Down Expand Up @@ -1894,7 +1902,7 @@ def trait_defaults(self, *names: str, **metadata: t.Any) -> dict[str, t.Any] | S
raise TraitError(f"'{n}' is not a trait of '{type(self).__name__}' instances")

if len(names) == 1 and len(metadata) == 0:
return t.cast(Sentinel, self._get_trait_default_generator(names[0])(self))
return self._get_trait_default_generator(names[0])(self) # type:ignore[no-any-return]

trait_names = self.trait_names(**metadata)
trait_names.extend(names)
Expand Down Expand Up @@ -2144,7 +2152,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
) from e
try:
if issubclass(value, self.klass): # type:ignore[arg-type]
return t.cast(G, value)
return value # type:ignore[no-any-return]
except Exception:
pass

Expand Down Expand Up @@ -2306,7 +2314,7 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
if self.allow_none and value is None:
return value
if isinstance(value, self.klass): # type:ignore[arg-type]
return t.cast(T, value)
return value # type:ignore[no-any-return]
else:
self.error(obj, value)

Expand Down Expand Up @@ -2338,7 +2346,7 @@ def default_value_repr(self) -> str:
return repr(self.make_dynamic_default())

def from_string(self, s: str) -> T | None:
return t.cast(T, _safe_literal_eval(s))
return _safe_literal_eval(s) # type:ignore[no-any-return]


class ForwardDeclaredMixin:
Expand Down Expand Up @@ -2635,12 +2643,12 @@ def __init__(
def validate(self, obj: t.Any, value: t.Any) -> G:
if not isinstance(value, int):
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return t.cast(G, int(s))
return None # type:ignore[return-value]
return int(s) # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -2691,7 +2699,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = int(value)
except Exception:
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]


Long, CLong = Int, CInt
Expand Down Expand Up @@ -2753,12 +2761,12 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = float(value)
if not isinstance(value, float):
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return t.cast(G, float(s))
return None # type:ignore[return-value]
return float(s) # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -2809,7 +2817,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = float(value)
except Exception:
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]


class Complex(TraitType[complex, t.Union[complex, float, int]]):
Expand Down Expand Up @@ -2935,18 +2943,18 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value, str):
return t.cast(G, value)
return value # type:ignore[return-value]
if isinstance(value, bytes):
try:
return t.cast(G, value.decode("ascii", "strict"))
return value.decode("ascii", "strict") # type:ignore[return-value]
except UnicodeDecodeError as e:
msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
raise TraitError(msg.format(value, self.name, class_of(obj))) from e
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
s = os.path.expanduser(s)
if len(s) >= 2:
# handle deprecated "1"
Expand All @@ -2960,7 +2968,7 @@ def from_string(self, s: str) -> G:
DeprecationWarning,
stacklevel=2,
)
return t.cast(G, s)
return s # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -3008,7 +3016,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
try:
return t.cast(G, str(value))
return str(value) # type:ignore[return-value]
except Exception:
self.error(obj, value)

Expand Down Expand Up @@ -3091,22 +3099,22 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value, bool):
return t.cast(G, value)
return value # type:ignore[return-value]
elif isinstance(value, int):
if value == 1:
return t.cast(G, True)
return True # type:ignore[return-value]
elif value == 0:
return t.cast(G, False)
return False # type:ignore[return-value]
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
s = s.lower()
if s in {"true", "1"}:
return t.cast(G, True)
return True # type:ignore[return-value]
elif s in {"false", "0"}:
return t.cast(G, False)
return False # type:ignore[return-value]
else:
raise ValueError("%r is not 1, 0, true, or false")

Expand Down Expand Up @@ -3163,7 +3171,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
try:
return t.cast(G, bool(value))
return bool(value) # type:ignore[return-value]
except Exception:
self.error(obj, value)

Expand Down Expand Up @@ -3220,7 +3228,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if self.values and value in self.values:
return t.cast(G, value)
return value # type:ignore[no-any-return]
self.error(obj, value)

def _choices_str(self, as_rst: bool = False) -> str:
Expand All @@ -3247,7 +3255,7 @@ def from_string(self, s: str) -> G:
try:
return self.validate(None, s)
except TraitError:
return t.cast(G, _safe_literal_eval(s))
return _safe_literal_eval(s) # type:ignore[no-any-return]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -3275,7 +3283,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
for v in self.values or []:
assert isinstance(v, str)
if v.lower() == value.lower():
return t.cast(G, v)
return v # type:ignore[return-value]
self.error(obj, value)

def _info(self, as_rst: bool = False) -> str:
Expand Down Expand Up @@ -3479,14 +3487,12 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
if value is None:
return value

value = self.validate_elements(obj, value)

return t.cast(T, value)
return self.validate_elements(obj, value)

def validate_elements(self, obj: t.Any, value: t.Any) -> T | None:
validated = []
if self._trait is None or isinstance(self._trait, Any):
return t.cast(T, value)
return value # type:ignore[no-any-return]
for v in value:
try:
v = self._trait._validate(obj, v)
Expand Down Expand Up @@ -3553,7 +3559,7 @@ def from_string_list(self, s_list: list[str]) -> T | None:
else:
# backward-compat: allow item_from_string to ignore index arg
def item_from_string(s: str, index: int | None = None) -> T | str:
return t.cast(T, self.item_from_string(s))
return self.item_from_string(s)

return self.klass( # type:ignore[call-arg]
[item_from_string(s, index=idx) for idx, s in enumerate(s_list)]
Expand All @@ -3565,7 +3571,7 @@ def item_from_string(self, s: str, index: int | None = None) -> T | str:
Evaluated when parsing CLI configuration from a string
"""
if self._trait:
return t.cast(T, self._trait.from_string(s))
return self._trait.from_string(s) # type:ignore[no-any-return]
else:
return s

Expand Down Expand Up @@ -4051,7 +4057,7 @@ def from_string(self, s: str) -> dict[K, V] | None:
if not isinstance(s, str):
raise TypeError(f"from_string expects a string, got {s!r} of type {type(s)}")
try:
return t.cast("dict[K, V]", self.from_string_list([s]))
return self.from_string_list([s]) # type:ignore[no-any-return]
except Exception:
test = _safe_literal_eval(s)
if isinstance(test, dict):
Expand Down Expand Up @@ -4109,7 +4115,7 @@ def item_from_string(self, s: str) -> dict[K, V]:
value_trait = (self._per_key_traits or {}).get(key, self._value_trait)
if value_trait:
value = value_trait.from_string(value)
return t.cast("dict[K, V]", {key: value})
return {key: value} # type:ignore[dict-item]


class TCPAddress(TraitType[G, S]):
Expand Down Expand Up @@ -4165,17 +4171,17 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value[0], str) and isinstance(value[1], int):
port = value[1]
if port >= 0 and port <= 65535:
return t.cast(G, value)
return value # type:ignore[return-value]
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
if ":" not in s:
raise ValueError("Require `ip:port`, got %r" % s)
ip, port_str = s.split(":", 1)
port = int(port_str)
return t.cast(G, (ip, port))
return (ip, port) # type:ignore[return-value]


class CRegExp(TraitType["re.Pattern[t.Any]", t.Union["re.Pattern[t.Any]", str]]):
Expand Down
Loading