Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.celeborn.client

import java.lang.{Byte => JByte}
import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
Expand All @@ -41,12 +43,14 @@ import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
import org.apache.celeborn.common.protocol.message.ControlMessages._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc._
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.security.{ClientSaslContextBuilder, RpcSecurityContext, RpcSecurityContextBuilder}
import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
Expand Down Expand Up @@ -108,6 +112,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
.build().asInstanceOf[Cache[Int, ByteBuffer]]

private val mockDestroyFailure = conf.testMockDestroySlotsFailure
private val authEnabled = conf.authEnabled

@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] =
Expand Down Expand Up @@ -159,7 +164,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends

logInfo(s"Starting LifecycleManager on ${rpcEnv.address}")

private val masterClient = new MasterClient(rpcEnv, conf)
private var masterRpcEnvInUse = rpcEnv
private var workerRpcEnvInUse = rpcEnv
if (authEnabled) {
logInfo(s"Authentication is enabled; setting up master and worker RPC environments")
val appSecret = createSecret()
val registrationInfo = new RegistrationInfo()
masterRpcEnvInUse =
RpcEnv.create(
RpcNameConstants.LIFECYCLE_MANAGER_MASTER_SYS,
lifecycleHost,
0,
conf,
createRpcSecurityContext(
appSecret,
addClientRegistrationBootstrap = true,
Some(registrationInfo)))
workerRpcEnvInUse =
RpcEnv.create(
RpcNameConstants.LIFECYCLE_MANAGER_WORKER_SYS,
lifecycleHost,
0,
conf,
createRpcSecurityContext(appSecret))
}

private val masterClient = new MasterClient(masterRpcEnvInUse, conf, false)
val commitManager = new CommitManager(appUniqueId, conf, this)
val workerStatusTracker = new WorkerStatusTracker(conf, this)
private val heartbeater =
Expand Down Expand Up @@ -214,6 +244,36 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
rpcEnv.shutdown()
rpcEnv.awaitTermination()
}
if (authEnabled) {
if (masterRpcEnvInUse != null) {
masterRpcEnvInUse.shutdown()
masterRpcEnvInUse.awaitTermination()
}
if (workerRpcEnvInUse != null) {
workerRpcEnvInUse.shutdown()
workerRpcEnvInUse.awaitTermination()
}
}
}

/**
* Creates security context for external RPC endpoint.
*/
def createRpcSecurityContext(
appSecret: String,
addClientRegistrationBootstrap: Boolean = false,
registrationInfo: Option[RegistrationInfo] = None): Option[RpcSecurityContext] = {
val clientSaslContextBuilder = new ClientSaslContextBuilder()
.withAddRegistrationBootstrap(addClientRegistrationBootstrap)
.withAppId(appUniqueId)
.withSaslUser(appUniqueId)
.withSaslPassword(appSecret)
if (registrationInfo.isDefined) {
clientSaslContextBuilder.withRegistrationInfo(registrationInfo.get)
}
val rpcSecurityContext = new RpcSecurityContextBuilder()
.withClientSaslContext(clientSaslContextBuilder.build()).build()
Some(rpcSecurityContext)
}

def getUserIdentifier: UserIdentifier = {
Expand Down Expand Up @@ -356,7 +416,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
connectFailedWorkers: ShuffleFailedWorkers): Unit = {
val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]()
slots.asScala foreach { case (workerInfo, _) =>
val future = rpcEnv.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
val future = workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
RpcAddress.apply(workerInfo.host, workerInfo.rpcPort),
WORKER_EP))
futures.add((future, workerInfo))
Expand Down Expand Up @@ -1065,7 +1125,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
s" ${destroyWorkerInfo.readableAddress()}, init according to partition info")
try {
if (!workerStatusTracker.workerExcluded(destroyWorkerInfo)) {
destroyWorkerInfo.endpoint = rpcEnv.setupEndpointRef(
destroyWorkerInfo.endpoint = workerRpcEnvInUse.setupEndpointRef(
RpcAddress.apply(destroyWorkerInfo.host, destroyWorkerInfo.rpcPort),
WORKER_EP)
} else {
Expand Down Expand Up @@ -1573,4 +1633,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
heartbeater.stop()
super.stop()
}

private def createSecret(): String = {
val bits = 256
val rnd = new SecureRandom()
val secretBytes = new Array[Byte](bits / JByte.SIZE)
rnd.nextBytes(secretBytes)
JavaUtils.bytesToString(ByteBuffer.wrap(secretBytes))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,15 @@ public class MasterClient {

private final AtomicReference<RpcEndpointRef> rpcEndpointRef;
private final ExecutorService oneWayMessageSender;
private final CelebornConf conf;
private final boolean isWorker;
private String masterEndpointName;

public MasterClient(RpcEnv rpcEnv, CelebornConf conf) {
public MasterClient(RpcEnv rpcEnv, CelebornConf conf, boolean isWorker) {
this.rpcEnv = rpcEnv;
this.masterEndpoints = Arrays.asList(conf.masterEndpoints());
this.conf = conf;
this.isWorker = isWorker;
this.masterEndpoints = resolveMasterEndpoints();
Collections.shuffle(this.masterEndpoints);
this.maxRetries = Math.max(masterEndpoints.size(), conf.masterClientMaxRetries());
this.rpcTimeout = conf.masterClientRpcAskTimeout();
Expand Down Expand Up @@ -250,7 +255,7 @@ private RpcEndpointRef setupEndpointRef(String endpoint) {
RpcEndpointRef endpointRef = null;
try {
endpointRef =
rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), RpcNameConstants.MASTER_EP);
rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), masterEndpointName);
} catch (Exception e) {
// Catch all exceptions. Because we don't care whether this exception is IOException or
// TimeoutException or other exceptions, so we just try to connect to host:port, if fail,
Expand All @@ -259,4 +264,26 @@ private RpcEndpointRef setupEndpointRef(String endpoint) {
}
return endpointRef;
}

private List<String> resolveMasterEndpoints() {
if (isWorker) {
// For worker, we should use the internal endpoints if internal port is enabled.
if (conf.internalPortEnabled()) {
masterEndpointName = RpcNameConstants.MASTER_INTERNAL_EP;
return Arrays.asList(conf.masterInternalEndpoints());
} else {
masterEndpointName = RpcNameConstants.MASTER_EP;
return Arrays.asList(conf.masterEndpoints());
}
} else {
// This is for client, so we should use the secured endpoints if auth is enabled.
if (conf.authEnabled()) {
masterEndpointName = RpcNameConstants.MASTER_SECURED_EP;
return Arrays.asList(conf.masterSecuredEndpoints());
} else {
masterEndpointName = RpcNameConstants.MASTER_EP;
return Arrays.asList(conf.masterEndpoints());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ public class RpcNameConstants {
// For Master
public static String MASTER_SYS = "Master";
public static String MASTER_INTERNAL_SYS = "MasterInternal";
public static String MASTER_SECURED_SYS = "MasterSecured";
// Master Endpoint Name
public static String MASTER_EP = "MasterEndpoint";
public static String MASTER_INTERNAL_EP = "MasterInternalEndpoint";
public static String MASTER_SECURED_EP = "MasterSecuredEndpoint";

// For Worker
public static String WORKER_SYS = "Worker";
Expand All @@ -32,6 +34,8 @@ public class RpcNameConstants {

// For LifecycleManager
public static String LIFECYCLE_MANAGER_SYS = "LifecycleManager";
public static String LIFECYCLE_MANAGER_MASTER_SYS = "LifecycleManagerMasterSys";
public static String LIFECYCLE_MANAGER_WORKER_SYS = "LifecycleManagerWorkerSys";
public static String LIFECYCLE_MANAGER_EP = "LifecycleManagerEndpoint";

// For Shuffle Client
Expand Down
103 changes: 92 additions & 11 deletions common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.concurrent.duration._
import scala.util.Try

import org.apache.celeborn.common.CelebornConf.MASTER_INTERNAL_ENDPOINTS
import org.apache.celeborn.common.identity.{DefaultIdentityProvider, IdentityProvider}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.internal.config._
Expand Down Expand Up @@ -1116,22 +1117,52 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
// //////////////////////////////////////////////////////
// Authentication //
// //////////////////////////////////////////////////////
def authEnabled: Boolean = get(AUTH_ENABLED)
def authEnabled: Boolean = {
val authEnabled = get(AUTH_ENABLED)
val internalPortEnabled = get(INTERNAL_PORT_ENABLED)
if (authEnabled && !internalPortEnabled) {
throw new IllegalArgumentException(
s"${AUTH_ENABLED.key} is true, but ${INTERNAL_PORT_ENABLED.key} is false")
Copy link
Member

@pan3793 pan3793 Feb 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are all valid combinations of AUTH_ENABLED and INTERNAL_PORT_ENABLED?

  • true, true
  • false, false
  • and others?

what if we eliminate INTERNAL_PORT_ENABLED and just respect AUTH_ENABLED?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another valid combination is auth_enabled = false and internal_port_enabled =true.
Having Masters and workers communicate on a separate port is a distinct feature from authentication. In a prior discussion with @waitinfuture, they were considering adding a separate port for internal communication for different reasons. However, it's important to note that this separate internal port is a prerequisite for authentication.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for explanation

}
return authEnabled && internalPortEnabled
}

def haMasterNodeSecuredPort(nodeId: String): Int = {
val key = HA_MASTER_NODE_SECURED_PORT.key.replace("<id>", nodeId)
getInt(key, HA_MASTER_NODE_SECURED_PORT.defaultValue.get)
}

def masterSecuredPort: Int = get(MASTER_SECURED_PORT)

def masterSecuredEndpoints: Array[String] =
get(MASTER_SECURED_ENDPOINTS).toArray.map { endpoint =>
Utils.parseHostPort(endpoint.replace("<localhost>", Utils.localHostName(this))) match {
case (host, 0) => s"$host:${HA_MASTER_NODE_SECURED_PORT.defaultValue.get}"
case (host, port) => s"$host:$port"
}
}

// //////////////////////////////////////////////////////
// Internal Port //
// //////////////////////////////////////////////////////
def internalPortEnabled: Boolean = get(INTERNAL_PORT_ENABLED)

def masterInternalEndpoints: Array[String] =
get(MASTER_INTERNAL_ENDPOINTS).toArray.map { endpoint =>
Utils.parseHostPort(endpoint.replace("<localhost>", Utils.localHostName(this))) match {
case (host, 0) => s"$host:${HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get}"
case (host, port) => s"$host:$port"
}
}

// //////////////////////////////////////////////////////
// Rack Resolver //
// //////////////////////////////////////////////////////
def rackResolverRefreshInterval = get(RACKRESOLVER_REFRESH_INTERVAL)

def haMasterNodeInternalPort(nodeId: String): Int = {
val key = HA_MASTER_NODE_INTERNAL_PORT.key.replace("<id>", nodeId)
val legacyKey = HA_MASTER_NODE_INTERNAL_PORT.alternatives.head._1.replace("<id>", nodeId)
getInt(key, getInt(legacyKey, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get))
getInt(key, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get)
}

def masterInternalPort: Int = get(MASTER_INTERNAL_PORT)
Expand Down Expand Up @@ -4392,14 +4423,6 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")

val AUTH_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.auth.enabled")
.categories("auth")
.version("0.5.0")
.doc("Whether to enable authentication.")
.booleanConf
.createWithDefault(false)

val INTERNAL_PORT_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.internal.port.enabled")
.categories("master", "worker")
Expand All @@ -4411,6 +4434,15 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)

val AUTH_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.auth.enabled")
.categories("auth")
.version("0.5.0")
.doc("Whether to enable authentication. Authentication will be enabled only when " +
s"${INTERNAL_PORT_ENABLED.key} is enabled as well.")
.booleanConf
.createWithDefault(false)

val MASTER_INTERNAL_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.internal.port")
.categories("master")
Expand All @@ -4431,11 +4463,60 @@ object CelebornConf extends Logging {
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(8097)

val MASTER_INTERNAL_ENDPOINTS: ConfigEntry[Seq[String]] =
buildConf("celeborn.master.internal.endpoints")
.categories("worker")
.doc("Endpoints of master nodes just for celeborn workers to connect, allowed pattern " +
"is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:8097,clb2:8097,clb3:8097`. " +
"If the port is omitted, 8097 will be used.")
.version("0.5.0")
.stringConf
.toSequence
.checkValue(
endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess),
"Allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`")
.createWithDefaultString(s"<localhost>:8097")

val RACKRESOLVER_REFRESH_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.master.rackResolver.refresh.interval")
.categories("master")
.version("0.5.0")
.doc("Interval for refreshing the node rack information periodically.")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")

val MASTER_SECURED_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.secured.port")
.categories("master", "auth")
.version("0.5.0")
.doc(
"Secured port on the master where clients connect.")
.intConf
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(19097)

val HA_MASTER_NODE_SECURED_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.ha.node.<id>.secured.port")
.categories("ha", "auth")
.doc(
"Secured port for the clients to bind to a master node <id> in HA mode.")
.version("0.5.0")
.intConf
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(19097)

val MASTER_SECURED_ENDPOINTS: ConfigEntry[Seq[String]] =
buildConf("celeborn.master.secured.endpoints")
.categories("client", "auth")
.doc("Endpoints of master nodes for celeborn client to connect for secured communication, allowed pattern " +
"is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:19097,clb2:19097,clb3:19097`. " +
"If the port is omitted, 19097 will be used.")
.version("0.5.0")
.stringConf
.toSequence
.checkValue(
endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess),
"Allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`")
.createWithDefaultString(s"<localhost>:19097")

}
Loading