diff --git a/src/main/java/org/apposed/appose/NDArray.java b/src/main/java/org/apposed/appose/NDArray.java index e86084b..7e0bdde 100644 --- a/src/main/java/org/apposed/appose/NDArray.java +++ b/src/main/java/org/apposed/appose/NDArray.java @@ -43,7 +43,7 @@ public class NDArray implements AutoCloseable { /** * shared memory containing the flattened array data. */ - private final SharedMemory sharedMemory; + private final SharedMemory shm; /** * data type of the array elements. @@ -56,16 +56,19 @@ public class NDArray implements AutoCloseable { private final Shape shape; /** - * Constructs an {@code NDArray} with the specified {@code SharedMemory}. + * Constructs an {@code NDArray} with the specified data type, shape, + * and {@code SharedMemory}. * - * @param sharedMemory the flattened array data. * @param dType element data type * @param shape array shape + * @param shm the flattened array data. */ - public NDArray(final SharedMemory sharedMemory, final DType dType, final Shape shape) { - this.sharedMemory = sharedMemory; + public NDArray(final DType dType, final Shape shape, final SharedMemory shm) { this.dType = dType; this.shape = shape; + this.shm = shm == null + ? SharedMemory.create(null, safeInt(shape.numElements() * dType.bytesPerElement())) + : shm; } /** @@ -76,8 +79,7 @@ public NDArray(final SharedMemory sharedMemory, final DType dType, final Shape s * @param shape array shape */ public NDArray(final DType dType, final Shape shape) { - this(SharedMemory.create(null, - safeInt(shape.numElements() * dType.bytesPerElement())), dType, shape); + this(dType, shape, null); } /** @@ -98,7 +100,7 @@ public Shape shape() { * @return The shared memory block containing the array data. */ public SharedMemory shm() { - return sharedMemory; + return shm; } /** @@ -108,7 +110,7 @@ public SharedMemory shm() { */ public ByteBuffer buffer() { final long length = shape.numElements() * dType.bytesPerElement(); - return sharedMemory.pointer().getByteBuffer(0, length); + return shm.pointer().getByteBuffer(0, length); } /** @@ -116,16 +118,16 @@ public ByteBuffer buffer() { */ @Override public void close() throws Exception { - sharedMemory.close(); + shm.close(); } @Override public String toString() { - return "NDArray{" + - "sharedMemory=" + sharedMemory + - ", dType=" + dType + - ", shape=" + shape + - '}'; + return "NDArray(" + + "dType=" + dType + + ", shape=" + shape + + ", shm=" + shm + + ")"; } /** diff --git a/src/main/java/org/apposed/appose/Types.java b/src/main/java/org/apposed/appose/Types.java index 811ff0c..f89bae4 100644 --- a/src/main/java/org/apposed/appose/Types.java +++ b/src/main/java/org/apposed/appose/Types.java @@ -132,9 +132,9 @@ public Object convert(final Object value, final String key) { map.put("name", shm.name()); map.put("size", shm.size()); })).addConverter(convert(NDArray.class, "ndarray", (map, ndArray) -> { - map.put("shm", ndArray.shm()); map.put("dtype", ndArray.dType().label()); map.put("shape", ndArray.shape().toIntArray(C_ORDER)); + map.put("shm", ndArray.shm()); })).build(); @@ -163,10 +163,10 @@ private static Object processValue(Object value) { final int size = (int) map.get("size"); return SharedMemory.attach(name, size); case "ndarray": - final SharedMemory shm = (SharedMemory) map.get("shm"); final NDArray.DType dType = toDType((String) map.get("dtype")); final NDArray.Shape shape = toShape((List) map.get("shape")); - return new NDArray(shm, dType, shape); + final SharedMemory shm = (SharedMemory) map.get("shm"); + return new NDArray(dType, shape, shm); default: System.err.println("unknown appose_type \"" + appose_type + "\""); } diff --git a/src/test/java/org/apposed/appose/NDArrayExampleGroovy.java b/src/test/java/org/apposed/appose/NDArrayExampleGroovy.java index 1467656..0948605 100644 --- a/src/test/java/org/apposed/appose/NDArrayExampleGroovy.java +++ b/src/test/java/org/apposed/appose/NDArrayExampleGroovy.java @@ -13,7 +13,8 @@ public static void main(String[] args) throws Exception { // create a FLOAT32 NDArray with shape (4,3,2) in F_ORDER // respectively (2,3,4) in C_ORDER final NDArray.DType dType = NDArray.DType.FLOAT32; - final NDArray ndArray = new NDArray(dType, new NDArray.Shape(F_ORDER, 4, 3, 2)); + final NDArray.Shape shape = new NDArray.Shape(F_ORDER, 4, 3, 2); + final NDArray ndArray = new NDArray(dType, shape); // fill with values 0..23 in flat iteration order final FloatBuffer buf = ndArray.buffer().asFloatBuffer(); diff --git a/src/test/java/org/apposed/appose/NDArrayExamplePython.java b/src/test/java/org/apposed/appose/NDArrayExamplePython.java index 4a80b39..4512377 100644 --- a/src/test/java/org/apposed/appose/NDArrayExamplePython.java +++ b/src/test/java/org/apposed/appose/NDArrayExamplePython.java @@ -13,7 +13,8 @@ public static void main(String[] args) throws Exception { // create a FLOAT32 NDArray with shape (4,3,2) in F_ORDER // respectively (2,3,4) in C_ORDER final NDArray.DType dType = NDArray.DType.FLOAT32; - final NDArray ndArray = new NDArray(dType, new NDArray.Shape(F_ORDER, 4, 3, 2)); + final NDArray.Shape shape = new NDArray.Shape(F_ORDER, 4, 3, 2); + final NDArray ndArray = new NDArray(dType, shape); // fill with values 0..23 in flat iteration order final FloatBuffer buf = ndArray.buffer().asFloatBuffer();