diff --git a/CMakeLists.txt b/CMakeLists.txt
index aa2a385683d7..38dd59b9c906 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,6 +13,7 @@ include(cmake/utils/FindLLVM.cmake)
 include(cmake/utils/FindROCM.cmake)
 include(cmake/utils/FindRCCL.cmake)
 include(cmake/utils/FindEthosN.cmake)
+include(cmake/utils/FindNVSHMEM.cmake)
 
 if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
   include(${CMAKE_BINARY_DIR}/config.cmake)
@@ -133,6 +134,7 @@ tvm_option(USE_UMA "Build with UMA support" OFF)
 tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
 tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)
 tvm_option(USE_MRVL "Build with MRVL TVM support" OFF)
+tvm_option(USE_NVSHMEM "Build with NVSHMEM support" OFF)
 
 # include directories
 include_directories(${CMAKE_INCLUDE_PATH})
@@ -472,6 +474,18 @@ if(USE_CUDA AND USE_NCCL)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
 endif()
 
+# Locate NVSHMEM through find_nvshmem() and add the sources under
+# src/runtime/contrib/nvshmem to the runtime; configuration stops if it is missing.
+if (USE_CUDA AND USE_NVSHMEM)
+  message(STATUS "Build with NVSHMEM...")
+  find_nvshmem(${USE_NVSHMEM})
+  if (NOT NVSHMEM_FOUND)
+    message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
+  endif()
+  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
+endif()
+
 if(USE_ROCM AND USE_RCCL)
   message(STATUS "Build with RCCL...")
   find_rccl(${USE_RCCL})
@@ -957,6 +971,29 @@ if(USE_CUDA AND USE_NCCL)
   target_link_libraries(tvm_runtime PRIVATE nccl ${LIBRT})
 endif()
 
+
+# Link the NVSHMEM host and device libraries into tvm and tvm_runtime, and
+# enable CUDA separable compilation on both targets.
+if (USE_CUDA AND USE_NVSHMEM)
+  # USE_NVSHMEM may be a plain ON, in which case ${USE_NVSHMEM}/include is not
+  # a path; prefer the directory reported through find_nvshmem() and keep the
+  # old location only as a fallback.
+  if (NVSHMEM_INCLUDE_DIR)
+    include_directories(SYSTEM ${NVSHMEM_INCLUDE_DIR})
+  else()
+    include_directories(SYSTEM ${USE_NVSHMEM}/include)
+  endif()
+  find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR})
+  find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR})
+  if (NOT NVSHMEM_HOST OR NOT NVSHMEM_DEVICE)
+    message(FATAL_ERROR "Cannot find nvshmem_host/nvshmem_device, NVSHMEM_LIB_DIR=" ${NVSHMEM_LIB_DIR})
+  endif()
+  target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
+  target_link_libraries(tvm_runtime PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
+  set_target_properties(tvm PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
+  set_target_properties(tvm_runtime PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
+endif()
+
 if(USE_ROCM AND USE_RCCL)
   target_link_libraries(tvm PRIVATE rccl)
   target_link_libraries(tvm_runtime PRIVATE rccl)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index da9bc3e1c9d3..a2b51bb33195 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -143,6 +143,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
     TVM_INFO_USE_MSC="${USE_MSC}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
+    TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
     TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
   )
 
diff --git a/cmake/utils/FindNVSHMEM.cmake b/cmake/utils/FindNVSHMEM.cmake
new file mode 100644
index 000000000000..1a833332a289
--- /dev/null
+++ b/cmake/utils/FindNVSHMEM.cmake
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#######################################################
+# Enhanced version of find NVSHMEM.
+#
+# Usage:
+#   find_nvshmem(${USE_NVSHMEM})
+#
+# - When USE_NVSHMEM=ON, use auto search; $ENV{NVSHMEM_HOME} is used as the
+#   install prefix if it is a directory.
+# - When USE_NVSHMEM=/path/to/installed/nvshmem, use the installed nvshmem path.
+#   Can be useful when nvshmem is installed at specified location.
+#
+# Provide variables:
+#
+# - NVSHMEM_FOUND
+# - NVSHMEM_INCLUDE_DIR
+# - NVSHMEM_LIB_DIR
+#
+
+macro(find_nvshmem use_nvshmem)
+  set(__use_nvshmem ${use_nvshmem})
+  # Quote both paths so each reaches IS_DIRECTORY as exactly one argument,
+  # even when NVSHMEM_HOME is unset or a path contains ';'.
+  if(IS_DIRECTORY "${__use_nvshmem}")
+    set(__nvshmem_path ${__use_nvshmem})
+    message(STATUS "Custom NVSHMEM PATH=" ${__use_nvshmem})
+  elseif(IS_DIRECTORY "$ENV{NVSHMEM_HOME}")
+    set(__nvshmem_path $ENV{NVSHMEM_HOME})
+  else()
+    set(__nvshmem_path "")
+  endif()
+
+  # The hint points at the CMake package directory under the chosen prefix.
+  find_package(NVSHMEM HINTS ${__nvshmem_path}/lib/cmake/nvshmem/)
+
+  if(NVSHMEM_FOUND)
+    message(STATUS "NVSHMEM_INCLUDE_DIR=" ${NVSHMEM_INCLUDE_DIR})
+    message(STATUS "NVSHMEM_LIB_DIR=" ${NVSHMEM_LIB_DIR})
+  endif()
+endmacro()
diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/nvshmem.cc
new file mode 100644
index 000000000000..985ba5510762
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/nvshmem.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// NOTE(review): the <...> header names and the template arguments below were
+// reconstructed from the symbols this file uses; confirm before merging.
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <tvm/runtime/disco/disco_worker.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Create an NVSHMEM unique id and flatten it into a ShapeTuple.
+ * \return uid.version followed by the UNIQUEID_PADDING entries of uid.internal.
+ */
+ShapeTuple InitNVSHMEMUID() {
+  nvshmemx_uniqueid_t uid;
+  nvshmemx_get_uniqueid(&uid);
+  std::vector<int64_t> uid_64;
+  uid_64.push_back(static_cast<int64_t>(uid.version));
+  for (int i = 0; i < UNIQUEID_PADDING; ++i) {
+    uid_64.push_back(static_cast<int64_t>(uid.internal[i]));
+  }
+  return ShapeTuple(uid_64);
+}
+
+/*!
+ * \brief Initialize NVSHMEM on the calling Disco worker from a flattened unique id.
+ * \param uid_64 The unique id in the layout produced by InitNVSHMEMUID.
+ * \param num_workers Passed to nvshmemx_set_attr_uniqueid_args along with the worker id.
+ */
+void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
+  DiscoWorker* worker = DiscoWorker::ThreadLocal();
+  ICHECK(worker != nullptr);
+  // The tuple holds uid.version plus UNIQUEID_PADDING entries, so report that length.
+  CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
+      << "ValueError: The length of unique_id must be " << (UNIQUEID_PADDING + 1)
+      << ", but got " << uid_64.size() << ".";
+
+  nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
+
+  nvshmemx_uniqueid_t uid;
+  uid.version = static_cast<int>(uid_64[0]);
+  for (int i = 0; i < UNIQUEID_PADDING; ++i) {
+    uid.internal[i] = static_cast<char>(uid_64[i + 1]);
+  }
+  nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+  LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
+           << ", npes=" << nvshmem_n_pes();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 984a2f3323ad..73800338b143 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -275,6 +275,10 @@
 #define TVM_INFO_USE_CCACHE "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_USE_NVSHMEM
+#define TVM_INFO_USE_NVSHMEM "NOT-FOUND"
+#endif
+
 namespace tvm {
 
 /*!
@@ -387,6 +391,7 @@ TVM_DLL Map GetLibInfo() {
       {"USE_VERILATOR", TVM_INFO_USE_VERILATOR},
       {"USE_MSC", TVM_INFO_USE_MSC},
       {"USE_CCACHE", TVM_INFO_USE_CCACHE},
+      {"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM},
       {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT},
   };
   return result;
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
new file mode 100644
index 000000000000..0b16fe93612f
--- /dev/null
+++ b/tests/python/disco/test_nvshmem.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Basic tests for a Disco nvshmem support"""
+# pylint: disable=missing-docstring
+import tempfile
+
+import numpy as np
+import pytest
+import subprocess
+import threading
+import sys
+
+import tvm
+import tvm.testing
+from tvm.runtime import ShapeTuple
+from tvm.runtime import disco as di
+from tvm.exec import disco_worker as _  # pylint: disable=unused-import
+
+# Holds the latest SocketSessionTester; dropping it runs __del__, which shuts the session down.
+_SOCKET_SESSION_TESTER = None
+
+
+def get_free_port():
+    """Return a TCP port number that the OS reported as free at call time."""
+    import socket
+
+    # The context manager closes the socket even if bind() raises.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+class SocketSessionTester:
+    """Start a two-node Disco socket session: a server here plus one remote process."""
+
+    def __init__(self, num_workers):
+        # Set both attributes first so __del__ is safe if construction fails midway.
+        self.sess = None
+        self.remote_nodes = []
+        num_nodes = 2
+        num_groups = 1
+        assert num_workers % num_nodes == 0
+        num_workers_per_node = num_workers // num_nodes
+        server_host = "localhost"
+        server_port = get_free_port()
+
+        def start_server():
+            self.sess = di.SocketSession(
+                num_nodes, num_workers_per_node, num_groups, server_host, server_port
+            )
+
+        thread = threading.Thread(target=start_server)
+        thread.start()
+
+        cmd = "tvm.exec.disco_remote_socket_session"
+        for _ in range(num_nodes - 1):
+            self.remote_nodes.append(
+                subprocess.Popen(
+                    [
+                        # Launch the remote node with the interpreter running this test.
+                        sys.executable,
+                        "-m",
+                        cmd,
+                        server_host,
+                        str(server_port),
+                        str(num_workers_per_node),
+                    ],
+                    stdout=sys.stdout,
+                    stderr=sys.stderr,
+                )
+            )
+
+        thread.join()
+
+    def __del__(self):
+        # Reap each killed child so no zombie process is left behind.
+        for node in self.remote_nodes:
+            node.kill()
+            node.wait()
+        if self.sess is not None:
+            self.sess.shutdown()
+            del self.sess
+
+
+def create_socket_session(num_workers):
+    """Replace the module-level tester with a new one and return its session."""
+    global _SOCKET_SESSION_TESTER
+    # Rebind instead of `del` so the global still exists if construction raises.
+    _SOCKET_SESSION_TESTER = None
+    _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
+    assert _SOCKET_SESSION_TESTER.sess is not None
+    return _SOCKET_SESSION_TESTER.sess
+
+
+@pytest.mark.parametrize("num_workers", [2, 4])
+def test_nvshmem_init(num_workers):
+    """Create a uid locally and pass it to init_nvshmem through the session."""
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True)
+    if f_init_nvshmem_uid is None:
+        # Report a skip rather than a pass when the NVSHMEM functions are not registered.
+        pytest.skip("runtime.disco.nvshmem.init_nvshmem_uid is not registered")
+    sess = create_socket_session(num_workers=num_workers)
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers)
+    sess.sync_worker_0()
+
+
+if __name__ == "__main__":
+    tvm.testing.main()