Skip to content

Commit

Permalink
Revert "[Java] Remove RayRuntimeInternal class (ray-project#25016)" (r…
Browse files Browse the repository at this point in the history
…ay-project#25139)

This reverts commit 4026b38.

Broke test_raydp_dataset
  • Loading branch information
krfricke authored May 24, 2022
1 parent a7e7593 commit 804b6b1
Show file tree
Hide file tree
Showing 21 changed files with 97 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.parallelactor.ParallelActorContext;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtime.RayRuntime;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.config.RayConfig;
Expand All @@ -32,7 +31,6 @@
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
Expand All @@ -52,7 +50,7 @@
import org.slf4j.LoggerFactory;

/** Core functionality to implement Ray APIs. */
public abstract class AbstractRayRuntime implements RayRuntime {
public abstract class AbstractRayRuntime implements RayRuntimeInternal {

private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
Expand Down Expand Up @@ -84,12 +82,6 @@ public <T> ObjectRef<T> put(T obj) {
/*skipAddingLocalRef=*/ true);
}

public abstract GcsClient getGcsClient();

public abstract void start();

public abstract void run();

@Override
public <T> ObjectRef<T> put(T obj, BaseActorHandle ownerActor) {
if (LOGGER.isDebugEnabled()) {
Expand Down Expand Up @@ -363,22 +355,27 @@ private BaseActorHandle createActorImpl(

abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);

@Override
public WorkerContext getWorkerContext() {
return workerContext;
}

@Override
public ObjectStore getObjectStore() {
return objectStore;
}

@Override
public TaskExecutor getTaskExecutor() {
return taskExecutor;
}

@Override
public FunctionManager getFunctionManager() {
return functionManager;
}

@Override
public RayConfig getRayConfig() {
return rayConfig;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ConcurrencyGroupImpl(String name, int maxConcurrency, List<RayFunc> funcs
funcs.forEach(
func -> {
RayFunction rayFunc =
((AbstractRayRuntime) Ray.internal()).getFunctionManager().getFunction(func);
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func);
functionDescriptors.add(rayFunc.getFunctionDescriptor());
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ public RayRuntime createRayRuntime() {

try {
logger.debug("Initializing runtime with config: {}", rayConfig);
AbstractRayRuntime runtime =
AbstractRayRuntime innerRuntime =
rayConfig.runMode == RunMode.LOCAL
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime = innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ private static native void nativeInitialize(

private static native byte[] nativeGetActorIdOfNamedActor(String actorName, String namespace);

private static native void nativeSetCoreWorker(byte[] workerId);

private static native Map<String, List<ResourceValue>> nativeGetResourceIds();

private static native String nativeGetNamespace();
Expand Down
30 changes: 30 additions & 0 deletions java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.ray.runtime;

import io.ray.api.runtime.RayRuntime;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.task.TaskExecutor;

/** This interface is required to make {@link RayRuntimeProxy} work. */
public interface RayRuntimeInternal extends RayRuntime {

/** Start runtime. */
void start();

WorkerContext getWorkerContext();

ObjectStore getObjectStore();

TaskExecutor getTaskExecutor();

FunctionManager getFunctionManager();

RayConfig getRayConfig();

GcsClient getGcsClient();

void run();
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Language;
import java.io.Externalizable;
import java.io.IOException;
Expand Down Expand Up @@ -122,7 +122,7 @@ private static final class NativeActorHandleReference
public NativeActorHandleReference(NativeActorHandle handle) {
super(handle, REFERENCE_QUEUE);
this.actorId = handle.actorId;
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
this.workerId = runtime.getWorkerContext().getCurrentWorkerId().getBytes();
this.removed = new AtomicBoolean(false);
REFERENCES.add(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.util.ResourceUtil;
import java.util.ArrayList;
Expand All @@ -21,9 +21,9 @@

public class RuntimeContextImpl implements RuntimeContext {

private AbstractRayRuntime runtime;
private RayRuntimeInternal runtime;

public RuntimeContextImpl(AbstractRayRuntime runtime) {
public RuntimeContextImpl(RayRuntimeInternal runtime) {
this.runtime = runtime;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import io.ray.api.id.BaseId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.generated.Common.Address;
import java.util.HashMap;
Expand Down Expand Up @@ -40,7 +40,7 @@ public ObjectId putRaw(NativeRayObject obj) {
@Override
public ObjectId putRaw(NativeRayObject obj, ActorId ownerActorId) {
byte[] serializedOwnerAddressBytes =
((AbstractRayRuntime) Ray.internal()).getGcsClient().getActorAddress(ownerActorId);
((RayRuntimeInternal) Ray.internal()).getGcsClient().getActorAddress(ownerActorId);
return new ObjectId(nativePut(obj, serializedOwnerAddressBytes));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
Expand Down Expand Up @@ -60,7 +60,7 @@ public ObjectRefImpl(ObjectId id, Class<T> type) {
public void init(ObjectId id, Class<?> type, boolean skipAddingLocalRef) {
this.id = id;
this.type = (Class<T>) type;
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
Preconditions.checkState(workerId == null);
workerId = runtime.getWorkerContext().getCurrentWorkerId();

Expand Down Expand Up @@ -106,7 +106,7 @@ public String toString() {
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(this.getId());
out.writeObject(this.getType());
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId());
out.writeInt(ownerAddress.length);
out.write(ownerAddress);
Expand All @@ -121,7 +121,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
byte[] ownerAddress = new byte[len];
in.readFully(ownerAddress);

AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
Preconditions.checkState(workerId == null);
workerId = runtime.getWorkerContext().getCurrentWorkerId();
runtime.getObjectStore().addLocalReference(workerId, id);
Expand Down Expand Up @@ -156,7 +156,7 @@ public void finalizeReferent() {
REFERENCES.remove(this);
// It's possible that GC is executed after the runtime is shutdown.
if (Ray.isInitialized()) {
((AbstractRayRuntime) (Ray.internal()))
((RayRuntimeInternal) (Ray.internal()))
.getObjectStore()
.removeLocalReference(workerId, objectId);
allObjects.remove(objectId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.runner.worker;

import io.ray.api.Ray;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;

/** Default implementation of the worker process. */
public class DefaultWorker {
Expand All @@ -12,6 +12,6 @@ public static void main(String[] args) {
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.worker.mode", "WORKER");
Ray.init();
((AbstractRayRuntime) Ray.internal()).run();
((RayRuntimeInternal) Ray.internal()).run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
Expand Down Expand Up @@ -41,7 +41,7 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
if (arg instanceof ObjectRef) {
Preconditions.checkState(arg instanceof ObjectRefImpl);
id = ((ObjectRefImpl<?>) arg).getId();
address = ((AbstractRayRuntime) Ray.internal()).getObjectStore().getOwnerAddress(id);
address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id);
} else {
value = ObjectSerializer.serialize(arg);
if (language != Language.JAVA) {
Expand All @@ -60,8 +60,8 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
}
}
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime) Ray.internal()).getObjectStore().putRaw(value);
address = ((AbstractRayRuntime) Ray.internal()).getWorkerContext().getRpcAddress();
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore().putRaw(value);
address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress();
value = null;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.task;

import io.ray.api.id.UniqueId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;

/** Task executor for local mode. */
public class LocalModeTaskExecutor extends TaskExecutor<LocalModeTaskExecutor.LocalActorContext> {
Expand All @@ -20,7 +20,7 @@ public UniqueId getWorkerId() {
}
}

public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import io.ray.api.options.CallOptions;
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.ConcurrencyGroupImpl;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.actor.LocalModeActorHandle;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.functionmanager.FunctionDescriptor;
Expand Down Expand Up @@ -59,7 +59,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {

private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
private final Object taskAndObjectLock = new Object();
private final AbstractRayRuntime runtime;
private final RayRuntimeInternal runtime;
private final TaskExecutor taskExecutor;
private final LocalModeObjectStore objectStore;

Expand Down Expand Up @@ -169,7 +169,7 @@ public synchronized void shutdown() {
}

public LocalModeTaskSubmitter(
AbstractRayRuntime runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package io.ray.runtime.task;

import io.ray.api.id.UniqueId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;

/** Task executor for cluster mode. */
public class NativeTaskExecutor extends TaskExecutor<NativeTaskExecutor.NativeActorContext> {

static class NativeActorContext extends TaskExecutor.ActorContext {}

public NativeTaskExecutor(AbstractRayRuntime runtime) {
public NativeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common.TaskType;
Expand All @@ -32,7 +32,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {

private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);

protected final AbstractRayRuntime runtime;
protected final RayRuntimeInternal runtime;

// TODO(qwang): Use actorContext instead later.
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
Expand All @@ -44,7 +44,7 @@ static class ActorContext {
Object currentActor = null;
}

TaskExecutor(AbstractRayRuntime runtime) {
TaskExecutor(RayRuntimeInternal runtime) {
this.runtime = runtime;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.util;

import io.ray.api.Ray;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -47,7 +47,7 @@ public static Class<?> getReturnTypeFromSignature(String signature) {
/// This code path indicates that here might be in another thread of a worker.
/// So try to load the class from URLClassLoader of this worker.
ClassLoader cl =
((AbstractRayRuntime) Ray.internal()).getFunctionManager().getClassLoader();
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader();
actorClz = Class.forName(className, true, cl);
}
} catch (Exception e) {
Expand Down
Loading

0 comments on commit 804b6b1

Please sign in to comment.