Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wire hex encoding for tensors #33233

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions container-search/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -5443,6 +5443,8 @@
"public java.util.Set getSummaryFields()",
"public void setSummaryFields(java.lang.String)",
"public boolean getTensorShortForm()",
"public boolean getTensorHexDense()",
"public java.lang.String getTensorFormat()",
"public void setTensorShortForm(java.lang.String)",
"public void setTensorFormat(java.lang.String)",
"public void setTensorShortForm(boolean)",
Expand Down Expand Up @@ -8067,6 +8069,7 @@
"public java.lang.String toJson()",
"public java.lang.String toJson(boolean)",
"public java.lang.String toJson(boolean, boolean)",
"public java.lang.String toJson(com.yahoo.tensor.serialization.JsonFormat$EncodeOptions)",
"public java.lang.StringBuilder writeJson(java.lang.StringBuilder)",
"public java.lang.Double getDouble(java.lang.String)",
"public com.yahoo.tensor.Tensor getTensor(java.lang.String)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public class Presentation implements Cloneable {
/** Whether to render tensor values directly (the "-value" tensor formats) */
private boolean tensorDirectValues = false; // TODO: Flip default on Vespa 9

/** Whether to render the dense (part of) tensors in hex string form */
private boolean tensorHexDense = false;

/** Set of explicitly requested summary fields, instead of summary classes */
private Set<String> summaryFields = LazySet.newHashSet();

Expand Down Expand Up @@ -186,6 +189,20 @@ public void setSummaryFields(String asString) {
*/
public boolean getTensorShortForm() { return tensorShortForm; }

/** Returns whether the dense part of tensors should be represented as a string of hex digits */
public boolean getTensorHexDense() {
    return tensorHexDense;
}

/**
 * Returns the name of the current tensor format, see setTensorFormat().
 *
 * The base name is "hex" when hex rendering of dense parts is on, otherwise "short"
 * when short form is on, otherwise "long"; "-value" is appended when direct values are on.
 */
public String getTensorFormat() {
    final String baseName;
    if (tensorHexDense) {
        // hex takes precedence over the short/long distinction
        baseName = "hex";
    } else if (tensorShortForm) {
        baseName = "short";
    } else {
        baseName = "long";
    }
    return tensorDirectValues ? baseName + "-value" : baseName;
}

/** @deprecated use setTensorFormat(). */
@Deprecated // TODO: Remove on Vespa 9
public void setTensorShortForm(String value) {
Expand All @@ -199,6 +216,16 @@ public void setTensorShortForm(String value) {
*/
public void setTensorFormat(String value) {
switch (value) {
case "hex" :
tensorHexDense = true;
tensorShortForm = true;
tensorDirectValues = false;
break;
case "hex-value" :
tensorHexDense = true;
tensorShortForm = true;
tensorDirectValues = true;
break;
case "short" :
tensorShortForm = true;
tensorDirectValues = false;
Expand Down Expand Up @@ -254,4 +281,3 @@ public int hashCode() {
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private static Map<CompoundName, GetterSetter> createPropertySetterMap() {
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT), GetterSetter.of(query -> query.getPresentation().getFormat(), (query, value) -> query.getPresentation().setFormat(asString(value, ""))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.TIMING), GetterSetter.of(query -> query.getPresentation().getTiming(), (query, value) -> query.getPresentation().setTiming(asBoolean(value, true))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.SUMMARY_FIELDS), GetterSetter.of(query -> query.getPresentation().getSummaryFields(), (query, value) -> query.getPresentation().setSummaryFields(asString(value, ""))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT, Presentation.TENSORS), GetterSetter.of(query -> query.getPresentation().getTensorShortForm(), (query, value) -> query.getPresentation().setTensorFormat(asString(value, "short")))); // TODO: Switch default to short-value on Vespa 9);
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT, Presentation.TENSORS), GetterSetter.of(query -> query.getPresentation().getTensorFormat(), (query, value) -> query.getPresentation().setTensorFormat(asString(value, "short")))); // TODO: Switch default to short-value on Vespa 9);
map.put(Query.HITS, GetterSetter.of(Query::getHits, (query, value) -> query.setHits(asInteger(value,10))));
map.put(Query.OFFSET, GetterSetter.of(Query::getOffset, (query, value) -> query.setOffset(asInteger(value,0))));
map.put(Query.TIMEOUT, GetterSetter.of(Query::getTimeout, (query, value) -> query.setTimeout(value.toString())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,15 @@ static class FieldConsumerSettings {
volatile boolean jsonWsets = true;
volatile boolean jsonMapsAll = true;
volatile boolean jsonWsetsAll = false;
volatile boolean tensorShortForm = true;
volatile boolean tensorDirectValues = false;
volatile JsonFormat.EncodeOptions tensorOptions;
boolean convertDeep() { return (jsonDeepMaps || jsonWsets); }
void init() {
this.debugRendering = false;
this.jsonDeepMaps = true;
this.jsonWsets = true;
this.jsonMapsAll = true;
this.jsonWsetsAll = true;
this.tensorShortForm = true;
this.tensorDirectValues = false;
this.tensorOptions = new JsonFormat.EncodeOptions(true, false, false);
}
void getSettings(Query q) {
if (q == null) {
Expand All @@ -156,9 +154,11 @@ void getSettings(Query q) {
// we may need more fine tuning, but for now use the same query parameters here:
this.jsonMapsAll = props.getBoolean(WRAP_DEEP_MAPS, true);
this.jsonWsetsAll = props.getBoolean(WRAP_WSETS, true);
this.tensorShortForm = q.getPresentation().getTensorShortForm();
this.tensorDirectValues = q.getPresentation().getTensorDirectValues();
}
this.tensorOptions = new JsonFormat.EncodeOptions(
q.getPresentation().getTensorShortForm(),
q.getPresentation().getTensorDirectValues(),
q.getPresentation().getTensorHexDense());
}
}

private volatile FieldConsumerSettings fieldConsumerSettings;
Expand Down Expand Up @@ -560,14 +560,16 @@ public static class FieldConsumer implements Hit.RawUtf8Consumer, TraceRenderer.

/** Invoke this from your constructor when sub-classing {@link FieldConsumer} */
protected FieldConsumer(boolean debugRendering, boolean tensorShortForm, boolean jsonMaps) {
this(null, debugRendering, tensorShortForm, jsonMaps);
this(null, debugRendering, new JsonFormat.EncodeOptions(tensorShortForm, false, false), jsonMaps);
}

private FieldConsumer(JsonGenerator generator, boolean debugRendering, boolean tensorShortForm, boolean jsonMaps) {
private FieldConsumer(JsonGenerator generator, boolean debugRendering,
JsonFormat.EncodeOptions tensorOptions,
boolean jsonMaps) {
this.generator = generator;
this.settings = new FieldConsumerSettings();
this.settings.debugRendering = debugRendering;
this.settings.tensorShortForm = tensorShortForm;
this.settings.tensorOptions = tensorOptions;
this.settings.jsonDeepMaps = jsonMaps;
}

Expand Down Expand Up @@ -768,27 +770,27 @@ protected void renderFieldContents(Object field) throws IOException {
public void accept(Object field) throws IOException {
if (field == null) {
generator().writeNull();
} else if (field instanceof Boolean) {
generator().writeBoolean((Boolean)field);
} else if (field instanceof Number) {
renderNumberField((Number) field);
} else if (field instanceof TreeNode) {
generator().writeTree((TreeNode) field);
} else if (field instanceof Tensor) {
renderTensor(Optional.of((Tensor)field));
} else if (field instanceof FeatureData) {
generator().writeRawValue(((FeatureData)field).toJson(settings.tensorShortForm, settings.tensorDirectValues));
} else if (field instanceof Inspectable) {
renderInspectorDirect(((Inspectable)field).inspect());
} else if (field instanceof JsonProducer) {
generator().writeRawValue(((JsonProducer) field).toJson());
} else if (field instanceof StringFieldValue) {
generator().writeString(((StringFieldValue)field).getString());
} else if (field instanceof TensorFieldValue) {
renderTensor(((TensorFieldValue)field).getTensor());
} else if (field instanceof FieldValue) {
// the null below is the field which has already been written
((FieldValue) field).serialize(null, new JsonWriter(generator));
} else if (field instanceof Boolean bool) {
generator().writeBoolean(bool);
} else if (field instanceof Number num) {
renderNumberField(num);
} else if (field instanceof TreeNode treenode) {
generator().writeTree(treenode);
} else if (field instanceof Tensor t) {
renderTensor(Optional.of(t));
} else if (field instanceof FeatureData featureData) {
generator().writeRawValue(featureData.toJson(settings.tensorOptions));
} else if (field instanceof Inspectable i) {
renderInspectorDirect(i.inspect());
} else if (field instanceof JsonProducer jp) {
generator().writeRawValue(jp.toJson());
} else if (field instanceof StringFieldValue sfv) {
generator().writeString(sfv.getString());
} else if (field instanceof TensorFieldValue tfv) {
renderTensor(tfv.getTensor());
} else if (field instanceof FieldValue fv) {
// the null below is the field name which has already been written
fv.serialize(null, new JsonWriter(generator));
} else {
generator().writeString(field.toString());
}
Expand All @@ -797,27 +799,27 @@ public void accept(Object field) throws IOException {
private void renderNumberField(Number field) throws IOException {
if (field instanceof Integer) {
generator().writeNumber(field.intValue());
} else if (field instanceof Float) {
} else if (field instanceof Float) {
generator().writeNumber(field.floatValue());
} else if (field instanceof Double) {
} else if (field instanceof Double) {
generator().writeNumber(field.doubleValue());
} else if (field instanceof Long) {
generator().writeNumber(field.longValue());
} else if (field instanceof Byte || field instanceof Short) {
generator().writeNumber(field.intValue());
} else if (field instanceof BigInteger) {
generator().writeNumber((BigInteger) field);
} else if (field instanceof BigDecimal) {
generator().writeNumber((BigDecimal) field);
} else if (field instanceof BigInteger bigint) {
generator().writeNumber(bigint);
} else if (field instanceof BigDecimal bigdec) {
generator().writeNumber(bigdec);
} else {
generator().writeNumber(field.doubleValue());
}
}

private void renderTensor(Optional<Tensor> tensor) throws IOException {
generator().writeRawValue(new String(JsonFormat.encode(tensor.orElse(Tensor.Builder.of(TensorType.empty).build()),
settings.tensorShortForm, settings.tensorDirectValues),
StandardCharsets.UTF_8));
var t = tensor.orElse(Tensor.Builder.of(TensorType.empty).build());
byte[] json = JsonFormat.encode(t, settings.tensorOptions);
generator().writeRawValue(new String(json, StandardCharsets.UTF_8));
}

private JsonGenerator generator() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ public class FeatureData implements Inspectable, JsonProducer {
/** The lazily computed feature names of this */
private Set<String> featureNames = null;

/** The lazily computed json form of this */
private String jsonForm = null;

public FeatureData(Inspector encodedValues) {
this.encodedValues = Objects.requireNonNull(encodedValues);
}
Expand All @@ -71,40 +68,43 @@ public Inspector inspect() {

@Override
public String toJson() {
return toJson(false, false);
return toJson(new JsonFormat.EncodeOptions(false, false, false));
}

public String toJson(boolean tensorShortForm) {
return toJson(tensorShortForm, false);
return toJson(new JsonFormat.EncodeOptions(tensorShortForm, false, false));
}

public String toJson(boolean tensorShortForm, boolean tensorDirectValues) {
return writeJson(tensorShortForm, tensorDirectValues, new StringBuilder()).toString();
return toJson(new JsonFormat.EncodeOptions(tensorShortForm, tensorDirectValues, false));
}

/**
 * Renders this as a JSON string, passing the given options on to the tensor encoding.
 *
 * @param tensorOptions the encode options handed to writeJson for tensor values
 * @return the JSON text produced by writeJson into a fresh StringBuilder
 */
public String toJson(JsonFormat.EncodeOptions tensorOptions) {
    StringBuilder rendered = writeJson(tensorOptions, new StringBuilder());
    return rendered.toString();
}

@Override
public StringBuilder writeJson(StringBuilder target) {
return JsonRender.render(encodedValues, new Encoder(target, true, false, false));
return writeJson(new JsonFormat.EncodeOptions(false, false, false), target);
}

private StringBuilder writeJson(boolean tensorShortForm, boolean tensorDirectValues, StringBuilder target) {
private StringBuilder writeJson(JsonFormat.EncodeOptions tensorOptions, StringBuilder target) {
if (this == empty) return target.append("{}");
if (jsonForm != null) return target.append(jsonForm);

if (encodedValues != null)
return JsonRender.render(encodedValues, new Encoder(target, true, tensorShortForm, tensorDirectValues));
return JsonRender.render(encodedValues, new Encoder(target, true, tensorOptions));
else
return writeJson(values, tensorShortForm, tensorDirectValues, target);
return writeJson(values, tensorOptions, target);
}

private StringBuilder writeJson(Map<String, Tensor> values, boolean tensorShortForm, boolean tensorDirectValues, StringBuilder target) {
private StringBuilder writeJson(Map<String, Tensor> values, JsonFormat.EncodeOptions tensorOptions, StringBuilder target) {
target.append("{");
for (Map.Entry<String, Tensor> entry : values.entrySet()) {
target.append("\"").append(entry.getKey()).append("\":");
if (entry.getValue().type().rank() == 0) {
target.append(entry.getValue().asDouble());
} else {
byte[] encodedTensor = JsonFormat.encode(entry.getValue(), tensorShortForm, tensorDirectValues);
byte[] encodedTensor = JsonFormat.encode(entry.getValue(), tensorOptions);
target.append(new String(encodedTensor, StandardCharsets.UTF_8));
}
target.append(",");
Expand Down Expand Up @@ -149,7 +149,7 @@ private Tensor decodeTensor(String featureName) {

return switch (featureValue.type()) {
case DOUBLE -> Tensor.from(featureValue.asDouble());
case DATA -> TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(featureValue.asData()));
case DATA -> tensorFromData(featureValue.asData());
default -> throw new IllegalStateException("Unexpected feature value type " + featureValue.type());
};
}
Expand Down Expand Up @@ -192,23 +192,24 @@ public boolean equals(Object other) {
/** A JSON encoder which encodes DATA as a tensor */
private static class Encoder extends JsonRender.StringEncoder {

private final boolean tensorShortForm;
private final boolean tensorDirectValues;
private final JsonFormat.EncodeOptions tensorOptions;

Encoder(StringBuilder out, boolean compact, boolean tensorShortForm, boolean tensorDirectValues) {
Encoder(StringBuilder out, boolean compact, JsonFormat.EncodeOptions tensorOptions) {
super(out, compact);
this.tensorShortForm = tensorShortForm;
this.tensorDirectValues = tensorDirectValues;
this.tensorOptions = tensorOptions;
}

@Override
public void encodeDATA(byte[] value) {
// This could be done more efficiently ...
Tensor tensor = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(value));
byte[] encodedTensor = JsonFormat.encode(tensor, tensorShortForm, tensorDirectValues);
Tensor tensor = tensorFromData(value);
byte[] encodedTensor = JsonFormat.encode(tensor, tensorOptions);
target().append(new String(encodedTensor, StandardCharsets.UTF_8));
}

}

/**
 * Decodes the given bytes into a tensor using TypedBinaryFormat,
 * passing an empty Optional as the decoder's first argument.
 */
private static Tensor tensorFromData(byte[] value) {
    GrowableByteBuffer buffer = GrowableByteBuffer.wrap(value);
    return TypedBinaryFormat.decode(Optional.empty(), buffer);
}
}
Loading