Skip to content

Commit

Permalink
socket session
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 committed Aug 29, 2024
1 parent 2688307 commit 0dc7e8e
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/runtime/contrib/nvshmem/nvshmem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,40 @@
namespace tvm {
namespace runtime {

// NOTE(review): this span comes from a rendered commit diff (the page header reports both
// additions and deletions for nvshmem.cc) and carries no +/- markers, so lines from the
// pre-commit and post-commit versions are interleaved. The old/new labels below are a
// best-effort reading -- TODO confirm against the real diff before relying on them.
//
// Presumably pre-commit: session-side entry point that looked up the per-worker init function
// by its registered name and logged the device list.
void InitNVSHMEM(Session sess, IntTuple device_ids) {
DRef func = sess->GetGlobalFunc("runtime.disco.nvshmem.init_nvshmem_per_worker");
DLOG(INFO) << "Initializing NVSHMEM with devices: " << device_ids;
// Presumably post-commit: takes no arguments and returns the NVSHMEM unique id packed into a
// ShapeTuple of int64 values.
ShapeTuple InitNVSHMEMUID() {
nvshmemx_uniqueid_t uid;
// Requests a unique id from NVSHMEM; `uid` is passed by address to be filled in.
nvshmemx_get_uniqueid(&uid);
// Presumably pre-commit: wrapped uid.internal (UNIQUEID_PADDING bytes) in a TVMByteArray and
// sent it, together with uid.version and device_ids, through the session to `func`.
TVMByteArray array;
array.data = uid.internal;
array.size = UNIQUEID_PADDING;
sess->CallPacked(func, device_ids, uid.version, array);
// Presumably post-commit: flatten the id into UNIQUEID_PADDING + 1 int64 values -- the version
// first, then each element of uid.internal widened to int64_t.
std::vector<int64_t> uid_64;
uid_64.push_back(static_cast<int64_t>(uid.version));
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid_64.push_back(static_cast<int64_t>(uid.internal[i]));
}
return ShapeTuple(uid_64);
}

// NOTE(review): this span comes from a rendered commit diff without +/- markers, so pre-commit
// and post-commit lines are interleaved. The old/new labels below are a best-effort reading --
// TODO confirm against the real diff.
//
// Presumably pre-commit signature: version and unique-id bytes arrive as separate arguments.
void InitNVSHMEMPerWorker(IntTuple device_ids, int version, std::string unique_id_bytes) {
// Presumably post-commit signature: uid_64 holds the version in element 0 followed by
// UNIQUEID_PADDING id values (matching the unpacking loop further down).
void InitNVSHMEM(ShapeTuple uid_64, IntTuple device_ids) {
// Fetch the calling thread's DiscoWorker; the check fails if there is none.
DiscoWorker* worker = DiscoWorker::ThreadLocal();
ICHECK(worker != nullptr);

// Two variants of the length check and of its final message operand appear here (presumably
// old first, then new).
// NOTE(review): the uid_64 variant requires UNIQUEID_PADDING + 1 elements, but the shared
// message text still reports UNIQUEID_PADDING as the required length -- the message looks
// stale relative to the new check.
CHECK_EQ(unique_id_bytes.size(), UNIQUEID_PADDING)
CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
<< "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
<< unique_id_bytes.size() << ".";
<< uid_64.size() << ".";

nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;

// Rebuild the nvshmemx_uniqueid_t on the worker side.
nvshmemx_uniqueid_t uid;
// Presumably pre-commit: version taken directly from the argument, id bytes copied with memcpy.
uid.version = version;
std::memcpy(uid.internal, unique_id_bytes.data(), UNIQUEID_PADDING);
// Presumably pre-commit call: passes worker->local_worker_id as the first argument.
nvshmemx_set_attr_uniqueid_args(worker->local_worker_id, device_ids.size(), &uid, &attr);
// Presumably post-commit: element 0 is narrowed to the int version and elements 1..N are
// narrowed to char. Assumes each value fits in a char (i.e. was produced by widening one) --
// TODO confirm against the producer of uid_64.
uid.version = static_cast<int>(uid_64[0]);
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
}
// Presumably post-commit call: passes worker->worker_id instead of local_worker_id, with
// device_ids.size() as the second argument (looks like rank and total count -- verify against
// the NVSHMEM API reference).
nvshmemx_set_attr_uniqueid_args(worker->worker_id, device_ids.size(), &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
// Two variants of the completion log line follow (presumably old, then new with a
// "NVSHMEM init finished: " prefix); both print nvshmem_my_pe() and nvshmem_n_pes().
LOG_INFO << "mype=" << nvshmem_my_pe() << " " << ", npes=" << nvshmem_n_pes();
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}

// NOTE(review): these registrations come from a rendered commit diff without +/- markers, so
// pre-commit and post-commit lines are interleaved -- TODO confirm which ones are live.
// "runtime.disco.nvshmem.init_nvshmem" appears twice in this span (first and last line),
// presumably once as a removed line and once as an added line; if both were live the same global
// name would be registered twice.
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
// Exposes InitNVSHMEMUID under the global name "runtime.disco.nvshmem.init_nvshmem_uid".
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

// Presumably pre-commit: registration of the per-worker entry point.
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_per_worker")
.set_body_typed(InitNVSHMEMPerWorker);
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

} // namespace runtime
} // namespace tvm

0 comments on commit 0dc7e8e

Please sign in to comment.