Skip to content

Commit 172a3ae

Browse files
authoredJan 31, 2025
Merge pull request #33232 from vespa-engine/arnej/add-hex-encoding-1b
optional hex encoding for dense parts of tensors
2 parents d888663 + 1e6d102 commit 172a3ae

File tree

2 files changed

+150
-29
lines changed

2 files changed

+150
-29
lines changed
 

‎vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import com.yahoo.lang.MutableInteger;
55
import com.yahoo.slime.ArrayTraverser;
66
import com.yahoo.slime.Cursor;
7+
import com.yahoo.slime.Inserter;
78
import com.yahoo.slime.Inspector;
89
import com.yahoo.slime.JsonDecoder;
10+
import com.yahoo.slime.ObjectInserter;
911
import com.yahoo.slime.ObjectTraverser;
1012
import com.yahoo.slime.Slime;
13+
import com.yahoo.slime.SlimeInserter;
1114
import com.yahoo.slime.Type;
1215
import com.yahoo.tensor.DimensionSizes;
1316
import com.yahoo.tensor.IndexedTensor;
@@ -17,6 +20,7 @@
1720
import com.yahoo.tensor.TensorAddress;
1821
import com.yahoo.tensor.TensorType;
1922
import java.util.Iterator;
23+
import java.util.function.Function;
2024

2125
/**
2226
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -28,6 +32,10 @@
2832
*/
2933
public class JsonFormat {
3034

35+
/** Options for encode */
36+
public record EncodeOptions(boolean shortForm, boolean directValues, boolean hexForDensePart) {
37+
// TODO - consider "compact" flag
38+
}
3139
/**
3240
* Serializes the given tensor value into JSON format.
3341
*
@@ -36,38 +44,55 @@ public class JsonFormat {
3644
* @param directValues whether to encode values directly, or wrapped in am object containing "type" and "cells"
3745
*/
3846
public static byte[] encode(Tensor tensor, boolean shortForm, boolean directValues) {
47+
return encode(tensor, new EncodeOptions(shortForm, directValues, false));
48+
}
49+
50+
/**
51+
* Serializes the given tensor value into JSON format.
52+
*
53+
* @param tensor the tensor to serialize
54+
* @param options format options for short/long, wrapped/direct, etc
55+
*/
56+
public static byte[] encode(Tensor tensor, EncodeOptions options) {
3957
Slime slime = new Slime();
40-
Cursor root = null;
41-
if ( ! directValues) {
42-
root = slime.setObject();
58+
Function<String, Inserter> target = (key -> new SlimeInserter(slime));
59+
final Cursor root = options.directValues() ? null : slime.setObject();
60+
if ( ! options.directValues()) {
4361
root.setString("type", tensor.type().toString());
62+
target = (key -> new ObjectInserter(root, key));
4463
}
4564

46-
if (shortForm) {
65+
if (options.shortForm()) {
4766
if (tensor instanceof IndexedTensor denseTensor) {
48-
// Encode as nested lists if indexed tensor
49-
Cursor parent = root == null ? slime.setArray() : root.setArray("values");
50-
encodeValues(denseTensor, parent, new long[denseTensor.dimensionSizes().dimensions()], 0);
51-
} else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
67+
if (options.hexForDensePart()) {
68+
target.apply("values").insertSTRING(asHexString(denseTensor));
69+
} else {
70+
// Encode as nested arrays if indexed tensor
71+
Cursor parent = target.apply("values").insertARRAY();
72+
encodeDenseValues(denseTensor, parent);
73+
}
74+
} else if (tensor instanceof MappedTensor mapped && tensor.type().dimensions().size() == 1) {
5275
// Short form for a single mapped dimension
53-
Cursor parent = root == null ? slime.setObject() : root.setObject("cells");
54-
encodeSingleDimensionCells((MappedTensor) tensor, parent);
55-
} else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) {
76+
Cursor parent = target.apply("cells").insertOBJECT();
77+
encodeSingleDimensionCells(mapped, parent);
78+
} else if (tensor instanceof MixedTensor mixed && tensor.type().hasMappedDimensions()) {
5679
// Short form for a mixed tensor
5780
boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1;
58-
Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() )
59-
: ( singleMapped ? root.setObject("blocks") : root.setArray("blocks"));
60-
encodeBlocks((MixedTensor) tensor, parent);
81+
if (singleMapped) {
82+
encodeLabeledBlocks(mixed, target.apply("blocks").insertOBJECT(), options.hexForDensePart());
83+
} else {
84+
encodeAddressedBlocks(mixed, target.apply("blocks").insertARRAY(), options.hexForDensePart());
85+
}
6186
} else {
6287
// default to standard cell address output
63-
Cursor parent = root == null ? slime.setArray() : root.setArray("cells");
88+
Cursor parent = target.apply("cells").insertARRAY();
6489
encodeCells(tensor, parent);
6590
}
6691

6792
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
6893
}
6994
else {
70-
Cursor parent = root == null ? slime.setArray() : root.setArray("cells");
95+
Cursor parent = target.apply("cells").insertARRAY();
7196
encodeCells(tensor, parent);
7297
}
7398
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
@@ -118,6 +143,73 @@ private static void encodeAddress(TensorType type, TensorAddress address, Cursor
118143
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
119144
}
120145

146+
private static final char[] hexDigits = {
147+
'0', '1', '2', '3', '4', '5', '6', '7',
148+
'8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
149+
};
150+
151+
152+
private static String asHexString(IndexedTensor tensor) {
153+
return asHexString(tensor.sizeAsInt(),
154+
tensor.type().valueType(),
155+
i -> tensor.get(i),
156+
i -> tensor.getFloat(i));
157+
}
158+
159+
private static String asHexString(int denseSize,
160+
TensorType.Value cellType,
161+
Function<Integer, Double> dblSrc,
162+
Function<Integer, Float> fltSrc)
163+
{
164+
StringBuilder buf = new StringBuilder();
165+
switch (cellType) {
166+
case DOUBLE:
167+
for (int i = 0; i < denseSize; i++) {
168+
double d = dblSrc.apply(i);
169+
long bits = Double.doubleToRawLongBits(d);
170+
for (int nibble = 16; nibble-- > 0; ) {
171+
int digit = (int) (bits >> (4 * nibble)) & 0xF;
172+
buf.append(hexDigits[digit]);
173+
}
174+
}
175+
break;
176+
case FLOAT:
177+
for (int i = 0; i < denseSize; i++) {
178+
float f = fltSrc.apply(i);
179+
int bits = Float.floatToRawIntBits(f);
180+
for (int nibble = 8; nibble-- > 0; ) {
181+
int digit = (bits >> (4 * nibble)) & 0xF;
182+
buf.append(hexDigits[digit]);
183+
}
184+
}
185+
break;
186+
case BFLOAT16:
187+
for (int i = 0; i < denseSize; i++) {
188+
float f = fltSrc.apply(i);
189+
int bits = Float.floatToRawIntBits(f);
190+
for (int nibble = 8; nibble-- > 4; ) {
191+
int digit = (bits >> (4 * nibble)) & 0xF;
192+
buf.append(hexDigits[digit]);
193+
}
194+
}
195+
break;
196+
case INT8:
197+
for (int i = 0; i < denseSize; i++) {
198+
byte bits = fltSrc.apply(i).byteValue();
199+
for (int nibble = 2; nibble-- > 0; ) {
200+
int digit = (bits >> (4 * nibble)) & 0xF;
201+
buf.append(hexDigits[digit]);
202+
}
203+
}
204+
break;
205+
}
206+
return buf.toString();
207+
}
208+
209+
private static void encodeDenseValues(IndexedTensor tensor, Cursor target) {
210+
encodeValues(tensor, target, new long[tensor.dimensionSizes().dimensions()], 0);
211+
}
212+
121213
private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
122214
DimensionSizes sizes = tensor.dimensionSizes();
123215
if (indexes.length == 0) {
@@ -133,26 +225,41 @@ private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] ind
133225
}
134226
}
135227

136-
private static void encodeBlocks(MixedTensor tensor, Cursor cursor) {
137-
var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped)
138-
.map(d -> TensorType.Dimension.mapped(d.name())).toList();
228+
private static void encodeLabeledSubspace(String label, MixedTensor.DenseSubspace subspace, TensorType denseSubType, Cursor cursor, boolean hexForDensePart) {
229+
if (hexForDensePart) {
230+
cursor.setString(label, asHexString(subspace.cells.length,
231+
denseSubType.valueType(),
232+
i -> subspace.cells[i],
233+
i -> (float)subspace.cells[i]));
234+
} else {
235+
IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
236+
var target = cursor.setArray(label);
237+
encodeDenseValues(denseSubspace, target);
238+
}
239+
}
240+
241+
private static void encodeLabeledBlocks(MixedTensor tensor, Cursor cursor, boolean hexForDensePart) {
242+
TensorType denseSubType = tensor.type().indexedSubtype();
243+
for (var subspace : tensor.getInternalDenseSubspaces()) {
244+
String label = subspace.sparseAddress.label(0);
245+
encodeLabeledSubspace(label, subspace, denseSubType, cursor, hexForDensePart);
246+
}
247+
}
248+
249+
private static void encodeAddressedBlocks(MixedTensor tensor, Cursor cursor, boolean hexForDensePart) {
250+
var mappedDimensions = tensor.type().dimensions().stream()
251+
.filter(TensorType.Dimension::isMapped)
252+
.toList();
139253
if (mappedDimensions.isEmpty()) {
140254
throw new IllegalArgumentException("Should be ensured by caller");
141255
}
142-
143256
// Create tensor type for mapped dimensions subtype
144257
TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build();
145258
TensorType denseSubType = tensor.type().indexedSubtype();
146259
for (var subspace : tensor.getInternalDenseSubspaces()) {
147-
IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
148-
if (mappedDimensions.size() == 1) {
149-
encodeValues(denseSubspace, cursor.setArray(subspace.sparseAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0);
150-
} else {
151-
Cursor block = cursor.addObject();
152-
encodeAddress(mappedSubType, subspace.sparseAddress, block.setObject("address"));
153-
encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0);
154-
}
155-
260+
Cursor block = cursor.addObject();
261+
encodeAddress(mappedSubType, subspace.sparseAddress, block.setObject("address"));
262+
encodeLabeledSubspace("values", subspace, denseSubType, block, hexForDensePart);
156263
}
157264
}
158265

‎vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ public void testDenseInt8Tensor() {
187187
assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
188188
assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
189189

190+
assertEquals("\"02030507\"", new String(JsonFormat.encode(tensor, new JsonFormat.EncodeOptions(true, true, true)), StandardCharsets.UTF_8));
191+
190192
String longJson = """
191193
{
192194
"type":"tensor<int8>(x[2],y[2])",
@@ -481,6 +483,8 @@ public void testMixedInt8TensorWithHexForm() {
481483
String shortJson = "{\"a\": \"020304\", \"b\": \"050607\"}";
482484
decoded = JsonFormat.decode(expected.type(), shortJson.getBytes(StandardCharsets.UTF_8));
483485
assertEquals(expected, decoded);
486+
var encoded = JsonFormat.encode(decoded, new JsonFormat.EncodeOptions(true, true, true));
487+
assertEquals("{\"a\":\"020304\",\"b\":\"050607\"}", new String(encoded, StandardCharsets.UTF_8));
484488
}
485489

486490
@Test
@@ -505,6 +509,9 @@ public void testBFloat16VectorInHexForm() {
505509
String denseJson = "{\"values\":\"422849803580c37f0000800000807f7f7f80ff807fc0ffc0\"}";
506510
Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
507511
assertEquals(expected, decoded);
512+
var encoded = JsonFormat.encode(decoded, new JsonFormat.EncodeOptions(true, true, true));
513+
assertEquals("\"422849803580C37F0000800000807F7F7F80FF807FC0FFC0\"",
514+
new String(encoded, StandardCharsets.UTF_8));
508515
}
509516

510517
@Test
@@ -533,6 +540,13 @@ public void testFloatVectorInHexForm() {
533540
+"\"}";
534541
Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
535542
assertEquals(expected, decoded);
543+
var encoded = JsonFormat.encode(decoded, new JsonFormat.EncodeOptions(true, true, true));
544+
assertEquals("\""
545+
+"42280000"+"49800008"+"35800000"+"C37F0000"
546+
+"00000000"+"80000000"+"00000001"+"7F7FFFFF"
547+
+"7F800000"+"FF800000"+"7FC00000"+"FFC00000"
548+
+"\"",
549+
new String(encoded, StandardCharsets.UTF_8));
536550
}
537551

538552
@Test

0 commit comments

Comments
 (0)