diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamInput.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamInput.java new file mode 100644 index 0000000000000..a0e50722bf01d --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamInput.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.core.common.io.stream; + +import java.io.InputStream; + +/** + * Foundation class for reading core types off the transport stream + * + * todo: refactor {@code StreamInput} primitive readers to this class + * + * @opensearch.internal + */ +public abstract class BaseStreamInput extends InputStream {} diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamOutput.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamOutput.java new file mode 100644 index 0000000000000..f7a8862fa5f2c --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseStreamOutput.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.core.common.io.stream; + +import java.io.OutputStream; + +/** + * Foundation class for writing core types over the transport stream + * + * todo: refactor {@code StreamOutput} primitive writers to this class + * + * @opensearch.internal + */ +public abstract class BaseStreamOutput extends OutputStream {} diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseWriteable.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseWriteable.java new file mode 100644 index 0000000000000..56172e7c6a50e --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/BaseWriteable.java @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.core.common.io.stream; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementers can be written to a {@code StreamOutput} and read from a {@code StreamInput}. This allows them to be "thrown + * across the wire" using OpenSearch's internal protocol. If the implementer also implements equals and hashCode then a copy made by + * serializing and deserializing must be equal and have the same hashCode. It isn't required that such a copy be entirely unchanged. + * + * @opensearch.internal + */ +public interface BaseWriteable { + /** + * A WriteableRegistry registers {@link Writer} methods for writing data types over a + * {@link BaseStreamOutput} channel and {@link Reader} methods for reading data from a + * {@link BaseStreamInput} channel. + * + * @opensearch.internal + */ + class WriteableRegistry { + private static final Map, Writer> WRITER_REGISTRY = new ConcurrentHashMap<>(); + private static final Map> READER_REGISTRY = new ConcurrentHashMap<>(); + + /** + * registers a streamable writer + * + * @opensearch.internal + */ + public static > void registerWriter(final Class clazz, final W writer) { + if (WRITER_REGISTRY.containsKey(clazz)) { + throw new IllegalArgumentException("Streamable writer already registered for type [" + clazz.getName() + "]"); + } + WRITER_REGISTRY.put(clazz, writer); + } + + /** + * registers a streamable reader + * + * @opensearch.internal + */ + public static > void registerReader(final byte ordinal, final R reader) { + if (READER_REGISTRY.containsKey(ordinal)) { + throw new IllegalArgumentException("Streamable reader already registered for ordinal [" + (int) ordinal + "]"); + } + READER_REGISTRY.put(ordinal, reader); + } + + /** + * Returns the registered writer keyed by the class type + */ + @SuppressWarnings("unchecked") + public static > W getWriter(final Class clazz) { + return (W) WRITER_REGISTRY.get(clazz); + } + + /** + * Returns the ristered reader keyed by the unique ordinal + */ + @SuppressWarnings("unchecked") + public static > R getReader(final byte b) { + return (R) READER_REGISTRY.get(b); + } + } + + /** + * Write this into the {@linkplain BaseStreamOutput}. + */ + void writeTo(final S out) throws IOException; + + /** + * Reference to a method that can write some object to a {@link BaseStreamOutput}. + *

+ * By convention this is a method from {@link BaseStreamOutput} itself (e.g., {@code StreamOutput#writeString}). If the value can be + * {@code null}, then the "optional" variant of methods should be used! + *

+ * Most classes should implement {@code Writeable} and the {@code Writeable#writeTo(BaseStreamOutput)} method should use + * {@link BaseStreamOutput} methods directly or this indirectly: + *


+     * public void writeTo(StreamOutput out) throws IOException {
+     *     out.writeVInt(someValue);
+     *     out.writeMapOfLists(someMap, StreamOutput::writeString, StreamOutput::writeString);
+     * }
+     * 
+ */ + @FunctionalInterface + interface Writer { + + /** + * Write {@code V}-type {@code value} to the {@code out}put stream. + * + * @param out Output to write the {@code value} too + * @param value The value to add + */ + void write(final S out, V value) throws IOException; + } + + /** + * Reference to a method that can read some object from a stream. By convention this is a constructor that takes + * {@linkplain BaseStreamInput} as an argument for most classes and a static method for things like enums. Returning null from one of these + * is always wrong - for that we use methods like {@code StreamInput#readOptionalWriteable(Reader)}. + *

+ * As most classes will implement this via a constructor (or a static method in the case of enumerations), it's something that should + * look like: + *


+     * public MyClass(final StreamInput in) throws IOException {
+     *     this.someValue = in.readVInt();
+     *     this.someMap = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
+     * }
+     * 
+ */ + @FunctionalInterface + interface Reader { + + /** + * Read {@code V}-type value from a stream. + * + * @param in Input to read the value from + */ + V read(final S in) throws IOException; + } +} diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/package-info.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/package-info.java new file mode 100644 index 0000000000000..76d0842466b96 --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/package-info.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +/** Core transport stream classes */ +package org.opensearch.core.common.io.stream; diff --git a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java index f8cfcabc10593..9b5ed7777204e 100644 --- a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java +++ b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java @@ -33,7 +33,6 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.common.util.LongObjectPagedHashMap; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.search.aggregations.InternalAggregation; @@ -69,7 +68,7 @@ protected BaseGeoGrid(String name, int requiredSize, List buc this.buckets = buckets; } - protected abstract Writeable.Reader getBucketReader(); + protected abstract Reader getBucketReader(); /** * Read from a stream. diff --git a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoHashGrid.java b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoHashGrid.java index aa1d5504ad24f..9b6713ac033ae 100644 --- a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoHashGrid.java +++ b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoHashGrid.java @@ -76,7 +76,7 @@ protected InternalGeoHashGridBucket createBucket(long hashAsLong, long docCount, } @Override - protected Reader getBucketReader() { + protected Reader getBucketReader() { return InternalGeoHashGridBucket::new; } diff --git a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoTileGrid.java b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoTileGrid.java index 91c523c80855e..bf45080759a07 100644 --- a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoTileGrid.java +++ b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/GeoTileGrid.java @@ -76,7 +76,7 @@ protected InternalGeoTileGridBucket createBucket(long hashAsLong, long docCount, } @Override - protected Reader getBucketReader() { + protected Reader getBucketReader() { return InternalGeoTileGridBucket::new; } diff --git a/server/src/main/java/org/opensearch/common/geo/GeoBoundingBox.java b/server/src/main/java/org/opensearch/common/geo/GeoBoundingBox.java index 8e3db45854f28..9609c10b6614f 100644 --- a/server/src/main/java/org/opensearch/common/geo/GeoBoundingBox.java +++ b/server/src/main/java/org/opensearch/common/geo/GeoBoundingBox.java @@ -79,8 +79,8 @@ public GeoBoundingBox(GeoPoint topLeft, GeoPoint bottomRight) { } public GeoBoundingBox(StreamInput input) throws IOException { - this.topLeft = input.readGeoPoint(); - this.bottomRight = input.readGeoPoint(); + this.topLeft = new GeoPoint(input); + this.bottomRight = new GeoPoint(input); } public boolean isUnbounded() { @@ -164,8 +164,8 @@ public boolean pointInBounds(double lon, double lat) { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeGeoPoint(topLeft); - out.writeGeoPoint(bottomRight); + topLeft.writeTo(out); + bottomRight.writeTo(out); } @Override diff --git a/server/src/main/java/org/opensearch/common/geo/GeoPoint.java b/server/src/main/java/org/opensearch/common/geo/GeoPoint.java index b5c2d6a846f92..874f0ffb80be1 100644 --- a/server/src/main/java/org/opensearch/common/geo/GeoPoint.java +++ b/server/src/main/java/org/opensearch/common/geo/GeoPoint.java @@ -40,6 +40,11 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.OpenSearchParseException; import org.opensearch.common.geo.GeoUtils.EffectivePoint; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.BaseWriteable.Reader; +import org.opensearch.core.common.io.stream.BaseWriteable.Writer; +import org.opensearch.core.common.io.stream.BaseWriteable.WriteableRegistry; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.geometry.Geometry; @@ -87,6 +92,22 @@ public GeoPoint(GeoPoint template) { this(template.getLat(), template.getLon()); } + public GeoPoint(final StreamInput in) throws IOException { + this.lat = in.readDouble(); + this.lon = in.readDouble(); + } + + /** + * Register this type as a streamable so it can be serialized over the wire + */ + public static void registerStreamables() { + WriteableRegistry.>registerWriter(GeoPoint.class, (o, v) -> { + o.writeByte((byte) 22); + ((GeoPoint) v).writeTo(o); + }); + WriteableRegistry.>registerReader(Byte.valueOf((byte) 22), GeoPoint::new); + } + public GeoPoint reset(double lat, double lon) { this.lat = lat; this.lon = lon; @@ -210,6 +231,11 @@ public GeoPoint resetFromGeoHash(long geohashLong) { return this.resetFromIndexHash(BitUtil.flipFlop((geohashLong >>> 4) << ((level * 5) + 2))); } + public void writeTo(final StreamOutput out) throws IOException { + out.writeDouble(this.lat); + out.writeDouble(this.lon); + } + public double lat() { return this.lat; } diff --git a/server/src/main/java/org/opensearch/common/io/stream/StreamInput.java b/server/src/main/java/org/opensearch/common/io/stream/StreamInput.java index f1b4ffe2219aa..2b51d6c469fcf 100644 --- a/server/src/main/java/org/opensearch/common/io/stream/StreamInput.java +++ b/server/src/main/java/org/opensearch/common/io/stream/StreamInput.java @@ -50,20 +50,18 @@ import org.opensearch.common.Strings; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.settings.SecureString; import org.opensearch.common.text.Text; -import org.opensearch.common.time.DateUtils; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.BaseStreamInput; +import org.opensearch.core.common.io.stream.BaseWriteable; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.script.JodaCompatibleZonedDateTime; import java.io.ByteArrayInputStream; import java.io.EOFException; import java.io.FileNotFoundException; import java.io.FilterInputStream; import java.io.IOException; -import java.io.InputStream; import java.math.BigInteger; import java.nio.file.AccessDeniedException; import java.nio.file.AtomicMoveNotSupportedException; @@ -108,7 +106,7 @@ * * @opensearch.internal */ -public abstract class StreamInput extends InputStream { +public abstract class StreamInput extends BaseStreamInput { private Version version = Version.CURRENT; @@ -686,6 +684,11 @@ public Map readMap() throws IOException { @Nullable public Object readGenericValue() throws IOException { byte type = readByte(); + BaseWriteable.Reader r = BaseWriteable.WriteableRegistry.getReader(type); + if (r != null) { + return r.read(this); + } + switch (type) { case -1: return null; @@ -715,8 +718,6 @@ public Object readGenericValue() throws IOException { return readByte(); case 12: return readDate(); - case 13: - return readDateTime(); case 14: return readBytesReference(); case 15: @@ -733,8 +734,6 @@ public Object readGenericValue() throws IOException { return readDoubleArray(); case 21: return readBytesRef(); - case 22: - return readGeoPoint(); case 23: return readZonedDateTime(); case 24: @@ -778,14 +777,6 @@ private List readArrayList() throws IOException { return list; } - private JodaCompatibleZonedDateTime readDateTime() throws IOException { - // we reuse DateTime to communicate with older nodes that don't know about the joda compat layer, but - // here we are on a new node so we always want a compat datetime - final ZoneId zoneId = DateUtils.dateTimeZoneToZoneId(DateTimeZone.forID(readString())); - long millis = readLong(); - return new JodaCompatibleZonedDateTime(Instant.ofEpochMilli(millis), zoneId); - } - private ZonedDateTime readZonedDateTime() throws IOException { final String timeZoneId = readString(); return ZonedDateTime.ofInstant(Instant.ofEpochMilli(readLong()), ZoneId.of(timeZoneId)); @@ -833,13 +824,6 @@ private Date readDate() throws IOException { return new Date(readLong()); } - /** - * Reads a {@link GeoPoint} from this stream input - */ - public GeoPoint readGeoPoint() throws IOException { - return new GeoPoint(readDouble(), readDouble()); - } - /** * Read a {@linkplain DateTimeZone}. */ @@ -1181,7 +1165,7 @@ public C readOptionalNamedWriteable(Class category * @return the list of objects * @throws IOException if an I/O exception occurs reading the list */ - public List readList(final Writeable.Reader reader) throws IOException { + public List readList(final BaseWriteable.Reader reader) throws IOException { return readCollection(reader, ArrayList::new, Collections.emptyList()); } @@ -1223,8 +1207,11 @@ public Set readSet(Writeable.Reader reader) throws IOException { /** * Reads a collection of objects */ - private > C readCollection(Writeable.Reader reader, IntFunction constructor, C empty) - throws IOException { + private > C readCollection( + BaseWriteable.Reader reader, + IntFunction constructor, + C empty + ) throws IOException { int count = readArraySize(); if (count == 0) { return empty; diff --git a/server/src/main/java/org/opensearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/opensearch/common/io/stream/StreamOutput.java index da17cddfcf97b..b0f4f6c8a6139 100644 --- a/server/src/main/java/org/opensearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/opensearch/common/io/stream/StreamOutput.java @@ -51,18 +51,17 @@ import org.opensearch.common.Nullable; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.io.stream.Writeable.Writer; import org.opensearch.common.settings.SecureString; import org.opensearch.common.text.Text; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.BaseStreamOutput; +import org.opensearch.core.common.io.stream.BaseWriteable; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.script.JodaCompatibleZonedDateTime; import java.io.EOFException; import java.io.FileNotFoundException; import java.io.IOException; -import java.io.OutputStream; import java.math.BigInteger; import java.nio.file.AccessDeniedException; import java.nio.file.AtomicMoveNotSupportedException; @@ -103,7 +102,7 @@ * * @opensearch.internal */ -public abstract class StreamOutput extends OutputStream { +public abstract class StreamOutput extends BaseStreamOutput { private static final int MAX_NESTED_EXCEPTION_LEVEL = 100; @@ -663,10 +662,10 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep } } - private static final Map, Writer> WRITERS; + private static final Map, BaseWriteable.Writer> WRITERS; static { - Map, Writer> writers = new HashMap<>(); + Map, BaseWriteable.Writer> writers = new HashMap<>(); writers.put(String.class, (o, v) -> { o.writeByte((byte) 0); o.writeString((String) v); @@ -773,25 +772,12 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep o.writeByte((byte) 21); o.writeBytesRef((BytesRef) v); }); - writers.put(GeoPoint.class, (o, v) -> { - o.writeByte((byte) 22); - o.writeGeoPoint((GeoPoint) v); - }); writers.put(ZonedDateTime.class, (o, v) -> { o.writeByte((byte) 23); final ZonedDateTime zonedDateTime = (ZonedDateTime) v; o.writeString(zonedDateTime.getZone().getId()); o.writeLong(zonedDateTime.toInstant().toEpochMilli()); }); - writers.put(JodaCompatibleZonedDateTime.class, (o, v) -> { - // write the joda compatibility datetime as joda datetime - o.writeByte((byte) 13); - final JodaCompatibleZonedDateTime zonedDateTime = (JodaCompatibleZonedDateTime) v; - String zoneId = zonedDateTime.getZonedDateTime().getZone().getId(); - // joda does not understand "Z" for utc, so we must special case - o.writeString(zoneId.equals("Z") ? DateTimeZone.UTC.getID() : zoneId); - o.writeLong(zonedDateTime.toInstant().toEpochMilli()); - }); writers.put(Set.class, (o, v) -> { if (v instanceof LinkedHashSet) { o.writeByte((byte) 24); @@ -838,7 +824,12 @@ public void writeGenericValue(@Nullable Object value) throws IOException { return; } final Class type = getGenericType(value); - final Writer writer = WRITERS.get(type); + BaseWriteable.Writer writer = BaseWriteable.WriteableRegistry.getWriter(type); + if (writer == null) { + // fallback to this local hashmap + // todo: move all writers to the registry + writer = WRITERS.get(type); + } if (writer != null) { writer.write(this, value); } else { @@ -1145,14 +1136,6 @@ public void writeOptionalNamedWriteable(@Nullable NamedWriteable namedWriteable) } } - /** - * Writes the given {@link GeoPoint} to the stream - */ - public void writeGeoPoint(GeoPoint geoPoint) throws IOException { - writeDouble(geoPoint.lat()); - writeDouble(geoPoint.lon()); - } - /** * Write a {@linkplain DateTimeZone} to the stream. */ @@ -1193,7 +1176,7 @@ public void writeOptionalZoneId(@Nullable ZoneId timeZone) throws IOException { /** * Writes a collection to this stream. The corresponding collection can be read from a stream input using - * {@link StreamInput#readList(Writeable.Reader)}. + * {@link StreamInput#readList(BaseWriteable.Reader)}. * * @param collection the collection to write to this stream * @throws IOException if an I/O exception occurs writing the collection @@ -1224,7 +1207,7 @@ public void writeCollection(final Collection collection, final Writer /** * Writes a collection of a strings. The corresponding collection can be read from a stream input using - * {@link StreamInput#readList(Writeable.Reader)}. + * {@link StreamInput#readList(BaseWriteable.Reader)}. * * @param collection the collection of strings * @throws IOException if an I/O exception occurs writing the collection @@ -1235,7 +1218,7 @@ public void writeStringCollection(final Collection collection) throws IO /** * Writes an optional collection of a strings. The corresponding collection can be read from a stream input using - * {@link StreamInput#readList(Writeable.Reader)}. + * {@link StreamInput#readList(BaseWriteable.Reader)}. * * @param collection the collection of strings * @throws IOException if an I/O exception occurs writing the collection diff --git a/server/src/main/java/org/opensearch/common/io/stream/Writeable.java b/server/src/main/java/org/opensearch/common/io/stream/Writeable.java index 5fd227db6ca83..c04cd7977fdc0 100644 --- a/server/src/main/java/org/opensearch/common/io/stream/Writeable.java +++ b/server/src/main/java/org/opensearch/common/io/stream/Writeable.java @@ -32,6 +32,8 @@ package org.opensearch.common.io.stream; +import org.opensearch.core.common.io.stream.BaseWriteable; + import java.io.IOException; /** @@ -41,7 +43,7 @@ * * @opensearch.internal */ -public interface Writeable { +public interface Writeable extends BaseWriteable { /** * Write this into the {@linkplain StreamOutput}. @@ -64,17 +66,7 @@ public interface Writeable { * */ @FunctionalInterface - interface Writer { - - /** - * Write {@code V}-type {@code value} to the {@code out}put stream. - * - * @param out Output to write the {@code value} too - * @param value The value to add - */ - void write(StreamOutput out, V value) throws IOException; - - } + interface Writer extends BaseWriteable.Writer {} /** * Reference to a method that can read some object from a stream. By convention this is a constructor that takes @@ -91,15 +83,6 @@ interface Writer { * */ @FunctionalInterface - interface Reader { - - /** - * Read {@code V}-type value from a stream. - * - * @param in Input to read the value from - */ - V read(StreamInput in) throws IOException; - - } + interface Reader extends BaseWriteable.Reader {} } diff --git a/server/src/main/java/org/opensearch/index/query/GeoDistanceQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/GeoDistanceQueryBuilder.java index e3ad2781f5546..bc52d8fe6a6df 100644 --- a/server/src/main/java/org/opensearch/index/query/GeoDistanceQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/GeoDistanceQueryBuilder.java @@ -113,7 +113,7 @@ public GeoDistanceQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); distance = in.readDouble(); validationMethod = GeoValidationMethod.readFromStream(in); - center = in.readGeoPoint(); + center = new GeoPoint(in); geoDistance = GeoDistance.readFromStream(in); ignoreUnmapped = in.readBoolean(); } @@ -123,7 +123,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeDouble(distance); validationMethod.writeTo(out); - out.writeGeoPoint(center); + center.writeTo(out); geoDistance.writeTo(out); out.writeBoolean(ignoreUnmapped); } diff --git a/server/src/main/java/org/opensearch/index/query/GeoPolygonQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/GeoPolygonQueryBuilder.java index e0c7d44a08fac..f5e40fec78c25 100644 --- a/server/src/main/java/org/opensearch/index/query/GeoPolygonQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/GeoPolygonQueryBuilder.java @@ -114,7 +114,7 @@ public GeoPolygonQueryBuilder(StreamInput in) throws IOException { int size = in.readVInt(); shell = new ArrayList<>(size); for (int i = 0; i < size; i++) { - shell.add(in.readGeoPoint()); + shell.add(new GeoPoint(in)); } validationMethod = GeoValidationMethod.readFromStream(in); ignoreUnmapped = in.readBoolean(); @@ -125,7 +125,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeVInt(shell.size()); for (GeoPoint point : shell) { - out.writeGeoPoint(point); + point.writeTo(out); } validationMethod.writeTo(out); out.writeBoolean(ignoreUnmapped); diff --git a/server/src/main/java/org/opensearch/script/JodaCompatibleZonedDateTime.java b/server/src/main/java/org/opensearch/script/JodaCompatibleZonedDateTime.java index 08306b3f275a8..8d14a4fae992d 100644 --- a/server/src/main/java/org/opensearch/script/JodaCompatibleZonedDateTime.java +++ b/server/src/main/java/org/opensearch/script/JodaCompatibleZonedDateTime.java @@ -32,8 +32,15 @@ package org.opensearch.script; +import org.joda.time.DateTimeZone; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.BaseWriteable.Reader; +import org.opensearch.core.common.io.stream.BaseWriteable.Writer; import org.opensearch.common.time.DateFormatter; +import org.opensearch.common.time.DateUtils; +import org.opensearch.core.common.io.stream.BaseWriteable.WriteableRegistry; import java.time.DayOfWeek; import java.time.Instant; @@ -77,6 +84,26 @@ public JodaCompatibleZonedDateTime(Instant instant, ZoneId zone) { this.dt = ZonedDateTime.ofInstant(instant, zone); } + /** + * Register this type as a streamable so it can be serialized over the wire + */ + public static void registerStreamables() { + WriteableRegistry.>registerWriter(JodaCompatibleZonedDateTime.class, (o, v) -> { + // write the joda compatibility datetime as joda datetime + o.writeByte((byte) 13); + final JodaCompatibleZonedDateTime zonedDateTime = (JodaCompatibleZonedDateTime) v; + String zoneId = zonedDateTime.getZonedDateTime().getZone().getId(); + // joda does not understand "Z" for utc, so we must special case + o.writeString(zoneId.equals("Z") ? DateTimeZone.UTC.getID() : zoneId); + o.writeLong(zonedDateTime.toInstant().toEpochMilli()); + }); + WriteableRegistry.>registerReader(Byte.valueOf((byte) 13), (i) -> { + final ZoneId zoneId = DateUtils.dateTimeZoneToZoneId(DateTimeZone.forID(i.readString())); + long millis = i.readLong(); + return new JodaCompatibleZonedDateTime(Instant.ofEpochMilli(millis), zoneId); + }); + } + // access the underlying ZonedDateTime public ZonedDateTime getZonedDateTime() { return dt; diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index 2b5dee229b8cb..e96f1d11c89d0 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -44,6 +44,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.Strings; import org.opensearch.common.component.AbstractLifecycleComponent; +import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; @@ -61,6 +62,7 @@ import org.opensearch.common.util.io.IOUtils; import org.opensearch.node.NodeClosedException; import org.opensearch.node.ReportingService; +import org.opensearch.script.JodaCompatibleZonedDateTime; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; import org.opensearch.threadpool.Scheduler; @@ -161,6 +163,14 @@ public boolean isClosed() { public void close() {} }; + static { + // registers server specific streamables + registerStreamables(); + } + + /** does nothing. easy way to ensure class is loaded */ + public static void ensureClassloaded() {} + /** * Build the service. * @@ -231,6 +241,15 @@ public TransportService( ); } + /** + * Registers server specific types as a streamables for serialization + * over the {@link StreamOutput} and {@link StreamInput} wire + */ + private static void registerStreamables() { + JodaCompatibleZonedDateTime.registerStreamables(); + GeoPoint.registerStreamables(); + } + public RemoteClusterService getRemoteClusterService() { return remoteClusterService; } diff --git a/server/src/test/java/org/opensearch/common/io/stream/BaseStreamTests.java b/server/src/test/java/org/opensearch/common/io/stream/BaseStreamTests.java index b92e59e43e0db..bd970be5e977d 100644 --- a/server/src/test/java/org/opensearch/common/io/stream/BaseStreamTests.java +++ b/server/src/test/java/org/opensearch/common/io/stream/BaseStreamTests.java @@ -40,12 +40,14 @@ import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.SecureString; +import org.opensearch.script.JodaCompatibleZonedDateTime; import org.opensearch.test.OpenSearchTestCase; import java.io.ByteArrayInputStream; import java.io.EOFException; import java.io.IOException; import java.time.Instant; +import java.time.ZoneOffset; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -396,6 +398,18 @@ public void testOptionalInstantSerialization() throws IOException { } } + public void testJodaDateTimeSerialization() throws IOException { + final BytesStreamOutput output = new BytesStreamOutput(); + long millis = randomIntBetween(0, Integer.MAX_VALUE); + JodaCompatibleZonedDateTime time = new JodaCompatibleZonedDateTime(Instant.ofEpochMilli(millis), ZoneOffset.ofHours(-7)); + output.writeGenericValue(time); + + final BytesReference bytesReference = output.bytes(); + final StreamInput input = getStreamInput(bytesReference); + Object inTime = input.readGenericValue(); + assertEquals(time, inTime); + } + static final class WriteableString implements Writeable { final String string; diff --git a/server/src/test/java/org/opensearch/common/io/stream/BytesStreamsTests.java b/server/src/test/java/org/opensearch/common/io/stream/BytesStreamsTests.java index bff35ad9fc975..36b84560bce41 100644 --- a/server/src/test/java/org/opensearch/common/io/stream/BytesStreamsTests.java +++ b/server/src/test/java/org/opensearch/common/io/stream/BytesStreamsTests.java @@ -639,9 +639,9 @@ public void testReadWriteGeoPoint() throws IOException { try (BytesStreamOutput out = new BytesStreamOutput()) { GeoPoint geoPoint = new GeoPoint(randomDouble(), randomDouble()); - out.writeGeoPoint(geoPoint); + geoPoint.writeTo(out); StreamInput wrap = out.bytes().streamInput(); - GeoPoint point = wrap.readGeoPoint(); + GeoPoint point = new GeoPoint(wrap); assertEquals(point, geoPoint); } } diff --git a/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java b/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java index aed9064c9b09f..ca416cafe8ea4 100644 --- a/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java +++ b/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java @@ -129,6 +129,7 @@ import org.opensearch.test.junit.listeners.LoggingListener; import org.opensearch.test.junit.listeners.ReproduceInfoPrinter; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import org.opensearch.transport.nio.MockNioTransportPlugin; import org.joda.time.DateTimeZone; import org.junit.After; @@ -258,6 +259,7 @@ public void append(LogEvent event) { })); BootstrapForTesting.ensureInitialized(); + TransportService.ensureClassloaded(); // ensure server streamables are registered // filter out joda timezones that are deprecated for the java time migration List jodaTZIds = DateTimeZone.getAvailableIDs()