Skip to content

Commit

Permalink
[api] Support encode/decode String tensor (#3034)
Browse files Browse the repository at this point in the history
Fixes: #3033
  • Loading branch information
frankfliu authored Mar 22, 2024
1 parent e3a8e4c commit b3b04f5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
21 changes: 21 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ static void encode(NDArray array, OutputStream os) throws IOException {
Shape shape = array.getShape();
dos.write(shape.getEncoded());

if (array.getDataType() == DataType.STRING) {
String[] data = array.toStringArray();
dos.writeInt(data.length);
for (String str : data) {
dos.writeUTF(str);
}
dos.flush();
return;
}

ByteBuffer bb = array.toByteBuffer();
dos.write(bb.order() == ByteOrder.BIG_ENDIAN ? '>' : '<');
int length = bb.remaining();
Expand Down Expand Up @@ -167,6 +177,17 @@ static NDArray decode(NDManager manager, ByteBuffer bb) {
// Shape
Shape shape = Shape.decode(bb);

if (dataType == DataType.STRING) {
int size = bb.getInt();
String[] data = new String[size];
for (int i = 0; i < size; ++i) {
data[i] = readUTF(bb);
}
NDArray array = manager.create(data, StandardCharsets.UTF_8, shape);
array.setName(name);
return array;
}

// Data
ByteOrder order;
if (version > 2) {
Expand Down
17 changes: 17 additions & 0 deletions api/src/test/java/ai/djl/ndarray/NDSerializerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ public void testNDSerializer() throws IOException {
}
}

/**
 * Checks that String tensors survive an encode/decode round trip, first for a scalar and then
 * for a two-element array.
 */
@Test
public void testStringTensor() {
    try (NDManager manager = NDManager.newBaseManager("PyTorch")) {
        // Scalar case: only the shape of the decoded tensor is checked here.
        NDArray scalar = manager.create("hello");
        NDArray scalarCopy = NDArray.decode(manager, scalar.encode());
        Assert.assertTrue(scalarCopy.getShape().isScalar());

        // Array case: shape, an individual element and overall equality must all match.
        String[] words = {"hello", "world"};
        NDArray vector = manager.create(words);
        NDArray vectorCopy = NDArray.decode(manager, vector.encode());
        Assert.assertEquals(vectorCopy.getShape(), vector.getShape());
        Assert.assertEquals(vectorCopy.toStringArray()[1], "world");
        Assert.assertEquals(vectorCopy, vector);
    }
}

private static byte[] encode(NDArray array) throws IOException {
try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
NDSerializer.encodeAsNumpy(array, bos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) {
super(-1L);
this.manager = manager;
this.strs = strs;
this.sparseFormat = SparseFormat.DENSE;
this.shape = shape;
this.dataType = DataType.STRING;
NDScope.register(this);
Expand Down Expand Up @@ -225,6 +226,10 @@ public NDArray stopGradient() {
/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
    // Every data type except STRING is handed to the JNI helper for conversion.
    if (getDataType() != DataType.STRING) {
        return JniUtils.getByteBuffer(this);
    }
    // String tensors are rejected explicitly instead of being passed to the JNI helper.
    throw new UnsupportedOperationException("toByteBuffer is not supported for String tensor.");
}

Expand Down Expand Up @@ -429,6 +434,9 @@ public boolean contentEquals(NDArray other) {
if (getDataType() != other.getDataType()) {
return false;
}
if (getDataType() == DataType.STRING) {
return Arrays.equals(toStringArray(), other.toStringArray());
}
return JniUtils.contentEqual(this, manager.from(other));
}

Expand Down

0 comments on commit b3b04f5

Please sign in to comment.