diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index a37452490d1d5..0f3d975960364 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -328,6 +328,11 @@ private abstract static class Collector extends LeafBucketCollector implements R } + /** + * This collector enhance the delegate collector with pruning ability on term field + * The iterators of term field values are wrapped into a priority queue, and able to + * pop/prune the values after being collected + */ private static class PruningCollector extends Collector { private final Collector delegate; @@ -348,8 +353,8 @@ private static class PruningCollector extends Collector { } this.queue = new DisiPriorityQueue(postingMap.size()); - for (Map.Entry entry : postingMap.entrySet()) { - queue.add(new DisiWrapper(entry.getValue())); + for (Scorer scorer : postingMap.values()) { + queue.add(new DisiWrapper(scorer)); } competitiveIterator = new DisjunctionDISI(queue); diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index 1de553a960414..b5dd27e37c332 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -43,6 +43,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.Term; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.IndexSearcher; @@ -73,6 +74,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -224,6 +226,39 @@ private void testAggregation( testCase(aggregationBuilder, query, buildIndex, verify, fieldType); } + public void testDynamicPruningDisabledWhenExceedingThreshold() throws IOException { + final String fieldName = "testField"; + final String filterFieldName = "filterField"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + + int randomCardinality = randomIntBetween(20, 100); + AtomicInteger counter = new AtomicInteger(); + + testDynamicPruning(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + for (int i = 0; i < randomCardinality; i++) { + String filterValue = "foo"; + if (randomBoolean()) { + filterValue = "bar"; + counter.getAndIncrement(); + } + iw.addDocument( + asList( + new KeywordField(filterFieldName, filterValue, Field.Store.NO), + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + } + }, + card -> { assertEquals(randomCardinality - counter.get(), card.getValue(), 0); }, + fieldType, + 10, + (collectCount) -> assertEquals(randomCardinality - counter.get(), (int) collectCount) + ); + } + public void testDynamicPruningFixedValues() throws IOException { final String fieldName = "testField"; final String filterFieldName = "filterField"; @@ -356,15 +391,29 @@ public void testDynamicPruningFieldMissingInSegment() throws IOException { MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + int randomNumSegments = randomIntBetween(1, 50); + logger.info("Indexing [{}] segments", randomNumSegments); + testDynamicPruning(aggregationBuilder, new MatchAllDocsQuery(), iw -> { - iw.addDocument(asList(new KeywordField(fieldName, "1", Field.Store.NO), new KeywordField(fieldName, "2", Field.Store.NO))); - iw.addDocument(asList(new KeywordField(fieldName, "1", Field.Store.NO), new KeywordField(fieldName, "3", Field.Store.NO))); - iw.addDocument(asList(new KeywordField(fieldName, "2", Field.Store.NO), new KeywordField(fieldName, "3", Field.Store.NO))); + for (int i = 0; i < randomNumSegments; i++) { + iw.addDocument( + asList( + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + iw.commit(); + } + iw.addDocument(List.of(new KeywordField(fieldName2, "100", Field.Store.NO))); + iw.addDocument(List.of(new KeywordField(fieldName2, "101", Field.Store.NO))); + iw.addDocument(List.of(new KeywordField(fieldName2, "102", Field.Store.NO))); iw.commit(); - iw.addDocument(asList(new KeywordField(fieldName2, "100", Field.Store.NO))); - iw.addDocument(asList(new KeywordField(fieldName2, "101", Field.Store.NO))); - iw.addDocument(asList(new KeywordField(fieldName2, "102", Field.Store.NO))); - }, card -> { assertEquals(3, card.getValue(), 0); }, fieldType, 100, (collectCount) -> assertEquals(3, (int) collectCount)); + }, + card -> { assertEquals(randomNumSegments, card.getValue(), 0); }, + fieldType, + 100, + (collectCount) -> assertEquals(3, (int) collectCount) + ); } private void testDynamicPruning( @@ -377,7 +426,13 @@ private void testDynamicPruning( Consumer verifyCollectCount ) throws IOException { try (Directory directory = newDirectory()) { - try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()))) { + try ( + IndexWriter indexWriter = new IndexWriter( + directory, + new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()).setMergePolicy(NoMergePolicy.INSTANCE) + ) + ) { + // disable merge so segment number is same as commit times buildIndex.accept(indexWriter); } @@ -411,6 +466,7 @@ private void testDynamicPruning( verify.accept(card); + logger.info("aggregator collect count {}", aggregator.getCollectCount().get()); verifyCollectCount.accept(aggregator.getCollectCount().get()); } }