From 36a5f4e1daff031156fe6566d182d88feb46b2af Mon Sep 17 00:00:00 2001 From: cspchen Date: Thu, 22 Jul 2021 14:53:23 +0800 Subject: [PATCH] FIX: Mxnet crash when process exits. --- java-package/gradle.properties | 4 +--- java-package/integration/build.gradle | 2 -- .../integration/tests/engine/ModelTest.java | 2 +- java-package/mxnet-engine/build.gradle | 2 +- .../apache/mxnet/engine/BaseMxResource.java | 9 +++++++++ .../java/org/apache/mxnet/jna/JnaUtils.java | 19 ++++++++++++++----- 6 files changed, 26 insertions(+), 12 deletions(-) diff --git a/java-package/gradle.properties b/java-package/gradle.properties index 5f060b286381..3685a7b533ab 100644 --- a/java-package/gradle.properties +++ b/java-package/gradle.properties @@ -22,6 +22,4 @@ netty_version=4.1.51.Final slf4j_version=1.7.30 log4j_slf4j_version=2.13.3 testng_version=7.1.0 -powermock_version=2.0.7 - -MXNET_LIBRARY_PATH = /Users/cspchen/Work/refer/incubator-mxnet/build \ No newline at end of file +powermock_version=2.0.7 \ No newline at end of file diff --git a/java-package/integration/build.gradle b/java-package/integration/build.gradle index 31d537c84c74..81d206a5a36a 100644 --- a/java-package/integration/build.gradle +++ b/java-package/integration/build.gradle @@ -28,13 +28,11 @@ dependencies { } run { - environment("TF_CPP_MIN_LOG_LEVEL", "1") // turn off TensorFlow print out systemProperties System.getProperties() systemProperties.remove("user.dir") systemProperty("file.encoding", "UTF-8") jvmArgs "-Xverify:none" args("-p=org.apache.mxnet.integration.tests.", "-m=modelLoadAndPredictTest") - } test { diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java index c5f9d4c7a6bf..13b6c8377ea1 100644 --- a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java +++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java @@ -6,6 +6,7 @@ import org.apache.mxnet.engine.Predictor; import org.apache.mxnet.integration.tests.jna.JnaUtilTest; import org.apache.mxnet.integration.util.Assertions; +import org.apache.mxnet.jna.JnaUtils; import org.apache.mxnet.ndarray.MxNDArray; import org.apache.mxnet.ndarray.MxNDList; import org.apache.mxnet.ndarray.types.Shape; @@ -15,7 +16,6 @@ import org.testng.annotations.Test; import java.io.IOException; -import java.nio.file.Paths; public class ModelTest { private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class); diff --git a/java-package/mxnet-engine/build.gradle b/java-package/mxnet-engine/build.gradle index 04477870ae57..113f795bf281 100644 --- a/java-package/mxnet-engine/build.gradle +++ b/java-package/mxnet-engine/build.gradle @@ -64,7 +64,7 @@ test { useDefaultListeners = true } environment "PATH", "src/test/bin:${environment.PATH}" - environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}" +// environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}" maxHeapSize = '6G' testLogging.showStandardStreams = true beforeTest { descriptor -> diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java index 3dfb358453b3..b823bb070e95 100644 --- a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java @@ -1,11 +1,20 @@ package org.apache.mxnet.engine; +import org.apache.mxnet.jna.JnaUtils; + public final class BaseMxResource extends MxResource{ static BaseMxResource SYSTEM_MX_RESOURCE; protected BaseMxResource() { super(); + // Workaround MXNet engine lazy initialization issue + JnaUtils.getAllOpNames(); + + JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON); + + // Workaround MXNet shutdown crash issue + Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD } public static BaseMxResource getSystemMxResource() { diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java index 1e19aa67895e..4cf1c44faf11 100644 --- a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java @@ -41,6 +41,7 @@ public final class JnaUtils { private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class); public static final MxnetLibrary LIB = LibUtils.loadLibrary(); + public static final ObjectPool REFS = new ObjectPool<>(PointerByReference::new, r -> r.setValue(null)); @@ -63,6 +64,14 @@ public enum NumpyMode { GLOBAL_ON } + public static void waitAll() { + checkCall(LIB.MXNDArrayWaitAll()); + } + + public static void init() { + Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD + } + public static void setNumpyMode(NumpyMode mode) { IntBuffer ret = IntBuffer.allocate(1); checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), ret)); @@ -813,6 +822,7 @@ private static void checkNDArray(Pointer pointer, String msg) { public static void checkCall(int ret) { if (ret != 0) { + logger.error("MXNet engine call failed: " + getLastError()); throw new JnaCallException("MXNet engine call failed: " + getLastError()); } } @@ -875,26 +885,25 @@ public static boolean autogradIsTraining() { return isTraining.get(0) == 1; } - public static void waitAll() { - checkCall(LIB.MXNDArrayWaitAll()); - } - /***************************************************************************** * Tests *****************************************************************************/ public static void main(String... args) { try { + Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); Set opNames = JnaUtils.getAllOpNames(); List list = new ArrayList<>(opNames); PointerByReference ref = REFS.acquire(); - for (String opName : list.subList(0, 300)) { + + for (String opName : list.subList(0, 400)) { checkCall(LIB.NNGetOpHandle(opName, ref)); String functionName = getOpNamePrefix(opName); // System.out.println("Name: " + opName + "/" + functionName); getFunctionByName(opName, functionName, ref.getValue()); } + ref.setValue(null); REFS.recycle(ref);