Skip to content

Commit 6233b53

Browse files
hawkinsptensorflower-gardener
authored andcommitted
Fix crash on reentrant call to LRUCache::Clear().
Fixes jax-ml/jax#30517 Example traceback of crash: ``` * thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x10) * frame #0: 0x0000000150679fe4 libjax_common.dylib`xla::LRUCache<jax::CallSignature, std::__1::shared_ptr<jax::(anonymous namespace)::PjitCacheEntry>, absl::lts_20250127::hash_internal::Hash<jax::CallSignature>, std::__1::equal_to<jax::CallSignature>>::Clear() + 164 frame #1: 0x000000015067dc20 libjax_common.dylib`std::__1::__shared_ptr_emplace<xla::LRUCache<jax::CallSignature, std::__1::shared_ptr<jax::(anonymous namespace)::PjitCacheEntry>, absl::lts_20250127::hash_internal::Hash<jax::CallSignature>, std::__1::equal_to<jax::CallSignature>>, std::__1::allocator<xla::LRUCache<jax::CallSignature, std::__1::shared_ptr<jax::(anonymous namespace)::PjitCacheEntry>, absl::lts_20250127::hash_internal::Hash<jax::CallSignature>, std::__1::equal_to<jax::CallSignature>>>>::__on_zero_shared() + 32 frame #2: 0x000000015067a0e4 libjax_common.dylib`std::__1::unique_ptr<jax::(anonymous namespace)::PjitFunctionCache::Value, std::__1::default_delete<jax::(anonymous namespace)::PjitFunctionCache::Value>>::reset[abi:ne180100](jax::(anonymous namespace)::PjitFunctionCache::Value*) + 104 frame #3: 0x000000015067e230 libjax_common.dylib`_object* nanobind::detail::func_create<true, true, jax::(anonymous namespace)::PjitFunctionCache::Lookup(xla::nb_class_ptr<jax::(anonymous namespace)::PjitFunctionCache>, nanobind::handle, nanobind::object)::$_1, void, nanobind::handle, 0ul>(jax::(anonymous namespace)::PjitFunctionCache::Lookup(xla::nb_class_ptr<jax::(anonymous namespace)::PjitFunctionCache>, nanobind::handle, nanobind::object)::$_1&&, void (*)(nanobind::handle), std::__1::integer_sequence<unsigned long, 0ul>)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) + 336 frame #4: 0x000000015394ba58 libjax_common.dylib`nanobind::detail::nb_func_vectorcall_simple_1(_object*, _object* const*, unsigned long, _object*) + 156 frame #5: 0x0000000106356620 libpython3.12.dylib`PyObject_CallOneArg + 116 frame #6: 0x0000000106421144 libpython3.12.dylib`PyObject_ClearWeakRefs + 340 frame #7: 0x0000000106377f78 libpython3.12.dylib`func_dealloc + 352 frame #8: 0x00000001506714c8 libjax_common.dylib`PjitFunction_tp_dealloc + 504 frame #9: 0x0000000106420dd8 libpython3.12.dylib`PyDict_DelItem + 668 frame #10: 0x00000001063cbc38 libpython3.12.dylib`_PyEval_EvalFrameDefault + 26328 frame #11: 0x0000000106356620 libpython3.12.dylib`PyObject_CallOneArg + 116 frame #12: 0x0000000106421144 libpython3.12.dylib`PyObject_ClearWeakRefs + 340 frame #13: 0x0000000106377f78 libpython3.12.dylib`func_dealloc + 352 frame #14: 0x0000000150676a4c libjax_common.dylib`jax::ArgumentSignature::~ArgumentSignature() + 172 frame #15: 0x0000000150679c88 libjax_common.dylib`jax::CallSignature::~CallSignature() + 456 frame #16: 0x0000000150679fb0 libjax_common.dylib`xla::LRUCache<jax::CallSignature, std::__1::shared_ptr<jax::(anonymous namespace)::PjitCacheEntry>, absl::lts_20250127::hash_internal::Hash<jax::CallSignature>, std::__1::equal_to<jax::CallSignature>>::Clear() + 112 frame #17: 0x0000000150672400 libjax_common.dylib`jax::(anonymous namespace)::PjitFunctionCache::Clear() + 44 frame #18: 0x000000015067a340 libjax_common.dylib`_object* nanobind::detail::func_create<false, true, void nanobind::cpp_function_def<jax::(anonymous namespace)::PjitFunctionCache, void, jax::(anonymous namespace)::PjitFunctionCache, nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self>(void (jax::(anonymous namespace)::PjitFunctionCache::*)(), nanobind::scope const&, nanobind::name const&, nanobind::is_method const&, nanobind::lock_self const&)::'lambda'(jax::(anonymous namespace)::PjitFunctionCache*), void, jax::(anonymous namespace)::PjitFunctionCache*, 0ul, nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self>(jax::(anonymous namespace)::PjitFunctionCache&&, void (*)(nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self), std::__1::integer_sequence<unsigned long, 0ul>, nanobind::scope const&, nanobind::name const&, nanobind::is_method const&, nanobind::lock_self const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) + 80 frame #19: 0x000000015394ba58 libjax_common.dylib`nanobind::detail::nb_func_vectorcall_simple_1(_object*, _object* const*, unsigned long, _object*) + 156 frame #20: 0x00000001063ed3a8 libpython3.12.dylib`_PyEval_EvalFrameDefault + 163400 frame #21: 0x00000001064e3450 libpython3.12.dylib`atexit_callfuncs.llvm.13196908868581062239 + 96 frame #22: 0x00000001064ece28 libpython3.12.dylib`Py_FinalizeEx + 96 frame #23: 0x000000010650366c libpython3.12.dylib`Py_Exit + 20 frame #24: 0x000000010650364c libpython3.12.dylib`handle_system_exit + 32 frame #25: 0x0000000106503330 libpython3.12.dylib`_PyErr_PrintEx.llvm.12194046240795210664 + 52 frame #26: 0x000000010650de00 libpython3.12.dylib`_PyRun_SimpleFileObject + 464 frame #27: 0x00000001065051e4 libpython3.12.dylib`_PyRun_AnyFileObject + 80 frame #28: 0x00000001065045a0 libpython3.12.dylib`pymain_run_file_obj + 164 frame #29: 0x0000000106503c00 libpython3.12.dylib`pymain_run_file + 72 frame #30: 0x0000000106501e04 libpython3.12.dylib`Py_RunMain + 1120 frame #31: 0x0000000106501808 libpython3.12.dylib`pymain_main + 456 frame #32: 0x0000000106501634 libpython3.12.dylib`Py_BytesMain + 36 frame #33: 0x00000001951fab98 dyld`start + 6076 ``` PiperOrigin-RevId: 788532094
1 parent e337221 commit 6233b53

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

third_party/xla/xla/pjrt/lru_cache.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ void LRUCache<Key, Value, Hash, Eq>::Clear() {
133133
l->prev->next = l->next;
134134
--lru_list_->size_;
135135
}
136-
entries_.clear();
136+
// Deleting a cache entry may reentrantly trigger other calls into, say,
137+
// Clear().
138+
std::unordered_map<Key, Entry, Hash, Eq> entries;
139+
std::swap(entries, entries_);
137140
}
138141

139142
template <typename Key, typename Value, typename Hash, typename Eq>

third_party/xla/xla/pjrt/lru_cache_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "xla/pjrt/lru_cache.h"
1717

18+
#include <memory>
1819
#include <random>
1920

2021
#include "xla/hlo/testlib/test.h"
@@ -115,5 +116,22 @@ TEST(LRUCache, RandomInsertions) {
115116
}
116117
}
117118

119+
TEST(LRUCache, ReentrantClear) {
120+
struct Value {
121+
explicit Value(LRUCache<int, std::shared_ptr<Value>>* cache)
122+
: cache(cache) {}
123+
~Value() { cache->Clear(); }
124+
125+
LRUCache<int, std::shared_ptr<Value>>* cache;
126+
};
127+
128+
LRUCache<int, std::shared_ptr<Value>>::LRUList list(3);
129+
LRUCache<int, std::shared_ptr<Value>> cache(&list);
130+
131+
cache.GetOrCreateIfAbsent(
132+
0, [&](int) { return std::make_shared<Value>(&cache); });
133+
cache.Clear();
134+
}
135+
118136
} // namespace
119137
} // namespace xla

0 commit comments

Comments
 (0)