Offline IVF powered by faiss big batch search (#3175)
Summary:
This PR introduces the offline IVF (OIVF) framework which contains some tooling to run search using IVFPQ indexes (plus OPQ pretransforms) for large batches of queries using [big_batch_search](https://github.com/mlomeli1/faiss/blob/main/contrib/big_batch_search.py) and GPU faiss. See the [README](https://github.com/mlomeli1/faiss/blob/oivf/demos/offline_ivf/README.md) for details about using this framework.

This PR includes the following unit tests, which can be run with the unittest library as so:
````
~/faiss/demos/offline_ivf$ python3 -m unittest tests/test_iterate_input.py -k test_iterate_back
````
In test_offline_ivf:
````
test_consistency_check
test_train_index
test_index_shard_equal_file_sizes
test_index_shard_unequal_file_sizes
test_search
test_evaluate_without_margin
test_evaluate_without_margin_OPQ
test_evaluate_with_margin
test_split_batch_size_bigger_than_file_sizes
test_split_batch_size_smaller_than_file_sizes
test_split_files_with_corrupted_input_file
````
In test_iterate_input:
````
test_iterate_input_file_larger_than_batch
test_get_vs_iterate
test_iterate_back
````

Pull Request resolved: #3175

Reviewed By: algoriddle

Differential Revision: D52218447

Pulled By: mlomeli1

fbshipit-source-id: 78b12457c79b02eb2c9ae993560f2e295798e7e5
1 parent be12427, commit 9a8b34e
Showing 34 changed files with 2,647 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
|
||
# Offline IVF | ||
|
||
This folder contains the code for the offline IVF algorithm powered by faiss big batch search. | ||
|
||
Create a conda env: | ||
|
||
`conda create --name oivf python=3.10` | ||
|
||
`conda activate oivf` | ||
|
||
`conda install -c pytorch/label/nightly -c nvidia faiss-gpu=1.7.4` | ||
|
||
`conda install tqdm` | ||
|
||
`conda install pyyaml` | ||
|
||
`conda install -c conda-forge submitit` | ||
|
||
|
||
## Run book | ||
|
||
1. Optionally shard your dataset (see `create_sharded_dataset.py`) and create the corresponding YAML file `config_ssnpp.yaml`. You can use `generate_config.py` by specifying the root directory of your dataset and the files with the data shards. | ||
|
||
`python generate_config.py` | ||
|
||
2. Run the train index command | ||
|
||
`python run.py --command train_index --config config_ssnpp.yaml --xb ssnpp_1B` | ||
|
||
|
||
3. Run the index-shard command so it produces sharded indexes, required for the search step | ||
|
||
`python run.py --command index_shard --config config_ssnpp.yaml --xb ssnpp_1B` | ||
|
||
|
||
4. Send jobs to the cluster to run search | ||
|
||
`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --cluster_run --partition <PARTITION-NAME>` | ||
|
||
|
||
Remarks about the `search` command: by default, the database vectors are also used as the query vectors in the search step. | ||
a. If the query vectors differ from the database vectors, pass them with the `--xq` argument.
b. A new dataset must be prepared (step 1) before it is passed to the query vectors argument `--xq`. | ||
|
||
`python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B --xq <QUERIES_DATASET_NAME>` | ||
|
||
|
||
5. We can always run the consistency check as a sanity check! | ||
|
||
`python run.py --command consistency_check --config config_ssnpp.yaml --xb ssnpp_1B` | ||
|
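
The run book above can be sketched as an ordered list of `run.py` invocations. This is a minimal sketch that only reuses the flags shown in the README (`--command`, `--config`, `--xb`, `--xq`, `--cluster_run`, `--partition`). The helper name `build_run_book` is hypothetical, placing the consistency check last is a choice made here, and whether `--xq` may be combined with `--cluster_run` is an assumption the README does not confirm.

```python
import shlex


# Hypothetical helper: it only assembles the commands listed in the run book.
def build_run_book(config, xb, xq=None, partition=None):
    """Return the run-book commands, in order, as argument lists."""
    base = ["--config", config, "--xb", xb]
    search = ["python", "run.py", "--command", "search", *base]
    if xq is not None:
        # Queries differ from the database vectors (remark a. in the README).
        search += ["--xq", xq]
    if partition is not None:
        # Assumption: cluster flags are simply appended to the search command.
        search += ["--cluster_run", "--partition", partition]
    return [
        ["python", "run.py", "--command", "train_index", *base],
        ["python", "run.py", "--command", "index_shard", *base],
        search,
        ["python", "run.py", "--command", "consistency_check", *base],
    ]


if __name__ == "__main__":
    for cmd in build_run_book("config_ssnpp.yaml", "ssnpp_1B"):
        print(shlex.join(cmd))
```

Each list can be handed to `subprocess.run(cmd, check=True)` so that a failing step stops the sequence before the next one starts.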
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
d: 256 | ||
output: /checkpoint/marialomeli/offline_faiss/ssnpp | ||
index: | ||
  prod: | ||
    - 'IVF8192,PQ128' | ||
  non-prod: | ||
    - 'IVF16384,PQ128' | ||
    - 'IVF32768,PQ128' | ||
nprobe: | ||
  prod: | ||
    - 512 | ||
  non-prod: | ||
    - 256 | ||
    - 128 | ||
    - 1024 | ||
    - 2048 | ||
    - 4096 | ||
    - 8192 | ||
|
||
k: 50 | ||
index_shard_size: 50000000 | ||
query_batch_size: 50000000 | ||
evaluation_sample: 10000 | ||
training_sample: 1572864 | ||
datasets: | ||
  ssnpp_1B: | ||
    root: /checkpoint/marialomeli/ssnpp_data | ||
    size: 1000000000 | ||
    files: | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000000.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000001.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000002.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000003.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000004.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000005.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000006.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000007.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000008.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000009.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000010.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000011.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000012.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000013.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000014.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000015.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000016.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000017.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000018.npy | ||
        size: 50000000 | ||
      - dtype: uint8 | ||
        format: npy | ||
        name: ssnpp_0000000019.npy | ||
        size: 50000000 |
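
In the config above, the 20 file entries of 50,000,000 vectors each add up to the dataset `size` of 1,000,000,000. The following is a minimal sketch of building and checking one such `datasets` entry as a plain dict (the shape `yaml.safe_load` would return for it). Both helper names are hypothetical, and nothing here is taken from `generate_config.py` or `run.py`, whose actual behavior is not shown in this diff.

```python
# Hypothetical helpers operating on a plain dict that mirrors one entry
# under `datasets` in the YAML config.
def make_dataset_entry(root, total, shard_size, prefix="ssnpp"):
    """Build a dataset entry whose files evenly split `total` vectors."""
    if total % shard_size != 0:
        raise ValueError("shard_size must divide the total number of vectors")
    files = [
        {
            "dtype": "uint8",
            "format": "npy",
            # Same zero-padded naming as the shard files in the config.
            "name": f"{prefix}_{i:010}.npy",
            "size": shard_size,
        }
        for i in range(total // shard_size)
    ]
    return {"root": root, "size": total, "files": files}


def check_dataset_entry(entry):
    """Return a list of problems found in a dataset entry (empty if none)."""
    problems = []
    total = sum(f["size"] for f in entry["files"])
    if total != entry["size"]:
        problems.append(f"file sizes sum to {total}, expected {entry['size']}")
    names = [f["name"] for f in entry["files"]]
    if len(set(names)) != len(names):
        problems.append("duplicate file names")
    return problems


if __name__ == "__main__":
    entry = make_dataset_entry(
        "/checkpoint/marialomeli/ssnpp_data", 1_000_000_000, 50_000_000
    )
    print(len(entry["files"]), check_dataset_entry(entry))
```

Running a check like this before `train_index` catches a missing or miscounted shard early, when it is still cheap to fix.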
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import numpy as np | ||
import argparse | ||
import os | ||
|
||
|
||
def xbin_mmap(fname, dtype, maxn=-1): | ||
    """ | ||
    Code from | ||
    https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94 | ||
    mmap the competition file format for a given type of items | ||
    """ | ||
    # The file starts with two uint32 values: number of vectors and dimension. | ||
    n, d = map(int, np.fromfile(fname, dtype="uint32", count=2)) | ||
    assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize | ||
    if maxn > 0: | ||
        n = min(n, maxn) | ||
    return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d)) | ||
|
||
|
||
def main(args: argparse.Namespace): | ||
    ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8") | ||
    num_batches = ssnpp_data.shape[0] // args.data_batch | ||
    assert ( | ||
        ssnpp_data.shape[0] % args.data_batch == 0 | ||
    ), "num of embeddings per file should divide total num of embeddings" | ||
    for i in range(num_batches): | ||
        xb_batch = ssnpp_data[ | ||
            i * args.data_batch : (i + 1) * args.data_batch, : | ||
        ] | ||
        filename = args.output_dir + f"/ssnpp_{(i):010}.npy" | ||
        np.save(filename, xb_batch) | ||
        print(f"File {filename} is saved!") | ||
|
||
|
||
if __name__ == "__main__": | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument( | ||
        "--data_batch", | ||
        dest="data_batch", | ||
        type=int, | ||
        default=50000000, | ||
        help="Number of embeddings per file, should be a divisor of 1B", | ||
    ) | ||
    parser.add_argument( | ||
        "--filepath", | ||
        dest="filepath", | ||
        type=str, | ||
        default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin", | ||
        help="path of 1B ssnpp database vectors' original file", | ||
    ) | ||
    # Renamed from a second "--filepath": argparse rejects duplicate option strings. | ||
    parser.add_argument( | ||
        "--output_dir", | ||
        dest="output_dir", | ||
        type=str, | ||
        default="/checkpoint/marialomeli/ssnpp_data", | ||
        help="path to put sharded files", | ||
    ) | ||
|
||
    args = parser.parse_args() | ||
    main(args) |
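
The file layout that `xbin_mmap` relies on (an 8-byte header holding `n` and `d` as two uint32 values, followed by `n * d` items) and the even-split requirement in `main` can be illustrated on a toy file using only the standard library. This is a sketch under stated assumptions: the header is taken to be little-endian, and `write_xbin`, `read_xbin_header`, and `shard_ranges` are hypothetical names that are not part of the commit.

```python
import os
import struct
import tempfile

# Assumption: the two uint32 header fields (n, d) are little-endian.
HEADER = struct.Struct("<II")


def write_xbin(fname, rows):
    """Write uint8 rows as: header (n, d), then the row-major data."""
    n, d = len(rows), len(rows[0])
    with open(fname, "wb") as f:
        f.write(HEADER.pack(n, d))
        for row in rows:
            f.write(bytes(row))


def read_xbin_header(fname, itemsize=1):
    """Read (n, d) and check the file size, like the assert in xbin_mmap."""
    with open(fname, "rb") as f:
        n, d = HEADER.unpack(f.read(HEADER.size))
    expected = HEADER.size + n * d * itemsize
    if os.stat(fname).st_size != expected:
        raise ValueError(f"expected {expected} bytes for n={n}, d={d}")
    return n, d


def shard_ranges(n, batch):
    """Return (start, stop) row ranges; batch must divide n, as in main()."""
    if n % batch != 0:
        raise ValueError("batch size must divide the number of vectors")
    return [(i * batch, (i + 1) * batch) for i in range(n // batch)]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy.u8bin")
        write_xbin(path, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
        n, d = read_xbin_header(path)
        print(n, d, shard_ranges(n, 2))
```

The size check is what makes a truncated or mistyped input fail fast, before any shard files are written.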