Skip to content

Commit eb6f9da

Browse files
strawgategithub-actions[bot]claude
authored
fix: resolve critical bugs in Memory Store TTL and Windows Registry operations (#163)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 681aa5e commit eb6f9da

File tree

4 files changed

+87
-90
lines changed
  • key-value
    • key-value-aio/src/key_value/aio/stores
    • key-value-sync/src/key_value/sync/code_gen/stores

4 files changed

+87
-90
lines changed

key-value/key-value-aio/src/key_value/aio/stores/memory/store.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import sys
2-
from dataclasses import dataclass, field
2+
from dataclasses import dataclass
33
from datetime import datetime
4-
from typing import Any, SupportsFloat
4+
from typing import Any
55

66
from key_value.shared.utils.managed_entry import ManagedEntry
7-
from key_value.shared.utils.time_to_live import epoch_to_datetime
87
from typing_extensions import Self, override
98

109
from key_value.aio.stores.base import (
@@ -30,30 +29,23 @@ class MemoryCacheEntry:
3029

3130
expires_at: datetime | None
3231

33-
ttl_at_insert: SupportsFloat | None = field(default=None)
34-
3532
@classmethod
36-
def from_managed_entry(cls, managed_entry: ManagedEntry, ttl: SupportsFloat | None = None) -> Self:
33+
def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self:
3734
return cls(
3835
json_str=managed_entry.to_json(),
3936
expires_at=managed_entry.expires_at,
40-
ttl_at_insert=ttl,
4137
)
4238

4339
def to_managed_entry(self) -> ManagedEntry:
4440
return ManagedEntry.from_json(json_str=self.json_str)
4541

4642

47-
def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, now: float) -> float:
48-
"""Calculate time-to-use for cache entries based on their TTL."""
49-
if value.ttl_at_insert is None:
43+
def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float:
44+
"""Calculate time-to-use for cache entries based on their expiration time."""
45+
if value.expires_at is None:
5046
return float(sys.maxsize)
5147

52-
expiration_epoch: float = now + float(value.ttl_at_insert)
53-
54-
value.expires_at = epoch_to_datetime(epoch=expiration_epoch)
55-
56-
return float(expiration_epoch)
48+
return value.expires_at.timestamp()
5749

5850

5951
def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001
@@ -93,7 +85,7 @@ def get(self, key: str) -> ManagedEntry | None:
9385
return managed_entry
9486

9587
def put(self, key: str, value: ManagedEntry) -> None:
96-
self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value, ttl=value.ttl)
88+
self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value)
9789

9890
def delete(self, key: str) -> bool:
9991
return self._cache.pop(key, None) is not None
Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
import contextlib
22
import winreg
3-
from collections.abc import Generator
4-
from contextlib import contextmanager
5-
6-
HiveType = int
73

4+
from key_value.shared.errors.store import StoreSetupError
85

9-
@contextmanager
10-
def handle_winreg_error() -> Generator[None, None, None]:
11-
try:
12-
yield
13-
except (FileNotFoundError, OSError):
14-
return None
6+
HiveType = int
157

168

179
def get_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> str | None:
@@ -24,8 +16,15 @@ def get_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> str | Non
2416

2517

2618
def set_reg_sz_value(hive: HiveType, sub_key: str, value_name: str, value: str) -> None:
27-
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
28-
winreg.SetValueEx(reg_key, value_name, 0, winreg.REG_SZ, value)
19+
try:
20+
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
21+
winreg.SetValueEx(reg_key, value_name, 0, winreg.REG_SZ, value)
22+
except FileNotFoundError as e:
23+
msg = f"Registry key '{sub_key}' does not exist"
24+
raise StoreSetupError(msg) from e
25+
except OSError as e:
26+
msg = f"Failed to set registry value '{value_name}' at '{sub_key}'"
27+
raise StoreSetupError(msg) from e
2928

3029

3130
def delete_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> bool:
@@ -46,29 +45,36 @@ def has_key(hive: HiveType, sub_key: str) -> bool:
4645

4746

4847
def create_key(hive: HiveType, sub_key: str) -> None:
49-
winreg.CreateKey(hive, sub_key)
48+
try:
49+
key = winreg.CreateKey(hive, sub_key)
50+
key.Close()
51+
except OSError as e:
52+
msg = f"Failed to create registry key '{sub_key}'"
53+
raise StoreSetupError(msg) from e
5054

5155

5256
def delete_key(hive: HiveType, sub_key: str) -> bool:
5357
try:
54-
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
55-
winreg.DeleteKey(reg_key, sub_key)
56-
return True
58+
winreg.DeleteKey(hive, sub_key)
5759
except (FileNotFoundError, OSError):
5860
return False
61+
else:
62+
return True
5963

6064

6165
def delete_sub_keys(hive: HiveType, sub_key: str) -> None:
62-
with (
63-
handle_winreg_error(),
64-
winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE | winreg.KEY_ENUMERATE_SUB_KEYS) as reg_key,
65-
):
66-
index = 0
67-
while True:
68-
if not (next_child_key := winreg.EnumKey(reg_key, index)):
69-
break
70-
71-
with contextlib.suppress(Exception):
72-
winreg.DeleteKey(reg_key, next_child_key)
73-
74-
index += 1
66+
try:
67+
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE | winreg.KEY_ENUMERATE_SUB_KEYS) as reg_key:
68+
while True:
69+
try:
70+
# Always enumerate at index 0 since keys shift after deletion
71+
next_child_key = winreg.EnumKey(reg_key, 0)
72+
except OSError:
73+
# No more subkeys
74+
break
75+
76+
# Key already deleted or can't be deleted, skip it
77+
with contextlib.suppress(FileNotFoundError, OSError):
78+
winreg.DeleteKey(reg_key, next_child_key)
79+
except (FileNotFoundError, OSError):
80+
return

key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
# from the original file 'store.py'
33
# DO NOT CHANGE! Change the original file instead.
44
import sys
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
66
from datetime import datetime
7-
from typing import Any, SupportsFloat
7+
from typing import Any
88

99
from key_value.shared.utils.managed_entry import ManagedEntry
10-
from key_value.shared.utils.time_to_live import epoch_to_datetime
1110
from typing_extensions import Self, override
1211

1312
from key_value.sync.code_gen.stores.base import (
@@ -33,26 +32,20 @@ class MemoryCacheEntry:
3332

3433
expires_at: datetime | None
3534

36-
ttl_at_insert: SupportsFloat | None = field(default=None)
37-
3835
@classmethod
39-
def from_managed_entry(cls, managed_entry: ManagedEntry, ttl: SupportsFloat | None = None) -> Self:
40-
return cls(json_str=managed_entry.to_json(), expires_at=managed_entry.expires_at, ttl_at_insert=ttl)
36+
def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self:
37+
return cls(json_str=managed_entry.to_json(), expires_at=managed_entry.expires_at)
4138

4239
def to_managed_entry(self) -> ManagedEntry:
4340
return ManagedEntry.from_json(json_str=self.json_str)
4441

4542

46-
def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, now: float) -> float:
47-
"""Calculate time-to-use for cache entries based on their TTL."""
48-
if value.ttl_at_insert is None:
43+
def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float:
44+
"""Calculate time-to-use for cache entries based on their expiration time."""
45+
if value.expires_at is None:
4946
return float(sys.maxsize)
5047

51-
expiration_epoch: float = now + float(value.ttl_at_insert)
52-
53-
value.expires_at = epoch_to_datetime(epoch=expiration_epoch)
54-
55-
return float(expiration_epoch)
48+
return value.expires_at.timestamp()
5649

5750

5851
def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int:
@@ -88,7 +81,7 @@ def get(self, key: str) -> ManagedEntry | None:
8881
return managed_entry
8982

9083
def put(self, key: str, value: ManagedEntry) -> None:
91-
self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value, ttl=value.ttl)
84+
self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value)
9285

9386
def delete(self, key: str) -> bool:
9487
return self._cache.pop(key, None) is not None

key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/utils.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,10 @@
33
# DO NOT CHANGE! Change the original file instead.
44
import contextlib
55
import winreg
6-
from collections.abc import Generator
7-
from contextlib import contextmanager
8-
9-
HiveType = int
106

7+
from key_value.shared.errors.store import StoreSetupError
118

12-
@contextmanager
13-
def handle_winreg_error() -> Generator[None, None, None]:
14-
try:
15-
yield
16-
except (FileNotFoundError, OSError):
17-
return None
9+
HiveType = int
1810

1911

2012
def get_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> str | None:
@@ -27,8 +19,15 @@ def get_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> str | Non
2719

2820

2921
def set_reg_sz_value(hive: HiveType, sub_key: str, value_name: str, value: str) -> None:
30-
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
31-
winreg.SetValueEx(reg_key, value_name, 0, winreg.REG_SZ, value)
22+
try:
23+
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
24+
winreg.SetValueEx(reg_key, value_name, 0, winreg.REG_SZ, value)
25+
except FileNotFoundError as e:
26+
msg = f"Registry key '{sub_key}' does not exist"
27+
raise StoreSetupError(msg) from e
28+
except OSError as e:
29+
msg = f"Failed to set registry value '{value_name}' at '{sub_key}'"
30+
raise StoreSetupError(msg) from e
3231

3332

3433
def delete_reg_sz_value(hive: HiveType, sub_key: str, value_name: str) -> bool:
@@ -49,29 +48,36 @@ def has_key(hive: HiveType, sub_key: str) -> bool:
4948

5049

5150
def create_key(hive: HiveType, sub_key: str) -> None:
52-
winreg.CreateKey(hive, sub_key)
51+
try:
52+
key = winreg.CreateKey(hive, sub_key)
53+
key.Close()
54+
except OSError as e:
55+
msg = f"Failed to create registry key '{sub_key}'"
56+
raise StoreSetupError(msg) from e
5357

5458

5559
def delete_key(hive: HiveType, sub_key: str) -> bool:
5660
try:
57-
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE) as reg_key:
58-
winreg.DeleteKey(reg_key, sub_key)
59-
return True
61+
winreg.DeleteKey(hive, sub_key)
6062
except (FileNotFoundError, OSError):
6163
return False
64+
else:
65+
return True
6266

6367

6468
def delete_sub_keys(hive: HiveType, sub_key: str) -> None:
65-
with (
66-
handle_winreg_error(),
67-
winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE | winreg.KEY_ENUMERATE_SUB_KEYS) as reg_key,
68-
):
69-
index = 0
70-
while True:
71-
if not (next_child_key := winreg.EnumKey(reg_key, index)):
72-
break
73-
74-
with contextlib.suppress(Exception):
75-
winreg.DeleteKey(reg_key, next_child_key)
76-
77-
index += 1
69+
try:
70+
with winreg.OpenKey(key=hive, sub_key=sub_key, access=winreg.KEY_WRITE | winreg.KEY_ENUMERATE_SUB_KEYS) as reg_key:
71+
while True:
72+
try:
73+
# Always enumerate at index 0 since keys shift after deletion
74+
next_child_key = winreg.EnumKey(reg_key, 0)
75+
except OSError:
76+
# No more subkeys
77+
break
78+
79+
# Key already deleted or can't be deleted, skip it
80+
with contextlib.suppress(FileNotFoundError, OSError):
81+
winreg.DeleteKey(reg_key, next_child_key)
82+
except (FileNotFoundError, OSError):
83+
return

0 commit comments

Comments
 (0)