Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Nov 13, 2024
1 parent 724ced2 commit 90b1ad3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 38 deletions.
76 changes: 43 additions & 33 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,48 +57,58 @@ getDalComm() {
}
#endif
JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port,
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jint computeDeviceOrdinal,
jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");
const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);
auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

#ifdef CPU_ONLY_PROFILE
auto t1 = std::chrono::high_resolution_clock::now();
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
rank_id = getComm().rank();
comm_size = getComm().size();

#endif

ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
switch (device) {
case ComputeDevice::host:
case ComputeDevice::cpu: {
auto t1 = std::chrono::high_resolution_clock::now();
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
rank_id = getComm().rank();
comm_size = getComm().size();
break;
}
#ifdef CPU_GPU_PROFILE
auto gpus = get_gpus();
sycl::queue queue{gpus[0]};
auto t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank,
singletonCCLInit.kvs);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
g_dal_comms.push_back(comm);
rank_id = getDalComm().get_rank();
comm_size = getDalComm().get_rank_count();
case ComputeDevice::gpu: {
auto gpus = get_gpus();
sycl::queue queue{gpus[0]};
auto t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank,
singletonCCLInit.kvs);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
g_dal_comms.push_back(comm);
rank_id = getDalComm().get_rank();
comm_size = getDalComm().get_rank_count();
break;
}
#endif
default: {
deviceError("communicator",
ComputeDeviceString[computeDeviceOrdinal].c_str());
}
}
jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.intel.oap.mllib

import com.intel.oneapi.dal.table.Common
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

Expand All @@ -25,7 +26,8 @@ object CommonJob {
kvsIPPort: String,
useDevice: String): Unit = {
data.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort,
Common.ComputeDevice.getDeviceByName(useDevice).ordinal())
val gpuIndices = if (useDevice == "GPU") {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
Expand Down
7 changes: 4 additions & 3 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ object OneCCL extends Logging {

var cclParam = new CCLParam()

def init(executor_num: Int, rank: Int, ip_port: String): Unit = {
def init(executor_num: Int, rank: Int, ip_port: String, computeDevice: Int): Unit = {

logInfo(s"Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip_port, cclParam)
c_init(executor_num, rank, ip_port, computeDevice, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.getCommSize,
Expand Down Expand Up @@ -61,7 +61,8 @@ object OneCCL extends Logging {

@native def c_getAvailPort(localIP: String): Int

@native private def c_init(size: Int, rank: Int, ip_port: String, param: CCLParam): Int
@native private def c_init(size: Int, rank: Int, ip_port: String,
computeDevice: Int, param: CCLParam): Int

@native private def c_cleanup(): Unit
}

0 comments on commit 90b1ad3

Please sign in to comment.