Skip to content

Commit

Permalink
Improve performance of encoding composite keys in multi-term aggregat…
Browse files Browse the repository at this point in the history
…ions (#9412)

* Improve performance of encoding composite keys in multi-term aggregations

Signed-off-by: Ketan Verma <ketan9495@gmail.com>

* Rename StreamOutput::getGenericValueWriter to getGenericValueWriterByClass

Signed-off-by: Ketan Verma <ketan9495@gmail.com>

* Rename StreamOutput::getGenericValueWriterByClass to getWriter and remove unused code

Signed-off-by: Ketan Verma <ketan9495@gmail.com>

---------

Signed-off-by: Ketan Verma <ketan9495@gmail.com>
  • Loading branch information
ketanv3 authored Aug 18, 2023
1 parent 8e95a82 commit 46f2bd0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add base class for parameterizing the search based tests #9083 ([#9083](https://github.com/opensearch-project/OpenSearch/pull/9083))
- Add support for wrapping CollectorManager with profiling during concurrent execution ([#9129](https://github.com/opensearch-project/OpenSearch/pull/9129))
- Rethrow OpenSearch exception for non-concurrent path while using concurrent search ([#9177](https://github.com/opensearch-project/OpenSearch/pull/9177))
- Improve performance of encoding composite keys in multi-term aggregations ([#9412](https://github.com/opensearch-project/OpenSearch/pull/9412))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,23 @@ private static Class<?> getGenericType(Object value) {
}
}

/**
* Returns the registered writer for the given class type.
*/
@SuppressWarnings("unchecked")
public static <W extends Writer<?>> W getWriter(Class<?> type) {
Writer<Object> writer = WriteableRegistry.getWriter(type);
if (writer == null) {
// fallback to this local hashmap
// todo: move all writers to the registry
writer = WRITERS.get(type);
}
if (writer == null) {
throw new IllegalArgumentException("can not write type [" + type + "]");
}
return (W) writer;
}

/**
* Notice: when serialization a map, the stream out map with the stream in map maybe have the
* different key-value orders, they will maybe have different stream order.
Expand All @@ -816,17 +833,8 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
return;
}
final Class<?> type = getGenericType(value);
Writer<Object> writer = WriteableRegistry.getWriter(type);
if (writer == null) {
// fallback to this local hashmap
// todo: move all writers to the registry
writer = WRITERS.get(type);
}
if (writer != null) {
writer.write(this, value);
} else {
throw new IllegalArgumentException("can not write type [" + type + "]");
}
final Writer<Object> writer = getWriter(type);
writer.write(this, value);
}

public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.common.Numbers;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.DocValueFormat;
Expand Down Expand Up @@ -218,8 +219,8 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
for (List<Object> value : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, encode(value));
for (BytesRef compositeKey : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
collectExistingBucket(sub, doc, bucketOrd);
Expand All @@ -233,16 +234,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {

@Override
protected void doClose() {
Releasables.close(bucketOrds);
}

private static BytesRef encode(List<Object> values) {
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.writeCollection(values, StreamOutput::writeGenericValue);
return output.bytes().toBytesRef();
} catch (IOException e) {
throw ExceptionsHelper.convertToRuntime(e);
}
Releasables.close(bucketOrds, multiTermsValue);
}

private static List<Object> decode(BytesRef bytesRef) {
Expand Down Expand Up @@ -279,8 +271,8 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
// brute force
for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
for (List<Object> value : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, encode(value));
for (BytesRef compositeKey : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, compositeKey);
}
}
}
Expand All @@ -295,7 +287,7 @@ interface MultiTermsValuesSourceCollector {
* Collect a list values of multi_terms on each doc.
* Each terms could have multi_values, so the result is the cartesian product of each term's values.
*/
List<List<Object>> apply(int doc) throws IOException;
List<BytesRef> apply(int doc) throws IOException;
}

@FunctionalInterface
Expand All @@ -314,16 +306,56 @@ interface InternalValuesSourceCollector {
/**
* Collect a list values of a term on specific doc.
*/
List<Object> apply(int doc) throws IOException;
List<TermValue<?>> apply(int doc) throws IOException;
}

/**
* Represents an individual term value.
*/
static class TermValue<T> implements Writeable {
private static final Writer<BytesRef> BYTES_REF_WRITER = StreamOutput.getWriter(BytesRef.class);
private static final Writer<Long> LONG_WRITER = StreamOutput.getWriter(Long.class);
private static final Writer<BigInteger> BIG_INTEGER_WRITER = StreamOutput.getWriter(BigInteger.class);
private static final Writer<Double> DOUBLE_WRITER = StreamOutput.getWriter(Double.class);

private final T value;
private final Writer<T> writer;

private TermValue(T value, Writer<T> writer) {
this.value = value;
this.writer = writer;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
writer.write(out, value);
}

public static TermValue<BytesRef> of(BytesRef value) {
return new TermValue<>(value, BYTES_REF_WRITER);
}

public static TermValue<Long> of(Long value) {
return new TermValue<>(value, LONG_WRITER);
}

public static TermValue<BigInteger> of(BigInteger value) {
return new TermValue<>(value, BIG_INTEGER_WRITER);
}

public static TermValue<Double> of(Double value) {
return new TermValue<>(value, DOUBLE_WRITER);
}
}

/**
* Multi_Term ValuesSource, it is a collection of {@link InternalValuesSource}
*
* @opensearch.internal
*/
static class MultiTermsValuesSource {
static class MultiTermsValuesSource implements Releasable {
private final List<InternalValuesSource> valuesSources;
private final BytesStreamOutput scratch = new BytesStreamOutput();

public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
this.valuesSources = valuesSources;
Expand All @@ -336,37 +368,50 @@ public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws I
}
return new MultiTermsValuesSourceCollector() {
@Override
public List<List<Object>> apply(int doc) throws IOException {
List<CheckedSupplier<List<Object>, IOException>> collectedValues = new ArrayList<>();
public List<BytesRef> apply(int doc) throws IOException {
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
for (InternalValuesSourceCollector collector : collectors) {
collectedValues.add(() -> collector.apply(doc));
collectedValues.add(collector.apply(doc));
}
List<List<Object>> result = new ArrayList<>();
apply(0, collectedValues, new ArrayList<>(), result);
List<BytesRef> result = new ArrayList<>();
scratch.seek(0);
scratch.writeVInt(collectors.size()); // number of fields per composite key
cartesianProduct(result, scratch, collectedValues, 0);
return result;
}

/**
* DFS traverse each term's values and add cartesian product to results lists.
* Cartesian product using depth first search.
*
* <p>
* Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
* but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
*/
private void apply(
int index,
List<CheckedSupplier<List<Object>, IOException>> collectedValues,
List<Object> current,
List<List<Object>> results
private void cartesianProduct(
List<BytesRef> compositeKeys,
BytesStreamOutput scratch,
List<List<TermValue<?>>> collectedValues,
int index
) throws IOException {
if (index == collectedValues.size()) {
results.add(List.copyOf(current));
} else if (null != collectedValues.get(index)) {
for (Object value : collectedValues.get(index).get()) {
current.add(value);
apply(index + 1, collectedValues, current, results);
current.remove(current.size() - 1);
}
if (collectedValues.size() == index) {
compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
return;
}

long position = scratch.position();
for (TermValue<?> value : collectedValues.get(index)) {
value.writeTo(scratch); // encode the value
cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
scratch.seek(position); // backtrack
}
}
};
}

@Override
public void close() {
scratch.close();
}
}

/**
Expand All @@ -379,27 +424,26 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
return ctx -> {
SortedBinaryDocValues values = valuesSource.bytesValues(ctx);
return doc -> {
BytesRefBuilder previous = new BytesRefBuilder();

if (false == values.advanceExact(doc)) {
return Collections.emptyList();
}
int valuesCount = values.docValueCount();
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);

// SortedBinaryDocValues don't guarantee uniqueness so we
// need to take care of dups
previous.clear();
BytesRef previous = null;
for (int i = 0; i < valuesCount; ++i) {
BytesRef bytes = values.nextValue();
if (includeExclude != null && false == includeExclude.accept(bytes)) {
continue;
}
if (i > 0 && previous.get().equals(bytes)) {
if (i > 0 && bytes.equals(previous)) {
continue;
}
previous.copyBytes(bytes);
termValues.add(BytesRef.deepCopyOf(bytes));
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
}
return termValues;
};
Expand All @@ -414,12 +458,12 @@ static InternalValuesSource unsignedLongValuesSource(ValuesSource.Numeric values
int valuesCount = values.docValueCount();

BigInteger previous = Numbers.MAX_UNSIGNED_LONG_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
BigInteger val = Numbers.toUnsignedBigInteger(values.nextValue());
if (previous.compareTo(val) != 0 || i == 0) {
if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val.doubleValue()))) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand All @@ -439,12 +483,12 @@ static InternalValuesSource longValuesSource(ValuesSource.Numeric valuesSource,
int valuesCount = values.docValueCount();

long previous = Long.MAX_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
long val = values.nextValue();
if (previous != val || i == 0) {
if (longFilter == null || longFilter.accept(val)) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand All @@ -464,12 +508,12 @@ static InternalValuesSource doubleValueSource(ValuesSource.Numeric valuesSource,
int valuesCount = values.docValueCount();

double previous = Double.MAX_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
double val = values.nextValue();
if (previous != val || i == 0) {
if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val))) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand Down

0 comments on commit 46f2bd0

Please sign in to comment.