Skip to content

Commit

Permalink
Added free-threading support to WeakrefLRUCache
Browse files Browse the repository at this point in the history
+ another multi-threaded test
  • Loading branch information
vfdev-5 committed Dec 19, 2024
1 parent 37e7fb0 commit 5bd17e2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
14 changes: 10 additions & 4 deletions xla/python/weakref_lru_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
if (cache == nullptr) {
return;
}
// Acquire the critical section of the Python object associated with this cache.
auto py_cache = nb::find(cache);
// This should never be null: the Python-side cache object should always be found.
CHECK(py_cache.ptr() != NULL);
nb::ft_object_guard lock(py_cache);

// The object the reference referred to is now in the process of being
// destroyed, so we cannot refer to its contents. Python weakref
// objects compare based on identity if the object they refer to is
Expand Down Expand Up @@ -367,10 +373,10 @@ void BuildWeakrefLRUCacheAPI(nb::module_& m) {
nb::class_<WeakrefLRUCache>(m, "WeakrefLRUCache",
nb::is_weak_referenceable(),
nb::type_slots(WeakrefLRUCache::slots_))
.def("__call__", &WeakrefLRUCache::Call)
.def("cache_keys", &WeakrefLRUCache::GetKeys)
.def("cache_info", &WeakrefLRUCache::GetCacheInfo)
.def("cache_clear", &WeakrefLRUCache::Clear);
.def("__call__", &WeakrefLRUCache::Call, nb::lock_self())
.def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self())
.def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self())
.def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self());
nb::class_<WeakrefLRUCache::CacheInfo>(weakref_lru_cache,
"WeakrefLRUCacheInfo")
.def_ro("hits", &WeakrefLRUCache::CacheInfo::hits)
Expand Down
31 changes: 31 additions & 0 deletions xla/python/weakref_lru_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,37 @@ def Body():
cache(wrkey, GilReleasingCacheKey())
t.join()

def testAnotherMultiThreaded(self):
  """Stress-test concurrent insertions racing against cache_clear()."""
  num_threads = 5
  start_gate = threading.Barrier(num_threads)
  cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)

  class WRKey:
    pass

  def Writer():
    # Each writer owns a distinct weakref key and inserts ten entries
    # once every thread has reached the barrier.
    start_gate.wait()
    key = WRKey()
    for value in range(10):
      cache(key, value)

  def Clearer():
    # Repeatedly wipe the cache while the writers are inserting.
    start_gate.wait()
    for _ in range(10):
      cache.cache_clear()

  threads = [threading.Thread(target=Writer) for _ in range(num_threads - 1)]
  threads.append(threading.Thread(target=Clearer))

  for t in threads:
    t.start()

  for t in threads:
    t.join()

def testKwargsDictOrder(self):
miss_id = 0

Expand Down

0 comments on commit 5bd17e2

Please sign in to comment.