Skip to content

Commit

Permalink
Latency improvements to Multi Term Aggregations (opensearch-project#1…
Browse files Browse the repository at this point in the history
…4993)

* Avoid deep copy and other allocation improvements
* Refactoring based on PR Comments and added JavaDocs
* Added more comments
* Added character for Triggering Jenkins build
* Changes to cover collectZeroDocEntries method
* Updated comment based on change in method's functionality
* Added test to cover branches in collectZeroDocEntriesIfRequired
* Rebased and resolved changelog conflict

---------

Signed-off-by: expani <anijainc@amazon.com>
  • Loading branch information
expani authored Oct 7, 2024
1 parent 146b0f7 commit e885aa9
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add changes to block calls in cat shards, indices and segments based on dynamic limit settings ([#15986](https://github.com/opensearch-project/OpenSearch/pull/15986))
- New `phone` & `phone-search` analyzer + tokenizer ([#15915](https://github.com/opensearch-project/OpenSearch/pull/15915))
- Add _list/shards API as paginated alternate to _cat/shards ([#14641](https://github.com/opensearch-project/OpenSearch/pull/14641))
- Latency and Memory allocation improvements to Multi Term Aggregation queries ([#14993](https://github.com/opensearch-project/OpenSearch/pull/14993))

### Dependencies
- Bump `com.azure:azure-identity` from 1.13.0 to 1.13.2 ([#15578](https://github.com/opensearch-project/OpenSearch/pull/15578))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalOrder;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds;
import org.opensearch.search.aggregations.support.AggregationPath;
Expand Down Expand Up @@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() {

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub);
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
for (BytesRef compositeKey : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
collectExistingBucket(sub, doc, bucketOrd);
} else {
collectBucket(sub, doc, bucketOrd);
}
}
collector.apply(doc, owningBucketOrd);
}
};
}
Expand Down Expand Up @@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
}
// we need to fill-in the blanks
for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) {
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
// brute force
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null);
for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
for (BytesRef compositeKey : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, compositeKey);
}
collector.apply(docId, owningBucketOrd);
}
}
}
Expand All @@ -284,10 +275,11 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
@FunctionalInterface
interface MultiTermsValuesSourceCollector {
/**
* Collect a list values of multi_terms on each doc.
* Each terms could have multi_values, so the result is the cartesian product of each term's values.
* Generates the cartesian product of all fields used in aggregation and
* collects them in buckets using the composite key of their field values.
*/
List<BytesRef> apply(int doc) throws IOException;
void apply(int doc, long owningBucketOrd) throws IOException;

}

@FunctionalInterface
Expand Down Expand Up @@ -361,47 +353,72 @@ public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
this.valuesSources = valuesSources;
}

public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException {
public MultiTermsValuesSourceCollector getValues(
LeafReaderContext ctx,
BytesKeyedBucketOrds bucketOrds,
BucketsAggregator aggregator,
LeafBucketCollector sub
) throws IOException {
List<InternalValuesSourceCollector> collectors = new ArrayList<>();
for (InternalValuesSource valuesSource : valuesSources) {
collectors.add(valuesSource.apply(ctx));
}
boolean collectBucketOrds = aggregator != null && sub != null;
return new MultiTermsValuesSourceCollector() {

/**
* This method does the following : <br>
* <li>Fetches the values of every field present in the doc List<List<TermValue<?>>> via @{@link InternalValuesSourceCollector}</li>
* <li>Generates Composite keys from the fetched values for all fields present in the aggregation.</li>
* <li>Adds every composite key to the @{@link BytesKeyedBucketOrds} and Optionally collects them via @{@link BucketsAggregator#collectBucket(LeafBucketCollector, int, long)}</li>
*/
@Override
public List<BytesRef> apply(int doc) throws IOException {
public void apply(int doc, long owningBucketOrd) throws IOException {
// TODO A new list creation can be avoided for every doc.
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
for (InternalValuesSourceCollector collector : collectors) {
collectedValues.add(collector.apply(doc));
}
List<BytesRef> result = new ArrayList<>();
scratch.seek(0);
scratch.writeVInt(collectors.size()); // number of fields per composite key
cartesianProduct(result, scratch, collectedValues, 0);
return result;
generateAndCollectCompositeKeys(collectedValues, 0, owningBucketOrd, doc);
}

/**
* Cartesian product using depth first search.
*
* <p>
* Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
* but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
* This generates and collects all Composite keys in their buckets by performing a cartesian product <br>
* of all the values in all the fields ( used in agg ) for the given doc recursively.
* @param collectedValues : Values of all fields present in the aggregation for the @doc
* @param index : Points to the field being added to generate the composite key
*/
private void cartesianProduct(
List<BytesRef> compositeKeys,
BytesStreamOutput scratch,
private void generateAndCollectCompositeKeys(
List<List<TermValue<?>>> collectedValues,
int index
int index,
long owningBucketOrd,
int doc
) throws IOException {
if (collectedValues.size() == index) {
compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
// Avoid performing a deep copy of the composite key by inlining.
long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef());
if (collectBucketOrds) {
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
aggregator.collectExistingBucket(sub, doc, bucketOrd);
} else {
aggregator.collectBucket(sub, doc, bucketOrd);
}
}
return;
}

long position = scratch.position();
for (TermValue<?> value : collectedValues.get(index)) {
List<TermValue<?>> values = collectedValues.get(index);
int numIterations = values.size();
// For each loop is not done to reduce the allocations done for Iterator objects
// once for every field in every doc.
for (int i = 0; i < numIterations; i++) {
TermValue<?> value = values.get(i);
value.writeTo(scratch); // encode the value
cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
generateAndCollectCompositeKeys(collectedValues, index + 1, owningBucketOrd, doc); // dfs
scratch.seek(position); // backtrack
}
}
Expand Down Expand Up @@ -441,9 +458,14 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
if (i > 0 && bytes.equals(previous)) {
continue;
}
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
// Performing a deep copy is not required for field containing only one value.
if (valuesCount > 1) {
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
} else {
termValues.add(TermValue.of(bytes));
}
}
return termValues;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase {

private static final Consumer<MultiTermsAggregationBuilder> NONE_DECORATOR = null;

private static final Consumer<InternalMultiTerms> IP_AND_KEYWORD_DESC_ORDER_VERIFY = h -> {
MatcherAssert.assertThat(h.getBuckets(), hasSize(3));
MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("192.168.0.0")));
MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|192.168.0.0"));
MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L));
MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo("192.168.0.1")));
MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("b|192.168.0.1"));
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L));
MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("192.168.0.2")));
MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("c|192.168.0.2"));
MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L));
};

@Override
protected List<ValuesSourceType> getSupportedValuesSourceTypes() {
return Collections.unmodifiableList(
Expand Down Expand Up @@ -672,8 +685,48 @@ public void testDatesFieldFormat() throws IOException {
);
}

public void testIpAndKeyword() throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, IP_FIELD)), NONE_DECORATOR, iw -> {
public void testIpAndKeywordDefaultDescOrder() throws IOException {
ipAndKeywordTest(NONE_DECORATOR, IP_AND_KEYWORD_DESC_ORDER_VERIFY);
}

public void testIpAndKeywordWithBucketCountSameAsSize() throws IOException {
ipAndKeywordTest(multiTermsAggregationBuilder -> {
multiTermsAggregationBuilder.minDocCount(0);
multiTermsAggregationBuilder.size(3);
multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(false)));
}, IP_AND_KEYWORD_DESC_ORDER_VERIFY);
}

public void testIpAndKeywordWithBucketCountGreaterThanSize() throws IOException {
ipAndKeywordTest(multiTermsAggregationBuilder -> {
multiTermsAggregationBuilder.minDocCount(0);
multiTermsAggregationBuilder.size(10);
multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(false)));
}, IP_AND_KEYWORD_DESC_ORDER_VERIFY);
}

public void testIpAndKeywordAscOrder() throws IOException {
ipAndKeywordTest(multiTermsAggregationBuilder -> {
multiTermsAggregationBuilder.minDocCount(0);
multiTermsAggregationBuilder.size(3);
multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(true)));
}, h -> {
MatcherAssert.assertThat(h.getBuckets(), hasSize(3));
MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("b"), equalTo("192.168.0.1")));
MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("b|192.168.0.1"));
MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(1L));
MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("c"), equalTo("192.168.0.2")));
MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("c|192.168.0.2"));
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L));
MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("a"), equalTo("192.168.0.0")));
MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("a|192.168.0.0"));
MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(2L));
});
}

private void ipAndKeywordTest(Consumer<MultiTermsAggregationBuilder> builderDecorator, Consumer<InternalMultiTerms> verify)
throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, IP_FIELD)), builderDecorator, iw -> {
iw.addDocument(
asList(
new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
Expand All @@ -698,18 +751,7 @@ public void testIpAndKeyword() throws IOException {
new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.0"))))
)
);
}, h -> {
MatcherAssert.assertThat(h.getBuckets(), hasSize(3));
MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("192.168.0.0")));
MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|192.168.0.0"));
MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L));
MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo("192.168.0.1")));
MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("b|192.168.0.1"));
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L));
MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("192.168.0.2")));
MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("c|192.168.0.2"));
MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L));
});
}, verify);
}

public void testEmpty() throws IOException {
Expand Down

0 comments on commit e885aa9

Please sign in to comment.