Skip to content

Commit

Permalink
DJL parameter saving (#669)
Browse files Browse the repository at this point in the history
* add load/save functionalities to MXNet and PyTorch

* make the change

Change-Id: Id1fe61f32a049185a94bf2d2ef01f5ec4cfdfeaf

* fix twice

* Refactoring JNI

Co-authored-by: gstu1130 <gstu1130@gmail.com>
  • Loading branch information
lanking520 and stu1130 authored Feb 23, 2021
1 parent 5b28382 commit 92cc448
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 7 deletions.
13 changes: 7 additions & 6 deletions api/src/main/java/ai/djl/util/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,17 @@ public boolean isPlaceholder() {
/**
* Returns true the platforms match (os and flavor).
*
* @param other the platform to compare it to
* @param system the platform to compare it to
* @return true if the platforms match
*/
public boolean matches(Platform other) {
if (!osPrefix.equals(other.osPrefix)) {
public boolean matches(Platform system) {
if (!osPrefix.equals(system.osPrefix)) {
return false;
}
if (flavor.startsWith("cu") != other.flavor.startsWith("cu")) {
return false;
// if system Machine is GPU
if (system.flavor.startsWith("cu")) {
return "".equals(flavor) || "cpu".equals(flavor) || flavor.equals(system.flavor);
}
return flavor.startsWith(other.flavor) || other.flavor.startsWith(flavor);
return "".equals(flavor) || "cpu".equals(flavor);
}
}
14 changes: 14 additions & 0 deletions mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,20 @@ public static Pointer createSymbolFromFile(String path) {
return pointer;
}

/**
 * Creates an MXNet symbol from its JSON representation.
 *
 * @param json the symbol serialized as a JSON string
 * @return a native handle (pointer) to the created symbol
 */
public static Pointer createSymbolFromString(String json) {
    PointerByReference ref = REFS.acquire();
    checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
    Pointer pointer = ref.getValue();
    // return the reference holder to the pool only after its value has been read
    REFS.recycle(ref);
    return pointer;
}

/**
 * Returns the JSON representation of an MXNet symbol.
 *
 * @param symbol the native handle of the symbol
 * @return the symbol serialized as a JSON string
 */
public static String getSymbolString(Pointer symbol) {
    // MXSymbolSaveToJSON writes the JSON into the single-element holder array
    String[] holder = new String[1];
    checkCall(LIB.MXSymbolSaveToJSON(symbol, holder));
    return holder[0];
}

private static List<Shape> recoverShape(
NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
int shapeLength = (int) size.getValue().longValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
Expand All @@ -44,6 +46,8 @@ public final class JniUtils {

private static final int NULL_PTR = 0;

private static final int BYTE_LENGTH = 4194304;

private JniUtils() {}

private static int layoutMapper(SparseFormat fmt, Device device) {
Expand Down Expand Up @@ -1374,6 +1378,36 @@ public static PtSymbolBlock loadModule(
return new PtSymbolBlock(manager, handle);
}

/**
 * Loads a TorchScript module from an {@code InputStream} onto the given device.
 *
 * <p>The stream is read on the native side; {@code buf} serves as the JNI transfer buffer
 * ({@code BYTE_LENGTH} bytes per read).
 *
 * @param manager the manager that owns the resulting block
 * @param is the stream containing the serialized module
 * @param device the device to load the module onto
 * @return the loaded {@link PtSymbolBlock}
 */
public static PtSymbolBlock loadModule(PtNDManager manager, InputStream is, Device device) {
    byte[] buf = new byte[BYTE_LENGTH];
    long handle =
            PyTorchLibrary.LIB.moduleLoad(
                    is,
                    new int[] {
                        PtDeviceType.toDeviceType(device),
                        // -1 means "no device id" for CPU; otherwise pass the GPU id
                        device.equals(Device.cpu()) ? -1 : device.getDeviceId()
                    },
                    buf);
    return new PtSymbolBlock(manager, handle);
}

/**
 * Serializes a TorchScript module to an {@code OutputStream}.
 *
 * <p>Serialization happens on the native side; {@code buf} is the JNI transfer buffer used
 * for chunked writes to the stream.
 *
 * @param block the block to serialize
 * @param os the stream to write the serialized module to
 */
public static void writeModule(PtSymbolBlock block, OutputStream os) {
    byte[] buf = new byte[BYTE_LENGTH];
    PyTorchLibrary.LIB.moduleWrite(block.getHandle(), os, buf);
}

/**
 * Retrieves the parameter tensors of a TorchScript module as an {@link NDList}.
 *
 * <p>Handles and names are fetched in matching order from the native side, so index
 * {@code i} of both arrays refers to the same parameter.
 *
 * @param block the block whose parameters are retrieved
 * @param manager the manager that owns the resulting arrays
 * @return the named parameter tensors
 */
public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) {
    long[] paramHandles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle());
    String[] paramNames = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle());
    NDList params = new NDList(paramHandles.length);
    for (int idx = 0; idx < paramHandles.length; ++idx) {
        PtNDArray parameter = new PtNDArray(manager, paramHandles[idx]);
        parameter.setName(paramNames[idx]);
        params.add(parameter);
    }
    return params;
}

public static void enableInferenceMode(PtSymbolBlock block) {
PyTorchLibrary.LIB.moduleEval(block.getHandle());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ai.djl.pytorch.jni;

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.Set;

Expand Down Expand Up @@ -460,12 +462,20 @@ native void torchBackward(
native long moduleLoad(
String path, int[] device, String[] extraFileNames, String[] extraFileValues);

native long moduleLoad(InputStream is, int[] device, byte[] buffer);

native void moduleEval(long handle);

native void moduleTrain(long handle);

native long moduleForward(long moduleHandle, long[] iValueHandles, boolean isTrain);

native void moduleWrite(long moduleHandle, OutputStream os, byte[] buffer);

native long[] moduleGetParams(long moduleHandle);

native String[] moduleGetParamNames(long moduleHandle);

native long iValueFromTensor(long tensorHandle);

native long iValueFromList(long[] tensorHandles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import org.testng.Assert;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -65,4 +72,21 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti
}
}
}

/**
 * Round-trips a traced ResNet-18 model through {@code loadModule}/{@code writeModule}.
 *
 * <p>Uses try-with-resources so all streams are closed even when loading fails (the
 * original closed them manually, which was skipped on exception), and asserts that each
 * step actually produced something instead of only checking for the absence of exceptions.
 */
@Test
public void testInputOutput() throws IOException {
    URL url =
            new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
    try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager();
            InputStream is = url.openStream();
            ByteArrayOutputStream os = new ByteArrayOutputStream()) {
        PtSymbolBlock block = JniUtils.loadModule(manager, is, manager.getDevice());
        Assert.assertNotNull(block);
        JniUtils.writeModule(block, os);
        // the serialized module must not be empty
        Assert.assertTrue(os.size() > 0);
        try (ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray())) {
            PtSymbolBlock roundTrip = JniUtils.loadModule(manager, bis, manager.getDevice());
            Assert.assertNotNull(roundTrip);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ struct JITCallGuard {
torch::NoGradGuard no_grad;
};

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad(
JNIEXPORT jlong JNICALL
Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3I_3Ljava_lang_String_2_3Ljava_lang_String_2(
JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray, jobjectArray jefnames, jobjectArray jefvalues) {
API_BEGIN()
const std::string path = djl::utils::jni::GetStringFromJString(env, jpath);
Expand All @@ -46,6 +47,72 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad(
API_END_RETURN()
}

// Loads a TorchScript module from a Java InputStream onto the device described by jarray.
// The stream is drained via InputStream.read(byte[], int, int), using the caller-supplied
// jbyteArray `arr` as the transfer buffer, into an in-memory string that torch::jit::load
// consumes. Returns the address of a heap-allocated Module as an opaque handle; the Java
// side owns it.
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_io_InputStream_2_3I_3B(
    JNIEnv* env, jobject jthis, jobject jis, jintArray jarray, jbyteArray arr) {
  API_BEGIN()
  jclass is_class = env->GetObjectClass(jis);
  if (is_class == nullptr) {
    env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "Java inputStream class is not found");
    return -1;
  }
  jmethodID method_id = env->GetMethodID(is_class, "read", "([BII)I");
  if (method_id == nullptr) {
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, "The read method in InputStream is not found");
    return -1;
  }
  std::ostringstream os;
  int len = env->GetArrayLength(arr);
  int available = 0;
  jbyte* data;
  // InputStream.read returns -1 at end of stream
  while (available != -1) {
    available = env->CallIntMethod(jis, method_id, arr, 0, len);
    // NOTE(review): if read() throws, CallIntMethod's return value is undefined; an
    // env->ExceptionCheck() here would be safer — confirm intended behavior.
    if (available != -1) {
      // fix: the second argument is a jboolean* isCopy out-parameter; the original passed
      // JNI_FALSE, which only compiled because JNI_FALSE happens to be the literal 0
      data = env->GetByteArrayElements(arr, nullptr);
      os.write(reinterpret_cast<char*>(data), available);
      // JNI_ABORT: the buffer was only read, no need to copy it back
      env->ReleaseByteArrayElements(arr, data, JNI_ABORT);
    }
  }
  std::istringstream in(os.str());
  const torch::Device device = utils::GetDeviceFromJDevice(env, jarray);
  const torch::jit::script::Module module = torch::jit::load(in, device);
  const auto* module_ptr = new torch::jit::script::Module(module);
  return reinterpret_cast<uintptr_t>(module_ptr);
  API_END_RETURN()
}

// Serializes the TorchScript module to a Java OutputStream in chunks, reusing the
// caller-supplied jbyteArray `arr` as the transfer buffer.
//
// Fixes the original chunking loop, which called str.substr(i, i + len): substr's second
// parameter is a COUNT, not an end index, so every chunk allocated up to twice the needed
// bytes (the output was still correct because SetByteArrayRegion copied only `len` of
// them). We now copy directly out of the string's buffer with no temporary allocations,
// and the single loop also covers the final partial chunk.
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleWrite(
    JNIEnv* env, jobject jthis, jlong module_handle, jobject jos, jbyteArray arr) {
  API_BEGIN()
  auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(module_handle);
  std::ostringstream stream;
  module_ptr->save(stream);
  auto str = stream.str();
  jclass os_class = env->GetObjectClass(jos);
  if (os_class == nullptr) {
    env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "Java OutputStream class is not found");
    return;
  }
  jmethodID method_id = env->GetMethodID(os_class, "write", "([BII)V");
  if (method_id == nullptr) {
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, "The write method in OutputStream is not found");
    return;
  }
  size_t buf_len = static_cast<size_t>(env->GetArrayLength(arr));
  size_t total = str.length();
  for (size_t i = 0; i < total; i += buf_len) {
    size_t chunk = (total - i < buf_len) ? (total - i) : buf_len;
    env->SetByteArrayRegion(arr, 0, chunk, reinterpret_cast<const jbyte*>(str.data() + i));
    env->CallVoidMethod(jos, method_id, arr, 0, static_cast<jint>(chunk));
  }
  API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleEval(
JNIEnv* env, jobject jthis, jlong module_handle) {
API_BEGIN()
Expand Down Expand Up @@ -99,3 +166,38 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteModule(
delete module_ptr;
API_END()
}

// Saves the TorchScript module identified by jhandle to the file at jpath
// (delegates to torch::jit::script::Module::save).
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleSave(
    JNIEnv* env, jobject jthis, jlong jhandle, jstring jpath) {
  API_BEGIN()
  auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
  module_ptr->save(djl::utils::jni::GetStringFromJString(env, jpath));
  API_END()
}

// Returns the module's parameter tensors as an array of opaque native handles. Each
// parameter is copied into a new torch::Tensor whose address is handed to Java; the Java
// side owns those native objects.
//
// Adds a null check on NewLongArray, consistent with the NewIntArray check in
// torchDevice, so an allocation failure raises a Java exception instead of crashing in
// SetLongArrayRegion.
JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleGetParams(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
  auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
  std::vector<jlong> jptrs;
  for (const auto& tensor : module_ptr->parameters()) {
    jptrs.push_back(reinterpret_cast<uintptr_t>(new torch::Tensor(tensor)));
  }
  size_t len = jptrs.size();
  jlongArray jarray = env->NewLongArray(len);
  if (jarray == nullptr) {
    env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "Unable to create long array");
    return nullptr;
  }
  env->SetLongArrayRegion(jarray, 0, len, jptrs.data());
  return jarray;
  API_END_RETURN()
}

// Returns the names of the module's parameters, in the same order that
// moduleGetParams returns their tensor handles.
JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleGetParamNames(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
  auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
  std::vector<std::string> names;
  for (const auto& named_tensor : module_ptr->named_parameters()) {
    names.emplace_back(named_tensor.name);
  }
  return djl::utils::jni::GetStringArrayFromVec(env, names);
  API_END_RETURN()
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNInterpolat
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto size_vec = djl::utils::jni::GetVecFromJLongArray(env, jsize);

#if defined(__ANDROID__)
torch::Tensor result;
if (jmode == 0) {
Expand All @@ -52,6 +53,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNInterpolat
result = torch::upsample_bicubic2d(*tensor_ptr, size_vec, jalign_corners);
} else {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android");
return nullptr;
}
const auto* result_ptr = new torch::Tensor(result);
#else
Expand Down Expand Up @@ -160,6 +162,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNRnn(J
jtraining, jbidirectional, jbatch_first);
} else {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "can't find activation");
return nullptr;
}

// process output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchShowConfig(
jclass set_class = env->GetObjectClass(jset);
if (set_class == nullptr) {
env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "Java Set class is not found");
return;
}
jmethodID add_method_id = env->GetMethodID(set_class, "add", "(Ljava/lang/Object;)Z");
if (add_method_id == nullptr) {
env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "The add method in Set is not found");
return;
}
std::string feature;
jstring jfeature;
Expand Down Expand Up @@ -253,6 +255,7 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStartProfile(
API_BEGIN()
if (profilerEnabled()) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "please call stopProfile before you start a new section");
return;
}
enableProfiler(ProfilerConfig(juse_cuda ? ProfilerState::CUDA : ProfilerState::CPU,
/* report_input_shapes */ jrecord_shape,
Expand All @@ -265,6 +268,7 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStopProfile(
API_BEGIN()
if (!profilerEnabled()) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "please call startProfiler() before you use stopProfile!");
return;
}
std::string output_file = djl::utils::jni::GetStringFromJString(env, joutput_file);
std::ofstream file(output_file);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ JNIEXPORT jintArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDevice(
jintArray result = env->NewIntArray(2);
if (result == nullptr) {
env->ThrowNew(NULL_PTR_EXCEPTION_CLASS, "Unable to create int array");
return nullptr;
}
jint temp_device[] = {static_cast<int>(tensor_ptr->device().type()), tensor_ptr->device().index()};
env->SetIntArrayRegion(result, 0, 2, temp_device);
Expand Down

0 comments on commit 92cc448

Please sign in to comment.