-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable PECOS HNSW with command line interface
- Loading branch information
1 parent
4cfb9d5
commit 597724e
Showing
5 changed files
with
333 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
# with the License. A copy of the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
import argparse | ||
import os | ||
import numpy as np | ||
from pecos.utils import smat_util | ||
from .model import HNSW | ||
|
||
|
||
def parse_arguments():
    """Build the command-line argument parser for HNSW inference.

    Returns:
        argparse.ArgumentParser: parser exposing the required `--inst-path`
            and `--model-folder` options and the optional search
            (`--efSearch`, `--topk`), threading (`--threads`), evaluation
            (`--label-path`) and output (`--save-pred-path`) options.
            The caller is expected to invoke `parse_args()` on it.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "-i",
        "--inst-path",
        type=str,
        required=True,
        metavar="PATH",
        help="path to the CSR npz or Row-majored npy file of the query feature matrix (nr_insts * nr_feats) to be searched against the HNSW index",
    )
    parser.add_argument(
        "-m",
        "--model-folder",
        type=str,
        required=True,
        metavar="DIR",
        help="path to the model folder to load the HNSW index for inference",
    )

    # Optional
    parser.add_argument(
        "-efS",
        "--efSearch",
        type=int,
        default=100,
        metavar="INT",
        help="size of the priority queue when performing best first search during inference. (Default 100)",
    )
    parser.add_argument(
        "-k",
        "--topk",
        type=int,
        default=10,
        metavar="INT",
        help="maximum number of candidates (sorted by distances, nearest first) to be returned",
    )
    parser.add_argument(
        "-n",
        "--threads",
        type=int,
        default=-1,
        # Upper-case like every other integer option so the help output is uniform.
        metavar="INT",
        help="number of threads to use for inference of hnsw indexer (default -1 to use all)",
    )
    parser.add_argument(
        "-y",
        "--label-path",
        type=str,
        default=None,
        metavar="PATH",
        help="path to the npz file of the ground truth label matrix (CSR, nr_tst * nr_items)",
    )
    parser.add_argument(
        "-o",
        "--save-pred-path",
        type=str,
        default=None,
        metavar="PATH",
        help="path to save the predictions (CSR sorted by distances, nr_tst * nr_items)",
    )

    return parser
|
||
|
||
def do_predict(args):
    """Predict and evaluate with a saved HNSW model.

    Loads the query matrix from `args.inst_path` and the index from
    `args.model_folder`, runs the search, optionally saves the CSR
    prediction matrix to `args.save_pred_path`, and, when
    `args.label_path` is given, prints Recall@k against the ground truth.

    Args:
        args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()`
    """
    # Load data
    Xt = smat_util.load_matrix(args.inst_path).astype(np.float32)

    # Load model
    model = HNSW.load(args.model_folder)

    # Setup HNSW Searchers for thread-safe inference.
    # A non-positive --threads means "use all cores"; os.cpu_count() may
    # return None when the count cannot be determined, so fall back to 1.
    if args.threads > 0:
        threads = args.threads
    else:
        threads = os.cpu_count() or 1
    searchers = model.searchers_create(num_searcher=threads)

    # Setup prediction params
    # pred_params.threads will be overridden if searchers are provided in model.predict()
    pred_params = HNSW.PredParams(
        efS=args.efSearch,
        topk=args.topk,
        threads=threads,
    )

    # Model Predicting
    Yt_pred = model.predict(
        Xt,
        pred_params=pred_params,
        searchers=searchers,
        ret_csr=True,
    )

    # Save prediction (done before the in-place distance -> similarity
    # conversion below, so the saved matrix keeps the raw distances)
    if args.save_pred_path:
        smat_util.save_matrix(args.save_pred_path, Yt_pred)

    # Evaluate Recall@k
    if args.label_path:
        Yt = smat_util.load_matrix(args.label_path)
        # assuming ground truth is similarity-based (larger the better)
        Yt_topk = smat_util.sorted_csr(Yt, only_topk=args.topk)
        # assuming prediction matrix is distance-based, so need 1-dist=similarity
        Yt_pred.data = 1.0 - Yt_pred.data
        metric = smat_util.Metrics.generate(Yt_topk, Yt_pred, topk=args.topk)
        print("Recall{}@{} {:.6f}%".format(args.topk, args.topk, 100. * metric.recall[-1]))
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the command line and run inference.
    cli_args = parse_arguments().parse_args()
    do_predict(cli_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
# with the License. A copy of the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
import argparse | ||
import os | ||
import numpy as np | ||
from pecos.utils import smat_util | ||
from .model import HNSW | ||
|
||
|
||
def parse_arguments():
    """Build the command-line argument parser for HNSW training.

    Returns:
        argparse.ArgumentParser: parser exposing the required `--inst-path`
            and `--model-folder` options, the HNSW indexing options
            (`--metric-type`, `--max-edge-per-node`, `--efConstruction`,
            `--threads`, `--max-level-upper-bound`) and the default
            prediction options (`--efSearch`, `--topk`). The caller is
            expected to invoke `parse_args()` on it.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "-i",
        "--inst-path",
        type=str,
        required=True,
        metavar="PATH",
        help="path to the CSR npz or Row-majored npy file of the item matrix (nr_items * nr_feats) to be indexed by HNSW",
    )
    parser.add_argument(
        "-m",
        "--model-folder",
        type=str,
        required=True,
        metavar="DIR",
        help="path to the model folder that saved the HNSW index",
    )

    # Optional

    # HNSW Indexing parameters
    parser.add_argument(
        "--metric-type",
        type=str,
        default="ip",
        metavar="STR",
        help="distance metric type, can be ip (inner product) or l2 (Euclidean distance), default is set to ip",
    )
    parser.add_argument(
        "-maxM",
        "--max-edge-per-node",
        type=int,
        default=32,
        metavar="INT",
        help="maximum number of edges per node for layer l=1,...,L. For l=0, it becomes 2*M (default 32)",
    )
    parser.add_argument(
        "-efC",
        "--efConstruction",
        type=int,
        default=100,
        metavar="INT",
        help="size of the priority queue when performing best first search during construction (default 100)",
    )
    parser.add_argument(
        "-n",
        "--threads",
        type=int,
        default=-1,
        # Upper-case like the other integer options so the help output is uniform.
        metavar="INT",
        help="number of threads to use for training and inference of hnsw indexer (default -1 to use all)",
    )
    parser.add_argument(
        "-maxL",
        "--max-level-upper-bound",
        type=int,
        default=-1,
        metavar="INT",
        help="number of maximum layers in the hierarchical graph (default -1 to ignore)",
    )

    # HNSW Prediction kwargs
    parser.add_argument(
        "-efS",
        "--efSearch",
        type=int,
        default=100,
        metavar="INT",
        help="size of the priority queue when performing best first search during inference (default 100)",
    )
    parser.add_argument(
        "-k",
        "--topk",
        type=int,
        default=10,
        metavar="INT",
        help="maximum number of candidates (sorted by distances, nearest first) to be returned (default 10)",
    )

    return parser
|
||
|
||
def do_train(args):
    """Train and save an HNSW model.

    Loads the item matrix from `args.inst_path`, builds the HNSW index with
    the indexing options in `args`, and saves the result into
    `args.model_folder` (created if it does not exist yet).

    Args:
        args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()`
    """
    # Create model folder; exist_ok avoids the check-then-create race of
    # testing os.path.exists() first.
    os.makedirs(args.model_folder, exist_ok=True)

    # Load training inputs
    X = smat_util.load_matrix(args.inst_path).astype(np.float32)

    # Setup training and prediction params
    # Note that prediction params can be overridden at inference time
    train_params = HNSW.TrainParams(
        M=args.max_edge_per_node,
        efC=args.efConstruction,
        metric_type=args.metric_type,
        max_level_upper_bound=args.max_level_upper_bound,
        threads=args.threads,
    )
    pred_params = HNSW.PredParams(
        efS=args.efSearch,
        topk=args.topk,
        threads=args.threads,
    )

    # train and save HNSW indexer
    model = HNSW.train(
        X,
        train_params=train_params,
        pred_params=pred_params,
    )

    model.save(args.model_folder)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the command line and build the index.
    cli_args = parse_arguments().parse_args()
    do_train(cli_args)
Oops, something went wrong.