Skip to content

Commit 05abefe

Browse files
feat: Add BearType enforcement to critical utility functions (#188)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 97e58b7 commit 05abefe

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed

key-value/key-value-shared/src/key_value/shared/type_checking/bear_spray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
enforce_bear_type_conf = BeartypeConf(strategy=BeartypeStrategy.O1, violation_type=TypeError)
1111
enforce_bear_type = beartype(conf=enforce_bear_type_conf)
1212

13-
P = ParamSpec("P")
14-
R = TypeVar("R")
13+
P = ParamSpec(name="P")
14+
R = TypeVar(name="R")
1515

1616

1717
def no_bear_type_check(func: Callable[P, R]) -> Callable[P, R]:

key-value/key-value-shared/src/key_value/shared/utils/compound.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections.abc import Sequence
22

3+
from key_value.shared.type_checking.bear_spray import bear_enforce
4+
35
DEFAULT_COMPOUND_SEPARATOR = "::"
46
DEFAULT_PREFIX_SEPARATOR = "__"
57

@@ -29,11 +31,13 @@ def uncompound_strings(strings: Sequence[str], separator: str | None = None) ->
2931
return [uncompound_string(string=string, separator=separator) for string in strings]
3032

3133

34+
@bear_enforce
3235
def compound_key(collection: str, key: str, separator: str | None = None) -> str:
3336
separator = separator or DEFAULT_COMPOUND_SEPARATOR
3437
return compound_string(first=collection, second=key, separator=separator)
3538

3639

40+
@bear_enforce
3741
def uncompound_key(key: str, separator: str | None = None) -> tuple[str, str]:
3842
separator = separator or DEFAULT_COMPOUND_SEPARATOR
3943
return uncompound_string(string=key, separator=separator)

key-value/key-value-shared/src/key_value/shared/utils/sanitize.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import hashlib
22
from enum import Enum
33

4+
from key_value.shared.type_checking.bear_spray import bear_enforce
5+
46
MINIMUM_MAX_LENGTH = 16
57

68
DEFAULT_HASH_FRAGMENT_SIZE = 8
@@ -59,6 +61,7 @@ def sanitize_characters_in_string(value: str, allowed_characters: str, replace_w
5961
return new_value
6062

6163

64+
@bear_enforce
6265
def sanitize_string(
6366
value: str,
6467
max_length: int,
@@ -133,6 +136,7 @@ def sanitize_string(
133136
return sanitized_value
134137

135138

139+
@bear_enforce
136140
def hash_excess_length(value: str, max_length: int) -> str:
137141
"""Hash part of the value if it exceeds the maximum length. This operation
138142
will truncate the value to the maximum length minus 8 characters and will swap

key-value/key-value-shared/src/key_value/shared/utils/serialization.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from typing import Any, Literal, TypeVar
1111

1212
from key_value.shared.errors import DeserializationError, SerializationError
13+
from key_value.shared.type_checking.bear_spray import bear_enforce
1314
from key_value.shared.utils.managed_entry import ManagedEntry, dump_to_json, load_from_json, verify_dict
1415

1516
T = TypeVar("T")
1617

1718

19+
@bear_enforce
1820
def key_must_be(dictionary: dict[str, Any], /, key: str, expected_type: type[T]) -> T | None:
1921
if key not in dictionary:
2022
return None
@@ -24,6 +26,7 @@ def key_must_be(dictionary: dict[str, Any], /, key: str, expected_type: type[T])
2426
return dictionary[key]
2527

2628

29+
@bear_enforce
2730
def parse_datetime_str(value: str) -> datetime:
2831
try:
2932
return datetime.fromisoformat(value)

key-value/key-value-shared/src/key_value/shared/utils/time_to_live.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import time
22
from datetime import datetime, timedelta, timezone
3-
from numbers import Real
43
from typing import Any, SupportsFloat, overload
54

65
from key_value.shared.errors import InvalidTTLError
6+
from key_value.shared.type_checking.bear_spray import bear_enforce
77

88

99
def epoch_to_datetime(epoch: float) -> datetime:
@@ -59,11 +59,18 @@ def prepare_ttl(t: SupportsFloat | None) -> float | None:
5959
if a bool is provided, an InvalidTTLError will be raised. If the user passes TTL=True, true becomes `1` and the
6060
entry immediately expires which is likely not what the user intended.
6161
"""
62+
try:
63+
return _validate_ttl(t=t)
64+
except TypeError as e:
65+
raise InvalidTTLError(ttl=t, extra_info={"type": type(t).__name__}) from e
66+
67+
68+
@bear_enforce
69+
def _validate_ttl(t: SupportsFloat | None) -> float | None:
6270
if t is None:
6371
return None
6472

65-
# This is not needed by the static type checker but is needed by the runtime type checker
66-
if not isinstance(t, Real | SupportsFloat) or isinstance(t, bool): # pyright: ignore[reportUnnecessaryIsInstance]
73+
if isinstance(t, bool):
6774
raise InvalidTTLError(ttl=t, extra_info={"type": type(t).__name__})
6875

6976
ttl = float(t)

key-value/key-value-shared/tests/utils/test_time_to_live.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_prepare_ttl(t: Any, expected: int | float | None):
5454
"bool-false",
5555
],
5656
)
57+
@pytest.mark.filterwarnings("ignore:Function key_value.shared.utils") # Ignore BearType warnings here
5758
def test_prepare_ttl_invalid(t: Any):
5859
with pytest.raises(InvalidTTLError):
5960
prepare_ttl(t)

0 commit comments

Comments
 (0)