From 10ebc3f0bcc5abc073cfb6398703a588b102583b Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Apr 2019 15:08:17 -0700 Subject: [PATCH 1/3] update reshape operator --- .../src/main/scala/org/apache/mxnet/LibInfo.scala | 5 +++-- .../src/main/scala/org/apache/mxnet/NDArray.scala | 13 ++++++++++++- .../test/scala/org/apache/mxnet/NDArraySuite.scala | 8 ++++++-- .../main/native/org_apache_mxnet_native_c_api.cc | 12 ++++++------ .../src/main/native/org_apache_mxnet_native_c_api.h | 6 +++--- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 20b6ed9fc806..40fc0951e885 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -86,9 +86,10 @@ private[mxnet] class LibInfo { @native def mxNDArrayAt(handle: NDArrayHandle, idx: MXUint, out: NDArrayHandleRef): Int - @native def mxNDArrayReshape(handle: NDArrayHandle, + @native def mxNDArrayReshape64(handle: NDArrayHandle, nDim: Int, - dims: Array[Int], + dims: Array[Long], + reverse: Boolean, reshapeHandle: NDArrayHandleRef): Int @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 915e4c69de31..ab42265ae102 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -950,8 +950,19 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return a reshaped NDArray that shares memory with current one. */ def reshape(dims: Array[Int]): NDArray = { + reshape(dims.map(_.toLong)) + } + + /** + * Return a reshaped NDArray that shares memory with current one. + * @param dims New shape. + * @param reverse whether to inplace reshape + * @return a reshaped NDArray that shares memory with current one. + */ + def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = { val reshapeHandle = new NDArrayHandleRef - checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle)) + checkCall(_LIB.mxNDArrayReshape64(handle, + dims.length, dims, reverse.getOrElse(false), reshapeHandle)) new NDArray(handle = reshapeHandle.value, writable = this.writable) } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 206094c15958..c2ef641f9c9a 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -878,14 +878,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("reshape") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2)) + var arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2)) - val arr1 = arr.reshape(Array(2, 3)) + var arr1 = arr.reshape(Array(2, 3)) assert(arr1.shape === Shape(2, 3)) assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f)) arr.set(1f) assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f)) + + arr = NDArray.ones(1, 384, 1) + arr1 = arr.reshape(Array(0, -3)) + assert(arr1.shape === Shape(1, 384)) } test("dispose deps") { diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index ea6e9c8f5ba4..55863d8864de 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -404,14 +404,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt return ret; } -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape - (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) { +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jlongArray dims, jboolean reverse, jobject reshapedHandle) { NDArrayHandle out; - jint *pdims = env->GetIntArrayElements(dims, NULL); - int ret = MXNDArrayReshape(reinterpret_cast(ndArrayPtr), ndim, - reinterpret_cast(pdims), &out); + jlong *pdims = env->GetLongArrayElements(dims, NULL); + int ret = MXNDArrayReshape64(reinterpret_cast(ndArrayPtr), ndim, + reinterpret_cast(pdims), reverse, &out); SetLongField(env, reshapedHandle, reinterpret_cast(out)); - env->ReleaseIntArrayElements(dims, pdims, 0); + env->ReleaseLongArrayElements(dims, pdims, 0); return ret; } diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index 7e8e03de9124..0cee416fd28f 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -161,11 +161,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt /* * Class: org_apache_mxnet_LibInfo - * Method: mxNDArrayReshape + * Method: mxNDArrayReshape64 * Signature: (JI[ILorg/apache/mxnet/Base/RefLong;)I */ -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape - (JNIEnv *, jobject, jlong, jint, jintArray, jobject); +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 + (JNIEnv *, jobject, jlong, jint, jlongArray, jboolean, jobject); /* * Class: org_apache_mxnet_LibInfo From 610e4171c914ce4fc33bfedbc8ad8942603954b3 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Apr 2019 15:15:28 -0700 Subject: [PATCH 2/3] Satisfy the Lint God =v= --- .../native/src/main/native/org_apache_mxnet_native_c_api.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 55863d8864de..33e4cca99b3a 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -405,7 +405,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt } JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 - (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jlongArray dims, jboolean reverse, jobject reshapedHandle) { + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, + jlongArray dims, jboolean reverse, jobject reshapedHandle) { NDArrayHandle out; jlong *pdims = env->GetLongArrayElements(dims, NULL); int ret = MXNDArrayReshape64(reinterpret_cast(ndArrayPtr), ndim, From 48b5b374b9983c3b58c8c5cb8ad459034c3f14eb Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Apr 2019 15:49:16 -0700 Subject: [PATCH 3/3] update the jni header signature --- .../native/src/main/native/org_apache_mxnet_native_c_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index 0cee416fd28f..b8a9b3b9e64f 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -162,7 +162,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt /* * Class: org_apache_mxnet_LibInfo * Method: mxNDArrayReshape64 - * Signature: (JI[ILorg/apache/mxnet/Base/RefLong;)I + * Signature: (JI[JZLorg/apache/mxnet/Base/RefLong;)I */ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 (JNIEnv *, jobject, jlong, jint, jlongArray, jboolean, jobject);