1919
2020import java .io .IOException ;
2121import java .io .OutputStream ;
22- import java .lang .management .ManagementFactory ;
23- import java .lang .management .ThreadMXBean ;
2422import java .nio .ByteBuffer ;
2523import java .nio .ByteOrder ;
2624import java .nio .IntBuffer ;
@@ -133,7 +131,14 @@ enum IndexType {
133131 FLAT
134132 }
135133
134+ enum FilterStrategy {
135+ QUERY_TIME_PRE_FILTER ,
136+ QUERY_TIME_POST_FILTER ,
137+ INDEX_TIME_FILTER
138+ }
139+
136140 public static final String KNN_FIELD = "knn" ;
141+ public static final String KNN_FIELD_FILTERED = "knn-filtered" ;
137142 public static final String ID_FIELD = "id" ;
138143 private static final String INDEX_DIR = "knnIndices" ;
139144 public static final String DOCTYPE_FIELD = "docType" ;
@@ -174,8 +179,8 @@ enum IndexType {
174179 private VectorSimilarityFunction similarityFunction ;
175180 private VectorEncoding vectorEncoding ;
176181 private Query filterQuery ;
177- private float selectivity ;
178- private boolean prefilter ;
182+ private FilterStrategy filterStrategy ;
183+ private Float filterSelectivity ;
179184 private boolean randomCommits ;
180185 private boolean parentJoin ;
181186 private Path parentJoinMetaFile ;
@@ -200,8 +205,8 @@ private KnnGraphTester() {
200205 fanout = topK ;
201206 similarityFunction = VectorSimilarityFunction .DOT_PRODUCT ;
202207 vectorEncoding = VectorEncoding .FLOAT32 ;
203- selectivity = 1f ;
204- prefilter = false ;
208+ filterStrategy = null ;
209+ filterSelectivity = null ;
205210 quantize = false ;
206211 randomCommits = false ;
207212 quantizeBits = 7 ;
@@ -412,18 +417,27 @@ private void run(String... args) throws Exception {
412417 case "-forceMerge" :
413418 forceMerge = true ;
414419 break ;
415- case "-prefilter" :
416- prefilter = true ;
417- break ;
418420 case "-randomCommits" :
419421 randomCommits = true ;
420422 break ;
423+ case "-filterStrategy" :
424+ if (iarg == args .length - 1 ) {
425+ throw new IllegalArgumentException ("-filterStrategy requires a following string, one of (case-insensitive) {'query-time-pre-filter', 'query-time-post-filter', 'index-time-filter'}" );
426+ }
427+ String filterStrategyVal = args [++iarg ].toLowerCase ().trim ();
428+ filterStrategy = switch (filterStrategyVal ) {
429+ case "query-time-pre-filter" -> FilterStrategy .QUERY_TIME_PRE_FILTER ;
430+ case "query-time-post-filter" -> FilterStrategy .QUERY_TIME_POST_FILTER ;
431+ case "index-time-filter" -> FilterStrategy .INDEX_TIME_FILTER ;
432+ default -> throw new IllegalArgumentException ("-filterStrategy must be one of (case-insensitive) {'query-time-pre-filter', 'query-time-post-filter', 'index-time-filter'}, found: " + filterStrategyVal );
433+ };
434+ break ;
421435 case "-filterSelectivity" :
422436 if (iarg == args .length - 1 ) {
423437 throw new IllegalArgumentException ("-filterSelectivity requires a following float" );
424438 }
425- selectivity = Float .parseFloat (args [++iarg ]);
426- if (selectivity <= 0 || selectivity >= 1 ) {
439+ filterSelectivity = Float .parseFloat (args [++iarg ]);
440+ if (filterSelectivity <= 0 || filterSelectivity >= 1 ) {
427441 throw new IllegalArgumentException ("-filterSelectivity must be between 0 and 1" );
428442 }
429443 break ;
@@ -484,8 +498,10 @@ private void run(String... args) throws Exception {
484498 if (operation == null && reindex == false ) {
485499 usage ();
486500 }
487- if (prefilter && selectivity == 1f ) {
488- throw new IllegalArgumentException ("-prefilter requires filterSelectivity between 0 and 1" );
501+
502+ BitSet filtered = null ;
503+ if (filterStrategy != null && filterSelectivity == null || filterStrategy == null && filterSelectivity != null ) {
504+ throw new IllegalArgumentException ("Either both or none of -filterStrategy or -filterSelectivity should be specified" );
489505 }
490506 if (indexPath == null ) {
491507 indexPath = Paths .get (formatIndexPath (docVectorsPath , numDocs )); // derive index path
@@ -501,10 +517,19 @@ private void run(String... args) throws Exception {
501517 log ("Seed = %d\n " , randomSeed );
502518 random = new Random (randomSeed );
503519
520+ if (filterSelectivity != null ) {
521+ filtered = selectRandomDocs (random , numDocs , filterSelectivity );
522+ }
504523 if (reindex || Files .exists (indexPath ) == false ) {
505524 if (docVectorsPath == null ) {
506525 throw new IllegalArgumentException ("-docs argument is required when indexing" );
507526 }
527+
528+ BitSet indexTimeFilter = null ;
529+ if (filterStrategy == FilterStrategy .INDEX_TIME_FILTER ) {
530+ indexTimeFilter = filtered ;
531+ }
532+
508533 reindexTimeMsec = new KnnIndexer (
509534 docVectorsPath ,
510535 indexPath ,
@@ -518,7 +543,8 @@ private void run(String... args) throws Exception {
518543 quiet ,
519544 parentJoin ,
520545 parentJoinMetaFile ,
521- useBp
546+ useBp ,
547+ indexTimeFilter
522548 ).createIndex ();
523549 log ("reindex takes %.2f sec\n " , msToSec (reindexTimeMsec ));
524550 }
@@ -533,7 +559,9 @@ private void run(String... args) throws Exception {
533559 if (docVectorsPath == null ) {
534560 throw new IllegalArgumentException ("missing -docs arg" );
535561 }
536- filterQuery = selectivity == 1f ? new MatchAllDocsQuery () : generateRandomQuery (random , indexPath , numDocs , selectivity );
562+ if (filterSelectivity != null ) {
563+ filterQuery = createFilterQuery (indexPath , filtered );
564+ }
537565 if (outputPath != null ) {
538566 testSearch (indexPath , queryPath , queryStartIndex , outputPath , null );
539567 } else {
@@ -655,7 +683,7 @@ private void printIndexStatistics(Path indexPath) throws IOException {
655683 }
656684 }
657685
658- private static Query generateRandomQuery (Random random , Path indexPath , int size , float selectivity ) throws IOException {
686+ private static BitSet selectRandomDocs (Random random , int size , float selectivity ) {
659687 FixedBitSet bitSet = new FixedBitSet (size );
660688 for (int i = 0 ; i < size ; i ++) {
661689 if (random .nextFloat () < selectivity ) {
@@ -664,7 +692,10 @@ private static Query generateRandomQuery(Random random, Path indexPath, int size
664692 bitSet .clear (i );
665693 }
666694 }
695+ return bitSet ;
696+ }
667697
698+ private static Query createFilterQuery (Path indexPath , BitSet bitSet ) throws IOException {
668699 try (Directory dir = FSDirectory .open (indexPath );
669700 DirectoryReader reader = DirectoryReader .open (dir )) {
670701 BitSet [] segmentDocs = new BitSet [reader .leaves ().size ()];
@@ -706,6 +737,14 @@ private String formatIndexPath(Path docsPath, int numDocs) {
706737 }
707738 // make sure we reindex if numDocs has changed:
708739 suffix .add (Integer .toString (numDocs ));
740+
741+ // make sure we reindex if index-time filter is used
742+ if (filterStrategy == FilterStrategy .INDEX_TIME_FILTER ) {
743+ suffix .add (filterStrategy .toString ());
744+ suffix .add (filterSelectivity .toString ());
745+ suffix .add (String .valueOf (randomSeed ));
746+ }
747+
709748 return INDEX_DIR + "/" + docsPath .getFileName () + "-" + String .join ("-" , suffix ) + ".index" ;
710749 }
711750
@@ -878,10 +917,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
878917 for (int i = 0 ; i < numQueryVectors ; i ++) {
879918 if (vectorEncoding .equals (VectorEncoding .BYTE )) {
880919 byte [] target = targetReaderByte .nextBytes ();
881- doKnnByteVectorQuery (searcher , KNN_FIELD , target , topK , fanout , prefilter , filterQuery );
920+ doKnnByteVectorQuery (searcher , target , topK , fanout , filterStrategy , filterQuery );
882921 } else {
883922 float [] target = targetReader .next ();
884- doKnnVectorQuery (searcher , KNN_FIELD , target , topK , fanout , prefilter , filterQuery , parentJoin );
923+ doKnnVectorQuery (searcher , target , topK , fanout , filterStrategy , filterQuery , parentJoin );
885924 }
886925 }
887926 targetReader .reset ();
@@ -890,10 +929,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
890929 for (int i = 0 ; i < numQueryVectors ; i ++) {
891930 if (vectorEncoding .equals (VectorEncoding .BYTE )) {
892931 byte [] target = targetReaderByte .nextBytes ();
893- results [i ] = doKnnByteVectorQuery (searcher , KNN_FIELD , target , topK , fanout , prefilter , filterQuery );
932+ results [i ] = doKnnByteVectorQuery (searcher , target , topK , fanout , filterStrategy , filterQuery );
894933 } else {
895934 float [] target = targetReader .next ();
896- results [i ] = doKnnVectorQuery (searcher , KNN_FIELD , target , topK , fanout , prefilter , filterQuery , parentJoin );
935+ results [i ] = doKnnVectorQuery (searcher , target , topK , fanout , filterStrategy , filterQuery , parentJoin );
897936 }
898937 }
899938 ThreadDetails endThreadDetails = new ThreadDetails ();
@@ -956,7 +995,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
956995 double reindexSec = reindexTimeMsec / 1000.0 ;
957996 System .out .printf (
958997 Locale .ROOT ,
959- "SUMMARY: %5.3f\t %5.3f\t %5.3f\t %5.3f\t %d\t %d\t %d\t %d\t %d\t %s\t %d\t %.2f\t %.2f\t %.2f\t %d\t %.2f\t %.2f \t %s \t %5.3f\t %5.3f\t %5.3f\t %s\n " ,
998+ "SUMMARY: %5.3f\t %5.3f\t %5.3f\t %5.3f\t %d\t %d\t %d\t %d\t %d\t %s\t %d\t %.2f\t %.2f\t %.2f\t %d\t %.2f\t %s \t %.2f \t %5.3f\t %5.3f\t %5.3f\t %s\n " ,
960999 recall ,
9611000 elapsedMS / (float ) numQueryVectors ,
9621001 totalCpuTimeMS / (float ) numQueryVectors ,
@@ -973,8 +1012,8 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
9731012 forceMergeTimeSec ,
9741013 indexNumSegments ,
9751014 indexSizeOnDiskMB ,
976- selectivity ,
977- prefilter ? "pre-filter" : "post-filter" ,
1015+ filterStrategy . toString (). toLowerCase (). replace ( '_' , '-' ) ,
1016+ filterSelectivity ,
9781017 overSample ,
9791018 vectorDiskSizeBytes / 1024. / 1024. ,
9801019 vectorRAMSizeBytes / 1024. / 1024. ,
@@ -992,31 +1031,62 @@ private static double msToSec(long ms) {
9921031 }
9931032
9941033 private static Result doKnnByteVectorQuery (
995- IndexSearcher searcher , String field , byte [] vector , int k , int fanout , boolean prefilter , Query filter )
1034+ IndexSearcher searcher , byte [] vector , int k , int fanout , FilterStrategy filterStrategy , Query filter )
9961035 throws IOException {
997- ProfiledKnnByteVectorQuery profiledQuery = new ProfiledKnnByteVectorQuery (field , vector , k , fanout , prefilter ? filter : null );
998- Query query = prefilter ? profiledQuery : new BooleanQuery .Builder ()
999- .add (profiledQuery , BooleanClause .Occur .MUST )
1000- .add (filter , BooleanClause .Occur .FILTER )
1001- .build ();
1036+
1037+ Query queryTimeFilter = null ;
1038+ if (filterStrategy == FilterStrategy .QUERY_TIME_PRE_FILTER ) {
1039+ queryTimeFilter = filter ;
1040+ }
1041+
1042+ String knnField = KNN_FIELD ;
1043+ if (filterStrategy == FilterStrategy .INDEX_TIME_FILTER ) {
1044+ knnField = KNN_FIELD_FILTERED ;
1045+ }
1046+
1047+ ProfiledKnnByteVectorQuery profiledQuery = new ProfiledKnnByteVectorQuery (knnField , vector , k , fanout , queryTimeFilter );
1048+
1049+ Query query = profiledQuery ;
1050+ if (filterStrategy == FilterStrategy .QUERY_TIME_POST_FILTER ) {
1051+ query = new BooleanQuery .Builder ()
1052+ .add (profiledQuery , BooleanClause .Occur .MUST )
1053+ .add (filter , BooleanClause .Occur .FILTER )
1054+ .build ();
1055+ }
10021056 TopDocs docs = searcher .search (query , k );
10031057 return new Result (docs , profiledQuery .totalVectorCount (), 0 );
10041058 }
10051059
10061060 private static Result doKnnVectorQuery (
1007- IndexSearcher searcher , String field , float [] vector , int k , int fanout , boolean prefilter , Query filter , boolean isParentJoinQuery )
1061+ IndexSearcher searcher , float [] vector , int k , int fanout , FilterStrategy filterStrategy , Query filter , boolean isParentJoinQuery )
10081062 throws IOException {
1063+
1064+ Query queryTimeFilter = null ;
1065+ if (filterStrategy == FilterStrategy .QUERY_TIME_PRE_FILTER ) {
1066+ queryTimeFilter = filter ;
1067+ }
1068+
1069+ String knnField = KNN_FIELD ;
1070+ if (filterStrategy == FilterStrategy .INDEX_TIME_FILTER ) {
1071+ knnField = KNN_FIELD_FILTERED ;
1072+ }
1073+
10091074 if (isParentJoinQuery ) {
1010- var topChildVectors = new DiversifyingChildrenFloatKnnVectorQuery (KNN_FIELD , vector , null , k + fanout , parentsFilter );
1075+ var topChildVectors = new DiversifyingChildrenFloatKnnVectorQuery (knnField , vector , null , k + fanout , parentsFilter );
10111076 var query = new ToParentBlockJoinQuery (topChildVectors , parentsFilter , org .apache .lucene .search .join .ScoreMode .Max );
10121077 TopDocs topDocs = searcher .search (query , k );
10131078 return new Result (topDocs , 0 , 0 );
10141079 }
1015- ProfiledKnnFloatVectorQuery profiledQuery = new ProfiledKnnFloatVectorQuery (field , vector , k , fanout , prefilter ? filter : null );
1016- Query query = prefilter ? profiledQuery : new BooleanQuery .Builder ()
1017- .add (profiledQuery , BooleanClause .Occur .MUST )
1018- .add (filter , BooleanClause .Occur .FILTER )
1019- .build ();
1080+
1081+ ProfiledKnnFloatVectorQuery profiledQuery = new ProfiledKnnFloatVectorQuery (knnField , vector , k , fanout , queryTimeFilter );
1082+
1083+ Query query = profiledQuery ;
1084+ if (filterStrategy == FilterStrategy .QUERY_TIME_POST_FILTER ) {
1085+ query = new BooleanQuery .Builder ()
1086+ .add (profiledQuery , BooleanClause .Occur .MUST )
1087+ .add (filter , BooleanClause .Occur .FILTER )
1088+ .build ();
1089+ }
10201090 TopDocs docs = searcher .search (query , k );
10211091 return new Result (docs , profiledQuery .totalVectorCount (), 0 );
10221092 }
@@ -1060,7 +1130,7 @@ private int compareNN(int[] expected, int[] results) {
10601130 */
10611131 private int [][] getExactNN (Path docPath , Path indexPath , Path queryPath , int queryStartIndex ) throws IOException , InterruptedException {
10621132 // look in working directory for cached nn file
1063- String hash = Integer .toString (Objects .hash (docPath , indexPath , queryPath , numDocs , numQueryVectors , topK , similarityFunction .ordinal (), parentJoin , queryStartIndex , prefilter ? selectivity : 1f , prefilter ? randomSeed : 0f ), 36 );
1133+ String hash = Integer .toString (Objects .hash (docPath , indexPath , queryPath , numDocs , numQueryVectors , topK , similarityFunction .ordinal (), parentJoin , queryStartIndex , filterSelectivity == null ? 0 : Objects . hash ( filterSelectivity , randomSeed ) ), 36 );
10641134 String nnFileName = "nn-" + hash + ".bin" ;
10651135 Path nnPath = Paths .get (nnFileName );
10661136 if (Files .exists (nnPath ) && isNewer (nnPath , docPath , queryPath )) {
@@ -1164,10 +1234,13 @@ public Void call() {
11641234 try {
11651235 var queryVector = new ConstKnnByteVectorValueSource (query );
11661236 var docVectors = new ByteKnnVectorFieldSource (KNN_FIELD );
1167- var query = new BooleanQuery .Builder ()
1168- .add (new FunctionQuery (new ByteVectorSimilarityFunction (similarityFunction , queryVector , docVectors )), BooleanClause .Occur .SHOULD )
1169- .add (filterQuery , BooleanClause .Occur .FILTER )
1170- .build ();
1237+ Query query = new FunctionQuery (new ByteVectorSimilarityFunction (similarityFunction , queryVector , docVectors ));
1238+ if (filterQuery != null ) {
1239+ query = new BooleanQuery .Builder ()
1240+ .add (query , BooleanClause .Occur .SHOULD )
1241+ .add (filterQuery , BooleanClause .Occur .FILTER )
1242+ .build ();
1243+ }
11711244 var topDocs = searcher .search (query , topK );
11721245 result [queryOrd ] = knn .KnnTesterUtils .getResultIds (topDocs , reader .storedFields ());
11731246 if ((queryOrd + 1 ) % 10 == 0 ) {
@@ -1238,10 +1311,13 @@ public Void call() {
12381311 try {
12391312 var queryVector = new ConstKnnFloatValueSource (query );
12401313 var docVectors = new FloatKnnVectorFieldSource (KNN_FIELD );
1241- var query = new BooleanQuery .Builder ()
1242- .add (new FunctionQuery (new FloatVectorSimilarityFunction (similarityFunction , queryVector , docVectors )), BooleanClause .Occur .SHOULD )
1243- .add (filterQuery , BooleanClause .Occur .FILTER )
1244- .build ();
1314+ Query query = new FunctionQuery (new FloatVectorSimilarityFunction (similarityFunction , queryVector , docVectors ));
1315+ if (filterQuery != null ) {
1316+ query = new BooleanQuery .Builder ()
1317+ .add (query , BooleanClause .Occur .SHOULD )
1318+ .add (filterQuery , BooleanClause .Occur .FILTER )
1319+ .build ();
1320+ }
12451321 var topDocs = searcher .search (query , topK );
12461322 result [queryOrd ] = knn .KnnTesterUtils .getResultIds (topDocs , reader .storedFields ());
12471323 if ((queryOrd + 1 ) % 10 == 0 ) {
0 commit comments