Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KNN] Add comment and remove duplicate code #13594

Open
Wants to merge 11 commits into base branch `main` from the contributor's branch.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.lucene.search;

import static org.apache.lucene.search.AnnQueryUtils.createBitSet;
import static org.apache.lucene.search.AnnQueryUtils.createFilterWeight;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
Expand Down Expand Up @@ -53,8 +55,13 @@ abstract class AbstractKnnVectorQuery extends Query {

private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

/** the KNN vector field to search */
protected final String field;

/** the number of documents to find */
protected final int k;

/** the filter to be executed. when the filter is applied is up to the underlying knn index */
protected final Query filter;

public AbstractKnnVectorQuery(String field, int k, Query filter) {
Expand All @@ -68,20 +75,12 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
// we need to perform search inside rewrite() because we need to get top-k
// matches across all segments

IndexReader reader = indexSearcher.getIndexReader();

final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}
final Weight filterWeight = createFilterWeight(indexSearcher, filter, field);

TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
Expand Down Expand Up @@ -116,6 +115,7 @@ private TopDocs searchLeaf(
return results;
}

// Perform kNN search for the provided LeafReaderContext applying filterWeight as necessary
private TopDocs getLeafResults(
LeafReaderContext ctx,
Weight filterWeight,
Expand Down Expand Up @@ -156,24 +156,6 @@ private TopDocs getLeafResults(
}
}

private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return bitSetIterator.getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}

protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new TopKnnCollectorManager(k, searcher);
}
Expand All @@ -188,6 +170,8 @@ protected abstract TopDocs approximateSearch(
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;

// Perform a brute-force search by computing the vector score for each accepted doc and try to
// take the top k docs.
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
Expand Down Expand Up @@ -255,6 +239,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
return TopDocs.merge(k, perLeafResults);
}

// At this point we already collected top k matching docs, thus we only wrap the cached docs with
// their scores here.
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;

Expand All @@ -272,6 +258,8 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
}

// For each segment, find the first index in <code>docs</code> belong to that segment.
// This method essentially partitions <code>docs</code> by segments
static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
int[] starts = new int[leaves.size() + 1];
starts[starts.length - 1] = docs.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
package org.apache.lucene.search;

import static org.apache.lucene.search.AnnQueryUtils.createBitSet;
import static org.apache.lucene.search.AnnQueryUtils.createFilterWeight;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
Expand Down Expand Up @@ -78,10 +81,7 @@ protected abstract TopDocs approximateSearch(
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new Weight(this) {
final Weight filterWeight =
filter == null
? null
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
final Weight filterWeight = createFilterWeight(searcher, filter, field);

final QueryTimeout queryTimeout = searcher.getTimeout();
final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
Expand Down Expand Up @@ -133,21 +133,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
return null;
}

BitSet acceptDocs;
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
// If there are no deletions, and matching docs are already cached
acceptDocs = bitSetIterator.getBitSet();
} else {
// Else collect all matching docs
FilteredDocIdSetIterator filtered =
new FilteredDocIdSetIterator(scorer.iterator()) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, leafReader.maxDoc());

int cardinality = acceptDocs.cardinality();
if (cardinality == 0) {
Expand Down
81 changes: 81 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/AnnQueryUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;

/** Shared helper methods for approximate nearest-neighbor (ANN) queries. */
final class AnnQueryUtils {

  // Static-only utility holder; never instantiated.
  private AnnQueryUtils() {}

  /**
   * Builds a {@link BitSet} of the docs produced by {@code iterator} that are also live.
   *
   * <p>When the segment has no deletions and the iterator already wraps a {@link BitSet}, the
   * underlying bit set is returned directly rather than copied.
   *
   * @param iterator iterator over the matching docs
   * @param liveDocs the segment's live (non-deleted) docs, or {@code null} if nothing is deleted
   * @param maxDoc one greater than the largest doc id in the segment, i.e. the size of the
   *     returned bit set
   * @return a bit set over the matching docs that are live
   */
  static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
      throws IOException {
    if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
      // No deletions and matches are already cached as a BitSet: reuse it as-is.
      return bitSetIterator.getBitSet();
    }
    // Otherwise materialize a fresh BitSet holding only the live matching docs.
    FilteredDocIdSetIterator liveMatches =
        new FilteredDocIdSetIterator(iterator) {
          @Override
          protected boolean match(int doc) {
            return liveDocs == null || liveDocs.get(doc);
          }
        };
    return BitSet.of(liveMatches, maxDoc);
  }

  /**
   * Creates a {@link Weight} for the given filter, additionally restricted to documents that have
   * a value in the provided KNN vector field.
   *
   * @param indexSearcher searcher used to rewrite the filter and create the weight
   * @param filter the filter query, may be {@code null}
   * @param field the KNN vector field that matching documents must have a value for
   * @return the filter's weight, or {@code null} when {@code filter} is {@code null}
   */
  static Weight createFilterWeight(IndexSearcher indexSearcher, Query filter, String field)
      throws IOException {
    if (filter == null) {
      return null;
    }
    Query combined =
        new BooleanQuery.Builder()
            .add(filter, BooleanClause.Occur.FILTER)
            .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
            .build();
    return indexSearcher.createWeight(
        indexSearcher.rewrite(combined), ScoreMode.COMPLETE_NO_SCORES, 1f);
  }
}
Loading