From b6044955bd2b24676575ed8bec4c215745276c67 Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Sun, 29 Sep 2024 23:08:33 +0000
Subject: [PATCH] [NVSHMEM] Enable NVSHMEM memory allocation

This PR adds support for NVSHMEM memory allocation and integrates it into
Disco.
---
 .../contrib/nvshmem/{nvshmem.cc => init.cc}   |  2 +
 .../contrib/nvshmem/memory_allocator.cc       | 97 +++++++++++++++++++
 tests/python/disco/test_nvshmem.py            | 28 +++++-
 3 files changed, 122 insertions(+), 5 deletions(-)
 rename src/runtime/contrib/nvshmem/{nvshmem.cc => init.cc} (96%)
 create mode 100644 src/runtime/contrib/nvshmem/memory_allocator.cc

diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/init.cc
similarity index 96%
rename from src/runtime/contrib/nvshmem/nvshmem.cc
rename to src/runtime/contrib/nvshmem/init.cc
index 985ba5510762b..50fdde4c49d8b 100644
--- a/src/runtime/contrib/nvshmem/nvshmem.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   }
   nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+  int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
+  CUDA_CALL(cudaSetDevice(mype_node));
   LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
            << ", npes=" << nvshmem_n_pes();
 }
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc
new file mode 100644
index 0000000000000..f9ca61ecaaaab
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <dlpack/dlpack.h>
+#include <nvshmem.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <string>
+
+#include "../../cuda/cuda_common.h"
+#include "../../memory/pooled_allocator.h"
+
+namespace tvm {
+namespace runtime {
+
+using tvm::runtime::memory::Buffer;
+using tvm::runtime::memory::PooledAllocator;
+
+/*!
+ * \brief The NVSHMEM memory allocator.
+ * Overrides PooledAllocator for efficient memory management.
+ */
+class NVSHMEMAllocator final : public PooledAllocator {
+ public:
+  explicit NVSHMEMAllocator() : PooledAllocator() {}
+
+  ~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); }
+
+  void Clear() final { PooledAllocator::ReleaseAll(); }
+
+  bool AllowMemoryScope(const std::string& mem_scope) const final {
+    // The only memory scope accepted by this allocator is "nvshmem".
+    return mem_scope == "nvshmem";
+  }
+
+  /*! \brief Return the global NVSHMEMAllocator singleton. */
+  static NVSHMEMAllocator* Global() {
+    static NVSHMEMAllocator* allocator = new NVSHMEMAllocator();
+    return allocator;
+  }
+
+  NDArray Empty(ShapeTuple shape, DataType dtype, Device device) {
+    NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device);
+    container->SetDeleter([](Object* obj) {
+      auto* ptr = static_cast<NDArray::Container*>(obj);
+      ICHECK(ptr->manager_ctx != nullptr);
+      Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
+      NVSHMEMAllocator::Global()->Free(*(buffer));
+      delete buffer;
+      delete ptr;
+    });
+    Buffer* buffer = new Buffer;
+    *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem"));
+    container->manager_ctx = reinterpret_cast<void*>(buffer);
+    container->dl_tensor.data = buffer->data;
+    return NDArray(GetObjectPtr<Object>(container));
+  }
+
+ private:
+  void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
+                             DLDataType type_hint) final {
+    ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA)
+        << "nvshmem can only allocate cuda device memory space.";
+    ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt ||
+           type_hint.code == DLDataTypeCode::kDLFloat)
+        << "nvshmem can only allocate tensors with int, unsigned int or float data types.";
+    return nvshmem_align(alignment, size);
+  }
+
+  void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); }
+};
+
+NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
+  return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 0b16fe93612f2..b67fd8266e4b3 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -23,9 +23,10 @@
 import subprocess
 import threading
 import sys
+from multiprocessing import Process
+
 import tvm
-import tvm.testing
 from tvm.runtime import ShapeTuple
 from tvm.runtime import disco as di
 from tvm.exec import disco_worker as _  # pylint: disable=unused-import

@@ -82,8 +83,6 @@ def start_server():
         thread.join()

     def __del__(self):
-        for node in self.remote_nodes:
-            node.kill()
         if self.sess is not None:
             self.sess.shutdown()
             del self.sess
@@ -98,7 +97,6 @@ def create_socket_session(num_workers):
     return _SOCKET_SESSION_TESTER.sess


-@pytest.mark.parametrize("num_workers", [2, 4])
 def test_nvshmem_init(num_workers):
     if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
         return
@@ -110,5 +108,25 @@ def test_nvshmem_init(num_workers):
     sess.sync_worker_0()


+def test_nvshmem_empty(num_workers):
+    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+        return
+    device = tvm.cuda()
+    sess = create_socket_session(num_workers=num_workers)
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers)
+    sess.sync_worker_0()
+    empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
+    a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
+    b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
+    sess.sync_worker_0()
+
+
 if __name__ == "__main__":
-    tvm.testing.main()
+    for num_worker in [2, 4]:
+        for test_func in [test_nvshmem_init, test_nvshmem_empty]:
+            p = Process(target=test_func, args=[num_worker])
+            p.start()
+            p.join()
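
For reviewers, a minimal usage sketch of the new allocation path, mirroring the added test. It assumes an NVSHMEM-enabled TVM build and reuses the create_socket_session helper defined in tests/python/disco/test_nvshmem.py; it is illustrative only, not part of the patch.

import tvm
from tvm.runtime import ShapeTuple

num_workers = 2
sess = create_socket_session(num_workers=num_workers)

# Generate the NVSHMEM unique id in the driver process and initialize all workers.
uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")()
sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")(uid, num_workers)
sess.sync_worker_0()

# Allocate an uninitialized float32 tensor from the NVSHMEM pool on every worker.
empty = sess.get_global_func("runtime.disco.nvshmem.empty")
x = empty(ShapeTuple((32, 64)), "float32", tvm.cuda())
sess.sync_worker_0()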