Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change NDArray.toString() output #1142

Merged
merged 3 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4397,6 +4397,15 @@ default NDArray countNonzero(int axis) {
*/
NDArrayEx getNDArrayInternal();

/**
* Runs the debug string representation of this {@code NDArray}.
*
* @return the debug string representation of this {@code NDArray}
*/
default String toDebugString() {
return toDebugString(100, 10, 10, 20);
}

/**
* Runs the debug string representation of this {@code NDArray}.
*
Expand Down
87 changes: 67 additions & 20 deletions api/src/main/java/ai/djl/ndarray/internal/NDFormat.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Utils;
import java.lang.management.ManagementFactory;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -26,6 +28,11 @@ public abstract class NDFormat {
private static final int PRECISION = 8;
private static final String LF = System.getProperty("line.separator");
private static final Pattern PATTERN = Pattern.compile("\\s*\\d\\.(\\d*?)0*e[+-](\\d+)");
private static final boolean DEBUG =
ManagementFactory.getRuntimeMXBean()
.getInputArguments()
.stream()
.anyMatch(arg -> arg.startsWith("-agentlib:jdwp"));

/**
* Formats the contents of an array as a pretty printable string.
Expand All @@ -39,24 +46,6 @@ public abstract class NDFormat {
*/
public static String format(
NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) {
NDFormat format;
DataType dataType = array.getDataType();

if (dataType == DataType.UINT8) {
format = new HexFormat();
} else if (dataType == DataType.BOOLEAN) {
format = new BooleanFormat();
} else if (dataType.isInteger()) {
format = new IntFormat(array);
} else {
format = new FloatFormat(array);
}
return format.dump(array, maxSize, maxDepth, maxRows, maxColumns);
}

protected abstract CharSequence format(Number value);

private String dump(NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) {
StringBuilder sb = new StringBuilder(1000);
String name = array.getName();
if (name != null) {
Expand All @@ -72,6 +61,39 @@ private String dump(NDArray array, int maxSize, int maxDepth, int maxRows, int m
if (array.hasGradient()) {
sb.append(" hasGradient");
}
if (DEBUG) {
return sb.toString();
}

NDFormat format;
DataType dataType = array.getDataType();

if (dataType == DataType.UINT8) {
format = new HexFormat();
} else if (dataType == DataType.BOOLEAN) {
format = new BooleanFormat();
} else if (dataType == DataType.STRING) {
format = new StringFormat();
} else if (dataType.isInteger()) {
format = new IntFormat();
} else {
format = new FloatFormat();
}
return format.dump(sb, array, maxSize, maxDepth, maxRows, maxColumns);
}

protected abstract CharSequence format(Number value);

protected void init(NDArray array) {}

protected String dump(
StringBuilder sb,
NDArray array,
int maxSize,
int maxDepth,
int maxRows,
int maxColumns) {
init(array);
sb.append(LF);

long size = array.size();
Expand Down Expand Up @@ -152,7 +174,9 @@ private static final class FloatFormat extends NDFormat {
private int precision;
private int totalLength;

public FloatFormat(NDArray array) {
/** {@inheritDoc} */
@Override
public void init(NDArray array) {
Number[] values = array.toArray();
int maxIntPartLen = 0;
int maxFractionLen = 0;
Expand Down Expand Up @@ -290,7 +314,9 @@ private static final class IntFormat extends NDFormat {
private int precision;
private int totalLength;

public IntFormat(NDArray array) {
/** {@inheritDoc} */
@Override
public void init(NDArray array) {
Number[] values = array.toArray();
// scalar case
if (values.length == 1) {
Expand Down Expand Up @@ -338,4 +364,25 @@ public CharSequence format(Number value) {
return value.byteValue() != 0 ? " true" : "false";
}
}

private static final class StringFormat extends NDFormat {

/** {@inheritDoc} */
@Override
public CharSequence format(Number value) {
return null;
}

/** {@inheritDoc} */
@Override
protected String dump(
StringBuilder sb,
NDArray array,
int maxSize,
int maxDepth,
int maxRows,
int maxColumns) {
return Arrays.toString(array.toStringArray());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.UUID;

/** {@code DlrNDArray} is the DLR implementation of {@link NDArray}. */
Expand Down Expand Up @@ -129,14 +128,7 @@ public String toString() {
if (isClosed) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
2 changes: 2 additions & 0 deletions extensions/benchmark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies {
implementation "ai.djl:model-zoo"
runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto"

runtimeOnly "ai.djl.tensorflow:tensorflow-model-zoo"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto"

runtimeOnly "ai.djl.mxnet:mxnet-model-zoo"
Expand Down
10 changes: 1 addition & 9 deletions ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong;
import ml.dmlc.xgboost4j.java.JniUtils;
Expand Down Expand Up @@ -147,14 +146,7 @@ public String toString() {
if (isClosed) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@
/** {@code MxNDArray} is the MXNet implementation of {@link NDArray}. */
public class MxNDArray extends NativeResource<Pointer> implements LazyNDArray {

private static final int MAX_SIZE = 100;
private static final int MAX_DEPTH = 10;
private static final int MAX_ROWS = 10;
private static final int MAX_COLUMNS = 20;

private String name;
private Device device;
private SparseFormat sparseFormat;
Expand Down Expand Up @@ -1617,7 +1612,7 @@ public String toString() {
if (isReleased()) {
return "This array is already closed";
}
return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import ai.onnxruntime.OrtException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.UUID;

/** {@code OrtNDArray} is the ONNX Runtime implementation of {@link NDArray}. */
Expand Down Expand Up @@ -150,13 +149,7 @@ public String toString() {
if (isClosed) {
return "This array is already closed";
}
String arrStr;
if (getDataType() == DataType.STRING) {
arrStr = Arrays.toString(toStringArray());
} else {
arrStr = Arrays.toString(toArray());
}
return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr;
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import ai.djl.paddlepaddle.jni.JniUtils;
import ai.djl.util.NativeResource;
import java.nio.ByteBuffer;
import java.util.Arrays;

/** {@code PpNDArray} is the PaddlePaddle implementation of {@link NDArray}. */
public class PpNDArray extends NativeResource<Long> implements NDArrayAdapter {
Expand Down Expand Up @@ -140,14 +139,7 @@ public String toString() {
if (isReleased()) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDFormat;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
Expand All @@ -39,11 +38,6 @@
/** {@code PtNDArray} is the PyTorch implementation of {@link NDArray}. */
public class PtNDArray extends NativeResource<Long> implements NDArray {

private static final int MAX_SIZE = 100;
private static final int MAX_DEPTH = 10;
private static final int MAX_ROWS = 10;
private static final int MAX_COLUMNS = 20;

private String name;
private Device device;
private DataType dataType;
Expand Down Expand Up @@ -328,9 +322,9 @@ public PtNDArray booleanMask(NDArray index, int axis) {
} else {
throw new UnsupportedOperationException(
"Not supported for shape not broadcastable "
+ indexShape.toString()
+ indexShape
+ " vs "
+ getShape().toString());
+ getShape());
}
}

Expand Down Expand Up @@ -1447,10 +1441,10 @@ public String toString() {
// index operator in toDebugString is not supported for MKLDNN & Sparse layout
if (JniUtils.getLayout(this) != 0) {
try (NDArray tmp = toDense()) {
return NDFormat.format(tmp, MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
return tmp.toDebugString();
}
}
return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
@SuppressWarnings("PMD.UseTryWithResources")
public class TfNDArray extends NativeResource<TFE_TensorHandle> implements NDArray {

private static final int MAX_SIZE = 100;
private static final int MAX_DEPTH = 10;
private static final int MAX_ROWS = 10;
private static final int MAX_COLUMNS = 20;

private Shape shape;
private Device device;
private TfNDManager manager;
Expand Down Expand Up @@ -1611,7 +1606,7 @@ public String toString() {
if (isReleased()) {
return "This array is already closed";
}
return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
return toDebugString();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ public NDArray toType(DataType dataType, boolean copy) {
return manager.create(booleanResult).reshape(shape);
default:
throw new UnsupportedOperationException(
"Type conversion is not supported for TFLite for data type "
+ dataType.toString());
"Type conversion is not supported for TFLite for data type " + dataType);
}
}

Expand Down Expand Up @@ -196,14 +195,7 @@ public String toString() {
if (isClosed) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
return toDebugString();
}

/** {@inheritDoc} */
Expand Down