@@ -17,10 +17,12 @@ namespace mooncake {
1717// Python-specific wrapper functions that handle GIL and return pybind11 types
1818class MooncakeStorePyWrapper {
1919 public:
20- PyClient store_;
20+ std::shared_ptr<PyClient> store_{nullptr };
21+
22+ MooncakeStorePyWrapper () : store_(PyClient::create()) {}
2123
2224 pybind11::bytes get (const std::string &key) {
23- if (!store_. client_ ) {
25+ if (!store_ || !store_-> client_ ) {
2426 LOG (ERROR) << " Client is not initialized" ;
2527 return pybind11::bytes (" \\ 0" , 0 );
2628 }
@@ -29,7 +31,7 @@ class MooncakeStorePyWrapper {
2931
3032 {
3133 py::gil_scoped_release release_gil;
32- auto buffer_handle = store_. get_buffer (key);
34+ auto buffer_handle = store_-> get_buffer (key);
3335 if (!buffer_handle) {
3436 py::gil_scoped_acquire acquire_gil;
3537 return kNullString ;
@@ -44,15 +46,15 @@ class MooncakeStorePyWrapper {
4446 std::vector<pybind11::bytes> get_batch (
4547 const std::vector<std::string> &keys) {
4648 const auto kNullString = pybind11::bytes (" \\ 0" , 0 );
47- if (!store_. client_ ) {
49+ if (!store_ || !store_-> client_ ) {
4850 LOG (ERROR) << " Client is not initialized" ;
4951 py::gil_scoped_acquire acquire_gil;
5052 return {kNullString };
5153 }
5254
5355 {
5456 py::gil_scoped_release release_gil;
55- auto batch_data = store_. batch_get_buffer (keys);
57+ auto batch_data = store_-> batch_get_buffer (keys);
5658 if (batch_data.empty ()) {
5759 py::gil_scoped_acquire acquire_gil;
5860 return {kNullString };
@@ -73,15 +75,15 @@ class MooncakeStorePyWrapper {
7375 }
7476
7577 pybind11::object get_tensor (const std::string &key) {
76- if (!store_. client_ ) {
78+ if (!store_ || !store_-> client_ ) {
7779 LOG (ERROR) << " Client is not initialized" ;
7880 return pybind11::none ();
7981 }
8082
8183 try {
8284 // Section with GIL released
8385 py::gil_scoped_release release_gil;
84- auto buffer_handle = store_. get_buffer (key);
86+ auto buffer_handle = store_-> get_buffer (key);
8587 if (!buffer_handle) {
8688 py::gil_scoped_acquire acquire_gil;
8789 return pybind11::none ();
@@ -144,7 +146,8 @@ class MooncakeStorePyWrapper {
144146 py::tuple shape_tuple = py::cast (shape_vec);
145147 np_array = np_array.attr (" reshape" )(shape_tuple);
146148 }
147- pybind11::object tensor = torch.attr (" from_numpy" )(np_array);
149+ pybind11::object tensor =
150+ torch_module ().attr (" from_numpy" )(np_array);
148151 return tensor;
149152
150153 } catch (const pybind11::error_already_set &e) {
@@ -154,7 +157,7 @@ class MooncakeStorePyWrapper {
154157 }
155158
156159 int put_tensor (const std::string &key, pybind11::object tensor) {
157- if (!store_. client_ ) {
160+ if (!store_ || !store_-> client_ ) {
158161 LOG (ERROR) << " Client is not initialized" ;
159162 return -static_cast <int >(ErrorCode::INVALID_PARAMS);
160163 }
@@ -211,7 +214,7 @@ class MooncakeStorePyWrapper {
211214 values.emplace_back (std::span<const char >(buffer, tensor_size));
212215
213216 // Use put_parts to put metadata and tensor together
214- auto put_result = store_. put_parts_internal (key, values);
217+ auto put_result = store_-> put_parts_internal (key, values);
215218 if (!put_result) {
216219 return -static_cast <int >(put_result.error ());
217220 }
@@ -287,77 +290,83 @@ PYBIND11_MODULE(store, m) {
287290 const std::string &protocol = " tcp" ,
288291 const std::string &rdma_devices = " " ,
289292 const std::string &master_server_addr = " 127.0.0.1:50051" ) {
290- return self.store_ .setup (local_hostname, metadata_server,
291- global_segment_size,
292- local_buffer_size, protocol,
293- rdma_devices, master_server_addr);
293+ if (!self.store_ ) {
294+ self.store_ = PyClient::create ();
295+ }
296+ return self.store_ ->setup (local_hostname, metadata_server,
297+ global_segment_size,
298+ local_buffer_size, protocol,
299+ rdma_devices, master_server_addr);
294300 })
295301 .def (" init_all" ,
296302 [](MooncakeStorePyWrapper &self, const std::string &protocol,
297303 const std::string &device_name,
298304 size_t mount_segment_size = 1024 * 1024 * 16 ) {
299- return self.store_ . initAll (protocol, device_name,
300- mount_segment_size);
305+ return self.store_ -> initAll (protocol, device_name,
306+ mount_segment_size);
301307 })
302308 .def (" get" , &MooncakeStorePyWrapper::get)
303309 .def (" get_batch" , &MooncakeStorePyWrapper::get_batch)
304310 .def (
305311 " get_buffer" ,
306312 [](MooncakeStorePyWrapper &self, const std::string &key) {
307313 py::gil_scoped_release release;
308- return self.store_ . get_buffer (key);
314+ return self.store_ -> get_buffer (key);
309315 },
310316 py::return_value_policy::take_ownership)
311317 .def (
312318 " batch_get_buffer" ,
313319 [](MooncakeStorePyWrapper &self,
314320 const std::vector<std::string> &keys) {
315321 py::gil_scoped_release release;
316- return self.store_ . batch_get_buffer (keys);
322+ return self.store_ -> batch_get_buffer (keys);
317323 },
318324 py::return_value_policy::take_ownership)
319325 .def (" remove" ,
320326 [](MooncakeStorePyWrapper &self, const std::string &key) {
321327 py::gil_scoped_release release;
322- return self.store_ . remove (key);
328+ return self.store_ -> remove (key);
323329 })
324330 .def (
325331 " remove_by_regex" ,
326332 [](MooncakeStorePyWrapper &self, const std::string &str) {
327333 py::gil_scoped_release release;
328- return self.store_ . removeByRegex (str);
334+ return self.store_ -> removeByRegex (str);
329335 },
330336 py::arg (" regex_pattern" ),
331337 " Removes objects from the store whose keys match the given "
332338 " regular expression." )
333339 .def (" remove_all" ,
334340 [](MooncakeStorePyWrapper &self) {
335341 py::gil_scoped_release release;
336- return self.store_ . removeAll ();
342+ return self.store_ -> removeAll ();
337343 })
338344 .def (" is_exist" ,
339345 [](MooncakeStorePyWrapper &self, const std::string &key) {
340346 py::gil_scoped_release release;
341- return self.store_ . isExist (key);
347+ return self.store_ -> isExist (key);
342348 })
343349 .def (
344350 " batch_is_exist" ,
345351 [](MooncakeStorePyWrapper &self,
346352 const std::vector<std::string> &keys) {
347353 py::gil_scoped_release release;
348- return self.store_ . batchIsExist (keys);
354+ return self.store_ -> batchIsExist (keys);
349355 },
350356 py::arg (" keys" ),
351357 " Check if multiple objects exist. Returns list of results: 1 if "
352358 " exists, 0 if not exists, -1 if error" )
353359 .def (" close" ,
354360 [](MooncakeStorePyWrapper &self) {
355- return self.store_ .tearDownAll ();
361+ if (!self.store_ ) return 0 ;
362+ int rc = self.store_ ->tearDownAll ();
363+ self.store_ .reset ();
364+ return rc;
356365 })
357366 .def (" get_size" ,
358367 [](MooncakeStorePyWrapper &self, const std::string &key) {
359368 py::gil_scoped_release release;
360- return self.store_ . getSize (key);
369+ return self.store_ -> getSize (key);
361370 })
362371 .def (" get_tensor" , &MooncakeStorePyWrapper::get_tensor, py::arg (" key" ),
363372 " Get a PyTorch tensor from the store" )
@@ -370,7 +379,7 @@ PYBIND11_MODULE(store, m) {
370379 // Register memory buffer for RDMA operations
371380 void *buffer = reinterpret_cast <void *>(buffer_ptr);
372381 py::gil_scoped_release release;
373- return self.store_ . register_buffer (buffer, size);
382+ return self.store_ -> register_buffer (buffer, size);
374383 },
375384 py::arg (" buffer_ptr" ), py::arg (" size" ),
376385 " Register a memory buffer for direct access operations" )
@@ -380,7 +389,7 @@ PYBIND11_MODULE(store, m) {
380389 // Unregister memory buffer
381390 void *buffer = reinterpret_cast <void *>(buffer_ptr);
382391 py::gil_scoped_release release;
383- return self.store_ . unregister_buffer (buffer);
392+ return self.store_ -> unregister_buffer (buffer);
384393 },
385394 py::arg (" buffer_ptr" ),
386395 " Unregister a previously registered memory "
@@ -392,7 +401,7 @@ PYBIND11_MODULE(store, m) {
392401 // Get data directly into user-provided buffer
393402 void *buffer = reinterpret_cast <void *>(buffer_ptr);
394403 py::gil_scoped_release release;
395- return self.store_ . get_into (key, buffer, size);
404+ return self.store_ -> get_into (key, buffer, size);
396405 },
397406 py::arg (" key" ), py::arg (" buffer_ptr" ), py::arg (" size" ),
398407 " Get object data directly into a pre-allocated buffer" )
@@ -408,7 +417,7 @@ PYBIND11_MODULE(store, m) {
408417 buffers.push_back (reinterpret_cast <void *>(ptr));
409418 }
410419 py::gil_scoped_release release;
411- return self.store_ . batch_get_into (keys, buffers, sizes);
420+ return self.store_ -> batch_get_into (keys, buffers, sizes);
412421 },
413422 py::arg (" keys" ), py::arg (" buffer_ptrs" ), py::arg (" sizes" ),
414423 " Get object data directly into pre-allocated buffers for "
@@ -422,7 +431,7 @@ PYBIND11_MODULE(store, m) {
422431 // Put data directly from user-provided buffer
423432 void *buffer = reinterpret_cast <void *>(buffer_ptr);
424433 py::gil_scoped_release release;
425- return self.store_ . put_from (key, buffer, size, config);
434+ return self.store_ -> put_from (key, buffer, size, config);
426435 },
427436 py::arg (" key" ), py::arg (" buffer_ptr" ), py::arg (" size" ),
428437 py::arg (" config" ) = ReplicateConfig{},
@@ -439,7 +448,7 @@ PYBIND11_MODULE(store, m) {
439448 void *metadata_buffer =
440449 reinterpret_cast <void *>(metadata_buffer_ptr);
441450 py::gil_scoped_release release;
442- return self.store_ . put_from_with_metadata (
451+ return self.store_ -> put_from_with_metadata (
443452 key, buffer, metadata_buffer, size, metadata_size, config);
444453 },
445454 py::arg (" key" ), py::arg (" buffer_ptr" ),
@@ -460,7 +469,8 @@ PYBIND11_MODULE(store, m) {
460469 buffers.push_back (reinterpret_cast <void *>(ptr));
461470 }
462471 py::gil_scoped_release release;
463- return self.store_ .batch_put_from (keys, buffers, sizes, config);
472+ return self.store_ ->batch_put_from (keys, buffers, sizes,
473+ config);
464474 },
465475 py::arg (" keys" ), py::arg (" buffer_ptrs" ), py::arg (" sizes" ),
466476 py::arg (" config" ) = ReplicateConfig{},
@@ -474,7 +484,7 @@ PYBIND11_MODULE(store, m) {
474484 const ReplicateConfig &config = ReplicateConfig{}) {
475485 py::buffer_info info = buf.request (/* writable=*/ false );
476486 py::gil_scoped_release release;
477- return self.store_ . put (
487+ return self.store_ -> put (
478488 key,
479489 std::span<const char >(static_cast <char *>(info.ptr ),
480490 static_cast <size_t >(info.size )),
@@ -507,7 +517,7 @@ PYBIND11_MODULE(store, m) {
507517
508518 // 2) Call C++ function
509519 py::gil_scoped_release unlock;
510- return self.store_ . put_parts (key, spans, config);
520+ return self.store_ -> put_parts (key, spans, config);
511521 },
512522 py::arg (" key" ), py::arg (" config" ) = ReplicateConfig{})
513523 .def (
@@ -530,12 +540,12 @@ PYBIND11_MODULE(store, m) {
530540 }
531541
532542 py::gil_scoped_release release;
533- return self.store_ . put_batch (keys, spans, config);
543+ return self.store_ -> put_batch (keys, spans, config);
534544 },
535545 py::arg (" keys" ), py::arg (" values" ),
536546 py::arg (" config" ) = ReplicateConfig{})
537547 .def (" get_hostname" , [](MooncakeStorePyWrapper &self) {
538- return self.store_ . get_hostname ();
548+ return self.store_ -> get_hostname ();
539549 });
540550
541551 // Expose NUMA binding as a module-level function (no self required)
0 commit comments