Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seperate the interface of registerTask and getClusterSpec in TaskExec… #646

Merged
merged 1 commit into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package com.linkedin.tony;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.linkedin.tony.events.TaskFinished;
Expand Down Expand Up @@ -883,9 +882,11 @@ public void taskExecutorHeartbeat(String taskId) {
}

@Override
public String getClusterSpec() throws IOException {
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(session.getClusterSpec());
public String getClusterSpec(String taskId) throws IOException {
if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) {
return amRuntimeAdapter.constructClusterSpec(taskId);
}
return null;
}

@Override
Expand All @@ -902,10 +903,7 @@ public String registerWorkerSpec(String taskId, String spec) throws IOException
LOG.info("[" + taskId + "] Received Registration for HB !!");
hbMonitor.register(task);
killChiefWorkerIfTesting(taskId);
}

if (amRuntimeAdapter.canStartTask(distributedMode, taskId)) {
return amRuntimeAdapter.constructClusterSpec(taskId);
return "";
}
return null;
}
Expand Down
23 changes: 8 additions & 15 deletions tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,6 @@ private void releasePort(ServerPort port) throws Exception {
}
}

/**
* Releases the reserved ports if any. This method has to be invoked after ports are created.
*/
private void releasePorts() throws Exception {
try {
this.releasePort(this.rpcPort);
} finally {
this.releasePort(this.tbPort);
}
}

/**
* @return true if reusing port is enabled by user, false otherwise.
*/
Expand Down Expand Up @@ -288,7 +277,7 @@ protected void initConfigs() {
Utils.initHdfsConf(hdfsConf);
}

private String registerAndGetClusterSpec() {
private String registerAndGetClusterSpec() throws IOException, YarnException {
ContainerId containerId = ContainerId.fromString(System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()));
String hostName = Utils.getCurrentHostName();
LOG.info("ContainerId is: " + containerId + " HostName is: " + hostName);
Expand All @@ -299,9 +288,13 @@ private String registerAndGetClusterSpec() {

LOG.info("Connecting to " + amHost + ":" + amPort + " to register worker spec: " + jobName + " " + taskIndex + " "
+ hostName + ":" + this.rpcPort.getPort());
return Utils.pollTillNonNull(() ->
proxy.registerWorkerSpec(jobName + ":" + taskIndex,
hostName + ":" + this.rpcPort.getPort()), 3, 0);

String taskId = String.format("%s:%s", jobName, taskIndex);
String hostAndPort = String.format("%s:%s", hostName, rpcPort.getPort());

Utils.pollTillNonNull(() -> proxy.registerWorkerSpec(taskId, hostAndPort), 3, 0);

return Utils.pollTillNonNull(() -> proxy.getClusterSpec(taskId), 3, 0);
}

public void callbackInfoToAM(String taskId, String callbackInfo) throws IOException, YarnException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ public interface ApplicationRpc {
*/
Set<TaskInfo> getTaskInfos() throws IOException, YarnException;

String getClusterSpec() throws IOException, YarnException;
String registerWorkerSpec(String worker, String spec) throws IOException, YarnException;
String getClusterSpec(String taskId) throws IOException, YarnException;
void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException;
String registerTensorBoardUrl(String spec) throws Exception;
String registerExecutionResult(int exitCode, String jobName, String jobIndex, String sessionId) throws Exception;
void finishApplication() throws YarnException, IOException;
void taskExecutorHeartbeat(String taskId) throws YarnException, IOException;
void reset();
void registerCallbackInfo(String taskId, String callbackInfo) throws YarnException, IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public GetTaskInfosResponse getTaskInfos(GetTaskInfosRequest request) throws IOE
public GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request)
throws YarnException, IOException {
GetClusterSpecResponse response = RECORD_FACTORY.newRecordInstance(GetClusterSpecResponse.class);
response.setClusterSpec(this.appRpc.getClusterSpec());
response.setClusterSpec(this.appRpc.getClusterSpec(request.getTaskId()));
return response;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
package com.linkedin.tony.rpc;

public interface GetClusterSpecRequest {
String getTaskId();
void setTaskId(String taskId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ public Set<TaskInfo> getTaskInfos() throws IOException, YarnException {
}

@Override
public String getClusterSpec() throws IOException, YarnException {
GetClusterSpecResponse response =
tensorflow.getClusterSpec(recordFactory.newRecordInstance(GetClusterSpecRequest.class));
public String getClusterSpec(String taskId) throws IOException, YarnException {
GetClusterSpecRequest request = recordFactory.newRecordInstance(GetClusterSpecRequest.class);
request.setTaskId(taskId);
GetClusterSpecResponse response = tensorflow.getClusterSpec(request);
return response.getClusterSpec();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
*/
package com.linkedin.tony.rpc.impl.pb;


import com.linkedin.tony.rpc.GetClusterSpecRequest;
import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos;
import com.linkedin.tony.rpc.proto.YarnTonyClusterProtos.GetClusterSpecRequestProto;

public class GetClusterSpecRequestPBImpl implements GetClusterSpecRequest {
private GetClusterSpecRequestProto proto = GetClusterSpecRequestProto.getDefaultInstance();
private GetClusterSpecRequestProto.Builder builder = null;
private boolean viaProto = false;

private boolean rebuild = false;
private String taskId;

public GetClusterSpecRequestPBImpl() {
builder = GetClusterSpecRequestProto.newBuilder();
Expand All @@ -28,15 +28,19 @@ private void mergeLocalToProto() {
if (viaProto) {
maybeInitBuilder();
}
mergeLocalToBuilder();
proto = builder.build();
rebuild = false;
viaProto = true;
}

private void mergeLocalToBuilder() {
if (this.taskId != null) {
builder.setTaskId(this.taskId);
}
}

public GetClusterSpecRequestProto getProto() {
if (rebuild) {
mergeLocalToProto();
}
mergeLocalToProto();
proto = viaProto ? proto : builder.build();
viaProto = true;
return proto;
Expand All @@ -48,4 +52,26 @@ private void maybeInitBuilder() {
}
viaProto = false;
}

@Override
public String getTaskId() {
YarnTonyClusterProtos.GetClusterSpecRequestProtoOrBuilder p = viaProto ? proto : builder;
if (this.taskId != null) {
return this.taskId;
}
if (!p.hasTaskId()) {
return null;
}
this.taskId = p.getTaskId();
return this.taskId;
}

@Override
public void setTaskId(String taskId) {
maybeInitBuilder();
if (taskId == null) {
builder.clearTaskId();
}
this.taskId = taskId;
}
}
1 change: 1 addition & 0 deletions tony-core/src/main/proto/yarn_tony_cluster_protos.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ message GetTaskInfosResponseProto {
}

message GetClusterSpecRequestProto {
required string taskId = 1;
}

message GetClusterSpecResponseProto {
Expand Down