Skip to content

Commit

Permalink
Auto-detect an available port for oneCCL
Browse files Browse the repository at this point in the history
  • Loading branch information
xwu99 committed Feb 5, 2021
1 parent b60158b commit 503aeaf
Show file tree
Hide file tree
Showing 14 changed files with 163 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
*.o
*.log
.vscode
*.iml
target/
.idea/
.idea_modules/
52 changes: 51 additions & 1 deletion mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
#include <iostream>
#include <chrono>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <oneapi/ccl.hpp>

#include "org_apache_spark_ml_util_OneCCL__.h"

// todo: fill initial comm_size and rank_id
Expand All @@ -17,10 +25,12 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init

std::cout << "oneCCL (native): init" << std::endl;

auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);
ccl::string ccl_ip_port(str);

auto kvs_attr = ccl::create_kvs_attr();
kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);
Expand All @@ -30,6 +40,10 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init

g_comms.push_back(ccl::create_communicator(size, rank, kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>( t2 - t1 ).count();
std::cout << "oneCCL (native): init took " << duration << " secs" << std::endl;

rank_id = getComm().rank();
comm_size = getComm().size();

Expand Down Expand Up @@ -97,3 +111,39 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_setEnv

return err;
}

/*
 * Class:     org_apache_spark_ml_util_OneCCL__
 * Method:    getAvailPort
 * Signature: (Ljava/lang/String;)I
 *
 * Finds a free TCP port on the local host by probing upward from
 * port_start_base with bind() until one succeeds. The probe socket is
 * closed before returning, so the port is only *likely* free when the
 * caller (the oneCCL KVS setup) later binds it — a small race window is
 * inherent to this approach.
 *
 * Returns the available port number (host byte order), or -1 on failure.
 */
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_getAvailPort
(JNIEnv *env, jobject obj, jstring localIP) {

// First candidate port; subsequent candidates are probed by incrementing.
const int port_start_base = 3000;

const char *local_host_ip = env->GetStringUTFChars(localIP, NULL);
if (local_host_ip == NULL) {
  // JVM has already thrown OutOfMemoryError.
  return -1;
}

int server_listen_sock = socket(AF_INET, SOCK_STREAM, 0);
if (server_listen_sock < 0) {
  perror("getAvailPort: socket creation failed");
  // BUG FIX: the original leaked the JNI string on this error path.
  env->ReleaseStringUTFChars(localIP, local_host_ip);
  return -1;
}

// Zero-initialize so sin_zero padding is not left indeterminate
// (the original left the struct uninitialized).
struct sockaddr_in main_server_address = {};
main_server_address.sin_family = AF_INET;
main_server_address.sin_addr.s_addr = inet_addr(local_host_ip);

// BUG FIX: sin_port must be in network byte order (htons). The original
// stored and incremented the port in host byte order, which probed the
// wrong ports (incrementing effectively stepped the high-order byte) and
// returned a byte-swapped port number to the Scala caller.
int port = port_start_base;
main_server_address.sin_port = htons(port);
while (bind(server_listen_sock,
            (const struct sockaddr *)&main_server_address,
            sizeof(main_server_address)) < 0) {
  port++;
  main_server_address.sin_port = htons(port);
}

close(server_listen_sock);
env->ReleaseStringUTFChars(localIP, local_host_ip);

// Host byte order, suitable for building the "ip_port" KVS string.
return port;
}
50 changes: 50 additions & 0 deletions mllib-dal/src/main/native/OneDAL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <cstring>
#include "org_apache_spark_ml_util_OneDAL__.h"

#include "service.h"

using namespace daal;
using namespace daal::data_management;

Expand Down Expand Up @@ -123,3 +125,51 @@ JNIEXPORT jboolean JNICALL Java_org_apache_spark_ml_util_OneDAL_00024_cCheckPlat
// Only guarantee compatibility and performance on Intel platforms, use oneDAL lib function
return daal_check_is_intel_cpu();
}

/*
 * Class:     org_apache_spark_ml_util_OneDAL__
 * Method:    cNewCSRNumericTable
 * Signature: ([F[J[JJJ)J
 *
 * Builds a single-precision oneDAL CSRNumericTable from JNI primitive
 * arrays (values, column indices, row offsets) and returns the address of
 * a heap-allocated CSRNumericTablePtr as a jlong handle for the Scala
 * side. Ownership of the handle passes to the caller; it must later be
 * freed via the corresponding native free routine.
 *
 * NOTE(review): the (size_t *) casts over jlong arrays assume a 64-bit
 * platform where sizeof(size_t) == sizeof(jlong) — TODO confirm for all
 * supported targets.
 */
JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_util_OneDAL_00024_cNewCSRNumericTable
(JNIEnv *env, jobject, jfloatArray data, jlongArray colIndices, jlongArray rowOffsets, jlong nFeatures, jlong nVectors) {

// Number of non-zero values; in CSR layout colIndices has the same length.
long numData = env->GetArrayLength(data);
// long numColIndices = numData;
// long numRowOffsets = env->GetArrayLength(rowOffsets);

// Let the table allocate its own internal buffers, then fetch pointers to
// them via getArrays() so we can fill them below.
size_t * resultRowOffsets = NULL;
size_t * resultColIndices = NULL;
float * resultData = NULL;
CSRNumericTable * numericTable = new CSRNumericTable(resultData, resultColIndices, resultRowOffsets, nFeatures, nVectors);
numericTable->allocateDataMemory(numData);
numericTable->getArrays<float>(&resultData, &resultColIndices, &resultRowOffsets);

// Pin (or copy) the Java arrays for reading.
size_t * pRowOffsets = (size_t *)env->GetLongArrayElements(rowOffsets, 0);
size_t * pColIndices = (size_t *)env->GetLongArrayElements(colIndices, 0);
float * pData = env->GetFloatArrayElements(data, 0);

// std::memcpy(resultRowOffsets, pRowOffsets, numRowOffsets*sizeof(jlong));
// std::memcpy(resultColIndices, pColIndices, numColIndices*sizeof(jlong));
// std::memcpy(resultData, pData, numData*sizeof(float));

// Element-wise copy (instead of memcpy above) so each jlong is converted
// through size_t assignment rather than raw byte copy.
for (size_t i = 0; i < (size_t)numData; ++i)
{
resultData[i] = pData[i];
resultColIndices[i] = pColIndices[i];
}
// CSR row-offsets array has nVectors + 1 entries.
for (size_t i = 0; i < (size_t)nVectors + 1; ++i)
{
resultRowOffsets[i] = pRowOffsets[i];
}

// Release the pinned arrays; mode 0 copies back, though nothing was modified.
env->ReleaseLongArrayElements(rowOffsets, (jlong *)pRowOffsets, 0);
env->ReleaseLongArrayElements(colIndices, (jlong *)pColIndices, 0);
env->ReleaseFloatArrayElements(data, pData, 0);

// Wrap in a shared pointer on the heap; the jlong handle is its address.
CSRNumericTablePtr *ret = new CSRNumericTablePtr(numericTable);

//printNumericTable(*ret, "cNewCSRNumericTable", 10);

return (jlong)ret;
}
1 change: 1 addition & 0 deletions mllib-dal/src/main/native/build-jni.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ javah -d $WORK_DIR/javah -classpath "$WORK_DIR/../../../target/classes:$DAAL_JAR
org.apache.spark.ml.util.OneDAL$ \
org.apache.spark.ml.clustering.KMeansDALImpl \
org.apache.spark.ml.feature.PCADALImpl

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions mllib-dal/src/main/native/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ CSRNumericTable * createSparseTable(const std::string & datasetFileName)
return numericTable;
}

// Non-template convenience wrapper exposing the float instantiation of
// createSparseTable<T> (defined above in this file), so it can be declared
// in service.h and called from translation units that do not see the
// template definition.
CSRNumericTable * createFloatSparseTable(const std::string & datasetFileName) {
return createSparseTable<float>(datasetFileName);
}

void printAprioriItemsets(NumericTablePtr largeItemsetsTable, NumericTablePtr largeItemsetsSupportTable, size_t nItemsetToPrint = 20)
{
size_t largeItemsetCount = largeItemsetsSupportTable->getNumberOfRows();
Expand Down
3 changes: 3 additions & 0 deletions mllib-dal/src/main/native/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,8 @@ typedef std::vector<daal::byte> ByteBuffer;

void printNumericTable(const NumericTablePtr & dataTable, const char * message = "", size_t nPrintedRows = 0, size_t nPrintedCols = 0,
size_t interval = 10);
size_t serializeDAALObject(SerializationIface * pData, ByteBuffer & buffer);
SerializationIfacePtr deserializeDAALObject(daal::byte * buff, size_t length);
CSRNumericTable * createFloatSparseTable(const std::string & datasetFileName);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class KMeansDALImpl (

val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)
val kvsPort = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsIPPort = kvsIP+"_"+kvsPort

// repartition to executorNum if not enough partitions
val dataForConversion = if (data.getNumPartitions < executorNum) {
Expand Down Expand Up @@ -114,7 +116,7 @@ class KMeansDALImpl (

val results = coalescedTables.mapPartitionsWithIndex { (rank, table) =>
val tableArr = table.next()
OneCCL.init(executorNum, rank, kvsIP)
OneCCL.init(executorNum, rank, kvsIPPort)

val initCentroids = OneDAL.makeNumericTable(centers)
val result = new KMeansResult()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,20 @@ class PCADALImpl (
res.map(_.asML)
}

def fitWithDAL(input: RDD[Vector]) : MLlibPCAModel = {
def fitWithDAL(data: RDD[Vector]) : MLlibPCAModel = {

val normalizedData = normalizeData(input)
val normalizedData = normalizeData(data)

val coalescedTables = OneDAL.rddVectorToNumericTables(normalizedData, executorNum)

val executorIPAddress = Utils.sparkFirstExecutorIP(input.sparkContext)
val kvsIP = input.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)
val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
val kvsIP = data.sparkContext.conf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)
val kvsPort = Utils.checkExecutorAvailPort(data.sparkContext, kvsIP)
val kvsIPPort = kvsIP+"_"+kvsPort

val results = coalescedTables.mapPartitionsWithIndex { (rank, table) =>
val tableArr = table.next()
OneCCL.init(executorNum, rank, kvsIP)
OneCCL.init(executorNum, rank, kvsIPPort)

val result = new PCAResult()
cPCATrainDAL(
Expand Down
11 changes: 6 additions & 5 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object OneCCL {
// var kvsIPPort = sys.env.getOrElse("CCL_KVS_IP_PORT", "")
// var worldSize = sys.env.getOrElse("CCL_WORLD_SIZE", "1").toInt

var kvsPort = 5000
// var kvsPort = 5000

// private def checkEnv() {
// val altTransport = sys.env.getOrElse("CCL_ATL_TRANSPORT", "")
Expand Down Expand Up @@ -57,21 +57,21 @@ object OneCCL {
// // setEnv("CCL_LOG_LEVEL", "2")
// }

def init(executor_num: Int, rank: Int, ip: String) = {
def init(executor_num: Int, rank: Int, ip_port: String) = {

// setExecutorEnv(executor_num, ip, port)
println(s"oneCCL: Initializing with IP_PORT: ${ip}_${kvsPort}")
println(s"oneCCL: Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip+"_"+kvsPort.toString, cclParam)
c_init(executor_num, rank, ip_port, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.commSize, "executor number should equal to oneCCL world size")

println(s"oneCCL: Initialized with executorNum: $executor_num, commSize, ${cclParam.commSize}, rankId: ${cclParam.rankId}")

// Use a new port when calling init again
kvsPort = kvsPort + 1
// kvsPort = kvsPort + 1

}

Expand All @@ -87,4 +87,5 @@ object OneCCL {
@native def rankID() : Int

@native def setEnv(key: String, value: String, overwrite: Boolean = true): Int
@native def getAvailPort(localIP: String): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,7 @@ object OneDAL {
@native def cFreeDataMemory(numTableAddr: Long)

@native def cCheckPlatformCompatibility() : Boolean

@native def cNewCSRNumericTable(data: Array[Float], colIndices: Array[Long], rowOffsets: Array[Long], nFeatures: Long,
nVectors: Long) : Long
}
14 changes: 14 additions & 0 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ object Utils {
ip
}

/**
 * Finds an available TCP port on an executor for use as the oneCCL KVS
 * endpoint.
 *
 * Launches a job with one partition per executor; only the partition with
 * index 0 loads the native libraries and probes for a free port via the
 * native `OneCCL.getAvailPort`.
 *
 * NOTE(review): partition 0 is not guaranteed to be scheduled on the
 * executor that owns `localIP`, so the probed port may come from a
 * different host — TODO confirm scheduling assumption.
 *
 * @param sc      active SparkContext
 * @param localIP IP address the native probe binds against (the KVS host)
 * @return an available port number reported by the probing executor
 */
def checkExecutorAvailPort(sc: SparkContext, localIP: String) : Int = {
  val executor_num = Utils.sparkExecutorNum(sc)
  // One record per partition so each executor receives exactly one task.
  val data = sc.parallelize(1 to executor_num, executor_num)
  val result = data.mapPartitionsWithIndex { (index, p) =>
    // Native libraries must be loaded in the executor JVM before the JNI call.
    LibLoader.loadLibraries()
    if (index == 0)
      Iterator(OneCCL.getAvailPort(localIP))
    else
      Iterator.empty
  }.collect()

  // Partition 0 produced the only element; the last expression is the
  // return value (no `return` keyword needed).
  result.head
}

def checkClusterPlatformCompatibility(sc: SparkContext) : Boolean = {
LibLoader.loadLibraries()

Expand Down
7 changes: 4 additions & 3 deletions mllib-dal/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ export LD_PRELOAD=$JAVA_HOME/jre/lib/amd64/libjsig.so
# -Dtest=none to turn off the Java tests

# Test all
mvn -Dtest=none -Dmaven.test.skip=false test
#mvn -Dtest=none -Dmaven.test.skip=false test

# Individual test
# mvn -Dtest=none -DwildcardSuites=org.apache.spark.ml.clustering.IntelKMeansSuite test
# mvn -Dtest=none -DwildcardSuites=org.apache.spark.ml.feature.IntelPCASuite test
mvn -Dtest=none -DwildcardSuites=org.apache.spark.ml.clustering.IntelKMeansSuite test
mvn -Dtest=none -DwildcardSuites=org.apache.spark.ml.feature.IntelPCASuite test

0 comments on commit 503aeaf

Please sign in to comment.