diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index facfc9521741..770db2f90227 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -57,20 +57,18 @@ class NVSHMEMAllocator final : public PooledAllocator { } NDArray Empty(ffi::Shape shape, DataType dtype, Device device) { - NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device); - container->SetDeleter([](Object* obj) { - auto* ptr = static_cast(obj); - ICHECK(ptr->manager_ctx != nullptr); - Buffer* buffer = reinterpret_cast(ptr->manager_ctx); - NVSHMEMAllocator::Global()->Free(*(buffer)); - delete buffer; - delete ptr; - }); - Buffer* buffer = new Buffer; - *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem")); - container->manager_ctx = reinterpret_cast(buffer); - container->dl_tensor.data = buffer->data; - return NDArray(GetObjectPtr(container)); + class NVSHMEMAlloc { + public: + explicit NVSHMEMAlloc(Buffer buffer) : buffer_(buffer) {} + void AllocData(DLTensor* tensor) { tensor->data = buffer_.data; } + void FreeData(DLTensor* tensor) { NVSHMEMAllocator::Global()->Free(buffer_); } + + private: + Buffer buffer_; + }; + + Buffer buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem")); + return NDArray::FromNDAlloc(NVSHMEMAlloc(buffer), shape, dtype, device); } private: