Skip to content

Commit

Permalink
Add dynamic pruning test with score mode set to Max
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikep86 committed Sep 6, 2024
1 parent ca0cd07 commit 51cf4fd
Showing 1 changed file with 145 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -55,7 +56,8 @@ public class TestBlockJoinBulkScorer extends LuceneTestCase {
private enum MatchValue {
MATCH_A("A", 1),
MATCH_B("B", 2),
MATCH_C("C", 3);
MATCH_C("C", 3),
MATCH_D("D", 4);

private static final List<MatchValue> VALUES = List.of(values());

Expand Down Expand Up @@ -92,7 +94,7 @@ public ChildDocMatch(int docId, List<MatchValue> matches) {
}
}

private static Map<Integer, List<ChildDocMatch>> populateIndex(
private static Map<Integer, List<ChildDocMatch>> populateRandomIndex(
RandomIndexWriter writer, int maxParentDocCount, int maxChildDocCount, int maxChildDocMatches)
throws IOException {
Map<Integer, List<ChildDocMatch>> expectedMatches = new HashMap<>();
Expand Down Expand Up @@ -139,6 +141,40 @@ private static Map<Integer, List<ChildDocMatch>> populateIndex(
return expectedMatches;
}

private static void populateStaticIndex(RandomIndexWriter writer) throws IOException {
// Use these vars to improve readability when defining the docs
final String A = MatchValue.MATCH_A.getText();
final String B = MatchValue.MATCH_B.getText();
final String C = MatchValue.MATCH_C.getText();
final String D = MatchValue.MATCH_D.getText();

for (String[][] values :
Arrays.asList(
new String[][] {{A, B}, {A, B, C}},
new String[][] {{A}, {B}},
new String[][] {{}},
new String[][] {{A, B, C}, {A, B, C, D}},
new String[][] {{B}},
new String[][] {{B, C}, {A, B}, {A, C}})) {

List<Document> docs = new ArrayList<>();
for (String[] value : values) {
Document childDoc = new Document();
childDoc.add(newStringField(TYPE_FIELD_NAME, CHILD_FILTER_VALUE, Field.Store.NO));
for (String v : value) {
childDoc.add(newStringField(VALUE_FIELD_NAME, v, Field.Store.NO));
}
docs.add(childDoc);
}

Document parentDoc = new Document();
parentDoc.add(newStringField(TYPE_FIELD_NAME, PARENT_FILTER_VALUE, Field.Store.NO));
docs.add(parentDoc);

writer.addDocuments(docs);
}
}

private static Map<Integer, Float> computeExpectedScores(
Map<Integer, List<ChildDocMatch>> expectedMatches,
ScoreMode joinScoreMode,
Expand Down Expand Up @@ -201,6 +237,51 @@ private static float computeExpectedScore(ChildDocMatch childDocMatch) {
return expectedScore;
}

private static ToParentBlockJoinQuery buildQuery(ScoreMode scoreMode) {
BooleanQuery.Builder childQueryBuilder = new BooleanQuery.Builder();
for (MatchValue matchValue : MatchValue.VALUES) {
childQueryBuilder.add(
new BoostQuery(
new ConstantScoreQuery(
new TermQuery(new Term(VALUE_FIELD_NAME, matchValue.getText()))),
matchValue.getScore()),
BooleanClause.Occur.SHOULD);
}
BitSetProducer parentsFilter =
new QueryBitSetProducer(new TermQuery(new Term(TYPE_FIELD_NAME, PARENT_FILTER_VALUE)));
return new ToParentBlockJoinQuery(childQueryBuilder.build(), parentsFilter, scoreMode);
}

private static void assertScores(
BulkScorer bulkScorer,
org.apache.lucene.search.ScoreMode scoreMode,
Float minScore,
Map<Integer, Float> expectedScores)
throws IOException {
Map<Integer, Float> actualScores = new HashMap<>();
bulkScorer.score(
new LeafCollector() {
private Scorable scorer;

@Override
public void setScorer(Scorable scorer) throws IOException {
assertNotNull(scorer);
this.scorer = scorer;
if (minScore != null) {
this.scorer.setMinCompetitiveScore(minScore);
}
}

@Override
public void collect(int doc) throws IOException {
assertNotNull(scorer);
actualScores.put(doc, scoreMode.needsScores() ? scorer.score() : 0);
}
},
null);
assertEquals(expectedScores, actualScores);
}

public void testScoreRandomIndices() throws IOException {
for (int i = 0; i < 200 * RANDOM_MULTIPLIER; i++) {
try (Directory dir = newDirectory()) {
Expand All @@ -215,7 +296,7 @@ public void testScoreRandomIndices() throws IOException {
newLogMergePolicy(random().nextBoolean())))) {

expectedMatches =
populateIndex(
populateRandomIndex(
w,
TestUtil.nextInt(random(), 10 * RANDOM_MULTIPLIER, 30 * RANDOM_MULTIPLIER),
20,
Expand All @@ -233,49 +314,74 @@ public void testScoreRandomIndices() throws IOException {
final Map<Integer, Float> expectedScores =
computeExpectedScores(expectedMatches, joinScoreMode, searchScoreMode);

BooleanQuery.Builder childQueryBuilder = new BooleanQuery.Builder();
for (MatchValue matchValue : MatchValue.VALUES) {
childQueryBuilder.add(
new BoostQuery(
new ConstantScoreQuery(
new TermQuery(new Term(VALUE_FIELD_NAME, matchValue.getText()))),
matchValue.getScore()),
BooleanClause.Occur.SHOULD);
}
BitSetProducer parentsFilter =
new QueryBitSetProducer(
new TermQuery(new Term(TYPE_FIELD_NAME, PARENT_FILTER_VALUE)));
ToParentBlockJoinQuery parentQuery =
new ToParentBlockJoinQuery(childQueryBuilder.build(), parentsFilter, joinScoreMode);

Weight weight = searcher.createWeight(searcher.rewrite(parentQuery), searchScoreMode, 1);
ToParentBlockJoinQuery query = buildQuery(joinScoreMode);
Weight weight = searcher.createWeight(searcher.rewrite(query), searchScoreMode, 1);
ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
if (ss == null) {
// Score supplier will be null when there are no matches
assertTrue(expectedScores.isEmpty());
continue;
}

Map<Integer, Float> actualScores = new HashMap<>();
BulkScorer bulkScorer = ss.bulkScorer();
bulkScorer.score(
new LeafCollector() {
private Scorable scorer;

@Override
public void setScorer(Scorable scorer) {
assertNotNull(scorer);
this.scorer = scorer;
}

@Override
public void collect(int doc) throws IOException {
assertNotNull(scorer);
actualScores.put(doc, searchScoreMode.needsScores() ? scorer.score() : 0);
}
},
null);
assertEquals(expectedScores, actualScores);
assertScores(ss.bulkScorer(), searchScoreMode, null, expectedScores);
}
}
}
}

public void testSetMinCompetitiveScoreWithScoreModeMax() throws IOException {
try (Directory dir = newDirectory()) {
try (RandomIndexWriter w =
new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig()
.setMergePolicy(
// retain doc id order
newLogMergePolicy(random().nextBoolean())))) {

populateStaticIndex(w);
w.forceMerge(1);
}

try (IndexReader reader = DirectoryReader.open(dir)) {
final IndexSearcher searcher = newSearcher(reader);
final ToParentBlockJoinQuery query = buildQuery(ScoreMode.Max);
final org.apache.lucene.search.ScoreMode scoreMode =
org.apache.lucene.search.ScoreMode.TOP_SCORES;
final Weight weight = searcher.createWeight(searcher.rewrite(query), scoreMode, 1);

{
Map<Integer, Float> expectedScores =
Map.of(
2, 6.0f,
5, 2.0f,
10, 10.0f,
12, 2.0f,
16, 5.0f);

ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, null, expectedScores);
}

{
Map<Integer, Float> expectedScores =
Map.of(
2, 6.0f,
10, 10.0f);

ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 6.0f, expectedScores);
}

{
Map<Integer, Float> expectedScores = Map.of();

ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
ss.setTopLevelScoringClause();
assertScores(ss.bulkScorer(), scoreMode, 11.0f, expectedScores);
}
}
}
Expand Down

0 comments on commit 51cf4fd

Please sign in to comment.