Skip to content

Commit

Permalink
Move serde helpers into specific class definitions
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Carroll <carrofin@amazon.com>
  • Loading branch information
finnegancarroll committed Sep 18, 2024
1 parent b13520e commit e430f8a
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public TransportSerializationException(StreamInput in) throws IOException {
super(in);
}

public TransportSerializationException(String msg) {
super(msg);
}

public TransportSerializationException(String msg, Throwable cause) {
super(msg, cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import com.google.protobuf.ByteString;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.SortField;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesArray;
Expand All @@ -21,46 +20,28 @@
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.SearchSortValues;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.proto.search.SearchHitsProtoDef.DocumentFieldProto;
import org.opensearch.proto.search.SearchHitsProtoDef.ExplanationProto;
import org.opensearch.proto.search.SearchHitsProtoDef.HighlightFieldProto;
import org.opensearch.proto.search.SearchHitsProtoDef.IndexProto;
import org.opensearch.proto.search.SearchHitsProtoDef.SearchShardTargetProto;
import org.opensearch.proto.search.SearchHitsProtoDef.SearchSortValuesProto;
import org.opensearch.proto.search.SearchHitsProtoDef.ShardIdProto;
import org.opensearch.proto.search.SearchHitsProtoDef.SortValueProto;
import org.opensearch.proto.search.SearchHitsProtoDef.SortFieldProto;
import org.opensearch.proto.search.SearchHitsProtoDef.SortTypeProto;
import org.opensearch.proto.search.SearchHitsProtoDef.GenericObjectProto;
import org.opensearch.proto.search.SearchHitsProtoDef.MissingValueProto;
import org.opensearch.transport.TransportSerializationException;

import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collection;

/**
* SerDe interfaces and protobuf SerDe implementations for some "primitive" types.
* Serialization helpers for common objects shared across multiple protobuf types.
* @opensearch.internal
*/
public class ProtoSerDeHelpers {

/**
* Serialization/Deserialization exception.
* @opensearch.internal
*/
public static class SerializationException extends RuntimeException {
public SerializationException(String message) {
super(message);
}

public SerializationException(String message, Throwable cause) {
super(message, cause);
}
}

public static GenericObjectProto genericObjectToProto(Object obj) {
GenericObjectProto.Builder builder = GenericObjectProto.newBuilder();

Expand All @@ -81,13 +62,12 @@ public static Object genericObjectFromProto(GenericObjectProto proto) {
try (StreamInput in = valuesBytes.streamInput()) {
obj = in.readGenericValue();
} catch (IOException e) {
throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize DocumentField values from proto object", e);
throw new TransportSerializationException("Failed to deserialize DocumentField values from proto object", e);
}

return obj;
}


public static ExplanationProto explanationToProto(Explanation explanation) {
ExplanationProto.Builder builder = ExplanationProto.newBuilder()
.setMatch(explanation.isMatch())
Expand All @@ -103,7 +83,7 @@ public static ExplanationProto explanationToProto(Explanation explanation) {
} else if (num instanceof Float) {
builder.setFloatValue(num.floatValue());
} else {
throw new SerializationException("Unknown numeric type [" + num + "]");
throw new TransportSerializationException("Unknown numeric type [" + num + "]");
}

for (Explanation detail : explanation.getDetails()) {
Expand Down Expand Up @@ -140,6 +120,7 @@ public static Explanation explanationFromProto(ExplanationProto proto) {
}

if (proto.getMatch()) {
assert val != null;
return Explanation.match(val, description, details);
}

Expand All @@ -156,7 +137,7 @@ public static DocumentFieldProto documentFieldToProto(DocumentField field) {
return builder.build();
}

public static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws ProtoSerDeHelpers.SerializationException {
public static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws TransportSerializationException {
String name = proto.getName();
ArrayList<Object> values = new ArrayList<>();

Expand Down Expand Up @@ -195,102 +176,6 @@ public static HighlightField highlightFieldFromProto(HighlightFieldProto proto)
return new HighlightField(name, fragments);
}

public static SearchSortValuesProto searchSortValuesToProto(SearchSortValues searchSortValues) {
SearchSortValuesProto.Builder builder = SearchSortValuesProto.newBuilder();

for (Object value : searchSortValues.getFormattedSortValues()) {
builder.addFormattedSortValues(sortValueToProto(value));
}

for (Object value : searchSortValues.getRawSortValues()) {
builder.addRawSortValues(sortValueToProto(value));
}

return builder.build();
}

public static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) throws ProtoSerDeHelpers.SerializationException {
Object[] formattedSortValues = new Object[proto.getFormattedSortValuesCount()];
Object[] rawSortValues = new Object[proto.getRawSortValuesCount()];

for (int i = 0; i < formattedSortValues.length; i++) {
SortValueProto sortProto = proto.getFormattedSortValues(i);
formattedSortValues[i] = sortValueFromProto(sortProto);
}

for (int i = 0; i < rawSortValues.length; i++) {
SortValueProto sortProto = proto.getRawSortValues(i);
rawSortValues[i] = sortValueFromProto(sortProto);
}

return new SearchSortValues(formattedSortValues, rawSortValues);
}

public static SortValueProto sortValueToProto(Object sortValue) throws ProtoSerDeHelpers.SerializationException {
SortValueProto.Builder builder = SortValueProto.newBuilder();

if (sortValue == null) {
builder.setIsNull(true);
} else if (sortValue.getClass().equals(String.class)) {
builder.setStringValue((String) sortValue);
} else if (sortValue.getClass().equals(Integer.class)) {
builder.setIntValue((Integer) sortValue);
} else if (sortValue.getClass().equals(Long.class)) {
builder.setLongValue((Long) sortValue);
} else if (sortValue.getClass().equals(Float.class)) {
builder.setFloatValue((Float) sortValue);
} else if (sortValue.getClass().equals(Double.class)) {
builder.setDoubleValue((Double) sortValue);
} else if (sortValue.getClass().equals(Byte.class)) {
builder.setByteValue((Byte) sortValue);
} else if (sortValue.getClass().equals(Short.class)) {
builder.setShortValue((Short) sortValue);
} else if (sortValue.getClass().equals(Boolean.class)) {
builder.setBoolValue((Boolean) sortValue);
} else if (sortValue.getClass().equals(BytesRef.class)) {
builder.setBytesValue(ByteString.copyFrom(
((BytesRef) sortValue).bytes,
((BytesRef) sortValue).offset,
((BytesRef) sortValue).length));
} else if (sortValue.getClass().equals(BigInteger.class)) {
builder.setBigIntegerValue(sortValue.toString());
} else {
throw new ProtoSerDeHelpers.SerializationException("Unexpected sortValue: " + sortValue.toString());
}

return builder.build();
}

public static Object sortValueFromProto(SortValueProto proto) throws ProtoSerDeHelpers.SerializationException {
switch (proto.getValueCase()) {
case STRING_VALUE:
return proto.getStringValue();
case INT_VALUE:
return proto.getIntValue();
case LONG_VALUE:
return proto.getLongValue();
case FLOAT_VALUE:
return proto.getFloatValue();
case DOUBLE_VALUE:
return proto.getDoubleValue();
case BYTE_VALUE:
return (byte) proto.getByteValue();
case SHORT_VALUE:
return (short) proto.getShortValue();
case BOOL_VALUE:
return proto.getBoolValue();
case BYTES_VALUE:
ByteString byteString = proto.getBytesValue();
return new BytesRef(byteString.toByteArray());
case BIG_INTEGER_VALUE:
return new BigInteger(proto.getBigIntegerValue());
case IS_NULL:
return null;
}

throw new ProtoSerDeHelpers.SerializationException("Unexpected value case: " + proto.getValueCase());
}

public static SortFieldProto sortFieldToProto(SortField sortField) {
SortFieldProto.Builder builder = SortFieldProto.newBuilder()
.setType(sortTypeToProto(sortField.getType()))
Expand Down Expand Up @@ -355,7 +240,7 @@ public static Object missingValueFromProto(MissingValueProto proto) {
case OBJ_VAL:
return genericObjectFromProto(proto.getObjVal());
default:
throw new ProtoSerDeHelpers.SerializationException("Unexpected value case: " + proto.getValueCase());
throw new TransportSerializationException("Unexpected value case: " + proto.getValueCase());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
package org.opensearch.transport.protobuf;

import com.google.protobuf.ByteString;
import org.apache.lucene.util.BytesRef;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.text.Text;
import org.opensearch.proto.search.SearchHitsProtoDef;
import org.opensearch.search.SearchHit;
import org.opensearch.proto.search.SearchHitsProtoDef.SearchHitProto;
import org.opensearch.proto.search.SearchHitsProtoDef.NestedIdentityProto;
import org.opensearch.search.SearchSortValues;
import org.opensearch.transport.TransportSerializationException;

import java.io.IOException;
import java.math.BigInteger;
import java.util.HashMap;

import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto;
Expand All @@ -28,8 +33,6 @@
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesToProto;

/**
* Serialization/Deserialization implementations for SearchHit.
Expand Down Expand Up @@ -152,4 +155,100 @@ public static SearchHit.NestedIdentity nestedIdentityFromProto(NestedIdentityPro

return new SearchHit.NestedIdentity(field, offset, child);
}

public static SearchHitsProtoDef.SearchSortValuesProto searchSortValuesToProto(SearchSortValues searchSortValues) {
SearchHitsProtoDef.SearchSortValuesProto.Builder builder = SearchHitsProtoDef.SearchSortValuesProto.newBuilder();

for (Object value : searchSortValues.getFormattedSortValues()) {
builder.addFormattedSortValues(sortValueToProto(value));
}

for (Object value : searchSortValues.getRawSortValues()) {
builder.addRawSortValues(sortValueToProto(value));
}

return builder.build();
}

public static SearchSortValues searchSortValuesFromProto(SearchHitsProtoDef.SearchSortValuesProto proto) throws TransportSerializationException {
Object[] formattedSortValues = new Object[proto.getFormattedSortValuesCount()];
Object[] rawSortValues = new Object[proto.getRawSortValuesCount()];

for (int i = 0; i < formattedSortValues.length; i++) {
SearchHitsProtoDef.SortValueProto sortProto = proto.getFormattedSortValues(i);
formattedSortValues[i] = sortValueFromProto(sortProto);
}

for (int i = 0; i < rawSortValues.length; i++) {
SearchHitsProtoDef.SortValueProto sortProto = proto.getRawSortValues(i);
rawSortValues[i] = sortValueFromProto(sortProto);
}

return new SearchSortValues(formattedSortValues, rawSortValues);
}

public static SearchHitsProtoDef.SortValueProto sortValueToProto(Object sortValue) throws TransportSerializationException {
SearchHitsProtoDef.SortValueProto.Builder builder = SearchHitsProtoDef.SortValueProto.newBuilder();

if (sortValue == null) {
builder.setIsNull(true);
} else if (sortValue.getClass().equals(String.class)) {
builder.setStringValue((String) sortValue);
} else if (sortValue.getClass().equals(Integer.class)) {
builder.setIntValue((Integer) sortValue);
} else if (sortValue.getClass().equals(Long.class)) {
builder.setLongValue((Long) sortValue);
} else if (sortValue.getClass().equals(Float.class)) {
builder.setFloatValue((Float) sortValue);
} else if (sortValue.getClass().equals(Double.class)) {
builder.setDoubleValue((Double) sortValue);
} else if (sortValue.getClass().equals(Byte.class)) {
builder.setByteValue((Byte) sortValue);
} else if (sortValue.getClass().equals(Short.class)) {
builder.setShortValue((Short) sortValue);
} else if (sortValue.getClass().equals(Boolean.class)) {
builder.setBoolValue((Boolean) sortValue);
} else if (sortValue.getClass().equals(BytesRef.class)) {
builder.setBytesValue(ByteString.copyFrom(
((BytesRef) sortValue).bytes,
((BytesRef) sortValue).offset,
((BytesRef) sortValue).length));
} else if (sortValue.getClass().equals(BigInteger.class)) {
builder.setBigIntegerValue(sortValue.toString());
} else {
throw new TransportSerializationException("Unexpected sortValue: " + sortValue);
}

return builder.build();
}

public static Object sortValueFromProto(SearchHitsProtoDef.SortValueProto proto) throws TransportSerializationException {
switch (proto.getValueCase()) {
case STRING_VALUE:
return proto.getStringValue();
case INT_VALUE:
return proto.getIntValue();
case LONG_VALUE:
return proto.getLongValue();
case FLOAT_VALUE:
return proto.getFloatValue();
case DOUBLE_VALUE:
return proto.getDoubleValue();
case BYTE_VALUE:
return (byte) proto.getByteValue();
case SHORT_VALUE:
return (short) proto.getShortValue();
case BOOL_VALUE:
return proto.getBoolValue();
case BYTES_VALUE:
ByteString byteString = proto.getBytesValue();
return new BytesRef(byteString.toByteArray());
case BIG_INTEGER_VALUE:
return new BigInteger(proto.getBigIntegerValue());
case IS_NULL:
return null;
}

throw new TransportSerializationException("Unexpected value case: " + proto.getValueCase());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import java.io.IOException;
import java.util.List;

import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortValueToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortValueFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldFromProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueFromProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueToProto;

/**
* SearchHits child which implements serde operations as protobuf.
Expand Down Expand Up @@ -110,7 +110,7 @@ private TotalHits totalHitsFromProto(TotalHitsProto proto) {
long rel = proto.getRelation();
long val = proto.getValue();
if (rel < 0 || rel >= TotalHits.Relation.values().length) {
throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize TotalHits from proto");
throw new TransportSerializationException("Failed to deserialize TotalHits from proto");
}
return new TotalHits(val, TotalHits.Relation.values()[(int) rel]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesToProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.searchSortValuesFromProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.searchSortValuesToProto;

public class SearchHitProtobufTests extends AbstractWireSerializingTestCase<SearchHitProtobuf> {
public void testDocumentFieldProtoSerialization () {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@

import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortValueFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortValueToProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueFromProto;
import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueToProto;

public class SearchHitsProtobufTests extends AbstractWireSerializingTestCase<SearchHitsProtobuf> {
public void testSortFieldProtoSerialization () {
Expand Down

0 comments on commit e430f8a

Please sign in to comment.