104 changes: 97 additions & 7 deletions tests/distributed/local_cpu_backend_test.py
@@ -132,9 +132,10 @@ def test_pin_on_hit(self, clean_backend_instance):
key = CacheKey(model_name="test_model", chunk_hash="A")
backend.add(key, create_mock_value(10))

assert key not in backend.pinned_keys
assert key not in backend.pin_counts
backend.contains(key, pin_on_hit=True)
assert key in backend.pinned_keys
assert key in backend.pin_counts
assert backend.pin_counts[key] == 1

def test_pinned_item_is_not_evicted(self, clean_backend_instance):
"""Tests that a pinned item is protected from eviction."""
@@ -147,7 +148,7 @@ def test_pinned_item_is_not_evicted(self, clean_backend_instance):
backend.add(key_a, value)
backend.add(key_b, value)
backend.contains(key_a, pin_on_hit=True)
assert key_a in backend.pinned_keys
assert key_a in backend.pin_counts

# This should evict key_b, because key_a is pinned
backend.add(key_c, value)
@@ -172,8 +173,8 @@ def test_unpin_makes_item_evictable(self, clean_backend_instance):
assert list(backend.cache.keys()) == [key_b, key_a]

# Unpin A, making it the LRU evictable item
backend.unpin_keys([key_a])
assert key_a not in backend.pinned_keys
backend.maybe_unpin_keys([key_a])
assert key_a not in backend.pin_counts

# This should now evict B
backend.add(key_c, value)
@@ -182,8 +183,7 @@ def test_unpin_makes_item_evictable(self, clean_backend_instance):
assert key_a in backend.cache
assert key_c in backend.cache

def test_cache_full_of_pinned_items_prevents_add(self,
clean_backend_instance):
def test_cache_full_of_pinned_items_prevents_add(self, clean_backend_instance):
"""
Tests that no new items can be added if the cache is full of
pinned items.
@@ -208,3 +208,93 @@ def test_cache_full_of_pinned_items_prevents_add(self,
assert key_a in backend.cache
assert key_b in backend.cache
assert backend.current_size_bytes == 100
assert key_a in backend.pin_counts
assert key_b in backend.pin_counts

def test_pinning_same_key_multiple_times_increments_count(
self, clean_backend_instance):
"""Verifies that pinning an already-pinned key increments its count."""
backend = LocalCPUBackend(max_cpu_cache_size_bytes=100)
key = CacheKey(model_name="test_model", chunk_hash="A")
backend.add(key, create_mock_value(10))

backend.contains(key, pin_on_hit=True)
assert backend.pin_counts[key] == 1

backend.contains(key, pin_on_hit=True)
assert backend.pin_counts[key] == 2

def test_unpin_decrements_count_and_removes_at_zero(
self, clean_backend_instance):
"""Tests the core reference counting logic of the unpin_keys method."""
backend = LocalCPUBackend(max_cpu_cache_size_bytes=100)
key = CacheKey(model_name="test_model", chunk_hash="A")
backend.add(key, create_mock_value(10))

# Pin twice
backend.contains(key, pin_on_hit=True)
backend.contains(key, pin_on_hit=True)
assert backend.pin_counts[key] == 2

# Unpin once
backend.maybe_unpin_keys([key])
assert key in backend.pin_counts
assert backend.pin_counts[key] == 1

# Unpin again
backend.maybe_unpin_keys([key])
assert key not in backend.pin_counts

def test_item_with_positive_pin_count_is_not_evicted(
self, clean_backend_instance):
"""
Tests that an item with a pin count > 0 is not evicted, confirming
the race condition fix.
"""
backend = LocalCPUBackend(max_cpu_cache_size_bytes=100)
key_a = CacheKey(model_name="test_model", chunk_hash="A")
key_b = CacheKey(model_name="test_model", chunk_hash="B")
key_c = CacheKey(model_name="test_model", chunk_hash="C")
value = create_mock_value(50)

backend.add(key_a, value) # Will be LRU
backend.add(key_b, value)

# Pin key_a twice (simulating two requests)
backend.contains(key_a, pin_on_hit=True)
backend.contains(key_a, pin_on_hit=True)

# Unpin key_a once (simulating one request finishing)
backend.maybe_unpin_keys([key_a])
assert backend.pin_counts[key_a] == 1

# This add should trigger eviction of key_b, as key_a is still pinned.
backend.add(key_c, value)

assert key_a in backend.cache
assert key_b not in backend.cache
assert key_c in backend.cache
assert key_a in backend.pin_counts

def test_unpin_keys_returns_correct_counts(self, clean_backend_instance):
"""Validates the meaningful return values of unpin_keys."""
backend = LocalCPUBackend(max_cpu_cache_size_bytes=100)
key_a = CacheKey(model_name="test_model", chunk_hash="A")
key_b = CacheKey(model_name="test_model", chunk_hash="B")
value = create_mock_value(10)

backend.add(key_a, value)
backend.add(key_b, value)

# Pin A twice, B once
backend.contains(key_a, pin_on_hit=True)
backend.contains(key_a, pin_on_hit=True)
backend.contains(key_b, pin_on_hit=True)

# Unpin both. A should be decremented, B should be fully unpinned.
unpinned_count, found_count = backend.maybe_unpin_keys([key_a, key_b])

assert found_count == 2 # Both keys were found in pin_counts
assert unpinned_count == 1 # Only key_b's count went to 0
assert backend.pin_counts[key_a] == 1
assert key_b not in backend.pin_counts
33 changes: 23 additions & 10 deletions tpu_inference/distributed/local_cpu_backend.py
@@ -50,7 +50,8 @@ def __init__(self,
# The cache is an OrderedDict for LRU behavior.
self.cache: OrderedDict[CacheKey, Any] = OrderedDict()
self.current_size_bytes = 0
self.pinned_keys: set[CacheKey] = set()
# Use a dictionary for reference counting of pinned keys.
self.pin_counts: dict[CacheKey, int] = {}
self._initialized = True
logger.info("Singleton LocalCPUBackend initialized."
f"CPU cache size: {self.max_cpu_cache_size_bytes} bytes")
@@ -93,7 +94,7 @@ def add(self, key: CacheKey, value: Any):
evicted_key = None
# Find the first unpinned key from the LRU end of the cache.
for k in self.cache:
if k not in self.pinned_keys:
if k not in self.pin_counts:
evicted_key = k
break

@@ -139,19 +140,31 @@ def contains(self, key: CacheKey, pin_on_hit: bool = False) -> bool:
# Mark as most recently used, since this is an access.
self.cache.move_to_end(key)
if pin_on_hit:
self.pinned_keys.add(key)
logger.info(f"Pinned key on hit. Hash: {key.chunk_hash}")
self.pin_counts[key] = self.pin_counts.get(key, 0) + 1
logger.info(f"Pinned key on hit. Hash: {key.chunk_hash}, "
f"New count: {self.pin_counts[key]}")
return True
return False

def unpin_keys(self, keys: List[CacheKey]) -> Tuple[int, int]:
"""Unpins a list of keys, making them eligible for eviction again."""
def maybe_unpin_keys(self, keys: List[CacheKey]) -> Tuple[int, int]:
"""
Unpins a list of keys.

Decrements the pin count for each key. If a key's count reaches zero,
it is fully unpinned and becomes eligible for eviction.
"""
unpinned_count = 0
found_count = 0
for key in keys:
if key in self.pinned_keys:
if key in self.pin_counts:
found_count += 1
self.pinned_keys.remove(key)
unpinned_count += 1
logger.info(f"Unpinned key. Hash: {key.chunk_hash}")
self.pin_counts[key] -= 1
logger.info(
f"Decremented pin count for key. Hash: {key.chunk_hash}, "
f"New count: {self.pin_counts[key]}")
if self.pin_counts[key] == 0:
del self.pin_counts[key]
unpinned_count += 1
logger.info(
f"Unpinned key completely. Hash: {key.chunk_hash}")
return unpinned_count, found_count
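
Taken together, the change gives pins reference-count semantics rather than set membership. A minimal sketch of the intended lifecycle, assuming the `LocalCPUBackend`, `CacheKey`, and `create_mock_value` helper shown in the tests above:

```python
# Sketch of the reference-counted pin lifecycle (assumes the classes above).
backend = LocalCPUBackend(max_cpu_cache_size_bytes=100)
key = CacheKey(model_name="test_model", chunk_hash="A")
backend.add(key, create_mock_value(10))

backend.contains(key, pin_on_hit=True)  # pin_counts[key] == 1
backend.contains(key, pin_on_hit=True)  # pin_counts[key] == 2 (second request)

backend.maybe_unpin_keys([key])  # count drops to 1; key is still pinned
backend.maybe_unpin_keys([key])  # count reaches 0; key is evictable again
```

Because eviction in `add` skips any key present in `pin_counts`, a key stays protected until every request that pinned it has unpinned it, which is what closes the race the old set-based `pinned_keys` left open.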
2 changes: 1 addition & 1 deletion tpu_inference/distributed/tpu_connector_local.py
@@ -1379,7 +1379,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
)

if keys_to_unpin:
unpinned_count, found_count = self.cpu_backend.unpin_keys(
unpinned_count, found_count = self.cpu_backend.maybe_unpin_keys(
keys_to_unpin)
logger.info(
f"Unpinned {unpinned_count} out of {found_count} existing keys (Request to unpin {len(keys_to_unpin)} keys)."
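
For reference, a hedged reading of the return tuple as a caller like the connector might consume it, assuming only the `maybe_unpin_keys` signature shown above (`backend` and `keys_to_unpin` stand in for the connector's `self.cpu_backend` and its collected keys):

```python
# found_count: keys that had a positive pin count and were decremented.
# unpinned_count: keys whose count reached zero and became evictable again.
unpinned_count, found_count = backend.maybe_unpin_keys(keys_to_unpin)

still_pinned = found_count - unpinned_count      # decremented, still held
never_pinned = len(keys_to_unpin) - found_count  # had no pin entry at all
```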