diff --git a/pecos/ann/hnsw/README.md b/pecos/ann/hnsw/README.md index 79339cd5..38f542dd 100644 --- a/pecos/ann/hnsw/README.md +++ b/pecos/ann/hnsw/README.md @@ -8,6 +8,23 @@ * Supports **thread-safe** graph construction in parallel on multi-core shared memory machines * Supports **thread-safe** Searchers to do inference in parallel, which reduces inference overhead +## Command Line Usage +Basic training (building HNSW index) and predicting (HNSW inference): +```bash +python3 -m pecos.ann.hnsw.train -x ${X_path} -m ${model_folder} +python3 -m pecos.ann.hnsw.predict -x ${Xt_path} -m ${model_folder} -o ${Yp_path} +``` +where +* `X_path` and `Xt_path` are the paths to the CSR npz or Row-majored npy files of the training/test feature matrices with shape `(N,d)` and `(Nt,d)` +* `model_folder` is the path to the model folder where the trained model will be saved to, will be created if not exist +* `Yp_path` is the path to save the prediction label matrix with shape `(Nt, N)` + +For detailed usage, please refer to +```bash +python3 -m pecos.ann.hnsw.train --help +python3 -m pecos.ann.hnsw.predict --help +``` + ## Python API examples #### Prepare data @@ -20,11 +37,15 @@ X_tst = np.random.randn(1000, 100).astype(np.float32) Note that the data type needed to be `np.float32`. #### HNSW Training -Train the HNSW model (i.e., building the graph-based indexing data structure) with maximum number of threads available on your machine (`threads=0`): +Train the HNSW model (i.e., building the graph-based indexing data structure) with maximum number of threads available on your machine (`threads=-1`): ```python from pecos.ann.hnsw import HNSW train_params = HNSW.TrainParams(M=32, efC=300, metric_type="ip", threads=-1) -model = HNSW.train(X_trn, train_params=train_params) +model = HNSW.train(X_trn, train_params=train_params, pred_params=None) +``` +Users are also welcome to train the default parameters via +```python +model = HNSW.train(X_trn) ``` #### HNSW Save and Load @@ -47,11 +68,13 @@ searchers = model.searchers_create(num_searcher=4) Finally, we conduct ANN inference by inputing searchers to the HNSW model. ```python pred_params = HNSW.PredParams(efS=100, topk=10) -indices, distances = model.predict(X_tst, pred_params=pred_params, searchers=searchers, ret_csr=False) +Yt_pred = model.predict(X_tst, pred_params=pred_params, searchers=searchers) ``` +where `Yt_pred` is a `scipy.sparse.csr_matrix` whose column indices for each row are sorted by its distances ascendingly. + Alternatively, it is also feasible to do inference without pre-allocating searchers, which may have larger overhead since it will **re-allocate** intermediate graph-searhing variables for each query matrix `X_tst`. ```python pred_params.threads = 2 indices, distances = model.predict(X_tst, pred_params=pred_params, ret_csr=False) ``` -When `ret_csr=True`, the prediction function will return a single csr matrix that combines the indices and distances numpy array. +When `ret_csr=False`, the prediction function will return the indices and distances numpy array. diff --git a/pecos/ann/hnsw/model.py b/pecos/ann/hnsw/model.py index 932b8c98..de3584da 100644 --- a/pecos/ann/hnsw/model.py +++ b/pecos/ann/hnsw/model.py @@ -34,14 +34,14 @@ class TrainParams(pecos.BaseParams): """Training Parameters of HNSW class Attributes: - M (int): maximum number of edges per node for layer l=1,...,L. For layer l=0, its 2*M. + M (int): maximum number of edges per node for layer l=1,...,L. For layer l=0, its 2*M. Default 32 efC (int): size of the priority queue when performing best first search during construction. Default 100 threads (int): number of threads to use for training HNSW indexer. Default -1 to use all max_level_upper_bound (int): number of maximum layers in the hierarchical graph. Default -1 to ignore metric_type (str): distance metric type, can be "ip" for inner product or "l2" for Euclidean distance """ - M: int = 24 + M: int = 32 efC: int = 100 threads: int = -1 max_level_upper_bound: int = -1 @@ -213,13 +213,13 @@ def get_pred_params(self): """ return copy.deepcopy(self.pred_params) - def predict(self, X, pred_params=None, searchers=None, ret_csr=False): + def predict(self, X, pred_params=None, searchers=None, ret_csr=True): """predict with multi-thread. If searchers are provided, less overhead for online inference. Args: X (nd.array/ScipyDrmF32, scipy.sparse.csr_matrix/ScipyCsrF32): query matrix to be predicted. (num_query x feat_dim). pred_params (HNSW.PredParams, optional): instance of pecos.ann.hnsw.HNSW.PredParams searchers (c_void_p): pointer to C/C++ std::vector. It's an object returned by self.create_searcher(). - ret_csr (bool): if true, the returns will be csr matrix. if false, return indices/distances np.array + ret_csr (bool): if true, the returns will be csr matrix. if false, return indices/distances np.array (default true) Returns: indices (np.array): returned indices array, sorted by smallest-to-largest distances. (num_query x pred_params.topk) distances (np.array): returned dinstances array, sorted by smallest-to-largest distances (num_query x pred_params.topk) diff --git a/pecos/ann/hnsw/predict.py b/pecos/ann/hnsw/predict.py new file mode 100644 index 00000000..875d1f78 --- /dev/null +++ b/pecos/ann/hnsw/predict.py @@ -0,0 +1,139 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. +import argparse +import os +import numpy as np +from pecos.utils import smat_util +from .model import HNSW + + +def parse_arguments(): + """Parse Inference arguments""" + + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "-x", + "--inst-path", + type=str, + required=True, + metavar="PATH", + help="path to the CSR npz or Row-majored npy file of the feature matrix (nr_insts * nr_feats) to be indexed by HNSW", + ) + parser.add_argument( + "-m", + "--model-folder", + type=str, + required=True, + metavar="DIR", + help="path to the model folder to load the HNSW index for inference", + ) + + # Optional + parser.add_argument( + "-efS", + "--efSearch", + type=int, + default=100, + metavar="INT", + help="size of the priority queue when performing best first search during inference. (Default 100)", + ) + parser.add_argument( + "-k", + "--only-topk", + type=int, + default=10, + metavar="INT", + help="maximum number of candidates (sorted by distances, nearest first) to be returned", + ) + parser.add_argument( + "-n", + "--threads", + type=int, + default=-1, + metavar="int", + help="number of threads to use for inference of hnsw indexer (default -1 to use all)", + ) + parser.add_argument( + "-y", + "--label-path", + type=str, + default=None, + metavar="PATH", + help="path to the npz file of the ground truth label matrix (CSR, nr_tst * nr_items)", + ) + parser.add_argument( + "-o", + "--save-pred-path", + type=str, + default=None, + metavar="PATH", + help="path to save the predictions (CSR sorted by distances, nr_tst * nr_items)", + ) + + return parser + + +def do_predict(args): + """Predict and Evaluate for HNSW model + + Args: + args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` + """ + + # Load data + Xt = smat_util.load_matrix(args.inst_path).astype(np.float32) + + # Load model + model = HNSW.load(args.model_folder) + + # Setup HNSW Searchers for thread-safe inference + threads = os.cpu_count() if args.threads <= 0 else args.threads + searchers = model.searchers_create(num_searcher=threads) + + # Setup prediction params + # pred_params.threads will be overrided if searchers are provided in model.predict() + pred_params = HNSW.PredParams( + efS=args.efSearch, + topk=args.only_topk, + threads=threads, + ) + + # Model Predicting + Yt_pred = model.predict( + Xt, + pred_params=pred_params, + searchers=searchers, + ret_csr=True, + ) + + # Save prediction + if args.save_pred_path: + smat_util.save_matrix(args.save_pred_path, Yt_pred) + + # Evaluate Recallk@k + if args.label_path: + Yt = smat_util.load_matrix(args.label_path) + # assuming ground truth is similarity-based (larger the better) + Yt_topk = smat_util.sorted_csr(Yt, only_topk=args.only_topk) + # assuming prediction matrix is distance-based, so need 1-dist=similiarty + Yt_pred.data = 1.0 - Yt_pred.data + metric = smat_util.Metrics.generate(Yt_topk, Yt_pred, topk=args.only_topk) + print( + "Recall{}@{} {:.6f}%".format(args.only_topk, args.only_topk, 100.0 * metric.recall[-1]) + ) + + +if __name__ == "__main__": + parser = parse_arguments() + args = parser.parse_args() + do_predict(args) diff --git a/pecos/ann/hnsw/train.py b/pecos/ann/hnsw/train.py new file mode 100644 index 00000000..67e756da --- /dev/null +++ b/pecos/ann/hnsw/train.py @@ -0,0 +1,147 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. +import argparse +import os +import numpy as np +from pecos.utils import smat_util +from .model import HNSW + + +def parse_arguments(): + """Parse training arguments""" + + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "-x", + "--inst-path", + type=str, + required=True, + metavar="PATH", + help="path to the CSR npz or Row-majored npy file of the item matrix (nr_items * nr_feats) to be indexed by HNSW", + ) + parser.add_argument( + "-m", + "--model-folder", + type=str, + required=True, + metavar="DIR", + help="path to the model folder that saved the HNSW index", + ) + + # Optional + + # HNSW Indexing parameters + parser.add_argument( + "--metric-type", + type=str, + default="ip", + metavar="STR", + help="distance metric type, can be ip (inner product) or l2 (Euclidean distance), default is set to ip", + ) + parser.add_argument( + "-M", + "--max-edge-per-node", + type=int, + default=32, + metavar="INT", + help="maximum number of edges per node for layer l=1,...,L. For l=0, it becomes 2*M (default 32)", + ) + parser.add_argument( + "-efC", + "--efConstruction", + type=int, + default=100, + metavar="INT", + help="size of the priority queue when performing best first search during construction (default 100)", + ) + parser.add_argument( + "-n", + "--threads", + type=int, + default=-1, + metavar="int", + help="number of threads to use for training and inference of hnsw indexer (default -1 to use all)", + ) + parser.add_argument( + "-L", + "--max-level-upper-bound", + type=int, + default=-1, + metavar="int", + help="number of maximum layers in the hierarchical graph (default -1 to ignore)", + ) + + # HNSW Prediction kwargs + parser.add_argument( + "-efS", + "--efSearch", + type=int, + default=100, + metavar="INT", + help="size of the priority queue when performing best first search during inference (default 100)", + ) + parser.add_argument( + "-k", + "--only-topk", + type=int, + default=10, + metavar="INT", + help="maximum number of candidates (sorted by distances, nearest first) to be returned (default 10)", + ) + + return parser + + +def do_train(args): + """Train and Save HNSW model + + Args: + args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` + """ + + # Create model folder + if not os.path.exists(args.model_folder): + os.makedirs(args.model_folder) + + # Load training inputs + X = smat_util.load_matrix(args.inst_path).astype(np.float32) + + # Setup training and prediction params + # Note that prediction params can be overrided in inference time + train_params = HNSW.TrainParams( + M=args.max_edge_per_node, + efC=args.efConstruction, + metric_type=args.metric_type, + max_level_upper_bound=args.max_level_upper_bound, + threads=args.threads, + ) + pred_params = HNSW.PredParams( + efS=args.efSearch, + topk=args.only_topk, + threads=args.threads, + ) + + # train and save HNSW indexer + model = HNSW.train( + X, + train_params=train_params, + pred_params=pred_params, + ) + + model.save(args.model_folder) + + +if __name__ == "__main__": + parser = parse_arguments() + args = parser.parse_args() + do_train(args) diff --git a/test/pecos/ann/test_hnsw.py b/test/pecos/ann/test_hnsw.py index 4deeb06b..03d3f156 100644 --- a/test/pecos/ann/test_hnsw.py +++ b/test/pecos/ann/test_hnsw.py @@ -36,12 +36,12 @@ def test_save_and_load(tmpdir): train_params=train_params, pred_params=pred_params, ) - Yp_from_mem, _ = model.predict(X_tst) + Yp_from_mem, _ = model.predict(X_tst, ret_csr=False) model.save(model_folder) del model model = HNSW.load(model_folder) - Yp_from_file, _ = model.predict(X_tst, pred_params=pred_params) + Yp_from_file, _ = model.predict(X_tst, pred_params=pred_params, ret_csr=False) assert Yp_from_mem == approx( Yp_from_file, abs=0.0 ), f"save and load failed: Yp_from_mem != Yp_from_file" @@ -112,3 +112,34 @@ def calc_recall(Y_true, Y_pred): 1.0, abs=1e-2 ), f"hnsw inference failed: data_type=csr, efS={efS}, recall={recall}" del searchers, model + + +def test_cli(tmpdir): + import subprocess + import shlex + + x_trn_path = "test/tst-data/ann/X.trn.l2-normalized.npy" + x_tst_path = "test/tst-data/ann/X.tst.l2-normalized.npy" + model_folder = str(tmpdir.join("hnsw_save_model")) + y_pred_path = str(tmpdir.join("Yt_pred.npz")) + + # train + cmd = [] + cmd += ["python3 -m pecos.ann.hnsw.train"] + cmd += ["-x {}".format(x_trn_path)] + cmd += ["-m {}".format(model_folder)] + process = subprocess.run( + shlex.split(" ".join(cmd)), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + assert process.returncode == 0, " ".join(cmd) + + # predict + cmd = [] + cmd += ["python3 -m pecos.ann.hnsw.predict"] + cmd += ["-x {}".format(x_tst_path)] + cmd += ["-m {}".format(model_folder)] + cmd += ["-o {}".format(y_pred_path)] + process = subprocess.run( + shlex.split(" ".join(cmd)), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + assert process.returncode == 0, " ".join(cmd)