Skip to content

Commit

Permalink
Seperate the interface of registerTask and getClusterSpec in TaskExec…
Browse files Browse the repository at this point in the history
…utor
  • Loading branch information
zuston committed Feb 20, 2022
1 parent 5f0267b commit ebcd1f0
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 29 deletions.
13 changes: 5 additions & 8 deletions tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package com.linkedin.tony;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.linkedin.tony.events.TaskFinished;
Expand Down Expand Up @@ -883,9 +882,11 @@ public void taskExecutorHeartbeat(String taskId) {
}

@Override
public String getClusterSpec() throws IOException {
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(session.getClusterSpec());
public String getClusterSpec(String taskId) throws IOException {
if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) {
return amRuntimeAdapter.constructClusterSpec(taskId);
}
return null;
}

@Override
Expand All @@ -903,10 +904,6 @@ public String registerWorkerSpec(String taskId, String spec) throws IOException
hbMonitor.register(task);
killChiefWorkerIfTesting(taskId);
}

if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) {
return amRuntimeAdapter.constructClusterSpec(taskId);
}
return null;
}

Expand Down
23 changes: 8 additions & 15 deletions tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,6 @@ private void releasePort(ServerPort port) throws Exception {
}
}

/**
* Releases the reserved ports if any. This method has to be invoked after ports are created.
*/
private void releasePorts() throws Exception {
try {
this.releasePort(this.rpcPort);
} finally {
this.releasePort(this.tbPort);
}
}

/**
* @return true if reusing port is enabled by user, false otherwise.
*/
Expand Down Expand Up @@ -288,7 +277,7 @@ protected void initConfigs() {
Utils.initHdfsConf(hdfsConf);
}

private String registerAndGetClusterSpec() {
private String registerAndGetClusterSpec() throws IOException, YarnException {
ContainerId containerId = ContainerId.fromString(System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()));
String hostName = Utils.getCurrentHostName();
LOG.info("ContainerId is: " + containerId + " HostName is: " + hostName);
Expand All @@ -299,9 +288,13 @@ private String registerAndGetClusterSpec() {

LOG.info("Connecting to " + amHost + ":" + amPort + " to register worker spec: " + jobName + " " + taskIndex + " "
+ hostName + ":" + this.rpcPort.getPort());
return Utils.pollTillNonNull(() ->
proxy.registerWorkerSpec(jobName + ":" + taskIndex,
hostName + ":" + this.rpcPort.getPort()), 3, 0);

String taskId = String.format("%s:%s", jobName, taskIndex);
String hostAndPort = String.format("%s:%s", hostName, rpcPort.getPort());

proxy.registerWorkerSpec(taskId, hostAndPort);

return Utils.pollTillNonNull(() -> proxy.getClusterSpec(taskId), 3, 0);
}

public void callbackInfoToAM(String taskId, String callbackInfo) throws IOException, YarnException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ public interface ApplicationRpc {
*/
Set<TaskInfo> getTaskInfos() throws IOException, YarnException;

String getClusterSpec() throws IOException, YarnException;
String registerWorkerSpec(String worker, String spec) throws IOException, YarnException;
String getClusterSpec(String taskId) throws IOException, YarnException;
void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException;
String registerTensorBoardUrl(String spec) throws Exception;
String registerExecutionResult(int exitCode, String jobName, String jobIndex, String sessionId) throws Exception;
void finishApplication() throws YarnException, IOException;
void taskExecutorHeartbeat(String taskId) throws YarnException, IOException;
void reset();
void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public GetTaskInfosResponse getTaskInfos(GetTaskInfosRequest request) throws IOE
public GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request)
throws YarnException, IOException {
GetClusterSpecResponse response = RECORD_FACTORY.newRecordInstance(GetClusterSpecResponse.class);
response.setClusterSpec(this.appRpc.getClusterSpec());
response.setClusterSpec(this.appRpc.getClusterSpec(request.getTaskId()));
return response;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
package com.linkedin.tony.rpc;

public interface GetClusterSpecRequest {
String getTaskId();
void setTaskId(String taskId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ public Set<TaskInfo> getTaskInfos() throws IOException, YarnException {
}

@Override
public String getClusterSpec() throws IOException, YarnException {
GetClusterSpecResponse response =
tensorflow.getClusterSpec(recordFactory.newRecordInstance(GetClusterSpecRequest.class));
public String getClusterSpec(String taskId) throws IOException, YarnException {
GetClusterSpecRequest request = recordFactory.newRecordInstance(GetClusterSpecRequest.class);
request.setTaskId(taskId);
GetClusterSpecResponse response = tensorflow.getClusterSpec(request);
return response.getClusterSpec();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@


import com.linkedin.tony.rpc.GetClusterSpecRequest;
import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos;
import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos.GetClusterSpecRequestProto;

public class GetClusterSpecRequestPBImpl implements GetClusterSpecRequest {
private GetClusterSpecRequestProto proto = GetClusterSpecRequestProto.getDefaultInstance();
private GetClusterSpecRequestProto.Builder builder = null;
private boolean viaProto = false;

private String taskId;

private boolean rebuild = false;

public GetClusterSpecRequestPBImpl() {
Expand All @@ -28,11 +31,18 @@ private void mergeLocalToProto() {
if (viaProto) {
maybeInitBuilder();
}
mergeLocalToBuilder();
proto = builder.build();
rebuild = false;
viaProto = true;
}

private void mergeLocalToBuilder() {
if (this.taskId != null) {
builder.setTaskId(this.taskId);
}
}

public GetClusterSpecRequestProto getProto() {
if (rebuild) {
mergeLocalToProto();
Expand All @@ -48,4 +58,26 @@ private void maybeInitBuilder() {
}
viaProto = false;
}

@Override
public String getTaskId() {
YarnTonyClusterProtos.GetClusterSpecRequestProtoOrBuilder p = viaProto ? proto : builder;
if (this.taskId != null) {
return this.taskId;
}
if (!p.hasTaskId()) {
return null;
}
this.taskId = p.getTaskId();
return this.taskId;
}

@Override
public void setTaskId(String taskId) {
maybeInitBuilder();
if (taskId == null) {
builder.clearTaskId();
}
this.taskId = taskId;
}
}
1 change: 1 addition & 0 deletions tony-core/src/main/proto/yarn_tony_cluster_protos.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ message GetTaskInfosResponseProto {
}

message GetClusterSpecRequestProto {
required string taskId = 1;
}

message GetClusterSpecResponseProto {
Expand Down

0 comments on commit ebcd1f0

Please sign in to comment.