Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Remove RayRuntimeInternal class #25016

Merged
merged 4 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.parallelactor.ParallelActorContext;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtime.RayRuntime;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.config.RayConfig;
Expand All @@ -31,6 +32,7 @@
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
Expand All @@ -50,7 +52,7 @@
import org.slf4j.LoggerFactory;

/** Core functionality to implement Ray APIs. */
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
public abstract class AbstractRayRuntime implements RayRuntime {

private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
Expand Down Expand Up @@ -82,6 +84,12 @@ public <T> ObjectRef<T> put(T obj) {
/*skipAddingLocalRef=*/ true);
}

public abstract GcsClient getGcsClient();

public abstract void start();

public abstract void run();

@Override
public <T> ObjectRef<T> put(T obj, BaseActorHandle ownerActor) {
if (LOGGER.isDebugEnabled()) {
Expand Down Expand Up @@ -355,27 +363,22 @@ private BaseActorHandle createActorImpl(

abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);

@Override
public WorkerContext getWorkerContext() {
return workerContext;
}

@Override
public ObjectStore getObjectStore() {
return objectStore;
}

@Override
public TaskExecutor getTaskExecutor() {
return taskExecutor;
}

@Override
public FunctionManager getFunctionManager() {
return functionManager;
}

@Override
public RayConfig getRayConfig() {
return rayConfig;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ConcurrencyGroupImpl(String name, int maxConcurrency, List<RayFunc> funcs
funcs.forEach(
func -> {
RayFunction rayFunc =
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func);
((AbstractRayRuntime) Ray.internal()).getFunctionManager().getFunction(func);
functionDescriptors.add(rayFunc.getFunctionDescriptor());
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ public RayRuntime createRayRuntime() {

try {
logger.debug("Initializing runtime with config: {}", rayConfig);
AbstractRayRuntime innerRuntime =
AbstractRayRuntime runtime =
rayConfig.runMode == RunMode.LOCAL
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime = innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,6 @@ private static native void nativeInitialize(

private static native byte[] nativeGetActorIdOfNamedActor(String actorName, String namespace);

private static native void nativeSetCoreWorker(byte[] workerId);

private static native Map<String, List<ResourceValue>> nativeGetResourceIds();

private static native String nativeGetNamespace();
Expand Down
30 changes: 0 additions & 30 deletions java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.generated.Common.Language;
import java.io.Externalizable;
import java.io.IOException;
Expand Down Expand Up @@ -122,7 +122,7 @@ private static final class NativeActorHandleReference
public NativeActorHandleReference(NativeActorHandle handle) {
super(handle, REFERENCE_QUEUE);
this.actorId = handle.actorId;
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
this.workerId = runtime.getWorkerContext().getCurrentWorkerId().getBytes();
this.removed = new AtomicBoolean(false);
REFERENCES.add(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.util.ResourceUtil;
import java.util.ArrayList;
Expand All @@ -21,9 +21,9 @@

public class RuntimeContextImpl implements RuntimeContext {

private RayRuntimeInternal runtime;
private AbstractRayRuntime runtime;

public RuntimeContextImpl(RayRuntimeInternal runtime) {
public RuntimeContextImpl(AbstractRayRuntime runtime) {
this.runtime = runtime;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import io.ray.api.id.BaseId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.generated.Common.Address;
import java.util.HashMap;
Expand Down Expand Up @@ -40,7 +40,7 @@ public ObjectId putRaw(NativeRayObject obj) {
@Override
public ObjectId putRaw(NativeRayObject obj, ActorId ownerActorId) {
byte[] serializedOwnerAddressBytes =
((RayRuntimeInternal) Ray.internal()).getGcsClient().getActorAddress(ownerActorId);
((AbstractRayRuntime) Ray.internal()).getGcsClient().getActorAddress(ownerActorId);
return new ObjectId(nativePut(obj, serializedOwnerAddressBytes));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
Expand Down Expand Up @@ -60,7 +60,7 @@ public ObjectRefImpl(ObjectId id, Class<T> type) {
public void init(ObjectId id, Class<?> type, boolean skipAddingLocalRef) {
this.id = id;
this.type = (Class<T>) type;
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
Preconditions.checkState(workerId == null);
workerId = runtime.getWorkerContext().getCurrentWorkerId();

Expand Down Expand Up @@ -106,7 +106,7 @@ public String toString() {
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(this.getId());
out.writeObject(this.getType());
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId());
out.writeInt(ownerAddress.length);
out.write(ownerAddress);
Expand All @@ -121,7 +121,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
byte[] ownerAddress = new byte[len];
in.readFully(ownerAddress);

RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
Preconditions.checkState(workerId == null);
workerId = runtime.getWorkerContext().getCurrentWorkerId();
runtime.getObjectStore().addLocalReference(workerId, id);
Expand Down Expand Up @@ -156,7 +156,7 @@ public void finalizeReferent() {
REFERENCES.remove(this);
// It's possible that GC is executed after the runtime is shutdown.
if (Ray.isInitialized()) {
((RayRuntimeInternal) (Ray.internal()))
((AbstractRayRuntime) (Ray.internal()))
.getObjectStore()
.removeLocalReference(workerId, objectId);
allObjects.remove(objectId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.runner.worker;

import io.ray.api.Ray;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;

/** Default implementation of the worker process. */
public class DefaultWorker {
Expand All @@ -12,6 +12,6 @@ public static void main(String[] args) {
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.worker.mode", "WORKER");
Ray.init();
((RayRuntimeInternal) Ray.internal()).run();
((AbstractRayRuntime) Ray.internal()).run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
Expand Down Expand Up @@ -41,7 +41,7 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
if (arg instanceof ObjectRef) {
Preconditions.checkState(arg instanceof ObjectRefImpl);
id = ((ObjectRefImpl<?>) arg).getId();
address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id);
address = ((AbstractRayRuntime) Ray.internal()).getObjectStore().getOwnerAddress(id);
} else {
value = ObjectSerializer.serialize(arg);
if (language != Language.JAVA) {
Expand All @@ -60,8 +60,8 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
}
}
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore().putRaw(value);
address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress();
id = ((AbstractRayRuntime) Ray.internal()).getObjectStore().putRaw(value);
address = ((AbstractRayRuntime) Ray.internal()).getWorkerContext().getRpcAddress();
value = null;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.task;

import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;

/** Task executor for local mode. */
public class LocalModeTaskExecutor extends TaskExecutor<LocalModeTaskExecutor.LocalActorContext> {
Expand All @@ -20,7 +20,7 @@ public UniqueId getWorkerId() {
}
}

public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
super(runtime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import io.ray.api.options.CallOptions;
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.ConcurrencyGroupImpl;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.actor.LocalModeActorHandle;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.functionmanager.FunctionDescriptor;
Expand Down Expand Up @@ -59,7 +59,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {

private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
private final Object taskAndObjectLock = new Object();
private final RayRuntimeInternal runtime;
private final AbstractRayRuntime runtime;
private final TaskExecutor taskExecutor;
private final LocalModeObjectStore objectStore;

Expand Down Expand Up @@ -169,7 +169,7 @@ public synchronized void shutdown() {
}

public LocalModeTaskSubmitter(
RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
AbstractRayRuntime runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package io.ray.runtime.task;

import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;

/** Task executor for cluster mode. */
public class NativeTaskExecutor extends TaskExecutor<NativeTaskExecutor.NativeActorContext> {

static class NativeActorContext extends TaskExecutor.ActorContext {}

public NativeTaskExecutor(RayRuntimeInternal runtime) {
public NativeTaskExecutor(AbstractRayRuntime runtime) {
super(runtime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common.TaskType;
Expand All @@ -32,7 +32,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {

private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);

protected final RayRuntimeInternal runtime;
protected final AbstractRayRuntime runtime;

// TODO(qwang): Use actorContext instead later.
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
Expand All @@ -44,7 +44,7 @@ static class ActorContext {
Object currentActor = null;
}

TaskExecutor(RayRuntimeInternal runtime) {
TaskExecutor(AbstractRayRuntime runtime) {
this.runtime = runtime;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.ray.runtime.util;

import io.ray.api.Ray;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.AbstractRayRuntime;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -47,7 +47,7 @@ public static Class<?> getReturnTypeFromSignature(String signature) {
/// This code path indicates that here might be in another thread of a worker.
/// So try to load the class from URLClassLoader of this worker.
ClassLoader cl =
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader();
((AbstractRayRuntime) Ray.internal()).getFunctionManager().getClassLoader();
actorClz = Class.forName(className, true, cl);
}
} catch (Exception e) {
Expand Down
Loading