Skip to content

Commit

Permalink
[NVSHMEM] Enable nvshmem memory allocation
Browse files Browse the repository at this point in the history
This PR add the support of nvshmem memory allocation, and integrates it into disco.
  • Loading branch information
cyx-6 committed Sep 29, 2024
1 parent d9ee637 commit b604495
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
}
nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
CUDA_CALL(cudaSetDevice(mype_node));
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}
Expand Down
97 changes: 97 additions & 0 deletions src/runtime/contrib/nvshmem/memory_allocator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <nvshmem.h>
#include <nvshmemx.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <thread>

#include "../../cuda/cuda_common.h"
#include "../../memory/pooled_allocator.h"

namespace tvm {
namespace runtime {

using tvm::runtime::memory::Buffer;
using tvm::runtime::memory::PooledAllocator;

/*!
* \brief The memory allocator of CUDAIPCMemory.
* Overriding PooledAllocator for efficient memory management.
*/
class NVSHMEMAllocator final : public PooledAllocator {
public:
explicit NVSHMEMAllocator() : PooledAllocator() {}

~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); }

void Clear() final { PooledAllocator::ReleaseAll(); }

bool AllowMemoryScope(const std::string& mem_scope) const final {
// The allowed memory scope of CUDAIPCMemory is "ipc_memory";
return mem_scope == "nvshmem";
}

/*! \brief Return the global CUDAIPCMemory singleton allocator. */
static NVSHMEMAllocator* Global() {
static NVSHMEMAllocator* allocator = new NVSHMEMAllocator();
return allocator;
}

NDArray Empty(ShapeTuple shape, DataType dtype, Device device) {
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device);
container->SetDeleter([](Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
ICHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
NVSHMEMAllocator::Global()->Free(*(buffer));
delete buffer;
delete ptr;
});
Buffer* buffer = new Buffer;
*buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem"));
container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data;
return NDArray(GetObjectPtr<Object>(container));
}

private:
void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
DLDataType type_hint) final {
ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA)
<< "nvshmem can only allocate cuda device memory space.";
ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt ||
type_hint.code == DLDataTypeCode::kDLFloat)
<< "nvshmem can only allocate tensor with int, usingned int or float data types.";
return nvshmem_align(alignment, size);
}

void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); }
};

NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
}

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);

} // namespace runtime
} // namespace tvm
28 changes: 23 additions & 5 deletions tests/python/disco/test_nvshmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import subprocess
import threading
import sys
from multiprocessing import Process


import tvm
import tvm.testing
from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.exec import disco_worker as _ # pylint: disable=unused-import
Expand Down Expand Up @@ -82,8 +83,6 @@ def start_server():
thread.join()

def __del__(self):
for node in self.remote_nodes:
node.kill()
if self.sess is not None:
self.sess.shutdown()
del self.sess
Expand All @@ -98,7 +97,6 @@ def create_socket_session(num_workers):
return _SOCKET_SESSION_TESTER.sess


@pytest.mark.parametrize("num_workers", [2, 4])
def test_nvshmem_init(num_workers):
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
Expand All @@ -110,5 +108,25 @@ def test_nvshmem_init(num_workers):
sess.sync_worker_0()


def test_nvshmem_empty(num_workers):
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
device = tvm.cuda()
sess = create_socket_session(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
sess.sync_worker_0()


if __name__ == "__main__":
tvm.testing.main()
for num_worker in [2, 4]:
for test_func in [test_nvshmem_init, test_nvshmem_empty]:
p = Process(target=test_func, args=[num_worker])
p.start()
p.join()

0 comments on commit b604495

Please sign in to comment.