Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce collective ops to ort inference build #14399

Merged
merged 24 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 80 additions & 77 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1338,102 +1338,105 @@ if (onnxruntime_ENABLE_TRAINING)
add_compile_definitions(ENABLE_STRIDED_TENSORS)
add_compile_definitions(ENABLE_TRAINING)

if (UNIX)
if (EXISTS "${onnxruntime_MPI_HOME}")
set(MPI_HOME "${onnxruntime_MPI_HOME}")
elseif (EXISTS "/bert_ort/openmpi")
set(MPI_HOME "/bert_ort/openmpi")
endif()
add_subdirectory(tensorboard EXCLUDE_FROM_ALL)
souptc marked this conversation as resolved.
Show resolved Hide resolved
list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard)
endif()

if (UNIX AND onnxruntime_USE_MPI)
if (EXISTS "${onnxruntime_MPI_HOME}")
set(MPI_HOME "${onnxruntime_MPI_HOME}")
elseif (EXISTS "/bert_ort/openmpi")
souptc marked this conversation as resolved.
Show resolved Hide resolved
set(MPI_HOME "/bert_ort/openmpi")
endif()

find_package(MPI)
find_package(MPI)

if (onnxruntime_USE_MPI OR onnxruntime_USE_NCCL)
if (MPI_CXX_FOUND)
message( STATUS "MPI Version: ${MPI_CXX_VERSION}")
message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" )
mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS})
else ()
message(
if (MPI_CXX_FOUND)
message( STATUS "MPI Version: ${MPI_CXX_VERSION}")
message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" )
mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS})
else ()
message(
FATAL_ERROR
"MPI is not found. Please define onnxruntime_MPI_HOME to specify the path of MPI. Otherwise, NCCL will be disabled."
)
endif()
endif()
endif()

# Find NCCL and MPI
if (onnxruntime_USE_NCCL)
if (onnxruntime_USE_CUDA)
set(NCCL_LIBNAME "nccl")
elseif (onnxruntime_USE_ROCM)
set(NCCL_LIBNAME "rccl")
# Find NCCL and MPI
if (onnxruntime_USE_NCCL AND MPI_CXX_FOUND)
if (onnxruntime_USE_CUDA)
set(NCCL_LIBNAME "nccl")
elseif (onnxruntime_USE_ROCM)
set(NCCL_LIBNAME "rccl")
endif()
find_path(NCCL_INCLUDE_DIR
NAMES ${NCCL_LIBNAME}.h
HINTS
${onnxruntime_NCCL_HOME}/include
$ENV{CUDA_ROOT}/include)

find_library(NCCL_LIBRARY
NAMES ${NCCL_LIBNAME}
HINTS
${onnxruntime_NCCL_HOME}/lib/x86_64-linux-gnu
${onnxruntime_NCCL_HOME}/lib
$ENV{CUDA_ROOT}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY)

if (NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/${NCCL_LIBNAME}.h")
message( STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}" )
file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
if (NCCL_MAJOR_VERSION_DEFINED)
string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message( STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}" )
endif()
file (STRINGS ${NCCL_HEADER_FILE} NCCL_MINOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
if (NCCL_MINOR_VERSION_DEFINED)
string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+" ""
NCCL_MINOR_VERSION ${NCCL_MINOR_VERSION_DEFINED})
message(STATUS "NCCL_MINOR_VERSION: ${NCCL_MINOR_VERSION}")
endif()

find_path(NCCL_INCLUDE_DIR
NAMES ${NCCL_LIBNAME}.h
HINTS
${onnxruntime_NCCL_HOME}/include
$ENV{CUDA_ROOT}/include)

find_library(NCCL_LIBRARY
NAMES ${NCCL_LIBNAME}
HINTS
${onnxruntime_NCCL_HOME}/lib/x86_64-linux-gnu
${onnxruntime_NCCL_HOME}/lib
$ENV{CUDA_ROOT}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY)

if (NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/${NCCL_LIBNAME}.h")
message( STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}" )
file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
if (NCCL_MAJOR_VERSION_DEFINED)
string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message( STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}" )
endif()
file (STRINGS ${NCCL_HEADER_FILE} NCCL_MINOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
if (NCCL_MINOR_VERSION_DEFINED)
string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+" ""
NCCL_MINOR_VERSION ${NCCL_MINOR_VERSION_DEFINED})
message(STATUS "NCCL_MINOR_VERSION: ${NCCL_MINOR_VERSION}")
if (NCCL_MAJOR_VERSION_DEFINED AND NCCL_MINOR_VERSION_DEFINED)
if ("${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}" VERSION_GREATER_EQUAL "2.7")
add_definitions(-DUSE_NCCL_P2P=1)
message( STATUS "NCCL P2P is enabled for supporting ncclSend and ncclRecv." )
endif()
endif()

if (NCCL_MAJOR_VERSION_DEFINED AND NCCL_MINOR_VERSION_DEFINED)
if ("${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}" VERSION_GREATER_EQUAL "2.7")
add_definitions(-DUSE_NCCL_P2P=1)
message( STATUS "NCCL P2P is enabled for supporting ncclSend and ncclRecv." )
endif()
endif()
set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR})
set(NCCL_LIBRARIES ${NCCL_LIBRARY})
message( STATUS "NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})" )
mark_as_advanced(NCCL_INCLUDE_DIRS NCCL_LIBRARIES)

set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR})
set(NCCL_LIBRARIES ${NCCL_LIBRARY})
message( STATUS "NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})" )
mark_as_advanced(NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${NCCL_LIBRARIES})

list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${NCCL_LIBRARIES})

add_definitions(-DORT_USE_NCCL=1)
message( STATUS "NCCL is enabled in Linux GPU Build." )
else ()
message(
add_definitions(-DORT_USE_NCCL=1)
message( STATUS "NCCL is enabled in Linux GPU Build." )
else ()
set(onnxruntime_USE_NCCL OFF)
message(
FATAL_ERROR
"NCCL is not found. Please use --nccl_home to specify the path of NCCL. Otherwise, NCCL is disabled."
)
endif()
endif()
endif()
else()
set(onnxruntime_USE_NCCL OFF)
set(onnxruntime_USE_MPI OFF)
message( WARNING "MPI and NCCL disabled on Win build." )
endif()

if (onnxruntime_USE_MPI)
add_definitions(-DUSE_MPI=1)
endif()

add_subdirectory(tensorboard EXCLUDE_FROM_ALL)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard)
if (onnxruntime_USE_MPI)
add_definitions(-DUSE_MPI=1)
endif()

# Default version parts for Microsoft.AI.MachineLearning.dll, onnxruntime.dll, onnxruntime_providers_openvino.dll and onnxruntime_providers_shared.dll in non-ADO pipeline local builds
Expand Down
6 changes: 3 additions & 3 deletions cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(onnxruntime_framework Python::Module)
target_include_directories(onnxruntime_framework PRIVATE ${dlpack_SOURCE_DIR}/include)
endif()
if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI)
target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS})
endif()
endif()
if (onnxruntime_USE_MPI)
target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS})
endif()

if (onnxruntime_ENABLE_ATEN)
Expand Down
20 changes: 13 additions & 7 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@ if (onnxruntime_USE_CUDA)
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc"
)
endif()
# Without NCCL support, take nccl_kernels.cc out of the CUDA contrib-op source list.
if(NOT onnxruntime_USE_NCCL)
  list(
    REMOVE_ITEM
    onnxruntime_cuda_contrib_ops_cc_srcs
    "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc")
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs})
list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs})
Expand Down Expand Up @@ -507,14 +512,15 @@ if (onnxruntime_USE_CUDA)

if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS})
if(onnxruntime_USE_MPI)
target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS})
endif()
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS})
target_link_libraries(onnxruntime_providers_cuda PRIVATE ${NCCL_LIBRARIES})
endif()
if(onnxruntime_USE_MPI)
target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS})
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS})
target_link_libraries(onnxruntime_providers_cuda PRIVATE ${NCCL_LIBRARIES})
endif()

if (WIN32)
Expand Down
3 changes: 3 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ set(contrib_ops_excluded_files
if (NOT onnxruntime_ENABLE_ATEN)
list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
endif()
# Without NCCL support, add collective/nccl_kernels.cc to the excluded contrib-op files.
if(NOT onnxruntime_USE_NCCL)
  list(
    APPEND
    contrib_ops_excluded_files
    "collective/nccl_kernels.cc")
endif()

set(provider_excluded_files
"atomic/common.cuh"
Expand Down
28 changes: 28 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/mpi_include.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Includes <mpi.h> and defines the MPI_CHECK error-checking macro.
// Everything below the include guard is compiled only when USE_MPI is defined;
// otherwise this header contributes nothing to the translation unit.

#pragma once

#if defined(USE_MPI)
#define OMPI_SKIP_MPICXX 1  // See https://github.com/open-mpi/ompi/issues/5157
#include <mpi.h>
#undef OMPI_SKIP_MPICXX

namespace onnxruntime {

// MPI_CHECK(condition)
//   Evaluates `condition` exactly once, stores the result as an int status code,
//   and hands `status == MPI_SUCCESS` to ORT_ENFORCE along with the file, line
//   and status value for the failure message.
//
//   Notes:
//   - ORT_ENFORCE is not defined in this header; it must already be visible at
//     every place MPI_CHECK is used.
//   - Macros are not scoped by namespaces, so MPI_CHECK is usable from any
//     namespace even though it is written inside `onnxruntime`.
//   - The status local uses a deliberately unusual name so that an expression
//     such as MPI_CHECK(f(error)) still refers to the caller's own `error`
//     variable instead of the macro's local.
#define MPI_CHECK(condition)                    \
  do {                                          \
    int ort_mpi_check_status_ = (condition);    \
    ORT_ENFORCE(                                \
        ort_mpi_check_status_ == MPI_SUCCESS,   \
        "MPI Error at: ",                       \
        __FILE__,                               \
        ":",                                    \
        __LINE__,                               \
        ": ",                                   \
        ort_mpi_check_status_);                 \
  } while (0)

}  // namespace onnxruntime
#endif  // defined(USE_MPI)
Loading