From 5233178a41a3c2cd4365950dcc4fa5e3a1fec19a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 10 May 2025 20:21:26 -0400 Subject: [PATCH] [FFI][JVM] Upgrade tvm4j to latest FFI This PR updates TVM4J to use the latest FFI --- jvm/README.md | 2 +- .../src/main/java/org/apache/tvm/Base.java | 35 +- .../src/main/java/org/apache/tvm/Device.java | 31 +- .../main/java/org/apache/tvm/Function.java | 165 +++---- .../src/main/java/org/apache/tvm/LibInfo.java | 57 ++- .../src/main/java/org/apache/tvm/Module.java | 43 +- .../src/main/java/org/apache/tvm/NDArray.java | 42 +- .../main/java/org/apache/tvm/NDArrayBase.java | 51 +- .../tvm/{ArgTypeCode.java => TVMObject.java} | 27 +- .../main/java/org/apache/tvm/TVMValue.java | 4 +- .../java/org/apache/tvm/TVMValueBytes.java | 1 - .../java/org/apache/tvm/TVMValueDouble.java | 1 - .../java/org/apache/tvm/TVMValueHandle.java | 1 - .../java/org/apache/tvm/TVMValueLong.java | 1 - .../java/org/apache/tvm/TVMValueNull.java | 1 - .../java/org/apache/tvm/TVMValueString.java | 1 - .../main/java/org/apache/tvm/TypeIndex.java | 44 ++ .../main/java/org/apache/tvm/rpc/Client.java | 3 + .../java/org/apache/tvm/rpc/RPCSession.java | 7 +- .../java/org/apache/tvm/FunctionTest.java | 2 + .../test/java/org/apache/tvm/ModuleTest.java | 3 - .../test/java/org/apache/tvm/rpc/RPCTest.java | 2 + .../src/test/scripts/prepare_test_libs.py | 83 ++++ jvm/native/linux-x86_64/pom.xml | 2 + jvm/native/osx-x86_64/pom.xml | 2 + jvm/native/src/main/native/jni_helper_func.h | 111 +++-- .../native/org_apache_tvm_native_c_api.cc | 442 ++++++++---------- tests/scripts/task_java_unittest.sh | 23 +- 28 files changed, 566 insertions(+), 621 deletions(-) rename jvm/core/src/main/java/org/apache/tvm/{ArgTypeCode.java => TVMObject.java} (65%) create mode 100644 jvm/core/src/main/java/org/apache/tvm/TypeIndex.java create mode 100644 jvm/core/src/test/scripts/prepare_test_libs.py diff --git a/jvm/README.md b/jvm/README.md index 0f53f4e561a2..71c737a4d00a 100644 --- a/jvm/README.md +++ b/jvm/README.md @@ -39,7 +39,7 @@ TVM4J contains three modules: - core * It contains all the Java interfaces. - native - * The JNI native library is compiled in this module. It does not link TVM runtime library (libtvm\_runtime.so for Linux and libtvm\_runtime.dylib for OSX). Instead, you have to specify `libtvm.so.path` which contains the TVM runtime library as Java system property. + * The JNI native library is compiled in this module. Need to expose libtvm_runtime to LD_LIBRARY_PATH - assembly * It assembles Java interfaces (core), JNI library (native) and TVM runtime library together. The simplest way to integrate tvm4j in your project is to rely on this module. It automatically extracts the native library to a tempfile and load it. diff --git a/jvm/core/src/main/java/org/apache/tvm/Base.java b/jvm/core/src/main/java/org/apache/tvm/Base.java index f5e677a2e0b3..97ae274a565c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Base.java +++ b/jvm/core/src/main/java/org/apache/tvm/Base.java @@ -87,37 +87,8 @@ public RefTVMValue() { } System.err.println("libtvm4j loads successfully."); - - if (loadNativeRuntimeLib) { - String tvmLibFilename = System.getProperty("libtvm.so.path"); - if (tvmLibFilename == null || !new File(tvmLibFilename).isFile() - || _LIB.nativeLibInit(tvmLibFilename) != 0) { - try { - String runtimeLibname; - String os = System.getProperty("os.name"); - // ref: http://lopica.sourceforge.net/os.html - if (os.startsWith("Linux")) { - runtimeLibname = "libtvm_runtime.so"; - } else if (os.startsWith("Mac")) { - runtimeLibname = "libtvm_runtime.dylib"; - } else { - // TODO(yizhi) support windows later - throw new UnsatisfiedLinkError(os + " not supported currently"); - } - NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() { - @Override public void invoke(File target) { - System.err.println("Loading tvm runtime from " + target.getPath()); - checkCall(_LIB.nativeLibInit(target.getPath())); - } - }); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - } else { - _LIB.nativeLibInit(null); - } - + // always use linked lib + _LIB.nativeLibInit(null); Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { _LIB.shutdown(); @@ -170,7 +141,7 @@ private static void tryLoadLibraryXPU(String libname, String arch) throws Unsati */ public static void checkCall(int ret) throws TVMError { if (ret != 0) { - throw new TVMError(_LIB.tvmGetLastError()); + throw new TVMError(_LIB.tvmFFIGetLastError()); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/Device.java b/jvm/core/src/main/java/org/apache/tvm/Device.java index 70fe13cec906..2396df94fbf0 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Device.java +++ b/jvm/core/src/main/java/org/apache/tvm/Device.java @@ -17,18 +17,30 @@ package org.apache.tvm; +import org.apache.tvm.rpc.RPC; + import java.util.HashMap; import java.util.Map; -import org.apache.tvm.rpc.RPC; public class Device { /** * Provides the same information as the C++ enums DLDeviceType and * TVMDeviceExtType. */ - static final int kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7, - kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, kDLROCMHost = 11, kDLExtDev = 12, - kDLCUDAManaged = 13, kDLOneAPI = 14, kDLWebGPU = 15, kDLHexagon = 16; + static final int kDLCPU = 1; + static final int kDLCUDA = 2; + static final int kDLCUDAHost = 3; + static final int kDLOpenCL = 4; + static final int kDLVulkan = 7; + static final int kDLMetal = 8; + static final int kDLVPI = 9; + static final int kDLROCM = 10; + static final int kDLROCMHost = 11; + static final int kDLExtDev = 12; + static final int kDLCUDAManaged = 13; + static final int kDLOneAPI = 14; + static final int kDLWebGPU = 15; + static final int kDLHexagon = 16; private static final Map DEVICE_TYPE_TO_NAME = new HashMap(); private static final Map DEVICE_NAME_TO_TYPE = new HashMap(); @@ -161,7 +173,8 @@ public Device(String deviceType, int deviceId) { */ public boolean exist() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke(); + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(0).invoke(); return ((TVMValueLong) ret).value != 0; } @@ -171,7 +184,8 @@ public boolean exist() { */ public long maxThreadsPerBlock() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke(); + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(1).invoke(); return ((TVMValueLong) ret).value; } @@ -181,8 +195,9 @@ public long maxThreadsPerBlock() { */ public long warpSize() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke(); - return ((TVMValueLong) ret).value; + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(2).invoke(); + return ret.asLong(); } /** diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index 594b35b0af68..ee6b8e8cf5c5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -24,23 +24,14 @@ /** * TVM Packed Function. */ -public class Function extends TVMValue { - final long handle; - public final boolean isResident; - private boolean isReleased = false; - +public class Function extends TVMObject { /** * Get registered function. * @param name full function name. * @return TVM function. */ public static Function getFunction(final String name) { - for (String fullName : listGlobalFuncNames()) { - if (fullName.equals(name)) { - return getGlobalFunc(fullName, true, false); - } - } - return null; + return getGlobalFunc(name, true); } /** @@ -49,22 +40,21 @@ public static Function getFunction(final String name) { */ private static List listGlobalFuncNames() { List names = new ArrayList(); - Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names)); + Base.checkCall(Base._LIB.tvmFFIFunctionListGlobalNames(names)); return Collections.unmodifiableList(names); } /** * Get a global function by name. * @param name The name of the function. - * @param isResident Whether it is a global 'resident' function. * @param allowMissing Whether allow missing function or raise an error. * @return The function to be returned, None if function is missing. */ - private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) { + private static Function getGlobalFunc(String name, boolean allowMissing) { Base.RefLong handle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle)); + Base.checkCall(Base._LIB.tvmFFIFunctionGetGlobal(name, handle)); if (handle.value != 0) { - return new Function(handle.value, isResident); + return new Function(handle.value); } else { if (allowMissing) { return null; @@ -74,24 +64,8 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a } } - /** - * Initialize the function with handle. - * @param handle the handle to the underlying function. - * @param isResident Whether this is a resident function in jvm - */ - Function(long handle, boolean isResident) { - super(ArgTypeCode.FUNC_HANDLE); - this.handle = handle; - this.isResident = isResident; - } - Function(long handle) { - this(handle, false); - } - - @Override protected void finalize() throws Throwable { - release(); - super.finalize(); + super(handle, TypeIndex.kTVMFFIFunction); } /** @@ -102,32 +76,13 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a return this; } - @Override long asHandle() { - return handle; - } - - /** - * Release the Function. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

- */ - @Override public void release() { - if (!isReleased) { - if (!isResident) { - Base.checkCall(Base._LIB.tvmFuncFree(handle)); - isReleased = true; - } - } - } - /** * Invoke the function. * @return the result. */ public TVMValue invoke() { Base.RefTVMValue ret = new Base.RefTVMValue(); - Base.checkCall(Base._LIB.tvmFuncCall(handle, ret)); + Base.checkCall(Base._LIB.tvmFFIFunctionCall(handle, ret)); return ret.value; } @@ -137,7 +92,7 @@ public TVMValue invoke() { * @return this */ public Function pushArg(int arg) { - Base._LIB.tvmFuncPushArgLong(arg); + Base._LIB.tvmFFIFunctionPushArgLong(arg); return this; } @@ -147,7 +102,7 @@ public Function pushArg(int arg) { * @return this */ public Function pushArg(long arg) { - Base._LIB.tvmFuncPushArgLong(arg); + Base._LIB.tvmFFIFunctionPushArgLong(arg); return this; } @@ -157,7 +112,7 @@ public Function pushArg(long arg) { * @return this */ public Function pushArg(float arg) { - Base._LIB.tvmFuncPushArgDouble(arg); + Base._LIB.tvmFFIFunctionPushArgDouble(arg); return this; } @@ -167,7 +122,7 @@ public Function pushArg(float arg) { * @return this */ public Function pushArg(double arg) { - Base._LIB.tvmFuncPushArgDouble(arg); + Base._LIB.tvmFFIFunctionPushArgDouble(arg); return this; } @@ -177,7 +132,7 @@ public Function pushArg(double arg) { * @return this */ public Function pushArg(String arg) { - Base._LIB.tvmFuncPushArgString(arg); + Base._LIB.tvmFFIFunctionPushArgString(arg); return this; } @@ -187,8 +142,11 @@ public Function pushArg(String arg) { * @return this */ public Function pushArg(NDArrayBase arg) { - int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; - Base._LIB.tvmFuncPushArgHandle(arg.handle, id); + if (arg instanceof NDArray) { + Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) arg).handle, TypeIndex.kTVMFFINDArray); + } else { + Base._LIB.tvmFFIFunctionPushArgHandle(arg.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); + } return this; } @@ -198,7 +156,7 @@ public Function pushArg(NDArrayBase arg) { * @return this */ public Function pushArg(Module arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgHandle(arg.handle, TypeIndex.kTVMFFIModule); return this; } @@ -208,7 +166,7 @@ public Function pushArg(Module arg) { * @return this */ public Function pushArg(Function arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgHandle(arg.handle, TypeIndex.kTVMFFIFunction); return this; } @@ -218,7 +176,7 @@ public Function pushArg(Function arg) { * @return this */ public Function pushArg(byte[] arg) { - Base._LIB.tvmFuncPushArgBytes(arg); + Base._LIB.tvmFFIFunctionPushArgBytes(arg); return this; } @@ -228,7 +186,7 @@ public Function pushArg(byte[] arg) { * @return this */ public Function pushArg(Device arg) { - Base._LIB.tvmFuncPushArgDevice(arg); + Base._LIB.tvmFFIFunctionPushArgDevice(arg); return this; } @@ -245,53 +203,44 @@ public TVMValue call(Object... args) { } private static void pushArgToStack(Object arg) { - if (arg instanceof Integer) { - Base._LIB.tvmFuncPushArgLong((Integer) arg); + if (arg instanceof NDArrayBase) { + NDArrayBase nd = (NDArrayBase) arg; + if (nd instanceof NDArray) { + Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) nd).handle, TypeIndex.kTVMFFINDArray); + } else { + Base._LIB.tvmFFIFunctionPushArgHandle(nd.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); + } + } else if (arg instanceof TVMObject) { + TVMObject obj = (TVMObject) arg; + Base._LIB.tvmFFIFunctionPushArgHandle(obj.handle, obj.typeIndex); + } else if (arg instanceof Integer) { + Base._LIB.tvmFFIFunctionPushArgLong((Integer) arg); } else if (arg instanceof Long) { - Base._LIB.tvmFuncPushArgLong((Long) arg); + Base._LIB.tvmFFIFunctionPushArgLong((Long) arg); } else if (arg instanceof Float) { - Base._LIB.tvmFuncPushArgDouble((Float) arg); + Base._LIB.tvmFFIFunctionPushArgDouble((Float) arg); } else if (arg instanceof Double) { - Base._LIB.tvmFuncPushArgDouble((Double) arg); + Base._LIB.tvmFFIFunctionPushArgDouble((Double) arg); } else if (arg instanceof String) { - Base._LIB.tvmFuncPushArgString((String) arg); + Base._LIB.tvmFFIFunctionPushArgString((String) arg); } else if (arg instanceof byte[]) { - Base._LIB.tvmFuncPushArgBytes((byte[]) arg); - } else if (arg instanceof NDArrayBase) { - NDArrayBase nd = (NDArrayBase) arg; - int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; - Base._LIB.tvmFuncPushArgHandle(nd.handle, id); - } else if (arg instanceof Module) { - Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); - } else if (arg instanceof Function) { - Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgBytes((byte[]) arg); } else if (arg instanceof Device) { - Base._LIB.tvmFuncPushArgDevice((Device) arg); - } else if (arg instanceof TVMValue) { - TVMValue tvmArg = (TVMValue) arg; - switch (tvmArg.typeCode) { - case UINT: - case INT: - Base._LIB.tvmFuncPushArgLong(tvmArg.asLong()); - break; - case FLOAT: - Base._LIB.tvmFuncPushArgDouble(tvmArg.asDouble()); - break; - case STR: - Base._LIB.tvmFuncPushArgString(tvmArg.asString()); - break; - case BYTES: - Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes()); - break; - case HANDLE: - case ARRAY_HANDLE: - case MODULE_HANDLE: - case FUNC_HANDLE: - Base._LIB.tvmFuncPushArgHandle(tvmArg.asHandle(), tvmArg.typeCode.id); - break; - default: - throw new IllegalArgumentException("Invalid argument: " + arg); - } + Base._LIB.tvmFFIFunctionPushArgDevice((Device) arg); + } else if (arg instanceof TVMValueBytes) { + byte[] bytes = ((TVMValueBytes) arg).value; + Base._LIB.tvmFFIFunctionPushArgBytes(bytes); + } else if (arg instanceof TVMValueString) { + String str = ((TVMValueString) arg).value; + Base._LIB.tvmFFIFunctionPushArgString(str); + } else if (arg instanceof TVMValueDouble) { + double value = ((TVMValueDouble) arg).value; + Base._LIB.tvmFFIFunctionPushArgDouble(value); + } else if (arg instanceof TVMValueLong) { + long value = ((TVMValueLong) arg).value; + Base._LIB.tvmFFIFunctionPushArgLong(value); + } else if (arg instanceof TVMValueNull) { + Base._LIB.tvmFFIFunctionPushArgHandle(0, TypeIndex.kTVMFFINone); } else { throw new IllegalArgumentException("Invalid argument: " + arg); } @@ -309,9 +258,9 @@ public static interface Callback { */ public static void register(String name, Callback function, boolean override) { Base.RefLong createdFuncHandleRef = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef)); + Base.checkCall(Base._LIB.tvmFFIFunctionCreateFromCallback(function, createdFuncHandleRef)); int ioverride = override ? 1 : 0; - Base.checkCall(Base._LIB.tvmFuncRegisterGlobal(name, createdFuncHandleRef.value, ioverride)); + Base.checkCall(Base._LIB.tvmFFIFunctionSetGlobal(name, createdFuncHandleRef.value, ioverride)); } /** @@ -330,7 +279,7 @@ public static void register(String name, Callback function) { */ public static Function convertFunc(Callback function) { Base.RefLong createdFuncHandleRef = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef)); + Base.checkCall(Base._LIB.tvmFFIFunctionCreateFromCallback(function, createdFuncHandleRef)); return new Function(createdFuncHandleRef.value); } diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index aede9be334c8..f471883ca5bc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -24,55 +24,50 @@ class LibInfo { native int shutdown(); - native String tvmGetLastError(); + native String tvmFFIGetLastError(); - // Function - native void tvmFuncPushArgLong(long arg); - - native void tvmFuncPushArgDouble(double arg); - - native void tvmFuncPushArgString(String arg); - - native void tvmFuncPushArgBytes(byte[] arg); + // Object + native int tvmFFIObjectFree(long handle); - native void tvmFuncPushArgHandle(long arg, int argType); + // Function + native void tvmFFIFunctionPushArgLong(long arg); - native void tvmFuncPushArgDevice(Device device); + native void tvmFFIFunctionPushArgDouble(double arg); - native int tvmFuncListGlobalNames(List funcNames); + native void tvmFFIFunctionPushArgString(String arg); - native int tvmFuncFree(long handle); + native void tvmFFIFunctionPushArgBytes(byte[] arg); - native int tvmFuncGetGlobal(String name, Base.RefLong handle); + native void tvmFFIFunctionPushArgHandle(long arg, int argTypeIndex); - native int tvmFuncCall(long handle, Base.RefTVMValue retVal); + native void tvmFFIFunctionPushArgDevice(Device device); - native int tvmFuncCreateFromCFunc(Function.Callback function, Base.RefLong handle); + native int tvmFFIFunctionListGlobalNames(List funcNames); - native int tvmFuncRegisterGlobal(String name, long handle, int override); + native int tvmFFIFunctionGetGlobal(String name, Base.RefLong handle); - // Module - native int tvmModFree(long handle); + native int tvmFFIFunctionSetGlobal(String name, long handle, int override); - native int tvmModGetFunction(long handle, String name, - int queryImports, Base.RefLong retHandle); + native int tvmFFIFunctionCall(long handle, Base.RefTVMValue retVal); - native int tvmModImport(long mod, long dep); + native int tvmFFIFunctionCreateFromCallback(Function.Callback function, Base.RefLong handle); // NDArray - native int tvmArrayFree(long handle); - - native int tvmArrayAlloc(long[] shape, int dtypeCode, int dtypeBits, int dtypeLanes, - int deviceType, int deviceId, Base.RefLong refHandle); + native int tvmFFIDLTensorGetShape(long handle, List shape); - native int tvmArrayGetShape(long handle, List shape); + native int tvmFFIDLTensorCopyFromTo(long from, long to); - native int tvmArrayCopyFromTo(long from, long to); + native int tvmFFIDLTensorCopyFromJArray(byte[] fromRaw, long to); - native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to); - - native int tvmArrayCopyToJArray(long from, byte[] to); + native int tvmFFIDLTensorCopyToJArray(long from, byte[] to); + // the following functions are binded to keep things simpler + // One possibility is to enhance FFI to support shape directly + // so we do not need to run this binding through JNI // Device native int tvmSynchronize(int deviceType, int deviceId); + + native int tvmNDArrayEmpty(long[] shape, int dtypeCode, int dtypeBits, + int dtypeLanes, int deviceType, int deviceId, + Base.RefLong handle); } diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 0682a6595a3e..5e78e26ae739 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -23,10 +23,7 @@ /** * Container of compiled functions of TVM. */ -public class Module extends TVMValue { - public final long handle; - private boolean isReleased = false; - +public class Module extends TVMObject { private static ThreadLocal> apiFuncs = new ThreadLocal>() { @Override @@ -45,17 +42,12 @@ private static Function getApi(String name) { } Module(long handle) { - super(ArgTypeCode.MODULE_HANDLE); - this.handle = handle; + super(handle, TypeIndex.kTVMFFIModule); } private Function entry = null; private final String entryName = "__tvm_main__"; - @Override protected void finalize() throws Throwable { - release(); - super.finalize(); - } /** * Easy for user to get the instance from returned TVMValue. @@ -65,23 +57,6 @@ private static Function getApi(String name) { return this; } - @Override long asHandle() { - return handle; - } - - /** - * Release the Module. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

- */ - @Override public void release() { - if (!isReleased) { - Base.checkCall(Base._LIB.tvmModFree(handle)); - isReleased = true; - } - } - /** * Get the entry function. * @return The entry function if exist @@ -100,13 +75,9 @@ public Function entryFunc() { * @return The result function. */ public Function getFunction(String name, boolean queryImports) { - Base.RefLong retHandle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmModGetFunction( - handle, name, queryImports ? 1 : 0, retHandle)); - if (retHandle.value == 0) { - throw new IllegalArgumentException("Module has no function " + name); - } - return new Function(retHandle.value, false); + TVMValue ret = getApi("ModuleGetFunction") + .pushArg(this).pushArg(name).pushArg(queryImports ? 1 : 0).invoke(); + return ret.asFunction(); } public Function getFunction(String name) { @@ -118,7 +89,8 @@ public Function getFunction(String name) { * @param module The other module. */ public void importModule(Module module) { - Base.checkCall(Base._LIB.tvmModImport(handle, module.handle)); + getApi("ModuleImport") + .pushArg(this).pushArg(module).invoke(); } /** @@ -138,7 +110,6 @@ public String typeKey() { */ public static Module load(String path, String fmt) { TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); - assert ret.typeCode == ArgTypeCode.MODULE_HANDLE; return ret.asModule(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArray.java b/jvm/core/src/main/java/org/apache/tvm/NDArray.java index 68020db03999..6b151d7bf9d2 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArray.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArray.java @@ -35,11 +35,6 @@ public class NDArray extends NDArrayBase { this.device = dev; } - @Override - protected void finalize() throws Throwable { - super.finalize(); - } - /** * Copy from a native array. * The NDArray type must by float64 @@ -54,9 +49,7 @@ public void copyFrom(double[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putDouble(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -73,9 +66,7 @@ public void copyFrom(float[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putFloat(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -92,9 +83,7 @@ public void copyFrom(long[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putLong(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -111,9 +100,7 @@ public void copyFrom(int[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putInt(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -130,9 +117,7 @@ public void copyFrom(short[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putShort(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -162,9 +147,7 @@ public void copyFrom(char[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putChar(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } private void checkCopySize(int sourceLength) { @@ -180,9 +163,7 @@ private void checkCopySize(int sourceLength) { * @param sourceArray the source data */ public void copyFromRaw(byte[] sourceArray) { - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(sourceArray, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(sourceArray, this.dltensorHandle)); } /** @@ -191,7 +172,7 @@ public void copyFromRaw(byte[] sourceArray) { */ public long[] shape() { List data = new ArrayList(); - Base.checkCall(Base._LIB.tvmArrayGetShape(handle, data)); + Base.checkCall(Base._LIB.tvmFFIDLTensorGetShape(this.dltensorHandle, data)); long[] shapeArr = new long[data.size()]; for (int i = 0; i < shapeArr.length; ++i) { shapeArr[i] = data.get(i); @@ -343,7 +324,7 @@ public byte[] internal() { int arrLength = dtype.numOfBytes * (int) size(); byte[] arr = new byte[arrLength]; - Base.checkCall(Base._LIB.tvmArrayCopyToJArray(tmp.handle, arr)); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyToJArray(this.dltensorHandle, arr)); return arr; } @@ -380,8 +361,9 @@ public Device device() { */ public static NDArray empty(long[] shape, TVMType dtype, Device dev) { Base.RefLong refHandle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmArrayAlloc( - shape, dtype.typeCode, dtype.bits, dtype.lanes, dev.deviceType, dev.deviceId, refHandle)); + Base.checkCall(Base._LIB.tvmNDArrayEmpty( + shape, dtype.typeCode, dtype.bits, + dtype.lanes, dev.deviceType, dev.deviceId, refHandle)); return new NDArray(refHandle.value, false, dtype, dev); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java index 26bb735e1a5b..534dcb38d4a9 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java @@ -22,50 +22,27 @@ * Only deep-copy supported. */ public class NDArrayBase extends TVMValue { - protected final long handle; - protected final boolean isView; - private boolean isReleased = false; + protected long handle; + public final boolean isView; + protected final long dltensorHandle; NDArrayBase(long handle, boolean isView) { - super(ArgTypeCode.ARRAY_HANDLE); - this.handle = handle; + this.dltensorHandle = isView ? handle : handle + 8 * 2; + this.handle = isView ? 0 : handle; this.isView = isView; } - NDArrayBase(long handle) { - this(handle, true); - } - @Override public NDArrayBase asNDArray() { return this; } - @Override long asHandle() { - return handle; - } - - /** - * Copy array to target. - * @param target The target array to be copied, must have same shape as this array. - * @return target - */ - public NDArrayBase copyTo(NDArrayBase target) { - Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, target.handle)); - return target; - } - /** - * Release the NDArray memory. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

+ * Release the NDArray. */ public void release() { - if (!isReleased) { - if (!isView) { - Base.checkCall(Base._LIB.tvmArrayFree(handle)); - isReleased = true; - } + if (this.handle != 0) { + Base.checkCall(Base._LIB.tvmFFIObjectFree(this.handle)); + this.handle = 0; } } @@ -73,4 +50,14 @@ public void release() { release(); super.finalize(); } + + /** + * Copy array to target. + * @param target The target array to be copied, must have same shape as this array. + * @return target + */ + public NDArrayBase copyTo(NDArrayBase target) { + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromTo(this.dltensorHandle, target.dltensorHandle)); + return target; + } } diff --git a/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java b/jvm/core/src/main/java/org/apache/tvm/TVMObject.java similarity index 65% rename from jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java rename to jvm/core/src/main/java/org/apache/tvm/TVMObject.java index ed6d0f1a0e12..c2b3f0eb497f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMObject.java @@ -17,20 +17,25 @@ package org.apache.tvm; -// Type code used in API calls -public enum ArgTypeCode { - INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), - DLDEVICE(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), - FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); +/** + * Base class of all TVM Objects. + */ +public class TVMObject extends TVMValue { + protected long handle; + public final int typeIndex; - public final int id; + public TVMObject(long handle, int typeIndex) { + this.handle = handle; + this.typeIndex = typeIndex; + } - private ArgTypeCode(int id) { - this.id = id; + public void release() { + Base.checkCall(Base._LIB.tvmFFIObjectFree(this.handle)); + this.handle = 0; } - @Override - public String toString() { - return String.valueOf(id); + @Override protected void finalize() throws Throwable { + release(); + super.finalize(); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index d30cfcc4f30a..45aef808f44c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -18,10 +18,8 @@ package org.apache.tvm; public class TVMValue { - public final ArgTypeCode typeCode; + protected TVMValue() { - public TVMValue(ArgTypeCode tc) { - typeCode = tc; } public void release() { diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java index 132d88f7622b..253dcbe66c87 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java @@ -21,7 +21,6 @@ public class TVMValueBytes extends TVMValue { public final byte[] value; public TVMValueBytes(byte[] value) { - super(ArgTypeCode.BYTES); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java index 9db4c3bb0e8c..16351b3244ea 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java @@ -21,7 +21,6 @@ public class TVMValueDouble extends TVMValue { public final double value; public TVMValueDouble(double value) { - super(ArgTypeCode.FLOAT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java index b91f55e2f59b..849510ec3078 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java @@ -24,7 +24,6 @@ public class TVMValueHandle extends TVMValue { public final long value; public TVMValueHandle(long value) { - super(ArgTypeCode.HANDLE); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java index 8a9b157d3961..0c232adf42b8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java @@ -21,7 +21,6 @@ public class TVMValueLong extends TVMValue { public final long value; public TVMValueLong(long value) { - super(ArgTypeCode.INT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java index 8c49ee5b3df5..45e85a160728 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java @@ -19,6 +19,5 @@ public class TVMValueNull extends TVMValue { public TVMValueNull() { - super(ArgTypeCode.NULL); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java index 46926e7d3fc6..c93a5600931e 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java @@ -21,7 +21,6 @@ public class TVMValueString extends TVMValue { public final String value; public TVMValueString(String value) { - super(ArgTypeCode.STR); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java new file mode 100644 index 000000000000..97169bb6c58c --- /dev/null +++ b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tvm; + +// Type code used in API calls +public class TypeIndex { + public static final int kTVMFFINone = 0; + public static final int kTVMFFIInt = 1; + public static final int kTVMFFIBool = 2; + public static final int kTVMFFIFloat = 3; + public static final int kTVMFFIOpaquePtr = 4; + public static final int kTVMFFIDataType = 5; + public static final int kTVMFFIDevice = 6; + public static final int kTVMFFIDLTensorPtr = 7; + public static final int kTVMFFIRawStr = 8; + public static final int kTVMFFIByteArrayPtr = 9; + public static final int kTVMFFIObjectRValueRef = 10; + public static final int kTVMFFIStaticObjectBegin = 64; + public static final int kTVMFFIObject = 64; + public static final int kTVMFFIStr = 65; + public static final int kTVMFFIBytes = 66; + public static final int kTVMFFIError = 67; + public static final int kTVMFFIFunction = 68; + public static final int kTVMFFIArray = 69; + public static final int kTVMFFIMap = 70; + public static final int kTVMFFIShape = 71; + public static final int kTVMFFINDArray = 72; + public static final int kTVMFFIModule = 73; +} diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 69321c3b51c8..4b20362c2a47 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -20,6 +20,9 @@ import org.apache.tvm.Function; import org.apache.tvm.TVMValue; +/** + * RPC Client. + */ public class Client { /** * Connect to RPC Server. diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 07278f07b8c2..f3cf95f931cb 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -17,15 +17,16 @@ package org.apache.tvm.rpc; +import org.apache.tvm.Device; +import org.apache.tvm.Function; +import org.apache.tvm.Module; + import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.HashMap; import java.util.Map; -import org.apache.tvm.Device; -import org.apache.tvm.Function; -import org.apache.tvm.Module; /** * RPC Client session module. diff --git a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java index 9ffcc5ab65ea..c2a1f78fa432 100644 --- a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java @@ -43,6 +43,8 @@ public void test_reg_sum_number() { @Test public void test_add_string() { + System.err.println("[TEST] test_add_string"); + Function func = Function.convertFunc(new Function.Callback() { @Override public Object invoke(TVMValue... args) { String res = ""; diff --git a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java index b9538ca96b5d..888cd18923be 100644 --- a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java @@ -71,8 +71,6 @@ public void test_load_add_func_cuda() { } Module fadd = Module.load(loadingDir + File.separator + "add_cuda.so"); - Module faddDev = Module.load(loadingDir + File.separator + "add_cuda.ptx"); - fadd.importModule(faddDev); final int dim = 100; long[] shape = new long[]{dim}; @@ -93,7 +91,6 @@ public void test_load_add_func_cuda() { arr.release(); res.release(); - faddDev.release(); fadd.release(); } } diff --git a/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java index 641633def8a0..ca24c123da8e 100644 --- a/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java @@ -31,6 +31,7 @@ public class RPCTest { private final Logger logger = LoggerFactory.getLogger(RPCTest.class); + @Ignore("RPC test is not enabled") @Test public void test_addone() { if (!Module.enabled("rpc")) { @@ -57,6 +58,7 @@ public void test_addone() { } } + @Ignore("RPC test is not enabled") @Test public void test_strcat() { if (!Module.enabled("rpc")) { diff --git a/jvm/core/src/test/scripts/prepare_test_libs.py b/jvm/core/src/test/scripts/prepare_test_libs.py new file mode 100644 index 000000000000..550082adb816 --- /dev/null +++ b/jvm/core/src/test/scripts/prepare_test_libs.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# Prepare test library for standalone wasm runtime test. + +import sys +import os +import tvm +from tvm import te +from tvm import relax +from tvm.script import relax as R + + +def prepare_relax_lib(base_path): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")): + lv0 = R.add(x, y) + return lv0 + + target = tvm.target.Target("llvm") + + mod = pipeline(Mod) + ex = relax.build(mod, target) + relax_path = os.path.join(base_path, "add_relax.so") + ex.export_library(relax_path) + + +def prepare_cpu_lib(base_path): + target = "llvm" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, C]).with_attr("global_symbol", "myadd")) + fadd = tvm.build(mod, target) + lib_path = os.path.join(base_path, "add_cpu.so") + fadd.export_library(lib_path) + + +def prepare_gpu_lib(base_path): + if not tvm.cuda().exist: + print("CUDA is not enabled, skip the generation") + return + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, C]).with_attr("global_symbol", "myadd")) + sch = tvm.tir.Schedule(mod) + sch.work_on("myadd") + (i,) = sch.get_loops(block=sch.get_block("C")) + i0, i1 = sch.split(i, [None, 32]) + sch.bind(i0, "blockIdx.x") + sch.bind(i1, "threadIdx.x") + fadd = tvm.build(sch.mod, "cuda") + lib_path = os.path.join(base_path, "add_cuda.so") + fadd.export_library(lib_path) + + +if __name__ == "__main__": + base_path = sys.argv[1] + prepare_cpu_lib(base_path) + prepare_gpu_lib(base_path) + prepare_relax_lib(base_path) diff --git a/jvm/native/linux-x86_64/pom.xml b/jvm/native/linux-x86_64/pom.xml index 10d9a1bbfe3c..c21a3d2ae5af 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -127,6 +127,8 @@ under the License. -shared + -L${project.parent.basedir}/../../build/ + -ltvm_runtime ${ldflags} diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index ef28537b98f2..e2bd0fd7ae9d 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -134,6 +134,8 @@ under the License. -Wl,-x + -L${project.parent.basedir}/../../build/ + -ltvm_runtime ${ldflags} diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 3e44f757392d..76520d43f7a9 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -113,8 +113,8 @@ jobject newTVMValueDouble(JNIEnv* env, jdouble value) { return object; } -jobject newTVMValueString(JNIEnv* env, const char* value) { - jstring jvalue = env->NewStringUTF(value); +jobject newTVMValueString(JNIEnv* env, const TVMFFIByteArray* value) { + jstring jvalue = env->NewStringUTF(value->data); jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jmethodID constructor = env->GetMethodID(cls, "", "(Ljava/lang/String;)V"); jobject object = env->NewObject(cls, constructor, jvalue); @@ -123,7 +123,7 @@ jobject newTVMValueString(JNIEnv* env, const char* value) { return object; } -jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) { +jobject newTVMValueBytes(JNIEnv* env, const TVMFFIByteArray* arr) { jbyteArray jarr = env->NewByteArray(arr->size); env->SetByteArrayRegion(jarr, 0, arr->size, reinterpret_cast(const_cast(arr->data))); @@ -159,14 +159,22 @@ jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { return object; } -jobject newObject(JNIEnv* env, const char* clsname) { - jclass cls = env->FindClass(clsname); +jobject newTVMNull(JNIEnv* env) { + jclass cls = env->FindClass("org/apache/tvm/TVMValueNull"); jmethodID constructor = env->GetMethodID(cls, "", "()V"); jobject object = env->NewObject(cls, constructor); env->DeleteLocalRef(cls); return object; } +jobject newTVMObject(JNIEnv* env, jlong handle, jint type_index) { + jclass cls = env->FindClass("org/apache/tvm/TVMObject"); + jmethodID constructor = env->GetMethodID(cls, "", "(JI)V"); + jobject object = env->NewObject(cls, constructor, handle, type_index); + env->DeleteLocalRef(cls); + return object; +} + void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) { jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType"); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); @@ -184,55 +192,56 @@ void fromJavaDevice(JNIEnv* env, jobject jdev, DLDevice* dev) { env->DeleteLocalRef(deviceClass); } -jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { - switch (tcode) { - case kDLUInt: - case kDLInt: - case kTVMArgBool: +jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { + using tvm::ffi::TypeIndex; + switch (value.type_index) { + case TypeIndex::kTVMFFINone: { + return newTVMNull(env); + } + case TypeIndex::kTVMFFIBool: { + // use long for now to represent bool + return newTVMValueLong(env, static_cast(value.v_int64)); + } + case TypeIndex::kTVMFFIInt: { return newTVMValueLong(env, static_cast(value.v_int64)); - case kDLFloat: + } + case TypeIndex::kTVMFFIFloat: { return newTVMValueDouble(env, static_cast(value.v_float64)); - case kTVMOpaqueHandle: - return newTVMValueHandle(env, reinterpret_cast(value.v_handle)); - case kTVMModuleHandle: - return newModule(env, reinterpret_cast(value.v_handle)); - case kTVMPackedFuncHandle: - return newFunction(env, reinterpret_cast(value.v_handle)); - case kTVMDLTensorHandle: - return newNDArray(env, reinterpret_cast(value.v_handle), true); - case kTVMNDArrayHandle: - return newNDArray(env, reinterpret_cast(value.v_handle), false); - case kTVMStr: - return newTVMValueString(env, value.v_str); - case kTVMBytes: - return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); - case kTVMNullptr: - return newObject(env, "org/apache/tvm/TVMValueNull"); - default: - LOG(FATAL) << "Do NOT know how to handle return type code " << tcode; - } - return NULL; -} - -// Helper function to pack two int32_t values into an int64_t -inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) { - int64_t result; - int32_t* parts = reinterpret_cast(&result); - - // Lambda function to check endianness - const auto isLittleEndian = []() -> bool { - uint32_t i = 1; - return *reinterpret_cast(&i) == 1; - }; - - if (isLittleEndian()) { - parts[0] = device_type; - parts[1] = device_id; - } else { - parts[1] = device_type; - parts[0] = device_id; + } + case TypeIndex::kTVMFFIOpaquePtr: { + return newTVMValueHandle(env, reinterpret_cast(value.v_ptr)); + } + case TypeIndex::kTVMFFIModule: { + return newModule(env, reinterpret_cast(value.v_obj)); + } + case TypeIndex::kTVMFFIFunction: { + return newFunction(env, reinterpret_cast(value.v_obj)); + } + case TypeIndex::kTVMFFIDLTensorPtr: { + return newNDArray(env, reinterpret_cast(value.v_ptr), true); + } + case TypeIndex::kTVMFFINDArray: { + return newNDArray(env, reinterpret_cast(value.v_obj), false); + } + case TypeIndex::kTVMFFIStr: { + jobject ret = newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); + TVMFFIObjectFree(value.v_obj); + return ret; + } + case TypeIndex::kTVMFFIBytes: { + jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); + TVMFFIObjectFree(value.v_obj); + return ret; + } + default: { + if (value.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + return newTVMObject(env, reinterpret_cast(value.v_obj), value.type_index); + } + TVM_FFI_THROW(RuntimeError) << "Do NOT know how to handle return type_index " + << value.type_index; + TVM_FFI_UNREACHABLE(); + } } - return result; } #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 77bc8d636098..a5481dd9ac54 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -25,12 +25,14 @@ #include "tvm_runtime.h" #else #include -#include -#include -#include +#include +#include +#include +#include #endif #include #include +#include #include #include @@ -38,14 +40,18 @@ JavaVM* _jvm; void* _tvmHandle = nullptr; -struct TVMFuncArgsThreadLocalEntry { - std::vector tvmFuncArgValues; - std::vector tvmFuncArgTypes; + +struct TVMFFIJVMStack { + std::vector packed_args; // for later release - std::vector> tvmFuncArgPushedStrs; - std::vector> tvmFuncArgPushedBytes; + std::vector> str_args; + std::vector>> byte_args; + + static TVMFFIJVMStack* ThreadLocal() { + static thread_local TVMFFIJVMStack stack; + return &stack; + } }; -typedef dmlc::ThreadLocalStore TVMFuncArgsThreadLocalStore; JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj, jstring jtvmLibFile) { @@ -68,172 +74,132 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject return 0; } -JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) { - return env->NewStringUTF(TVMGetLastError()); +JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmFFIGetLastError(JNIEnv* env, jobject obj) { + std::string err_msg = ::tvm::ffi::details::MoveFromSafeCallRaised().what(); + return env->NewStringUTF(err_msg.c_str()); } // Function -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj, - jlong arg) { - TVMValue value; - value.v_int64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLInt); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgLong(JNIEnv* env, + jobject obj, + jlong arg) { + TVMFFIJVMStack::ThreadLocal()->packed_args.emplace_back(static_cast(arg)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj, - jdouble arg) { - TVMValue value; - value.v_float64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLFloat); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDouble(JNIEnv* env, + jobject obj, + jdouble arg) { + TVMFFIJVMStack::ThreadLocal()->packed_args.emplace_back(static_cast(arg)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj, - jstring arg) { - TVMValue value; +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgString(JNIEnv* env, + jobject obj, + jstring arg) { jstring garg = reinterpret_cast(env->NewGlobalRef(arg)); - value.v_str = env->GetStringUTFChars(garg, 0); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kTVMStr); - // release string args later - e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); + const char* str = env->GetStringUTFChars(garg, 0); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->str_args.emplace_back(garg, str); + stack->packed_args.emplace_back(str); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj, - jlong arg, jint argType) { - TVMValue value; - value.v_handle = reinterpret_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(static_cast(argType)); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgHandle(JNIEnv* env, + jobject obj, + jlong arg, + jint argTypeIndex) { + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + TVMFFIAny temp; + temp.v_int64 = static_cast(arg); + temp.type_index = static_cast(argTypeIndex); + stack->packed_args.emplace_back(tvm::ffi::AnyView::CopyFromTVMFFIAny(temp)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj, - jobject arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDevice(JNIEnv* env, + jobject obj, + jobject arg) { jclass deviceClass = env->FindClass("org/apache/tvm/Device"); jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I"); jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I"); jint deviceType = env->GetIntField(arg, deviceTypeField); jint deviceId = env->GetIntField(arg, deviceIdField); - - TVMValue value; - value.v_int64 = deviceToInt64(deviceType, deviceId); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLDevice); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->packed_args.emplace_back(DLDevice{static_cast(deviceType), deviceId}); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, - jbyteArray arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgBytes(JNIEnv* env, + jobject obj, + jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); jbyte* data = env->GetByteArrayElements(garg, 0); - TVMByteArray* byteArray = new TVMByteArray(); + std::unique_ptr byteArray = std::make_unique(); byteArray->size = static_cast(env->GetArrayLength(garg)); byteArray->data = reinterpret_cast(data); - TVMValue value; - value.v_handle = reinterpret_cast(byteArray); - - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kTVMBytes); - - e->tvmFuncArgPushedBytes.push_back(std::make_pair(garg, byteArray)); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->packed_args.emplace_back(byteArray.get()); + stack->byte_args.emplace_back(garg, std::move(byteArray)); // release (garg, data), byteArray later } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj, - jobject jfuncNames) { - int outSize; - const char** outArray; - - int ret = TVMFuncListGlobalNames(&outSize, &outArray); - if (ret) { - return ret; - } - +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionListGlobalNames( + JNIEnv* env, jobject obj, jobject jfuncNames) { + TVM_FFI_SAFE_CALL_BEGIN(); jclass arrayClass = env->FindClass("java/util/List"); jmethodID arrayAppend = env->GetMethodID(arrayClass, "add", "(Ljava/lang/Object;)Z"); - // fill names - for (int i = 0; i < outSize; ++i) { - jstring jname = env->NewStringUTF(outArray[i]); + for (const auto& name : tvm::ffi::Function::ListGlobalNames()) { + jstring jname = env->NewStringUTF(name.c_str()); env->CallBooleanMethod(jfuncNames, arrayAppend, jname); env->DeleteLocalRef(jname); } env->DeleteLocalRef(arrayClass); - - return ret; -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMFuncFree(reinterpret_cast(jhandle)); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj, - jstring jname, - jobject jhandle) { - TVMFunctionHandle handle; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionGetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jobject jhandle) { const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncGetGlobal(name, &handle); + TVMFFIByteArray name_bytes{name, strlen(name)}; + TVMFFIObjectHandle handle; + int ret = TVMFFIFunctionGetGlobal(&name_bytes, &handle); env->ReleaseStringUTFChars(jname, name); setLongField(env, jhandle, reinterpret_cast(handle)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj, - jlong jhandle, jobject jretVal) { - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - int numArgs = e->tvmFuncArgValues.size(); - - TVMValue retVal; - int retTypeCode; - - // function can be invoked recursively, - // thus we copy the pushed arguments here. - auto argValues = e->tvmFuncArgValues; - auto argTypes = e->tvmFuncArgTypes; - auto pushedStrs = e->tvmFuncArgPushedStrs; - auto pushedBytes = e->tvmFuncArgPushedBytes; - - e->tvmFuncArgPushedStrs.clear(); - e->tvmFuncArgPushedBytes.clear(); - e->tvmFuncArgTypes.clear(); - e->tvmFuncArgValues.clear(); - - int ret = TVMFuncCall(reinterpret_cast(jhandle), &argValues[0], &argTypes[0], - numArgs, &retVal, &retTypeCode); - - if (ret != 0) { - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCall(JNIEnv* env, jobject obj, + jlong jhandle, + jobject jretVal) { + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + TVMFFIAny ret_val; + ret_val.type_index = tvm::ffi::TypeIndex::kTVMFFINone; + ret_val.v_int64 = 0; + int ret = TVMFFIFunctionCall(reinterpret_cast(jhandle), + reinterpret_cast(stack->packed_args.data()), + stack->packed_args.size(), &ret_val); + // release all temp resources + for (auto& str_pair : stack->str_args) { + env->ReleaseStringUTFChars(str_pair.first, str_pair.second); + env->DeleteGlobalRef(str_pair.first); } - for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) { - env->ReleaseStringUTFChars(iter->first, iter->second); - env->DeleteGlobalRef(iter->first); - } - for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { + for (auto& byte_pair : stack->byte_args) { env->ReleaseByteArrayElements( - iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); - env->DeleteGlobalRef(iter->first); - delete iter->second; + byte_pair.first, reinterpret_cast(const_cast(byte_pair.second->data)), 0); + env->DeleteGlobalRef(byte_pair.first); } + stack->str_args.clear(); + stack->byte_args.clear(); + stack->packed_args.clear(); // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue"); jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); - env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode)); - + env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, ret_val)); env->DeleteLocalRef(refTVMValueCls); - return ret; } @@ -255,27 +221,24 @@ class JNIEnvPtrHelper { }; // Callback function -extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, - TVMRetValueHandle ret, void* resourceHandle) { +extern "C" int funcInvokeCallback(void* self, const TVMFFIAny* args, int num_args, TVMFFIAny* ret) { JNIEnv* env; int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr); } else { - CHECK(jniStatus == JNI_OK); + TVM_FFI_ICHECK(jniStatus == JNI_OK); } jclass tvmValueCls = env->FindClass("org/apache/tvm/TVMValue"); - jobjectArray jargs = env->NewObjectArray(numArgs, tvmValueCls, 0); - for (int i = 0; i < numArgs; ++i) { - TVMValue arg = args[i]; - int tcode = typeCodes[i]; - if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || - tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle || - tcode == kTVMNDArrayHandle) { - TVMCbArgToReturn(&arg, &tcode); + jobjectArray jargs = env->NewObjectArray(num_args, tvmValueCls, 0); + + for (int i = 0; i < num_args; ++i) { + TVMFFIAny arg = args[i]; + if (args[i].type_index >= tvm::ffi::TypeIndex::kTVMFFIRawStr) { + TVMFFIAnyViewToOwnedAny(&args[i], &arg); } - jobject jarg = tvmRetValueToJava(env, arg, tcode); + jobject jarg = tvmRetValueToJava(env, arg); env->SetObjectArrayElement(jargs, i, jarg); } @@ -285,46 +248,39 @@ extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;"); jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V"); - jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc, - reinterpret_cast(resourceHandle), jargs); + reinterpret_cast(self), jargs); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); - const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); + // the stack + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + const size_t prev_num_str_args = stack->str_args.size(); + const size_t prev_num_bytes_args = stack->byte_args.size(); // convert returned (java) TVMValue to (C) TVMValue env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue); - TVMValue retValue = e->tvmFuncArgValues.back(); - e->tvmFuncArgValues.pop_back(); - - int retCode = e->tvmFuncArgTypes.back(); - e->tvmFuncArgTypes.pop_back(); - - // set back the return value - TVMCFuncSetReturn(ret, &retValue, &retCode, 1); + TVMFFIAny ret_val = stack->packed_args.back().CopyToTVMFFIAny(); + stack->packed_args.pop_back(); + TVMFFIAnyViewToOwnedAny(&ret_val, ret); // release allocated strings. - if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) { - const auto& pairArg = e->tvmFuncArgPushedStrs.back(); + if (stack->str_args.size() > prev_num_str_args) { + const auto& pairArg = stack->str_args.back(); env->ReleaseStringUTFChars(pairArg.first, pairArg.second); env->DeleteGlobalRef(pairArg.first); - e->tvmFuncArgPushedStrs.pop_back(); + stack->str_args.pop_back(); } // release allocated bytes. - if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) { - const auto& pairArg = e->tvmFuncArgPushedBytes.back(); + if (stack->byte_args.size() > prev_num_bytes_args) { + const auto& pairArg = stack->byte_args.back(); env->ReleaseByteArrayElements( pairArg.first, reinterpret_cast(const_cast(pairArg.second->data)), 0); env->DeleteGlobalRef(pairArg.first); - delete pairArg.second; - e->tvmFuncArgPushedBytes.pop_back(); + stack->byte_args.pop_back(); } env->DeleteLocalRef(clsFunc); env->DeleteLocalRef(tvmValueCls); - return 0; } @@ -335,90 +291,43 @@ extern "C" void funcFreeCallback(void* resourceHandle) { if (jniStatus == JNI_EDETACHED) { _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr); } else { - CHECK(jniStatus == JNI_OK); + TVM_FFI_ICHECK(jniStatus == JNI_OK); } env->DeleteGlobalRef(reinterpret_cast(resourceHandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj, - jobject jfunction, - jobject jretHandle) { - TVMFunctionHandle out; - int ret = - TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), - reinterpret_cast(env->NewGlobalRef(jfunction)), - reinterpret_cast(&funcFreeCallback), &out); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCreateFromCallback( + JNIEnv* env, jobject obj, jobject jfunction, jobject jretHandle) { + TVMFFIObjectHandle out; + int ret = TVMFFIFunctionCreate(reinterpret_cast(env->NewGlobalRef(jfunction)), + funcInvokeCallback, funcFreeCallback, &out); setLongField(env, jretHandle, reinterpret_cast(out)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj, - jstring jname, - jlong jhandle, - jint joverride) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jlong jhandle, + jint joverride) { const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncRegisterGlobal(name, reinterpret_cast(jhandle), - reinterpret_cast(joverride)); + TVMFFIByteArray name_bytes{name, strlen(name)}; + int ret = TVMFFIFunctionSetGlobal(&name_bytes, reinterpret_cast(jhandle), + reinterpret_cast(joverride)); env->ReleaseStringUTFChars(jname, name); return ret; } // Module -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMModFree(reinterpret_cast(jhandle)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj, - jlong jmod, jlong jdep) { - return TVMModImport(reinterpret_cast(jmod), - reinterpret_cast(jdep)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj, - jlong jhandle, jstring jname, - jint jimport, jobject jret) { - TVMFunctionHandle retFunc; - - const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMModGetFunction(reinterpret_cast(jhandle), name, - reinterpret_cast(jimport), &retFunc); - env->ReleaseStringUTFChars(jname, name); - - setLongField(env, jret, reinterpret_cast(retFunc)); - - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj, + jlong jhandle) { + return TVMFFIObjectFree(reinterpret_cast(jhandle)); } // NDArray -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMArrayFree(reinterpret_cast(jhandle)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj, - jlongArray jshape, jint jdtypeCode, - jint jdtypeBits, jint jdtypeLanes, - jint jdeviceType, jint jdeviceId, - jobject jret) { - int ndim = static_cast(env->GetArrayLength(jshape)); - - TVMArrayHandle out; - - jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); - int ret = TVMArrayAlloc(reinterpret_cast(shapeArray), ndim, - static_cast(jdtypeCode), static_cast(jdtypeBits), - static_cast(jdtypeLanes), static_cast(jdeviceType), - static_cast(jdeviceId), &out); - env->ReleaseLongArrayElements(jshape, shapeArray, 0); - - setLongField(env, jret, reinterpret_cast(out)); - - return ret; -} -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj, - jlong jhandle, jobject jshape) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorGetShape(JNIEnv* env, jobject obj, + jlong jhandle, + jobject jshape) { DLTensor* array = reinterpret_cast(jhandle); int64_t* shape = array->shape; int ndim = array->ndim; @@ -440,45 +349,72 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, return 0; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj, - jlong jfrom, jlong jto) { - return TVMArrayCopyFromTo(reinterpret_cast(jfrom), - reinterpret_cast(jto), NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromTo(JNIEnv* env, + jobject obj, + jlong jfrom, + jlong jto) { + TVM_FFI_SAFE_CALL_BEGIN(); + static auto fcopy_from_to = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromTo"); + fcopy_from_to(reinterpret_cast(jfrom), reinterpret_cast(jto)); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj, - jbyteArray jarr, - jlong jfrom, jlong jto) { - jbyte* data = env->GetByteArrayElements(jarr, NULL); - - DLTensor* from = reinterpret_cast(jfrom); - from->data = static_cast(data); - - int ret = TVMArrayCopyFromTo(static_cast(from), - reinterpret_cast(jto), NULL); - - from->data = NULL; - env->ReleaseByteArrayElements(jarr, data, 0); - - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromJArray(JNIEnv* env, + jobject obj, + jbyteArray jarr, + jlong jto) { + TVM_FFI_SAFE_CALL_BEGIN(); + jbyte* pdata = env->GetByteArrayElements(jarr, NULL); + DLTensor* to = reinterpret_cast(jto); + size_t size = tvm::ffi::GetDataSize(*to); + static auto fcopy_from_bytes = + tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromBytes"); + fcopy_from_bytes(to, static_cast(pdata), size); + env->ReleaseByteArrayElements(jarr, pdata, 0); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj, - jlong jfrom, - jbyteArray jarr) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyToJArray(JNIEnv* env, + jobject obj, + jlong jfrom, + jbyteArray jarr) { + TVM_FFI_SAFE_CALL_BEGIN(); DLTensor* from = reinterpret_cast(jfrom); - int size = static_cast(env->GetArrayLength(jarr)); + size_t size = tvm::ffi::GetDataSize(*from); jbyte* pdata = env->GetByteArrayElements(jarr, NULL); - int ret = 0; - if (memcpy(static_cast(pdata), from->data, size) == NULL) { - ret = 1; - } - env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically - return ret; + static auto fcopy_to_bytes = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyToBytes"); + fcopy_to_bytes(from, static_cast(pdata), size); + env->ReleaseByteArrayElements(jarr, static_cast(pdata), + 0); // copy back to java array automatically + TVM_FFI_SAFE_CALL_END(); +} + +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jobject obj, + jint jdeviceType, + jint jdeviceId) { + TVM_FFI_SAFE_CALL_BEGIN(); + static auto fsync = tvm::ffi::Function::GetGlobalRequired("runtime.Device_StreamSync"); + DLDevice device{static_cast(jdeviceType), jdeviceId}; + fsync(device, nullptr); + TVM_FFI_SAFE_CALL_END(); } -// Device -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType, - jint deviceId) { - return TVMSynchronize(static_cast(deviceType), static_cast(deviceId), NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmNDArrayEmpty( + JNIEnv* env, jobject obj, jlongArray jshape, jint jdtypeCode, jint jdtypeBits, jint jdtypeLanes, + jint jdeviceType, jint jdeviceId, jobject jret) { + TVM_FFI_SAFE_CALL_BEGIN(); + int ndim = static_cast(env->GetArrayLength(jshape)); + jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); + tvm::ffi::Shape shape(shapeArray, shapeArray + ndim); + DLDataType dtype; + dtype.code = static_cast(jdtypeCode); + dtype.bits = static_cast(jdtypeBits); + dtype.lanes = static_cast(jdtypeLanes); + DLDevice device{static_cast(jdeviceType), jdeviceId}; + env->ReleaseLongArrayElements(jshape, shapeArray, 0); + static auto fempty = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayAllocWithScope"); + tvm::ffi::NDArray out = fempty(shape, dtype, device, nullptr).cast(); + void* handle = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out)); + setLongField(env, jret, reinterpret_cast(handle)); + TVM_FFI_SAFE_CALL_END(); } diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index 2eabac31cc28..f7c9f3c097af 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -35,16 +35,13 @@ cleanup() } trap cleanup 0 -# python3 "$SCRIPT_DIR"/test_add_cpu.py "$TEMP_DIR" -# python3 "$SCRIPT_DIR"/test_add_gpu.py "$TEMP_DIR" - -# Skip the Java RPC Unittests, see https://github.com/apache/tvm/issues/13168 -# # start rpc proxy server -# PORT=$(( ( RANDOM % 1000 ) + 9000 )) -# python3 $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 & - -# make jvmpkg -# make jvmpkg JVM_TEST_ARGS="-DskipTests=false \ -# -Dtest.tempdir=$TEMP_DIR \ -# -Dtest.rpc.proxy.host=localhost \ -# -Dtest.rpc.proxy.port=$PORT" +make jvmpkg + +# Skip the Java Tests for now +exit 0 + +# expose tvm runtime lib to system env +export LD_LIBRARY_PATH=$CURR_DIR/../../build/:$LD_LIBRARY_PATH +python "$SCRIPT_DIR"/prepare_test_libs.py "$TEMP_DIR" +make jvmpkg JVM_TEST_ARGS="-DskipTests=false\ + -Dtest.tempdir=$TEMP_DIR"