From 804b6b11d1c0c9aa5f00ba7b51a0035b9a116710 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 24 May 2022 13:17:47 +0100 Subject: [PATCH] Revert "[Java] Remove RayRuntimeInternal class (#25016)" (#25139) This reverts commit 4026b38b098804aec8782ea498d66b4194f8c292. Broke test_raydp_dataset --- .../io/ray/runtime/AbstractRayRuntime.java | 15 ++++------ .../io/ray/runtime/ConcurrencyGroupImpl.java | 2 +- .../ray/runtime/DefaultRayRuntimeFactory.java | 3 +- .../java/io/ray/runtime/RayNativeRuntime.java | 2 ++ .../io/ray/runtime/RayRuntimeInternal.java | 30 +++++++++++++++++++ .../ray/runtime/actor/NativeActorHandle.java | 4 +-- .../runtime/context/RuntimeContextImpl.java | 6 ++-- .../ray/runtime/object/NativeObjectStore.java | 4 +-- .../io/ray/runtime/object/ObjectRefImpl.java | 10 +++---- .../runtime/runner/worker/DefaultWorker.java | 4 +-- .../io/ray/runtime/task/ArgumentsBuilder.java | 8 ++--- .../runtime/task/LocalModeTaskExecutor.java | 4 +-- .../runtime/task/LocalModeTaskSubmitter.java | 6 ++-- .../ray/runtime/task/NativeTaskExecutor.java | 4 +-- .../io/ray/runtime/task/TaskExecutor.java | 6 ++-- .../java/io/ray/runtime/util/MethodUtils.java | 4 +-- .../ParallelActorContextImpl.java | 6 ++-- .../ParallelActorExecutorImpl.java | 4 +-- .../src/main/java/io/ray/test/TestUtils.java | 10 +++++-- .../java/io_ray_runtime_RayNativeRuntime.cc | 6 ++++ .../java/io_ray_runtime_RayNativeRuntime.h | 8 +++++ 21 files changed, 97 insertions(+), 49 deletions(-) create mode 100644 java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index 36126f630da05..d19b25ed48d5a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -21,7 +21,6 @@ import io.ray.api.options.PlacementGroupCreationOptions; import io.ray.api.parallelactor.ParallelActorContext; import io.ray.api.placementgroup.PlacementGroup; -import io.ray.api.runtime.RayRuntime; import io.ray.api.runtimecontext.RuntimeContext; import io.ray.api.runtimeenv.RuntimeEnv; import io.ray.runtime.config.RayConfig; @@ -32,7 +31,6 @@ import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.PyFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; -import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.ObjectRefImpl; import io.ray.runtime.object.ObjectStore; @@ -52,7 +50,7 @@ import org.slf4j.LoggerFactory; /** Core functionality to implement Ray APIs. */ -public abstract class AbstractRayRuntime implements RayRuntime { +public abstract class AbstractRayRuntime implements RayRuntimeInternal { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); public static final String PYTHON_INIT_METHOD_NAME = "__init__"; @@ -84,12 +82,6 @@ public ObjectRef put(T obj) { /*skipAddingLocalRef=*/ true); } - public abstract GcsClient getGcsClient(); - - public abstract void start(); - - public abstract void run(); - @Override public ObjectRef put(T obj, BaseActorHandle ownerActor) { if (LOGGER.isDebugEnabled()) { @@ -363,22 +355,27 @@ private BaseActorHandle createActorImpl( abstract List getCurrentReturnIds(int numReturns, ActorId actorId); + @Override public WorkerContext getWorkerContext() { return workerContext; } + @Override public ObjectStore getObjectStore() { return objectStore; } + @Override public TaskExecutor getTaskExecutor() { return taskExecutor; } + @Override public FunctionManager getFunctionManager() { return functionManager; } + @Override public RayConfig getRayConfig() { return rayConfig; } diff --git a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java index 5403666e83210..53ac57da52e93 100644 --- a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java @@ -24,7 +24,7 @@ public ConcurrencyGroupImpl(String name, int maxConcurrency, List funcs funcs.forEach( func -> { RayFunction rayFunc = - ((AbstractRayRuntime) Ray.internal()).getFunctionManager().getFunction(func); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func); functionDescriptors.add(rayFunc.getFunctionDescriptor()); }); } diff --git a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java index e9ecc0889d9a3..806ec020951fe 100644 --- a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java +++ b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java @@ -28,10 +28,11 @@ public RayRuntime createRayRuntime() { try { logger.debug("Initializing runtime with config: {}", rayConfig); - AbstractRayRuntime runtime = + AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.LOCAL ? new RayDevRuntime(rayConfig) : new RayNativeRuntime(rayConfig); + RayRuntimeInternal runtime = innerRuntime; runtime.start(); return runtime; } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 6b0032594fcd9..2f245273028d4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -288,6 +288,8 @@ private static native void nativeInitialize( private static native byte[] nativeGetActorIdOfNamedActor(String actorName, String namespace); + private static native void nativeSetCoreWorker(byte[] workerId); + private static native Map> nativeGetResourceIds(); private static native String nativeGetNamespace(); diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java new file mode 100644 index 0000000000000..fd1a23b90b3be --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java @@ -0,0 +1,30 @@ +package io.ray.runtime; + +import io.ray.api.runtime.RayRuntime; +import io.ray.runtime.config.RayConfig; +import io.ray.runtime.context.WorkerContext; +import io.ray.runtime.functionmanager.FunctionManager; +import io.ray.runtime.gcs.GcsClient; +import io.ray.runtime.object.ObjectStore; +import io.ray.runtime.task.TaskExecutor; + +/** This interface is required to make {@link RayRuntimeProxy} work. */ +public interface RayRuntimeInternal extends RayRuntime { + + /** Start runtime. */ + void start(); + + WorkerContext getWorkerContext(); + + ObjectStore getObjectStore(); + + TaskExecutor getTaskExecutor(); + + FunctionManager getFunctionManager(); + + RayConfig getRayConfig(); + + GcsClient getGcsClient(); + + void run(); +} diff --git a/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java b/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java index f5ee9deffea84..10b1137505b5a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java +++ b/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java @@ -8,7 +8,7 @@ import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.generated.Common.Language; import java.io.Externalizable; import java.io.IOException; @@ -122,7 +122,7 @@ private static final class NativeActorHandleReference public NativeActorHandleReference(NativeActorHandle handle) { super(handle, REFERENCE_QUEUE); this.actorId = handle.actorId; - AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); this.workerId = runtime.getWorkerContext().getCurrentWorkerId().getBytes(); this.removed = new AtomicBoolean(false); REFERENCES.add(this); diff --git a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java index 41648ad0753b2..ba10acc0a1053 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java @@ -8,7 +8,7 @@ import io.ray.api.runtimecontext.NodeInfo; import io.ray.api.runtimecontext.ResourceValue; import io.ray.api.runtimecontext.RuntimeContext; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.config.RunMode; import io.ray.runtime.util.ResourceUtil; import java.util.ArrayList; @@ -21,9 +21,9 @@ public class RuntimeContextImpl implements RuntimeContext { - private AbstractRayRuntime runtime; + private RayRuntimeInternal runtime; - public RuntimeContextImpl(AbstractRayRuntime runtime) { + public RuntimeContextImpl(RayRuntimeInternal runtime) { this.runtime = runtime; } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index ef48447d5ac84..bf99e6f2ac7b0 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -6,7 +6,7 @@ import io.ray.api.id.BaseId; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.context.WorkerContext; import io.ray.runtime.generated.Common.Address; import java.util.HashMap; @@ -40,7 +40,7 @@ public ObjectId putRaw(NativeRayObject obj) { @Override public ObjectId putRaw(NativeRayObject obj, ActorId ownerActorId) { byte[] serializedOwnerAddressBytes = - ((AbstractRayRuntime) Ray.internal()).getGcsClient().getActorAddress(ownerActorId); + ((RayRuntimeInternal) Ray.internal()).getGcsClient().getActorAddress(ownerActorId); return new ObjectId(nativePut(obj, serializedOwnerAddressBytes)); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index 51bf3c20dc7dc..6fb64e8055ca6 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -8,7 +8,7 @@ import io.ray.api.Ray; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; @@ -60,7 +60,7 @@ public ObjectRefImpl(ObjectId id, Class type) { public void init(ObjectId id, Class type, boolean skipAddingLocalRef) { this.id = id; this.type = (Class) type; - AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); Preconditions.checkState(workerId == null); workerId = runtime.getWorkerContext().getCurrentWorkerId(); @@ -106,7 +106,7 @@ public String toString() { public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(this.getId()); out.writeObject(this.getType()); - AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); byte[] ownerAddress = runtime.getObjectStore().getOwnershipInfo(this.getId()); out.writeInt(ownerAddress.length); out.write(ownerAddress); @@ -121,7 +121,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept byte[] ownerAddress = new byte[len]; in.readFully(ownerAddress); - AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); Preconditions.checkState(workerId == null); workerId = runtime.getWorkerContext().getCurrentWorkerId(); runtime.getObjectStore().addLocalReference(workerId, id); @@ -156,7 +156,7 @@ public void finalizeReferent() { REFERENCES.remove(this); // It's possible that GC is executed after the runtime is shutdown. if (Ray.isInitialized()) { - ((AbstractRayRuntime) (Ray.internal())) + ((RayRuntimeInternal) (Ray.internal())) .getObjectStore() .removeLocalReference(workerId, objectId); allObjects.remove(objectId); diff --git a/java/runtime/src/main/java/io/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/io/ray/runtime/runner/worker/DefaultWorker.java index 3e88bfe484480..6352011474431 100644 --- a/java/runtime/src/main/java/io/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/io/ray/runtime/runner/worker/DefaultWorker.java @@ -1,7 +1,7 @@ package io.ray.runtime.runner.worker; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; /** Default implementation of the worker process. */ public class DefaultWorker { @@ -12,6 +12,6 @@ public static void main(String[] args) { System.setProperty("ray.run-mode", "CLUSTER"); System.setProperty("ray.worker.mode", "WORKER"); Ray.init(); - ((AbstractRayRuntime) Ray.internal()).run(); + ((RayRuntimeInternal) Ray.internal()).run(); } } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java index af46661a708a6..1c45d934038c4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java @@ -5,7 +5,7 @@ import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.id.ObjectId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.NativeRayObject; @@ -41,7 +41,7 @@ public static List wrap(Object[] args, Language language) { if (arg instanceof ObjectRef) { Preconditions.checkState(arg instanceof ObjectRefImpl); id = ((ObjectRefImpl) arg).getId(); - address = ((AbstractRayRuntime) Ray.internal()).getObjectStore().getOwnerAddress(id); + address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id); } else { value = ObjectSerializer.serialize(arg); if (language != Language.JAVA) { @@ -60,8 +60,8 @@ public static List wrap(Object[] args, Language language) { } } if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { - id = ((AbstractRayRuntime) Ray.internal()).getObjectStore().putRaw(value); - address = ((AbstractRayRuntime) Ray.internal()).getWorkerContext().getRpcAddress(); + id = ((RayRuntimeInternal) Ray.internal()).getObjectStore().putRaw(value); + address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress(); value = null; } } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskExecutor.java index 94830b316f490..90d93b0a405a1 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskExecutor.java @@ -1,7 +1,7 @@ package io.ray.runtime.task; import io.ray.api.id.UniqueId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; /** Task executor for local mode. */ public class LocalModeTaskExecutor extends TaskExecutor { @@ -20,7 +20,7 @@ public UniqueId getWorkerId() { } } - public LocalModeTaskExecutor(AbstractRayRuntime runtime) { + public LocalModeTaskExecutor(RayRuntimeInternal runtime) { super(runtime); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index a44df04c12110..c6376c6390826 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -14,8 +14,8 @@ import io.ray.api.options.CallOptions; import io.ray.api.options.PlacementGroupCreationOptions; import io.ray.api.placementgroup.PlacementGroup; -import io.ray.runtime.AbstractRayRuntime; import io.ray.runtime.ConcurrencyGroupImpl; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.actor.LocalModeActorHandle; import io.ray.runtime.context.LocalModeWorkerContext; import io.ray.runtime.functionmanager.FunctionDescriptor; @@ -59,7 +59,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private final Map> waitingTasks = new HashMap<>(); private final Object taskAndObjectLock = new Object(); - private final AbstractRayRuntime runtime; + private final RayRuntimeInternal runtime; private final TaskExecutor taskExecutor; private final LocalModeObjectStore objectStore; @@ -169,7 +169,7 @@ public synchronized void shutdown() { } public LocalModeTaskSubmitter( - AbstractRayRuntime runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) { + RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) { this.runtime = runtime; this.taskExecutor = taskExecutor; this.objectStore = objectStore; diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java index 755dde2e3dad5..e13e98fd87161 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java @@ -1,14 +1,14 @@ package io.ray.runtime.task; import io.ray.api.id.UniqueId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; /** Task executor for cluster mode. */ public class NativeTaskExecutor extends TaskExecutor { static class NativeActorContext extends TaskExecutor.ActorContext {} - public NativeTaskExecutor(AbstractRayRuntime runtime) { + public NativeTaskExecutor(RayRuntimeInternal runtime) { super(runtime); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 8a99006aa8711..1f13734e9423c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -8,7 +8,7 @@ import io.ray.api.id.JobId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; import io.ray.runtime.generated.Common.TaskType; @@ -32,7 +32,7 @@ public abstract class TaskExecutor { private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); - protected final AbstractRayRuntime runtime; + protected final RayRuntimeInternal runtime; // TODO(qwang): Use actorContext instead later. private final ConcurrentHashMap actorContextMap = new ConcurrentHashMap<>(); @@ -44,7 +44,7 @@ static class ActorContext { Object currentActor = null; } - TaskExecutor(AbstractRayRuntime runtime) { + TaskExecutor(RayRuntimeInternal runtime) { this.runtime = runtime; } diff --git a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java index aad31daf1890f..b6523562f5e68 100644 --- a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java +++ b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java @@ -1,7 +1,7 @@ package io.ray.runtime.util; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -47,7 +47,7 @@ public static Class getReturnTypeFromSignature(String signature) { /// This code path indicates that here might be in another thread of a worker. /// So try to load the class from URLClassLoader of this worker. ClassLoader cl = - ((AbstractRayRuntime) Ray.internal()).getFunctionManager().getClassLoader(); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader(); actorClz = Class.forName(className, true, cl); } } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java index 6bcd5c18b5b9e..9da1f4cd0e3e9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java @@ -8,7 +8,7 @@ import io.ray.api.function.RayFunc; import io.ray.api.function.RayFuncR; import io.ray.api.parallelactor.*; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; @@ -26,7 +26,7 @@ public ParallelActorHandle createParallelActorExecutor( .build(); } - FunctionManager functionManager = ((AbstractRayRuntime) Ray.internal()).getFunctionManager(); + FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = functionManager.getFunction(ctorFunc).getFunctionDescriptor(); ActorHandle parallelExecutorHandle = @@ -42,7 +42,7 @@ public ObjectRef submitTask( ParallelActorHandle parallelActorHandle, int instanceId, RayFunc func, Object[] args) { ActorHandle parallelExecutor = ((ParallelActorHandleImpl) parallelActorHandle).getExecutor(); - FunctionManager functionManager = ((AbstractRayRuntime) Ray.internal()).getFunctionManager(); + FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = functionManager.getFunction(func).getFunctionDescriptor(); ObjectRef ret = diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java index 91020366fd24a..3836303e13e94 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java @@ -2,7 +2,7 @@ import com.google.common.base.Preconditions; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; @@ -22,7 +22,7 @@ public class ParallelActorExecutorImpl { public ParallelActorExecutorImpl(int parallelism, JavaFunctionDescriptor javaFunctionDescriptor) throws InvocationTargetException, IllegalAccessException { - functionManager = ((AbstractRayRuntime) Ray.internal()).getFunctionManager(); + functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); RayFunction init = functionManager.getFunction(javaFunctionDescriptor); Thread.currentThread().setContextClassLoader(init.classLoader); for (int i = 0; i < parallelism; ++i) { diff --git a/java/test/src/main/java/io/ray/test/TestUtils.java b/java/test/src/main/java/io/ray/test/TestUtils.java index 408f189447924..d302a0baee57f 100644 --- a/java/test/src/main/java/io/ray/test/TestUtils.java +++ b/java/test/src/main/java/io/ray/test/TestUtils.java @@ -3,7 +3,7 @@ import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; +import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.config.RayConfig; import io.ray.runtime.config.RunMode; import io.ray.runtime.task.ArgumentsBuilder; @@ -122,8 +122,12 @@ public static void warmUpCluster() { Assert.assertEquals(obj.get(), "hi"); } - public static AbstractRayRuntime getRuntime() { - return (AbstractRayRuntime) Ray.internal(); + public static RayRuntimeInternal getRuntime() { + return (RayRuntimeInternal) Ray.internal(); + } + + public static RayRuntimeInternal getUnderlyingRuntime() { + return (RayRuntimeInternal) Ray.internal(); } public static ProcessBuilder buildDriver(Class mainClass, String[] args) { diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 1521b5d3f82db..c467b4ecaba0c 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -390,6 +390,12 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor( THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } +JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeSetCoreWorker( + JNIEnv *env, jclass, jbyteArray workerId) { + const auto worker_id = JavaByteArrayToId(env, workerId); + CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id); +} + JNIEXPORT jobject JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeGetResourceIds(JNIEnv *env, jclass) { auto key_converter = [](JNIEnv *env, const std::string &str) -> jstring { diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index b620a21e355a9..6650799ce2488 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -79,6 +79,14 @@ Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *, jstring, jstring); +/* + * Class: io_ray_runtime_RayNativeRuntime + * Method: nativeSetCoreWorker + * Signature: ([B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(JNIEnv *, jclass, jbyteArray); + /* * Class: io_ray_runtime_RayNativeRuntime * Method: nativeGetResourceIds