Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
FIX: MXNet crash when process exits.
Browse files Browse the repository at this point in the history
  • Loading branch information
cspchen committed Jul 22, 2021
1 parent dbb6d83 commit 36a5f4e
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 12 deletions.
4 changes: 1 addition & 3 deletions java-package/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,4 @@ netty_version=4.1.51.Final
slf4j_version=1.7.30
log4j_slf4j_version=2.13.3
testng_version=7.1.0
powermock_version=2.0.7

MXNET_LIBRARY_PATH = /Users/cspchen/Work/refer/incubator-mxnet/build
powermock_version=2.0.7
2 changes: 0 additions & 2 deletions java-package/integration/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ dependencies {
}

// Configuration of the `run` block: environment, JVM system properties, JVM flags and
// program arguments used when this module is launched.
run {
environment("TF_CPP_MIN_LOG_LEVEL", "1") // turn off TensorFlow print out
// Pass the system properties of the Gradle JVM on to the launched process, except user.dir.
systemProperties System.getProperties()
systemProperties.remove("user.dir")
systemProperty("file.encoding", "UTF-8")
// NOTE(review): -Xverify:none switches off bytecode verification and is deprecated on
// newer JDKs (13+) — confirm it is still required.
jvmArgs "-Xverify:none"
// NOTE(review): presumably -p selects the test package and -m the test method to run —
// confirm against the launcher's argument parsing.
args("-p=org.apache.mxnet.integration.tests.", "-m=modelLoadAndPredictTest")

}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.apache.mxnet.engine.Predictor;
import org.apache.mxnet.integration.tests.jna.JnaUtilTest;
import org.apache.mxnet.integration.util.Assertions;
import org.apache.mxnet.jna.JnaUtils;
import org.apache.mxnet.ndarray.MxNDArray;
import org.apache.mxnet.ndarray.MxNDList;
import org.apache.mxnet.ndarray.types.Shape;
Expand All @@ -15,7 +16,6 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Paths;

public class ModelTest {
// Logger for this test class. It is created with ModelTest.class (not JnaUtilTest.class) so
// that log output is attributed to the class that actually emits it.
private static final Logger logger = LoggerFactory.getLogger(ModelTest.class);
Expand Down
2 changes: 1 addition & 1 deletion java-package/mxnet-engine/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ test {
useDefaultListeners = true
}
environment "PATH", "src/test/bin:${environment.PATH}"
environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}"
// environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}"
maxHeapSize = '6G'
testLogging.showStandardStreams = true
beforeTest { descriptor ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package org.apache.mxnet.engine;

import org.apache.mxnet.jna.JnaUtils;

public final class BaseMxResource extends MxResource{

static BaseMxResource SYSTEM_MX_RESOURCE;

protected BaseMxResource() {
super();
// Workaround MXNet engine lazy initialization issue
JnaUtils.getAllOpNames();

JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);

// Workaround MXNet shutdown crash issue
Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD
}

public static BaseMxResource getSystemMxResource() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public final class JnaUtils {
private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class);

/** Native library binding, obtained from {@code LibUtils.loadLibrary()} when this class is initialized. */
public static final MxnetLibrary LIB = LibUtils.loadLibrary();

/**
 * Pool of {@link PointerByReference} objects. Instances are created with the no-arg
 * constructor; the second callback sets a reference's value to {@code null}
 * (presumably applied when a reference is handed back to the pool — confirm against ObjectPool).
 */
public static final ObjectPool<PointerByReference> REFS =
new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));

Expand All @@ -63,6 +64,14 @@ public enum NumpyMode {
GLOBAL_ON
}

/**
 * Invokes the native {@code MXNDArrayWaitAll} function and validates its status code.
 *
 * @throws JnaCallException if the native call returns a non-zero status
 */
public static void waitAll() {
    final int status = LIB.MXNDArrayWaitAll();
    checkCall(status);
}

/**
 * Registers a JVM shutdown hook that runs {@link #waitAll()} when the process exits.
 *
 * <p>Every invocation creates and registers a new hook thread.
 */
public static void init() {
    final Thread waitAllOnExit = new Thread(JnaUtils::waitAll); // NOPMD
    Runtime.getRuntime().addShutdownHook(waitAllOnExit);
}

public static void setNumpyMode(NumpyMode mode) {
IntBuffer ret = IntBuffer.allocate(1);
checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), ret));
Expand Down Expand Up @@ -813,6 +822,7 @@ private static void checkNDArray(Pointer pointer, String msg) {

/**
 * Verifies the status code returned by a native MXNet call.
 *
 * <p>A non-zero status is logged at error level and then reported as an exception.
 *
 * @param ret status code returned by the native call; {@code 0} is treated as success
 * @throws JnaCallException if {@code ret} is non-zero; the message carries the text
 *     returned by {@code getLastError()}
 */
public static void checkCall(int ret) {
    if (ret == 0) {
        return;
    }
    // Fetch the error text once so the log entry and the exception carry the same message.
    final String message = "MXNet engine call failed: " + getLastError();
    logger.error(message);
    throw new JnaCallException(message);
}
Expand Down Expand Up @@ -875,26 +885,25 @@ public static boolean autogradIsTraining() {
return isTraining.get(0) == 1;
}

/**
 * Calls the native {@code MXNDArrayWaitAll} function through {@code LIB} and passes its
 * return code to {@code checkCall(int)}.
 *
 * NOTE(review): an identical {@code waitAll()} definition also appears earlier in this view
 * (right after the NumpyMode enum); this copy looks like the removed side of the diff —
 * confirm that only one definition remains in the file.
 */
public static void waitAll() {
checkCall(LIB.MXNDArrayWaitAll());
}

/*****************************************************************************
* Tests
*****************************************************************************/
public static void main(String... args) {
try {
Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll));
Set<String> opNames = JnaUtils.getAllOpNames();
List<String> list = new ArrayList<>(opNames);

PointerByReference ref = REFS.acquire();
for (String opName : list.subList(0, 300)) {

for (String opName : list.subList(0, 400)) {
checkCall(LIB.NNGetOpHandle(opName, ref));
String functionName = getOpNamePrefix(opName);
// System.out.println("Name: " + opName + "/" + functionName);
getFunctionByName(opName, functionName, ref.getValue());

}

ref.setValue(null);
REFS.recycle(ref);

Expand Down

0 comments on commit 36a5f4e

Please sign in to comment.