diff --git a/static_tests/test_not_none.py b/static_tests/test_not_none.py new file mode 100644 index 0000000..6d26854 --- /dev/null +++ b/static_tests/test_not_none.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import Any +from typing_extensions import assert_type + +from useful_types import not_none + + +def test(x: Any, y: str | None, z: str) -> None: + assert_type(not_none(x), Any) + assert_type(not_none(y), str) + assert_type(not_none(z), str) diff --git a/useful_types/__init__.py b/useful_types/__init__.py index 064c4b9..6791c05 100644 --- a/useful_types/__init__.py +++ b/useful_types/__init__.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Iterable, Sequence, Set as AbstractSet, Sized from os import PathLike from typing import Any, TypeVar, Union, overload -from typing_extensions import Buffer, Literal, Protocol, TypeAlias +from typing_extensions import Buffer, Literal, Never, Protocol, TypeAlias _KT = TypeVar("_KT") _KT_co = TypeVar("_KT_co", covariant=True) @@ -321,3 +321,27 @@ def __getitem__(self, __i: int) -> int: class SizedBuffer(Sized, Buffer, Protocol): ... + + +@overload +def not_none(obj: _T, /, message: str | None = ...) -> _T: + ... + + +@overload +def not_none(obj: None, /, message: str | None = ...) -> Never: + ... + + +def not_none(obj: _T | None, /, message: str | None = None) -> _T: + """Raise TypeError if obj is None, otherwise return obj. + + Useful for safely casting away optional types. + + """ + if obj is None: + if message is not None: + raise TypeError(message) + else: + raise TypeError("Object is unexpectedly None") + return obj