Skip to content

Commit 785e939

Browse files
authored
refactor(store): use dedicated thread for signal handling (#840)
* refactor(store): use dedicated thread for signal handling and atomic cleanup
* refactor(store): switch PyClient to shared_ptr and improve resource management
* refactor(ResourceTracker): use leaked heap object to avoid static destruction order issues
* test: increase unittest verbosity to show running tests
* try fix ci
* fix setup after close
* fix comments
1 parent 51687e9 commit 785e939

File tree

6 files changed

+181
-87
lines changed

6 files changed

+181
-87
lines changed

mooncake-integration/integration_utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ namespace py = pybind11;
1111

1212
namespace mooncake {
1313

14-
auto torch = py::module_::import("torch");
14+
// Avoid global py::module_ objects
15+
// Looks up the Python "torch" module at call time, so no py::module_ object
// has to be kept at namespace scope.
inline py::module_ torch_module() {
    py::module_ torch_mod = py::module_::import("torch");
    return torch_mod;
}
1516

1617
enum class TensorDtype : int32_t {
1718
FLOAT32 = 0,
@@ -68,6 +69,8 @@ inline TensorDtype get_tensor_dtype(py::object dtype_obj) {
6869
return TensorDtype::UNKNOWN;
6970
}
7071

72+
auto torch = torch_module();
73+
7174
if (dtype_obj.equal(torch.attr("float32"))) return TensorDtype::FLOAT32;
7275
if (dtype_obj.equal(torch.attr("float64"))) return TensorDtype::FLOAT64;
7376
if (dtype_obj.equal(torch.attr("int8"))) return TensorDtype::INT8;

mooncake-integration/store/store_py.cpp

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ namespace mooncake {
1717
// Python-specific wrapper functions that handle GIL and return pybind11 types
1818
class MooncakeStorePyWrapper {
1919
public:
20-
PyClient store_;
20+
std::shared_ptr<PyClient> store_{nullptr};
21+
22+
// Construct the wrapper with a client obtained from PyClient::create().
MooncakeStorePyWrapper()
    : store_{PyClient::create()} {}
2123

2224
pybind11::bytes get(const std::string &key) {
23-
if (!store_.client_) {
25+
if (!store_ || !store_->client_) {
2426
LOG(ERROR) << "Client is not initialized";
2527
return pybind11::bytes("\\0", 0);
2628
}
@@ -29,7 +31,7 @@ class MooncakeStorePyWrapper {
2931

3032
{
3133
py::gil_scoped_release release_gil;
32-
auto buffer_handle = store_.get_buffer(key);
34+
auto buffer_handle = store_->get_buffer(key);
3335
if (!buffer_handle) {
3436
py::gil_scoped_acquire acquire_gil;
3537
return kNullString;
@@ -44,15 +46,15 @@ class MooncakeStorePyWrapper {
4446
std::vector<pybind11::bytes> get_batch(
4547
const std::vector<std::string> &keys) {
4648
const auto kNullString = pybind11::bytes("\\0", 0);
47-
if (!store_.client_) {
49+
if (!store_ || !store_->client_) {
4850
LOG(ERROR) << "Client is not initialized";
4951
py::gil_scoped_acquire acquire_gil;
5052
return {kNullString};
5153
}
5254

5355
{
5456
py::gil_scoped_release release_gil;
55-
auto batch_data = store_.batch_get_buffer(keys);
57+
auto batch_data = store_->batch_get_buffer(keys);
5658
if (batch_data.empty()) {
5759
py::gil_scoped_acquire acquire_gil;
5860
return {kNullString};
@@ -73,15 +75,15 @@ class MooncakeStorePyWrapper {
7375
}
7476

7577
pybind11::object get_tensor(const std::string &key) {
76-
if (!store_.client_) {
78+
if (!store_ || !store_->client_) {
7779
LOG(ERROR) << "Client is not initialized";
7880
return pybind11::none();
7981
}
8082

8183
try {
8284
// Section with GIL released
8385
py::gil_scoped_release release_gil;
84-
auto buffer_handle = store_.get_buffer(key);
86+
auto buffer_handle = store_->get_buffer(key);
8587
if (!buffer_handle) {
8688
py::gil_scoped_acquire acquire_gil;
8789
return pybind11::none();
@@ -144,7 +146,8 @@ class MooncakeStorePyWrapper {
144146
py::tuple shape_tuple = py::cast(shape_vec);
145147
np_array = np_array.attr("reshape")(shape_tuple);
146148
}
147-
pybind11::object tensor = torch.attr("from_numpy")(np_array);
149+
pybind11::object tensor =
150+
torch_module().attr("from_numpy")(np_array);
148151
return tensor;
149152

150153
} catch (const pybind11::error_already_set &e) {
@@ -154,7 +157,7 @@ class MooncakeStorePyWrapper {
154157
}
155158

156159
int put_tensor(const std::string &key, pybind11::object tensor) {
157-
if (!store_.client_) {
160+
if (!store_ || !store_->client_) {
158161
LOG(ERROR) << "Client is not initialized";
159162
return -static_cast<int>(ErrorCode::INVALID_PARAMS);
160163
}
@@ -211,7 +214,7 @@ class MooncakeStorePyWrapper {
211214
values.emplace_back(std::span<const char>(buffer, tensor_size));
212215

213216
// Use put_parts to put metadata and tensor together
214-
auto put_result = store_.put_parts_internal(key, values);
217+
auto put_result = store_->put_parts_internal(key, values);
215218
if (!put_result) {
216219
return -static_cast<int>(put_result.error());
217220
}
@@ -287,77 +290,83 @@ PYBIND11_MODULE(store, m) {
287290
const std::string &protocol = "tcp",
288291
const std::string &rdma_devices = "",
289292
const std::string &master_server_addr = "127.0.0.1:50051") {
290-
return self.store_.setup(local_hostname, metadata_server,
291-
global_segment_size,
292-
local_buffer_size, protocol,
293-
rdma_devices, master_server_addr);
293+
if (!self.store_) {
294+
self.store_ = PyClient::create();
295+
}
296+
return self.store_->setup(local_hostname, metadata_server,
297+
global_segment_size,
298+
local_buffer_size, protocol,
299+
rdma_devices, master_server_addr);
294300
})
295301
.def("init_all",
296302
[](MooncakeStorePyWrapper &self, const std::string &protocol,
297303
const std::string &device_name,
298304
size_t mount_segment_size = 1024 * 1024 * 16) {
299-
return self.store_.initAll(protocol, device_name,
300-
mount_segment_size);
305+
return self.store_->initAll(protocol, device_name,
306+
mount_segment_size);
301307
})
302308
.def("get", &MooncakeStorePyWrapper::get)
303309
.def("get_batch", &MooncakeStorePyWrapper::get_batch)
304310
.def(
305311
"get_buffer",
306312
[](MooncakeStorePyWrapper &self, const std::string &key) {
307313
py::gil_scoped_release release;
308-
return self.store_.get_buffer(key);
314+
return self.store_->get_buffer(key);
309315
},
310316
py::return_value_policy::take_ownership)
311317
.def(
312318
"batch_get_buffer",
313319
[](MooncakeStorePyWrapper &self,
314320
const std::vector<std::string> &keys) {
315321
py::gil_scoped_release release;
316-
return self.store_.batch_get_buffer(keys);
322+
return self.store_->batch_get_buffer(keys);
317323
},
318324
py::return_value_policy::take_ownership)
319325
.def("remove",
320326
[](MooncakeStorePyWrapper &self, const std::string &key) {
321327
py::gil_scoped_release release;
322-
return self.store_.remove(key);
328+
return self.store_->remove(key);
323329
})
324330
.def(
325331
"remove_by_regex",
326332
[](MooncakeStorePyWrapper &self, const std::string &str) {
327333
py::gil_scoped_release release;
328-
return self.store_.removeByRegex(str);
334+
return self.store_->removeByRegex(str);
329335
},
330336
py::arg("regex_pattern"),
331337
"Removes objects from the store whose keys match the given "
332338
"regular expression.")
333339
.def("remove_all",
334340
[](MooncakeStorePyWrapper &self) {
335341
py::gil_scoped_release release;
336-
return self.store_.removeAll();
342+
return self.store_->removeAll();
337343
})
338344
.def("is_exist",
339345
[](MooncakeStorePyWrapper &self, const std::string &key) {
340346
py::gil_scoped_release release;
341-
return self.store_.isExist(key);
347+
return self.store_->isExist(key);
342348
})
343349
.def(
344350
"batch_is_exist",
345351
[](MooncakeStorePyWrapper &self,
346352
const std::vector<std::string> &keys) {
347353
py::gil_scoped_release release;
348-
return self.store_.batchIsExist(keys);
354+
return self.store_->batchIsExist(keys);
349355
},
350356
py::arg("keys"),
351357
"Check if multiple objects exist. Returns list of results: 1 if "
352358
"exists, 0 if not exists, -1 if error")
353359
.def("close",
354360
[](MooncakeStorePyWrapper &self) {
355-
return self.store_.tearDownAll();
361+
if (!self.store_) return 0;
362+
int rc = self.store_->tearDownAll();
363+
self.store_.reset();
364+
return rc;
356365
})
357366
.def("get_size",
358367
[](MooncakeStorePyWrapper &self, const std::string &key) {
359368
py::gil_scoped_release release;
360-
return self.store_.getSize(key);
369+
return self.store_->getSize(key);
361370
})
362371
.def("get_tensor", &MooncakeStorePyWrapper::get_tensor, py::arg("key"),
363372
"Get a PyTorch tensor from the store")
@@ -370,7 +379,7 @@ PYBIND11_MODULE(store, m) {
370379
// Register memory buffer for RDMA operations
371380
void *buffer = reinterpret_cast<void *>(buffer_ptr);
372381
py::gil_scoped_release release;
373-
return self.store_.register_buffer(buffer, size);
382+
return self.store_->register_buffer(buffer, size);
374383
},
375384
py::arg("buffer_ptr"), py::arg("size"),
376385
"Register a memory buffer for direct access operations")
@@ -380,7 +389,7 @@ PYBIND11_MODULE(store, m) {
380389
// Unregister memory buffer
381390
void *buffer = reinterpret_cast<void *>(buffer_ptr);
382391
py::gil_scoped_release release;
383-
return self.store_.unregister_buffer(buffer);
392+
return self.store_->unregister_buffer(buffer);
384393
},
385394
py::arg("buffer_ptr"),
386395
"Unregister a previously registered memory "
@@ -392,7 +401,7 @@ PYBIND11_MODULE(store, m) {
392401
// Get data directly into user-provided buffer
393402
void *buffer = reinterpret_cast<void *>(buffer_ptr);
394403
py::gil_scoped_release release;
395-
return self.store_.get_into(key, buffer, size);
404+
return self.store_->get_into(key, buffer, size);
396405
},
397406
py::arg("key"), py::arg("buffer_ptr"), py::arg("size"),
398407
"Get object data directly into a pre-allocated buffer")
@@ -408,7 +417,7 @@ PYBIND11_MODULE(store, m) {
408417
buffers.push_back(reinterpret_cast<void *>(ptr));
409418
}
410419
py::gil_scoped_release release;
411-
return self.store_.batch_get_into(keys, buffers, sizes);
420+
return self.store_->batch_get_into(keys, buffers, sizes);
412421
},
413422
py::arg("keys"), py::arg("buffer_ptrs"), py::arg("sizes"),
414423
"Get object data directly into pre-allocated buffers for "
@@ -422,7 +431,7 @@ PYBIND11_MODULE(store, m) {
422431
// Put data directly from user-provided buffer
423432
void *buffer = reinterpret_cast<void *>(buffer_ptr);
424433
py::gil_scoped_release release;
425-
return self.store_.put_from(key, buffer, size, config);
434+
return self.store_->put_from(key, buffer, size, config);
426435
},
427436
py::arg("key"), py::arg("buffer_ptr"), py::arg("size"),
428437
py::arg("config") = ReplicateConfig{},
@@ -439,7 +448,7 @@ PYBIND11_MODULE(store, m) {
439448
void *metadata_buffer =
440449
reinterpret_cast<void *>(metadata_buffer_ptr);
441450
py::gil_scoped_release release;
442-
return self.store_.put_from_with_metadata(
451+
return self.store_->put_from_with_metadata(
443452
key, buffer, metadata_buffer, size, metadata_size, config);
444453
},
445454
py::arg("key"), py::arg("buffer_ptr"),
@@ -460,7 +469,8 @@ PYBIND11_MODULE(store, m) {
460469
buffers.push_back(reinterpret_cast<void *>(ptr));
461470
}
462471
py::gil_scoped_release release;
463-
return self.store_.batch_put_from(keys, buffers, sizes, config);
472+
return self.store_->batch_put_from(keys, buffers, sizes,
473+
config);
464474
},
465475
py::arg("keys"), py::arg("buffer_ptrs"), py::arg("sizes"),
466476
py::arg("config") = ReplicateConfig{},
@@ -474,7 +484,7 @@ PYBIND11_MODULE(store, m) {
474484
const ReplicateConfig &config = ReplicateConfig{}) {
475485
py::buffer_info info = buf.request(/*writable=*/false);
476486
py::gil_scoped_release release;
477-
return self.store_.put(
487+
return self.store_->put(
478488
key,
479489
std::span<const char>(static_cast<char *>(info.ptr),
480490
static_cast<size_t>(info.size)),
@@ -507,7 +517,7 @@ PYBIND11_MODULE(store, m) {
507517

508518
// 2) Call C++ function
509519
py::gil_scoped_release unlock;
510-
return self.store_.put_parts(key, spans, config);
520+
return self.store_->put_parts(key, spans, config);
511521
},
512522
py::arg("key"), py::arg("config") = ReplicateConfig{})
513523
.def(
@@ -530,12 +540,12 @@ PYBIND11_MODULE(store, m) {
530540
}
531541

532542
py::gil_scoped_release release;
533-
return self.store_.put_batch(keys, spans, config);
543+
return self.store_->put_batch(keys, spans, config);
534544
},
535545
py::arg("keys"), py::arg("values"),
536546
py::arg("config") = ReplicateConfig{})
537547
.def("get_hostname", [](MooncakeStorePyWrapper &self) {
538-
return self.store_.get_hostname();
548+
return self.store_->get_hostname();
539549
});
540550

541551
// Expose NUMA binding as a module-level function (no self required)

mooncake-store/include/pybind_client.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#pragma once
22

33
#include <csignal>
4-
#include <mutex>
4+
#include <atomic>
5+
#include <thread>
56
#include <string>
6-
#include <unordered_set>
7+
#include <memory>
8+
#include <vector>
79

810
#include "client.h"
911
#include "client_buffer.hpp"
12+
#include "mutex.h"
1013
#include "utils.h"
1114

1215
namespace mooncake {
@@ -40,10 +43,7 @@ class ResourceTracker {
4043
static ResourceTracker &getInstance();
4144

4245
// Register a PyClient instance for cleanup
43-
void registerInstance(PyClient *instance);
44-
45-
// Unregister a DistributedObjectStore instance
46-
void unregisterInstance(PyClient *instance);
46+
void registerInstance(const std::shared_ptr<PyClient> &instance);
4747

4848
private:
4949
ResourceTracker();
@@ -62,15 +62,26 @@ class ResourceTracker {
6262
// Exit handler function
6363
static void exitHandler();
6464

65-
std::mutex mutex_;
66-
std::unordered_set<PyClient *> instances_;
65+
Mutex mutex_;
66+
std::vector<std::weak_ptr<PyClient>> instances_ GUARDED_BY(mutex_);
67+
68+
// Ensure cleanup runs at most once
69+
std::atomic<bool> cleaned_{false};
70+
71+
// Dedicated signal handling thread
72+
void startSignalThread();
73+
std::once_flag signal_once_{};
74+
std::jthread signal_thread_{}; // joins on destruction
6775
};
6876

6977
class PyClient {
7078
public:
7179
PyClient();
7280
~PyClient();
7381

82+
// Factory that creates shared instances and auto-registers them with ResourceTracker
83+
static std::shared_ptr<PyClient> create();
84+
7485
int setup(const std::string &local_hostname,
7586
const std::string &metadata_server,
7687
size_t global_segment_size = 1024 * 1024 * 16,
@@ -308,6 +319,9 @@ class PyClient {
308319
std::string protocol;
309320
std::string device_name;
310321
std::string local_hostname;
322+
323+
// Ensure cleanup executes at most once across multiple entry points
324+
std::atomic<bool> closed_{false};
311325
};
312326

313327
} // namespace mooncake

0 commit comments

Comments
 (0)