Skip to content

Commit 26a05a2

Browse files
kaivalnp and Kaival Parikh authored
Add option for index-time filtering to knnPerfTest.py (#468)
* Add option for index-time filtering to knnPerfTest.py
* Improve error message for incorrect "-filterStrategy"
* Fix error message for missing "-filterStrategy" value

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
1 parent 13a50bf commit 26a05a2

File tree

5 files changed: +161 additions, -60 deletions

src/main/knn/IndexerThread.java

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import org.apache.lucene.document.Field;
3232
import org.apache.lucene.index.IndexWriter;
3333
import org.apache.lucene.index.VectorEncoding;
34+
import org.apache.lucene.util.BitSet;
3435

3536
class IndexerThread extends Thread {
3637
private final IndexWriter iw;
@@ -41,14 +42,16 @@ class IndexerThread extends Thread {
4142
private final VectorEncoding vectorEncoding;
4243
private final byte[] byteVectorBuffer;
4344
private final float[] floatVectorBuffer;
45+
private final BitSet filtered;
4446

45-
public IndexerThread(IndexWriter iw, int dims, VectorReader vectorReader, VectorEncoding vectorEncoding, FieldType fieldType, AtomicInteger numDocsIndexed, int numDocsToIndex) {
47+
public IndexerThread(IndexWriter iw, int dims, VectorReader vectorReader, VectorEncoding vectorEncoding, FieldType fieldType, AtomicInteger numDocsIndexed, int numDocsToIndex, BitSet filtered) {
4648
this.iw = iw;
4749
this.vectorReader = vectorReader;
4850
this.vectorEncoding = vectorEncoding;
4951
this.fieldType = fieldType;
5052
this.numDocsIndexed = numDocsIndexed;
5153
this.numDocsToIndex = numDocsToIndex;
54+
this.filtered = filtered;
5255
switch (vectorEncoding) {
5356
case BYTE -> {
5457
byteVectorBuffer = new byte[dims];
@@ -88,11 +91,17 @@ private void _run() throws IOException {
8891
byte[] bytes = ((VectorReaderByte) vectorReader).nextBytes();
8992
System.arraycopy(bytes, 0, byteVectorBuffer, 0, bytes.length);
9093
doc.add(new KnnByteVectorField(KnnGraphTester.KNN_FIELD, byteVectorBuffer, fieldType));
94+
if (filtered != null && filtered.get(id)) {
95+
doc.add(new KnnByteVectorField(KnnGraphTester.KNN_FIELD_FILTERED, byteVectorBuffer, fieldType));
96+
}
9197
}
9298
case FLOAT32 -> {
9399
float[] floats = vectorReader.next();
94100
System.arraycopy(floats, 0, floatVectorBuffer, 0, floats.length);
95101
doc.add(new KnnFloatVectorField(KnnGraphTester.KNN_FIELD, floatVectorBuffer, fieldType));
102+
if (filtered != null && filtered.get(id)) {
103+
doc.add(new KnnFloatVectorField(KnnGraphTester.KNN_FIELD_FILTERED, floatVectorBuffer, fieldType));
104+
}
96105
}
97106
}
98107

src/main/knn/KnnGraphTester.java

Lines changed: 121 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,6 @@
1919

2020
import java.io.IOException;
2121
import java.io.OutputStream;
22-
import java.lang.management.ManagementFactory;
23-
import java.lang.management.ThreadMXBean;
2422
import java.nio.ByteBuffer;
2523
import java.nio.ByteOrder;
2624
import java.nio.IntBuffer;
@@ -133,7 +131,14 @@ enum IndexType {
133131
FLAT
134132
}
135133

134+
enum FilterStrategy {
135+
QUERY_TIME_PRE_FILTER,
136+
QUERY_TIME_POST_FILTER,
137+
INDEX_TIME_FILTER
138+
}
139+
136140
public static final String KNN_FIELD = "knn";
141+
public static final String KNN_FIELD_FILTERED = "knn-filtered";
137142
public static final String ID_FIELD = "id";
138143
private static final String INDEX_DIR = "knnIndices";
139144
public static final String DOCTYPE_FIELD = "docType";
@@ -174,8 +179,8 @@ enum IndexType {
174179
private VectorSimilarityFunction similarityFunction;
175180
private VectorEncoding vectorEncoding;
176181
private Query filterQuery;
177-
private float selectivity;
178-
private boolean prefilter;
182+
private FilterStrategy filterStrategy;
183+
private Float filterSelectivity;
179184
private boolean randomCommits;
180185
private boolean parentJoin;
181186
private Path parentJoinMetaFile;
@@ -200,8 +205,8 @@ private KnnGraphTester() {
200205
fanout = topK;
201206
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
202207
vectorEncoding = VectorEncoding.FLOAT32;
203-
selectivity = 1f;
204-
prefilter = false;
208+
filterStrategy = null;
209+
filterSelectivity = null;
205210
quantize = false;
206211
randomCommits = false;
207212
quantizeBits = 7;
@@ -412,18 +417,27 @@ private void run(String... args) throws Exception {
412417
case "-forceMerge":
413418
forceMerge = true;
414419
break;
415-
case "-prefilter":
416-
prefilter = true;
417-
break;
418420
case "-randomCommits":
419421
randomCommits = true;
420422
break;
423+
case "-filterStrategy":
424+
if (iarg == args.length - 1) {
425+
throw new IllegalArgumentException("-filterStrategy requires a following string, one of (case-insensitive) {'query-time-pre-filter', 'query-time-post-filter', 'index-time-filter'}");
426+
}
427+
String filterStrategyVal = args[++iarg].toLowerCase().trim();
428+
filterStrategy = switch (filterStrategyVal) {
429+
case "query-time-pre-filter" -> FilterStrategy.QUERY_TIME_PRE_FILTER;
430+
case "query-time-post-filter" -> FilterStrategy.QUERY_TIME_POST_FILTER;
431+
case "index-time-filter" -> FilterStrategy.INDEX_TIME_FILTER;
432+
default -> throw new IllegalArgumentException("-filterStrategy must be one of (case-insensitive) {'query-time-pre-filter', 'query-time-post-filter', 'index-time-filter'}, found: " + filterStrategyVal);
433+
};
434+
break;
421435
case "-filterSelectivity":
422436
if (iarg == args.length - 1) {
423437
throw new IllegalArgumentException("-filterSelectivity requires a following float");
424438
}
425-
selectivity = Float.parseFloat(args[++iarg]);
426-
if (selectivity <= 0 || selectivity >= 1) {
439+
filterSelectivity = Float.parseFloat(args[++iarg]);
440+
if (filterSelectivity <= 0 || filterSelectivity >= 1) {
427441
throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
428442
}
429443
break;
@@ -484,8 +498,10 @@ private void run(String... args) throws Exception {
484498
if (operation == null && reindex == false) {
485499
usage();
486500
}
487-
if (prefilter && selectivity == 1f) {
488-
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
501+
502+
BitSet filtered = null;
503+
if (filterStrategy != null && filterSelectivity == null || filterStrategy == null && filterSelectivity != null) {
504+
throw new IllegalArgumentException("Either both or none of -filterStrategy or -filterSelectivity should be specified");
489505
}
490506
if (indexPath == null) {
491507
indexPath = Paths.get(formatIndexPath(docVectorsPath, numDocs)); // derive index path
@@ -501,10 +517,19 @@ private void run(String... args) throws Exception {
501517
log("Seed = %d\n", randomSeed);
502518
random = new Random(randomSeed);
503519

520+
if (filterSelectivity != null) {
521+
filtered = selectRandomDocs(random, numDocs, filterSelectivity);
522+
}
504523
if (reindex || Files.exists(indexPath) == false) {
505524
if (docVectorsPath == null) {
506525
throw new IllegalArgumentException("-docs argument is required when indexing");
507526
}
527+
528+
BitSet indexTimeFilter = null;
529+
if (filterStrategy == FilterStrategy.INDEX_TIME_FILTER) {
530+
indexTimeFilter = filtered;
531+
}
532+
508533
reindexTimeMsec = new KnnIndexer(
509534
docVectorsPath,
510535
indexPath,
@@ -518,7 +543,8 @@ private void run(String... args) throws Exception {
518543
quiet,
519544
parentJoin,
520545
parentJoinMetaFile,
521-
useBp
546+
useBp,
547+
indexTimeFilter
522548
).createIndex();
523549
log("reindex takes %.2f sec\n", msToSec(reindexTimeMsec));
524550
}
@@ -533,7 +559,9 @@ private void run(String... args) throws Exception {
533559
if (docVectorsPath == null) {
534560
throw new IllegalArgumentException("missing -docs arg");
535561
}
536-
filterQuery = selectivity == 1f ? new MatchAllDocsQuery() : generateRandomQuery(random, indexPath, numDocs, selectivity);
562+
if (filterSelectivity != null) {
563+
filterQuery = createFilterQuery(indexPath, filtered);
564+
}
537565
if (outputPath != null) {
538566
testSearch(indexPath, queryPath, queryStartIndex, outputPath, null);
539567
} else {
@@ -655,7 +683,7 @@ private void printIndexStatistics(Path indexPath) throws IOException {
655683
}
656684
}
657685

658-
private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
686+
private static BitSet selectRandomDocs(Random random, int size, float selectivity) {
659687
FixedBitSet bitSet = new FixedBitSet(size);
660688
for (int i = 0; i < size; i++) {
661689
if (random.nextFloat() < selectivity) {
@@ -664,7 +692,10 @@ private static Query generateRandomQuery(Random random, Path indexPath, int size
664692
bitSet.clear(i);
665693
}
666694
}
695+
return bitSet;
696+
}
667697

698+
private static Query createFilterQuery(Path indexPath, BitSet bitSet) throws IOException {
668699
try (Directory dir = FSDirectory.open(indexPath);
669700
DirectoryReader reader = DirectoryReader.open(dir)) {
670701
BitSet[] segmentDocs = new BitSet[reader.leaves().size()];
@@ -706,6 +737,14 @@ private String formatIndexPath(Path docsPath, int numDocs) {
706737
}
707738
// make sure we reindex if numDocs has changed:
708739
suffix.add(Integer.toString(numDocs));
740+
741+
// make sure we reindex if index-time filter is used
742+
if (filterStrategy == FilterStrategy.INDEX_TIME_FILTER) {
743+
suffix.add(filterStrategy.toString());
744+
suffix.add(filterSelectivity.toString());
745+
suffix.add(String.valueOf(randomSeed));
746+
}
747+
709748
return INDEX_DIR + "/" + docsPath.getFileName() + "-" + String.join("-", suffix) + ".index";
710749
}
711750

@@ -878,10 +917,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
878917
for (int i = 0; i < numQueryVectors; i++) {
879918
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
880919
byte[] target = targetReaderByte.nextBytes();
881-
doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
920+
doKnnByteVectorQuery(searcher, target, topK, fanout, filterStrategy, filterQuery);
882921
} else {
883922
float[] target = targetReader.next();
884-
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin);
923+
doKnnVectorQuery(searcher, target, topK, fanout, filterStrategy, filterQuery, parentJoin);
885924
}
886925
}
887926
targetReader.reset();
@@ -890,10 +929,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
890929
for (int i = 0; i < numQueryVectors; i++) {
891930
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
892931
byte[] target = targetReaderByte.nextBytes();
893-
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
932+
results[i] = doKnnByteVectorQuery(searcher, target, topK, fanout, filterStrategy, filterQuery);
894933
} else {
895934
float[] target = targetReader.next();
896-
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin);
935+
results[i] = doKnnVectorQuery(searcher, target, topK, fanout, filterStrategy, filterQuery, parentJoin);
897936
}
898937
}
899938
ThreadDetails endThreadDetails = new ThreadDetails();
@@ -956,7 +995,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
956995
double reindexSec = reindexTimeMsec / 1000.0;
957996
System.out.printf(
958997
Locale.ROOT,
959-
"SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n",
998+
"SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%s\t%.2f\t%5.3f\t%5.3f\t%5.3f\t%s\n",
960999
recall,
9611000
elapsedMS / (float) numQueryVectors,
9621001
totalCpuTimeMS / (float) numQueryVectors,
@@ -973,8 +1012,8 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
9731012
forceMergeTimeSec,
9741013
indexNumSegments,
9751014
indexSizeOnDiskMB,
976-
selectivity,
977-
prefilter ? "pre-filter" : "post-filter",
1015+
filterStrategy.toString().toLowerCase().replace('_', '-'),
1016+
filterSelectivity,
9781017
overSample,
9791018
vectorDiskSizeBytes / 1024. / 1024.,
9801019
vectorRAMSizeBytes / 1024. / 1024.,
@@ -992,31 +1031,62 @@ private static double msToSec(long ms) {
9921031
}
9931032

9941033
private static Result doKnnByteVectorQuery(
995-
IndexSearcher searcher, String field, byte[] vector, int k, int fanout, boolean prefilter, Query filter)
1034+
IndexSearcher searcher, byte[] vector, int k, int fanout, FilterStrategy filterStrategy, Query filter)
9961035
throws IOException {
997-
ProfiledKnnByteVectorQuery profiledQuery = new ProfiledKnnByteVectorQuery(field, vector, k, fanout, prefilter ? filter : null);
998-
Query query = prefilter ? profiledQuery : new BooleanQuery.Builder()
999-
.add(profiledQuery, BooleanClause.Occur.MUST)
1000-
.add(filter, BooleanClause.Occur.FILTER)
1001-
.build();
1036+
1037+
Query queryTimeFilter = null;
1038+
if (filterStrategy == FilterStrategy.QUERY_TIME_PRE_FILTER) {
1039+
queryTimeFilter = filter;
1040+
}
1041+
1042+
String knnField = KNN_FIELD;
1043+
if (filterStrategy == FilterStrategy.INDEX_TIME_FILTER) {
1044+
knnField = KNN_FIELD_FILTERED;
1045+
}
1046+
1047+
ProfiledKnnByteVectorQuery profiledQuery = new ProfiledKnnByteVectorQuery(knnField, vector, k, fanout, queryTimeFilter);
1048+
1049+
Query query = profiledQuery;
1050+
if (filterStrategy == FilterStrategy.QUERY_TIME_POST_FILTER) {
1051+
query = new BooleanQuery.Builder()
1052+
.add(profiledQuery, BooleanClause.Occur.MUST)
1053+
.add(filter, BooleanClause.Occur.FILTER)
1054+
.build();
1055+
}
10021056
TopDocs docs = searcher.search(query, k);
10031057
return new Result(docs, profiledQuery.totalVectorCount(), 0);
10041058
}
10051059

10061060
private static Result doKnnVectorQuery(
1007-
IndexSearcher searcher, String field, float[] vector, int k, int fanout, boolean prefilter, Query filter, boolean isParentJoinQuery)
1061+
IndexSearcher searcher, float[] vector, int k, int fanout, FilterStrategy filterStrategy, Query filter, boolean isParentJoinQuery)
10081062
throws IOException {
1063+
1064+
Query queryTimeFilter = null;
1065+
if (filterStrategy == FilterStrategy.QUERY_TIME_PRE_FILTER) {
1066+
queryTimeFilter = filter;
1067+
}
1068+
1069+
String knnField = KNN_FIELD;
1070+
if (filterStrategy == FilterStrategy.INDEX_TIME_FILTER) {
1071+
knnField = KNN_FIELD_FILTERED;
1072+
}
1073+
10091074
if (isParentJoinQuery) {
1010-
var topChildVectors = new DiversifyingChildrenFloatKnnVectorQuery(KNN_FIELD, vector, null, k + fanout, parentsFilter);
1075+
var topChildVectors = new DiversifyingChildrenFloatKnnVectorQuery(knnField, vector, null, k + fanout, parentsFilter);
10111076
var query = new ToParentBlockJoinQuery(topChildVectors, parentsFilter, org.apache.lucene.search.join.ScoreMode.Max);
10121077
TopDocs topDocs = searcher.search(query, k);
10131078
return new Result(topDocs, 0, 0);
10141079
}
1015-
ProfiledKnnFloatVectorQuery profiledQuery = new ProfiledKnnFloatVectorQuery(field, vector, k, fanout, prefilter ? filter : null);
1016-
Query query = prefilter ? profiledQuery : new BooleanQuery.Builder()
1017-
.add(profiledQuery, BooleanClause.Occur.MUST)
1018-
.add(filter, BooleanClause.Occur.FILTER)
1019-
.build();
1080+
1081+
ProfiledKnnFloatVectorQuery profiledQuery = new ProfiledKnnFloatVectorQuery(knnField, vector, k, fanout, queryTimeFilter);
1082+
1083+
Query query = profiledQuery;
1084+
if (filterStrategy == FilterStrategy.QUERY_TIME_POST_FILTER) {
1085+
query = new BooleanQuery.Builder()
1086+
.add(profiledQuery, BooleanClause.Occur.MUST)
1087+
.add(filter, BooleanClause.Occur.FILTER)
1088+
.build();
1089+
}
10201090
TopDocs docs = searcher.search(query, k);
10211091
return new Result(docs, profiledQuery.totalVectorCount(), 0);
10221092
}
@@ -1060,7 +1130,7 @@ private int compareNN(int[] expected, int[] results) {
10601130
*/
10611131
private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException {
10621132
// look in working directory for cached nn file
1063-
String hash = Integer.toString(Objects.hash(docPath, indexPath, queryPath, numDocs, numQueryVectors, topK, similarityFunction.ordinal(), parentJoin, queryStartIndex, prefilter ? selectivity : 1f, prefilter ? randomSeed : 0f), 36);
1133+
String hash = Integer.toString(Objects.hash(docPath, indexPath, queryPath, numDocs, numQueryVectors, topK, similarityFunction.ordinal(), parentJoin, queryStartIndex, filterSelectivity == null ? 0 : Objects.hash(filterSelectivity, randomSeed)), 36);
10641134
String nnFileName = "nn-" + hash + ".bin";
10651135
Path nnPath = Paths.get(nnFileName);
10661136
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
@@ -1164,10 +1234,13 @@ public Void call() {
11641234
try {
11651235
var queryVector = new ConstKnnByteVectorValueSource(query);
11661236
var docVectors = new ByteKnnVectorFieldSource(KNN_FIELD);
1167-
var query = new BooleanQuery.Builder()
1168-
.add(new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors)), BooleanClause.Occur.SHOULD)
1169-
.add(filterQuery, BooleanClause.Occur.FILTER)
1170-
.build();
1237+
Query query = new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
1238+
if (filterQuery != null) {
1239+
query = new BooleanQuery.Builder()
1240+
.add(query, BooleanClause.Occur.SHOULD)
1241+
.add(filterQuery, BooleanClause.Occur.FILTER)
1242+
.build();
1243+
}
11711244
var topDocs = searcher.search(query, topK);
11721245
result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields());
11731246
if ((queryOrd + 1) % 10 == 0) {
@@ -1238,10 +1311,13 @@ public Void call() {
12381311
try {
12391312
var queryVector = new ConstKnnFloatValueSource(query);
12401313
var docVectors = new FloatKnnVectorFieldSource(KNN_FIELD);
1241-
var query = new BooleanQuery.Builder()
1242-
.add(new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors)), BooleanClause.Occur.SHOULD)
1243-
.add(filterQuery, BooleanClause.Occur.FILTER)
1244-
.build();
1314+
Query query = new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
1315+
if (filterQuery != null) {
1316+
query = new BooleanQuery.Builder()
1317+
.add(query, BooleanClause.Occur.SHOULD)
1318+
.add(filterQuery, BooleanClause.Occur.FILTER)
1319+
.build();
1320+
}
12451321
var topDocs = searcher.search(query, topK);
12461322
result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields());
12471323
if ((queryOrd + 1) % 10 == 0) {

0 commit comments

Comments
 (0)