diff --git a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java index d579792d..67a1fbbf 100644 --- a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java +++ b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java @@ -4,7 +4,6 @@ */ package com.linkedin.tony; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.linkedin.tony.events.TaskFinished; @@ -883,9 +882,11 @@ public void taskExecutorHeartbeat(String taskId) { } @Override - public String getClusterSpec() throws IOException { - ObjectMapper objectMapper = new ObjectMapper(); - return objectMapper.writeValueAsString(session.getClusterSpec()); + public String getClusterSpec(String taskId) throws IOException { + if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) { + return amRuntimeAdapter.constructClusterSpec(taskId); + } + return null; } @Override @@ -902,10 +903,7 @@ public String registerWorkerSpec(String taskId, String spec) throws IOException LOG.info("[" + taskId + "] Received Registration for HB !!"); hbMonitor.register(task); killChiefWorkerIfTesting(taskId); - } - - if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) { - return amRuntimeAdapter.constructClusterSpec(taskId); + return ""; } return null; } diff --git a/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java b/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java index 1545f75c..dcb4059d 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java +++ b/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java @@ -110,17 +110,6 @@ private void releasePort(ServerPort port) throws Exception { } } - /** - * Releases the reserved ports if any. This method has to be invoked after ports are created. - */ - private void releasePorts() throws Exception { - try { - this.releasePort(this.rpcPort); - } finally { - this.releasePort(this.tbPort); - } - } - /** * @return true if reusing port is enabled by user, false otherwise. */ @@ -288,7 +277,7 @@ protected void initConfigs() { Utils.initHdfsConf(hdfsConf); } - private String registerAndGetClusterSpec() { + private String registerAndGetClusterSpec() throws IOException, YarnException { ContainerId containerId = ContainerId.fromString(System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())); String hostName = Utils.getCurrentHostName(); LOG.info("ContainerId is: " + containerId + " HostName is: " + hostName); @@ -299,9 +288,13 @@ private String registerAndGetClusterSpec() { LOG.info("Connecting to " + amHost + ":" + amPort + " to register worker spec: " + jobName + " " + taskIndex + " " + hostName + ":" + this.rpcPort.getPort()); - return Utils.pollTillNonNull(() -> - proxy.registerWorkerSpec(jobName + ":" + taskIndex, - hostName + ":" + this.rpcPort.getPort()), 3, 0); + + String taskId = String.format("%s:%s", jobName, taskIndex); + String hostAndPort = String.format("%s:%s", hostName, rpcPort.getPort()); + + Utils.pollTillNonNull(() -> proxy.registerWorkerSpec(taskId, hostAndPort), 3, 0); + + return Utils.pollTillNonNull(() -> proxy.getClusterSpec(taskId), 3, 0); } public void callbackInfoToAM(String taskId, String callbackInfo) throws IOException, YarnException { diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpc.java b/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpc.java index 758e24a7..70b8ab90 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpc.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpc.java @@ -16,12 +16,12 @@ public interface ApplicationRpc { */ Set getTaskInfos() throws IOException, YarnException; - String getClusterSpec() throws IOException, YarnException; String registerWorkerSpec(String worker, String spec) throws IOException, YarnException; + String getClusterSpec(String taskId) throws IOException, YarnException; + void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException; String registerTensorBoardUrl(String spec) throws Exception; String registerExecutionResult(int exitCode, String jobName, String jobIndex, String sessionId) throws Exception; void finishApplication() throws YarnException, IOException; void taskExecutorHeartbeat(String taskId) throws YarnException, IOException; void reset(); - void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException; } diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpcServer.java b/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpcServer.java index 8d9199cd..ee762687 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpcServer.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/ApplicationRpcServer.java @@ -54,7 +54,7 @@ public GetTaskInfosResponse getTaskInfos(GetTaskInfosRequest request) throws IOE public GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request) throws YarnException, IOException { GetClusterSpecResponse response = RECORD_FACTORY.newRecordInstance(GetClusterSpecResponse.class); - response.setClusterSpec(this.appRpc.getClusterSpec()); + response.setClusterSpec(this.appRpc.getClusterSpec(request.getTaskId())); return response; } diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/GetClusterSpecRequest.java b/tony-core/src/main/java/com/linkedin/tony/rpc/GetClusterSpecRequest.java index 5a2b18f6..39f670ac 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/GetClusterSpecRequest.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/GetClusterSpecRequest.java @@ -5,4 +5,6 @@ package com.linkedin.tony.rpc; public interface GetClusterSpecRequest { + String getTaskId(); + void setTaskId(String taskId); } diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/impl/ApplicationRpcClient.java b/tony-core/src/main/java/com/linkedin/tony/rpc/impl/ApplicationRpcClient.java index 7cba29f0..fae14a1c 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/impl/ApplicationRpcClient.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/impl/ApplicationRpcClient.java @@ -84,9 +84,10 @@ public Set getTaskInfos() throws IOException, YarnException { } @Override - public String getClusterSpec() throws IOException, YarnException { - GetClusterSpecResponse response = - tensorflow.getClusterSpec(recordFactory.newRecordInstance(GetClusterSpecRequest.class)); + public String getClusterSpec(String taskId) throws IOException, YarnException { + GetClusterSpecRequest request = recordFactory.newRecordInstance(GetClusterSpecRequest.class); + request.setTaskId(taskId); + GetClusterSpecResponse response = tensorflow.getClusterSpec(request); return response.getClusterSpec(); } diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/impl/pb/GetClusterSpecRequestPBImpl.java b/tony-core/src/main/java/com/linkedin/tony/rpc/impl/pb/GetClusterSpecRequestPBImpl.java index a92376c5..629d16aa 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/impl/pb/GetClusterSpecRequestPBImpl.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/impl/pb/GetClusterSpecRequestPBImpl.java @@ -4,8 +4,8 @@ */ package com.linkedin.tony.rpc.impl.pb; - import com.linkedin.tony.rpc.GetClusterSpecRequest; +import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos; import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos.GetClusterSpecRequestProto; public class GetClusterSpecRequestPBImpl implements GetClusterSpecRequest { @@ -13,7 +13,7 @@ public class GetClusterSpecRequestPBImpl implements GetClusterSpecRequest { private GetClusterSpecRequestProto.Builder builder = null; private boolean viaProto = false; - private boolean rebuild = false; + private String taskId; public GetClusterSpecRequestPBImpl() { builder = GetClusterSpecRequestProto.newBuilder(); @@ -28,15 +28,19 @@ private void mergeLocalToProto() { if (viaProto) { maybeInitBuilder(); } + mergeLocalToBuilder(); proto = builder.build(); - rebuild = false; viaProto = true; } + private void mergeLocalToBuilder() { + if (this.taskId != null) { + builder.setTaskId(this.taskId); + } + } + public GetClusterSpecRequestProto getProto() { - if (rebuild) { - mergeLocalToProto(); - } + mergeLocalToProto(); proto = viaProto ? proto : builder.build(); viaProto = true; return proto; @@ -48,4 +52,26 @@ private void maybeInitBuilder() { } viaProto = false; } + + @Override + public String getTaskId() { + YarnTonyClusterProtos.GetClusterSpecRequestProtoOrBuilder p = viaProto ? proto : builder; + if (this.taskId != null) { + return this.taskId; + } + if (!p.hasTaskId()) { + return null; + } + this.taskId = p.getTaskId(); + return this.taskId; + } + + @Override + public void setTaskId(String taskId) { + maybeInitBuilder(); + if (taskId == null) { + builder.clearTaskId(); + } + this.taskId = taskId; + } } diff --git a/tony-core/src/main/proto/yarn_tony_cluster_protos.proto b/tony-core/src/main/proto/yarn_tony_cluster_protos.proto index 16862419..177ec36d 100644 --- a/tony-core/src/main/proto/yarn_tony_cluster_protos.proto +++ b/tony-core/src/main/proto/yarn_tony_cluster_protos.proto @@ -28,6 +28,7 @@ message GetTaskInfosResponseProto { } message GetClusterSpecRequestProto { + required string taskId = 1; } message GetClusterSpecResponseProto {