From 9519a19f4291d5750c4b7c6036ef080522c7c299 Mon Sep 17 00:00:00 2001 From: Gergely Szilvasy Date: Mon, 4 Dec 2023 05:53:17 -0800 Subject: [PATCH] benchmark refactor Summary: 1. Support for index construction parameters outside of the factory string (arbitrary depth of quantizers). 2. Refactor that provides an index wrapper, which is a prerequisite for the optimizer; the optimizer will generate indices from pre-optimized components (particularly quantizers). Reviewed By: mdouze Differential Revision: D51427452 fbshipit-source-id: 014d05dd798d856360f2546963e7cad64c2fcaeb --- benchs/bench_fw/benchmark.py | 500 ++++++-------------- benchs/bench_fw/benchmark_io.py | 217 ++++----- benchs/bench_fw/descriptors.py | 46 +- benchs/bench_fw/index.py | 785 ++++++++++++++++++++++++++++++++ benchs/bench_fw_ivf_flat.py | 37 ++ benchs/bench_fw_test.py | 61 +++ contrib/factory_tools.py | 39 +- 7 files changed, 1202 insertions(+), 483 deletions(-) create mode 100644 benchs/bench_fw/index.py create mode 100644 benchs/bench_fw_ivf_flat.py create mode 100644 benchs/bench_fw_test.py diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index 0d7f1d8b0c..8ee53103e5 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -1,23 +1,20 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
-from contextlib import contextmanager -import json import logging from dataclasses import dataclass -from multiprocessing.pool import ThreadPool from operator import itemgetter from statistics import median, mean -from time import perf_counter -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional + +from .index import Index, IndexFromCodec, IndexFromFactory from .descriptors import DatasetDescriptor, IndexDescriptor import faiss # @manual=//faiss/python:pyfaiss_gpu from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu knn_intersection_measure, - OperatingPointsWithRanges, -) -from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu - add_preassigned, ) import numpy as np @@ -27,56 +24,21 @@ logger = logging.getLogger(__name__) -@contextmanager -def timer(name) -> float: - logger.info(f"Measuring {name}") - t1 = t2 = perf_counter() - yield lambda: t2 - t1 - t2 = perf_counter() - logger.info(f"Time for {name}: {t2 - t1:.3f} seconds") - - -def refine_distances_knn( - D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric -): - return np.where( - I >= 0, - np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2)) - if metric == faiss.METRIC_L2 - else np.einsum("qd,qkd->qk", xq, xb[I]), - D, - ) - - -def refine_distances_range( - lims: np.ndarray, - D: np.ndarray, - I: np.ndarray, - xq: np.ndarray, - xb: np.ndarray, - metric, -): - with ThreadPool(32) as pool: - R = pool.map( - lambda i: ( - np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1) - if metric == faiss.METRIC_L2 - else np.tensordot( - xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1) - ) - ) - if lims[i + 1] > lims[i] - else [], - range(len(lims) - 1), - ) - return np.hstack(R) - - def range_search_pr_curve( dist_ann: np.ndarray, metric_score: np.ndarray, gt_rsm: float ): assert dist_ann.shape == metric_score.shape assert dist_ann.ndim == 1 + l = len(dist_ann) + if l == 0: + return { + "dist_ann": [], + 
"metric_score_sample": [], + "cum_score": [], + "precision": [], + "recall": [], + "unique_key": [], + } sort_by_dist_ann = dist_ann.argsort() dist_ann = dist_ann[sort_by_dist_ann] metric_score = metric_score[sort_by_dist_ann] @@ -87,7 +49,7 @@ def range_search_pr_curve( tbl = np.vstack( [dist_ann, metric_score, cum_score, precision, recall, unique_key] ) - group_by_dist_max_cum_score = np.empty(len(dist_ann), bool) + group_by_dist_max_cum_score = np.empty(l, bool) group_by_dist_max_cum_score[-1] = True group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1] tbl = tbl[:, group_by_dist_max_cum_score] @@ -105,39 +67,7 @@ def range_search_pr_curve( } -def set_index_parameter(index, name, val): - index = faiss.downcast_index(index) - - if isinstance(index, faiss.IndexPreTransform): - set_index_parameter(index.index, name, val) - elif name.startswith("quantizer_"): - index_ivf = faiss.extract_index_ivf(index) - set_index_parameter( - index_ivf.quantizer, name[name.find("_") + 1:], val - ) - elif name == "efSearch": - index.hnsw.efSearch - index.hnsw.efSearch = int(val) - elif name == "nprobe": - index_ivf = faiss.extract_index_ivf(index) - index_ivf.nprobe - index_ivf.nprobe = int(val) - elif name == "noop": - pass - else: - raise RuntimeError(f"could not set param {name} on {index}") - - -def optimizer(codec, search, cost_metric, perf_metric): - op = OperatingPointsWithRanges() - op.add_range("noop", [0]) - codec_ivf = faiss.try_extract_index_ivf(codec) - if codec_ivf is not None: - op.add_range( - "nprobe", - [2**i for i in range(12) if 2**i < codec_ivf.nlist * 0.1], - ) - +def optimizer(op, search, cost_metric, perf_metric): totex = op.num_experiments() rs = np.random.RandomState(123) if totex > 1: @@ -243,7 +173,7 @@ def sigmoid(x, a, b, c): cutoff, lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0), popt.tolist(), - list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)) + list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)), ) 
else: # Assuming that the range_metric is a float, @@ -265,21 +195,20 @@ def sigmoid(x, a, b, c): @dataclass class Benchmark: training_vectors: Optional[DatasetDescriptor] = None - db_vectors: Optional[DatasetDescriptor] = None + database_vectors: Optional[DatasetDescriptor] = None query_vectors: Optional[DatasetDescriptor] = None index_descs: Optional[List[IndexDescriptor]] = None range_ref_index_desc: Optional[str] = None k: Optional[int] = None - distance_metric: str = "METRIC_L2" + distance_metric: str = "L2" def __post_init__(self): - if self.distance_metric == "METRIC_INNER_PRODUCT": + if self.distance_metric == "IP": self.distance_metric_type = faiss.METRIC_INNER_PRODUCT - elif self.distance_metric == "METRIC_L2": + elif self.distance_metric == "L2": self.distance_metric_type = faiss.METRIC_L2 else: raise ValueError - self.cached_index_key = None def set_io(self, benchmark_io): self.io = benchmark_io @@ -292,54 +221,24 @@ def get_index_desc(self, factory: str) -> Optional[IndexDescriptor]: return desc return None - def get_index(self, index_desc: IndexDescriptor): - if self.cached_index_key != index_desc.factory: - xb = self.io.get_dataset(self.db_vectors) - index = faiss.clone_index( - self.io.get_codec(index_desc, xb.shape[1]) - ) - assert index.ntotal == 0 - logger.info("Adding vectors to index") - index_ivf = faiss.try_extract_index_ivf(index) - if index_ivf is not None: - QD, QI, _, QP = self.knn_search( - index_desc, - parameters=None, - db_vectors=None, - query_vectors=self.db_vectors, - k=1, - index=index_ivf.quantizer, - level=1, - ) - print(f"{QI.ravel().shape=}") - add_preassigned(index_ivf, xb, QI.ravel()) - else: - index.add(xb) - assert index.ntotal == xb.shape[0] - logger.info("Added vectors to index") - self.cached_index_key = index_desc.factory - self.cached_index = index - return self.cached_index - - def range_search_reference(self, index_desc, range_metric): + def range_search_reference(self, index, parameters, range_metric): 
logger.info("range_search_reference: begin") if isinstance(range_metric, list): assert len(range_metric) > 0 - ri = len(range_metric[0]) - 1 m_radius = ( - max(rm[ri] for rm in range_metric) + max(rm[-2] for rm in range_metric) if self.distance_metric_type == faiss.METRIC_L2 - else min(rm[ri] for rm in range_metric) + else min(rm[-2] for rm in range_metric) ) else: m_radius = range_metric lims, D, I, R, P = self.range_search( - index_desc, - index_desc.parameters, + index, + parameters, radius=m_radius, ) - flat = index_desc.factory == "Flat" + flat = index.factory == "Flat" ( gt_radius, range_search_metric_function, @@ -351,111 +250,61 @@ def range_search_reference(self, index_desc, range_metric): R if not flat else None, ) logger.info("range_search_reference: end") - return gt_radius, range_search_metric_function, coefficients, coefficients_training_data + return ( + gt_radius, + range_search_metric_function, + coefficients, + coefficients_training_data, + ) - def estimate_range(self, index_desc, parameters, range_scoring_radius): - D, I, R, P = self.knn_search( - index_desc, parameters, self.db_vectors, self.query_vectors + def estimate_range(self, index, parameters, range_scoring_radius): + D, I, R, P = index.knn_search( + parameters, + self.query_vectors, + self.k, ) samples = [] for i, j in np.argwhere(R < range_scoring_radius): samples.append((R[i, j].item(), D[i, j].item())) - samples.sort(key=itemgetter(0)) - return median(r for _, r in samples[-3:]) + if len(samples) > 0: # estimate range + samples.sort(key=itemgetter(0)) + return median(r for _, r in samples[-3:]) + else: # ensure at least one result + i, j = np.argwhere(R.min() == R)[0] + return D[i, j].item() def range_search( self, - index_desc: IndexDescriptor, - parameters: Optional[dict[str, int]], + index: Index, + search_parameters: Optional[Dict[str, int]], radius: Optional[float] = None, gt_radius: Optional[float] = None, ): logger.info("range_search: begin") - flat = index_desc.factory == 
"Flat" if radius is None: assert gt_radius is not None radius = ( gt_radius - if flat - else self.estimate_range(index_desc, parameters, gt_radius) + if index.is_flat() + else self.estimate_range( + index, + search_parameters, + gt_radius, + ) ) logger.info(f"Radius={radius}") - filename = self.io.get_filename_range_search( - factory=index_desc.factory, - parameters=parameters, - level=0, - db_vectors=self.db_vectors, + return index.range_search( + search_parameters=search_parameters, query_vectors=self.query_vectors, - r=radius, + radius=radius, ) - if self.io.file_exist(filename): - logger.info(f"Using cached results for {index_desc.factory}") - lims, D, I, R, P = self.io.read_file( - filename, ["lims", "D", "I", "R", "P"] - ) - else: - xq = self.io.get_dataset(self.query_vectors) - index = self.get_index(index_desc) - if parameters: - for name, val in parameters.items(): - set_index_parameter(index, name, val) - - index_ivf = faiss.try_extract_index_ivf(index) - if index_ivf is not None: - QD, QI, _, QP = self.knn_search( - index_desc, - parameters=None, - db_vectors=None, - query_vectors=self.query_vectors, - k=index.nprobe, - index=index_ivf.quantizer, - level=1, - ) - # QD = QD[:, :index.nprobe] - # QI = QI[:, :index.nprobe] - faiss.cvar.indexIVF_stats.reset() - with timer("range_search_preassigned") as t: - lims, D, I = index.range_search_preassigned(xq, radius, QI, QD) - else: - with timer("range_search") as t: - lims, D, I = index.range_search(xq, radius) - if flat: - R = D - else: - xb = self.io.get_dataset(self.db_vectors) - R = refine_distances_range( - lims, D, I, xq, xb, self.distance_metric_type - ) - P = { - "time": t(), - "radius": radius, - "count": lims[-1].item(), - "parameters": parameters, - "index": index_desc.factory, - } - if index_ivf is not None: - stats = faiss.cvar.indexIVF_stats - P |= { - "quantizer": QP, - "nq": stats.nq, - "nlist": stats.nlist, - "ndis": stats.ndis, - "nheap_updates": stats.nheap_updates, - "quantization_time": 
stats.quantization_time, - "search_time": stats.search_time, - } - self.io.write_file( - filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P] - ) - logger.info("range_seach: end") - return lims, D, I, R, P def range_ground_truth(self, gt_radius, range_search_metric_function): logger.info("range_ground_truth: begin") flat_desc = self.get_index_desc("Flat") lims, D, I, R, P = self.range_search( - flat_desc, - flat_desc.parameters, + flat_desc.index, + search_parameters=None, radius=gt_radius, ) gt_rsm = np.sum(range_search_metric_function(R)).item() @@ -464,37 +313,32 @@ def range_ground_truth(self, gt_radius, range_search_metric_function): def range_search_benchmark( self, - results: dict[str, Any], - index_desc: IndexDescriptor, + results: Dict[str, Any], + index: Index, metric_key: str, + radius: float, gt_radius: float, range_search_metric_function, gt_rsm: float, ): - logger.info(f"range_search_benchmark: begin {index_desc.factory=}") - xq = self.io.get_dataset(self.query_vectors) - (nq, d) = xq.shape - logger.info( - f"Searching {index_desc.factory} with {nq} vectors of dimension {d}" - ) - codec = self.io.get_codec(index_desc, d) - faiss.omp_set_num_threads(16) + logger.info(f"range_search_benchmark: begin {index.get_index_name()}") def experiment(parameters, cost_metric, perf_metric): nonlocal results - key = self.io.get_filename_evaluation_name( - factory=index_desc.factory, - parameters=parameters, - level=0, - db_vectors=self.db_vectors, + key = index.get_range_search_name( + search_parameters=parameters, query_vectors=self.query_vectors, - evaluation_name=metric_key, + radius=radius, ) + key += metric_key if key in results["experiments"]: metrics = results["experiments"][key] else: lims, D, I, R, P = self.range_search( - index_desc, parameters, gt_radius=gt_radius + index, + parameters, + radius=radius, + gt_radius=gt_radius, ) range_search_metric = range_search_metric_function(R) range_search_pr = range_search_pr_curve( @@ -511,8 +355,9 @@ def 
experiment(parameters, cost_metric, perf_metric): for cost_metric in ["time"]: for perf_metric in ["range_score_max_recall"]: + op = index.get_operating_points() optimizer( - codec, + op, experiment, cost_metric, perf_metric, @@ -520,134 +365,33 @@ def experiment(parameters, cost_metric, perf_metric): logger.info("range_search_benchmark: end") return results - def knn_search( - self, - index_desc: IndexDescriptor, - parameters: Optional[dict[str, int]], - db_vectors: Optional[DatasetDescriptor], - query_vectors: DatasetDescriptor, - k: Optional[int] = None, - index: Optional[faiss.Index] = None, - level: int = 0, - ): - assert level >= 0 - if level == 0: - assert index is None - assert db_vectors is not None - else: - assert index is not None # quantizer - assert db_vectors is None - logger.info("knn_seach: begin") - k = k if k is not None else self.k - flat = index_desc.factory == "Flat" - filename = self.io.get_filename_knn_search( - factory=index_desc.factory, - parameters=parameters, - level=level, - db_vectors=db_vectors, - query_vectors=query_vectors, - k=k, - ) - if self.io.file_exist(filename): - logger.info(f"Using cached results for {index_desc.factory}") - D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"]) - else: - xq = self.io.get_dataset(query_vectors) - if index is None: - index = self.get_index(index_desc) - if parameters: - for name, val in parameters.items(): - set_index_parameter(index, name, val) - - index_ivf = faiss.try_extract_index_ivf(index) - if index_ivf is not None: - QD, QI, _, QP = self.knn_search( - index_desc, - parameters=None, - db_vectors=None, - query_vectors=query_vectors, - k=index.nprobe, - index=index_ivf.quantizer, - level=level + 1, - ) - # QD = QD[:, :index.nprobe] - # QI = QI[:, :index.nprobe] - faiss.cvar.indexIVF_stats.reset() - with timer("knn search_preassigned") as t: - D, I = index.search_preassigned(xq, k, QI, QD) - else: - with timer("knn search") as t: - D, I = index.search(xq, k) - if flat or level > 
0: - R = D - else: - xb = self.io.get_dataset(db_vectors) - R = refine_distances_knn( - D, I, xq, xb, self.distance_metric_type - ) - P = { - "time": t(), - "parameters": parameters, - "index": index_desc.factory, - "level": level, - } - if index_ivf is not None: - stats = faiss.cvar.indexIVF_stats - P |= { - "quantizer": QP, - "nq": stats.nq, - "nlist": stats.nlist, - "ndis": stats.ndis, - "nheap_updates": stats.nheap_updates, - "quantization_time": stats.quantization_time, - "search_time": stats.search_time, - } - self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P]) - logger.info("knn_seach: end") - return D, I, R, P - def knn_ground_truth(self): logger.info("knn_ground_truth: begin") flat_desc = self.get_index_desc("Flat") - self.gt_knn_D, self.gt_knn_I, _, _ = self.knn_search( - flat_desc, - flat_desc.parameters, - self.db_vectors, - self.query_vectors, + self.gt_knn_D, self.gt_knn_I, _, _ = flat_desc.index.knn_search( + search_parameters=None, + query_vectors=self.query_vectors, + k=self.k, ) logger.info("knn_ground_truth: end") - def knn_search_benchmark( - self, results: dict[str, Any], index_desc: IndexDescriptor - ): - logger.info(f"knn_search_benchmark: begin {index_desc.factory=}") - xq = self.io.get_dataset(self.query_vectors) - (nq, d) = xq.shape - logger.info( - f"Searching {index_desc.factory} with {nq} vectors of dimension {d}" - ) - codec = self.io.get_codec(index_desc, d) - codec_ivf = faiss.try_extract_index_ivf(codec) - if codec_ivf is not None: - results["indices"][index_desc.factory] = {"nlist": codec_ivf.nlist} - - faiss.omp_set_num_threads(16) + def knn_search_benchmark(self, results: Dict[str, Any], index: Index): + index_name = index.get_index_name() + logger.info(f"knn_search_benchmark: begin {index_name}") def experiment(parameters, cost_metric, perf_metric): nonlocal results - key = self.io.get_filename_evaluation_name( - factory=index_desc.factory, - parameters=parameters, - level=0, - db_vectors=self.db_vectors, - 
query_vectors=self.query_vectors, - evaluation_name="knn", + key = index.get_knn_search_name( + parameters, + self.query_vectors, + self.k, ) + key += "knn" if key in results["experiments"]: metrics = results["experiments"][key] else: - D, I, R, P = self.knn_search( - index_desc, parameters, self.db_vectors, self.query_vectors + D, I, R, P = index.knn_search( + parameters, self.query_vectors, self.k ) metrics = P | { "knn_intersection": knn_intersection_measure( @@ -662,8 +406,9 @@ def experiment(parameters, cost_metric, perf_metric): for cost_metric in ["time"]: for perf_metric in ["knn_intersection", "distance_ratio"]: + op = index.get_operating_points() optimizer( - codec, + op, experiment, cost_metric, perf_metric, @@ -671,18 +416,61 @@ def experiment(parameters, cost_metric, perf_metric): logger.info("knn_search_benchmark: end") return results - def benchmark(self) -> str: - logger.info("begin evaluate") - results = {"indices": {}, "experiments": {}} + def train(self, results): + xq = self.io.get_dataset(self.query_vectors) + self.d = xq.shape[1] if self.get_index_desc("Flat") is None: self.index_descs.append(IndexDescriptor(factory="Flat")) + for index_desc in self.index_descs: + if index_desc.factory is not None: + index = IndexFromFactory( + d=self.d, + metric=self.distance_metric, + database_vectors=self.database_vectors, + search_params=index_desc.search_params, + construction_params=index_desc.construction_params, + factory=index_desc.factory, + training_vectors=self.training_vectors, + ) + index.set_io(self.io) + index.train() + index_desc.index = index + results["indices"][index.get_codec_name()] = { + "code_size": index.get_code_size() + } + else: + index = IndexFromCodec( + d=self.d, + metric=self.distance_metric, + database_vectors=self.database_vectors, + search_params=index_desc.search_params, + construction_params=index_desc.construction_params, + path=index_desc.path, + bucket=index_desc.bucket, + ) + index.set_io(self.io) + index_desc.index = 
index + results["indices"][index.get_codec_name()] = { + "code_size": index.get_code_size() + } + return results + + def benchmark(self, result_file=None): + logger.info("begin evaluate") + + faiss.omp_set_num_threads(24) + results = {"indices": {}, "experiments": {}} + results = self.train(results) + + # knn search self.knn_ground_truth() for index_desc in self.index_descs: results = self.knn_search_benchmark( results=results, - index_desc=index_desc, + index=index_desc.index, ) + # range search if self.range_ref_index_desc is not None: index_desc = self.get_index_desc(self.range_ref_index_desc) if index_desc is None: @@ -700,7 +488,9 @@ def benchmark(self) -> str: range_search_metric_function, coefficients, coefficients_training_data, - ) = self.range_search_reference(index_desc, range_metric) + ) = self.range_search_reference( + index_desc.index, index_desc.search_params, range_metric + ) results["metrics"][metric_key] = { "coefficients": coefficients, "training_data": coefficients_training_data, @@ -709,14 +499,18 @@ def benchmark(self) -> str: gt_radius, range_search_metric_function ) for index_desc in self.index_descs: + if not index_desc.index.supports_range_search(): + continue results = self.range_search_benchmark( results=results, - index_desc=index_desc, + index=index_desc.index, metric_key=metric_key, + radius=index_desc.radius, gt_radius=gt_radius, range_search_metric_function=range_search_metric_function, gt_rsm=gt_rsm, ) - self.io.write_json(results, "result.json", overwrite=True) + if result_file is not None: + self.io.write_json(results, result_file, overwrite=True) logger.info("end evaluate") - return json.dumps(results) + return results diff --git a/benchs/bench_fw/benchmark_io.py b/benchs/bench_fw/benchmark_io.py index 30fda9c726..370efffce5 100644 --- a/benchs/bench_fw/benchmark_io.py +++ b/benchs/bench_fw/benchmark_io.py @@ -1,7 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib import io import json import logging import os +import pickle from dataclasses import dataclass from typing import Any, List, Optional from zipfile import ZipFile @@ -9,115 +16,45 @@ import faiss # @manual=//faiss/python:pyfaiss_gpu import numpy as np - -from .descriptors import DatasetDescriptor, IndexDescriptor +from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib_gpu + dataset_from_name, +) logger = logging.getLogger(__name__) +# merge RCQ coarse quantizer and ITQ encoder to one Faiss index +def merge_rcq_itq( + # pyre-ignore[11]: `faiss.ResidualCoarseQuantizer` is not defined as a type + rcq_coarse_quantizer: faiss.ResidualCoarseQuantizer, + itq_encoder: faiss.IndexPreTransform, + # pyre-ignore[11]: `faiss.IndexIVFSpectralHash` is not defined as a type. +) -> faiss.IndexIVFSpectralHash: + # pyre-ignore[16]: `faiss` has no attribute `IndexIVFSpectralHash`. + index = faiss.IndexIVFSpectralHash( + rcq_coarse_quantizer, + rcq_coarse_quantizer.d, + rcq_coarse_quantizer.ntotal, + itq_encoder.sa_code_size() * 8, + 1000000, # larger than the magnitude of the vectors + ) + index.replace_vt(itq_encoder) + return index + + @dataclass class BenchmarkIO: path: str def __post_init__(self): self.cached_ds = {} - self.cached_codec_key = None - - def get_filename_search( - self, - factory: str, - parameters: Optional[dict[str, int]], - level: int, - db_vectors: DatasetDescriptor, - query_vectors: DatasetDescriptor, - k: Optional[int] = None, - r: Optional[float] = None, - evaluation_name: Optional[str] = None, - ): - assert factory is not None - assert level is not None - assert self.distance_metric is not None - assert query_vectors is not None - assert self.distance_metric is not None - filename = f"{factory.lower().replace(',', '_')}." - if level > 0: - filename += f"l_{level}." 
- if db_vectors is not None: - filename += db_vectors.get_filename("d") - filename += query_vectors.get_filename("q") - filename += self.distance_metric.upper() + "." - if k is not None: - filename += f"k_{k}." - if r is not None: - filename += f"r_{int(r * 1000)}." - if parameters is not None: - for name, val in parameters.items(): - if name != "noop": - filename += f"{name}_{val}." - if evaluation_name is None: - filename += "zip" - else: - filename += evaluation_name - return filename - - def get_filename_knn_search( - self, - factory: str, - parameters: Optional[dict[str, int]], - level: int, - db_vectors: DatasetDescriptor, - query_vectors: DatasetDescriptor, - k: int, - ): - assert k is not None - return self.get_filename_search( - factory=factory, - parameters=parameters, - level=level, - db_vectors=db_vectors, - query_vectors=query_vectors, - k=k, - ) - - def get_filename_range_search( - self, - factory: str, - parameters: Optional[dict[str, int]], - level: int, - db_vectors: DatasetDescriptor, - query_vectors: DatasetDescriptor, - r: float, - ): - assert r is not None - return self.get_filename_search( - factory=factory, - parameters=parameters, - level=level, - db_vectors=db_vectors, - query_vectors=query_vectors, - r=r, - ) - - def get_filename_evaluation_name( - self, - factory: str, - parameters: Optional[dict[str, int]], - level: int, - db_vectors: DatasetDescriptor, - query_vectors: DatasetDescriptor, - evaluation_name: str, - ): - assert evaluation_name is not None - return self.get_filename_search( - factory=factory, - parameters=parameters, - level=level, - db_vectors=db_vectors, - query_vectors=query_vectors, - evaluation_name=evaluation_name, - ) def get_local_filename(self, filename): + if len(filename) > 184: + fn, ext = os.path.splitext(filename) + filename = ( + fn[:184] + hashlib.sha256(filename.encode()).hexdigest() + ext + ) return os.path.join(self.path, filename) def download_file_from_blobstore( @@ -143,22 +80,6 @@ def file_exist(self, 
filename: str): logger.info(f"{filename} {exists=}") return exists - def get_codec(self, index_desc: IndexDescriptor, d: int): - if index_desc.factory == "Flat": - return faiss.IndexFlat(d, self.distance_metric_type) - else: - if self.cached_codec_key != index_desc.factory: - codec = faiss.read_index( - self.get_local_filename(index_desc.path) - ) - assert ( - codec.metric_type == self.distance_metric_type - ), f"{codec.metric_type=} != {self.distance_metric_type=}" - logger.info(f"Loaded codec from {index_desc.path}") - self.cached_codec_key = index_desc.factory - self.cached_codec = codec - return self.cached_codec - def read_file(self, filename: str, keys: List[str]): fn = self.download_file_from_blobstore(filename) logger.info(f"Loading file {fn}") @@ -196,19 +117,50 @@ def write_file( self.upload_file_to_blobstore(filename, overwrite=overwrite) def get_dataset(self, dataset): - if dataset not in self.cached_ds: - self.cached_ds[dataset] = self.read_nparray( - os.path.join(self.path, dataset.tablename) - ) + if dataset.namespace is not None and dataset.namespace[:4] == "std_": + if dataset.tablename not in self.cached_ds: + self.cached_ds[dataset.tablename] = dataset_from_name( + dataset.tablename, + ) + p = dataset.namespace[4] + if p == "t": + return self.cached_ds[dataset.tablename].get_train() + elif p == "d": + return self.cached_ds[dataset.tablename].get_database() + elif p == "q": + return self.cached_ds[dataset.tablename].get_queries() + else: + raise ValueError + elif dataset not in self.cached_ds: + if dataset.namespace == "syn": + d, seed = dataset.tablename.split("_") + d = int(d) + seed = int(seed) + n = dataset.num_vectors + # based on faiss.contrib.datasets.SyntheticDataset + d1 = 10 + rs = np.random.RandomState(seed) + x = rs.normal(size=(n, d1)) + x = np.dot(x, rs.rand(d1, d)) + x = x * (rs.rand(d) * 4 + 0.1) + x = np.sin(x) + x = x.astype(np.float32) + self.cached_ds[dataset] = x + else: + self.cached_ds[dataset] = self.read_nparray( + 
os.path.join(self.path, dataset.tablename), + mmap_mode="r", + )[: dataset.num_vectors].copy() return self.cached_ds[dataset] def read_nparray( self, filename: str, + mmap_mode: Optional[str] = None, ): fn = self.download_file_from_blobstore(filename) logger.info(f"Loading nparray from {fn}") - nparray = np.load(fn) + nparray = np.load(fn, mmap_mode=mmap_mode) logger.info(f"Loaded nparray {nparray.shape} from {fn}") return nparray @@ -244,3 +196,32 @@ def write_json( with open(fn, "w") as fp: json.dump(json_dict, fp) self.upload_file_to_blobstore(filename, overwrite=overwrite) + + def read_index( + self, + filename: str, + bucket: Optional[str] = None, + path: Optional[str] = None, + ): + fn = self.download_file_from_blobstore(filename, bucket, path) + logger.info(f"Loading index {fn}") + ext = os.path.splitext(fn)[1] + if ext in [".faiss", ".codec"]: + index = faiss.read_index(fn) + elif ext == ".pkl": + with open(fn, "rb") as model_file: + model = pickle.load(model_file) + rcq_coarse_quantizer, itq_encoder = model["model"] + index = merge_rcq_itq(rcq_coarse_quantizer, itq_encoder) + logger.info(f"Loaded index from {fn}") + return index + + def write_index( + self, + index: faiss.Index, + filename: str, + ): + fn = self.get_local_filename(filename) + logger.info(f"Saving index to {fn}") + faiss.write_index(index, fn) + self.upload_file_to_blobstore(filename) diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index 0268ec328c..15e5b9330b 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -1,15 +1,21 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional @dataclass class IndexDescriptor: - factory: str bucket: Optional[str] = None + # either path or factory should be set, + # but not both at the same time. path: Optional[str] = None - parameters: Optional[dict[str, int]] = None + factory: Optional[str] = None + construction_params: Optional[List[Dict[str, int]]] = None + search_params: Optional[Dict[str, int]] = None # range metric definitions # key: name # value: one of the following: @@ -25,14 +31,38 @@ class IndexDescriptor: # [[radius1_from, radius1_to, score1], ...] # [radius1_from, radius1_to) -> score1, # [radius2_from, radius2_to) -> score2 - range_metrics: Optional[dict[str, Any]] = None + range_metrics: Optional[Dict[str, Any]] = None + radius: Optional[float] = None @dataclass class DatasetDescriptor: + # namespace possible values: + # 1. a hive namespace + # 2. 'std_t', 'std_d', 'std_q' for the standard datasets + # via faiss.contrib.datasets.dataset_from_name() + # t - training, d - database, q - queries + # eg. "std_t" + # 3. 'syn' for synthetic data + # 4. None for local files namespace: Optional[str] = None + + # tablename possible values, corresponding to the + # namespace value above: + # 1. a hive table name + # 2. name of the standard dataset as recognized + # by faiss.contrib.datasets.dataset_from_name() + # eg. "bigann1M" + # 3. d_seed, eg. 128_1234 for 128 dimensional vectors + # with seed 1234 + # 4. a local file name (relative to benchmark_io.path) tablename: Optional[str] = None + + # partition names and values for hive + # eg. 
["ds=2021-09-01"] partitions: Optional[List[str]] = None + + # number of vectors to load from the dataset num_vectors: Optional[int] = None def __hash__(self): @@ -40,9 +70,11 @@ def __hash__(self): def get_filename( self, - prefix: str = "v", + prefix: str = None, ) -> str: - filename = prefix + "_" + filename = "" + if prefix is not None: + filename += prefix + "_" if self.namespace is not None: filename += self.namespace + "_" assert self.tablename is not None diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py new file mode 100644 index 0000000000..3405f59561 --- /dev/null +++ b/benchs/bench_fw/index.py @@ -0,0 +1,785 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +from dataclasses import dataclass +from multiprocessing.pool import ThreadPool +from time import perf_counter +from typing import ClassVar, Dict, List, Optional + +import faiss # @manual=//faiss/python:pyfaiss_gpu + +import numpy as np +from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu + OperatingPointsWithRanges, +) + +from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu + reverse_index_factory, +) +from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu + add_preassigned, + replace_ivf_quantizer, +) + +from .descriptors import DatasetDescriptor + +logger = logging.getLogger(__name__) + + +def timer(name, func, once=False) -> float: + logger.info(f"Measuring {name}") + t1 = perf_counter() + res = func() + t2 = perf_counter() + t = t2 - t1 + repeat = 1 + if not once and t < 1.0: + repeat = int(2.0 // t) + logger.info( + f"Time for {name}: {t:.3f} seconds, repeating {repeat} times" + ) + t1 = perf_counter() + for _ in range(repeat): + res = func() + t2 = perf_counter() + t = (t2 - t1) / repeat + logger.info(f"Time for {name}: 
{t:.3f} seconds") + return res, t, repeat + + +def refine_distances_knn( + D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric +): + return np.where( + I >= 0, + np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2)) + if metric == faiss.METRIC_L2 + else np.einsum("qd,qkd->qk", xq, xb[I]), + D, + ) + + +def refine_distances_range( + lims: np.ndarray, + D: np.ndarray, + I: np.ndarray, + xq: np.ndarray, + xb: np.ndarray, + metric, +): + with ThreadPool(32) as pool: + R = pool.map( + lambda i: ( + np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1) + if metric == faiss.METRIC_L2 + else np.tensordot( + xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1) + ) + ) + if lims[i + 1] > lims[i] + else [], + range(len(lims) - 1), + ) + return np.hstack(R) + + +# The classes below are wrappers around Faiss indices, with different +# implementations for the case when we start with an already trained +# index (IndexFromCodec) vs factory strings (IndexFromFactory). +# In both cases the classes have operations for adding to an index +# and searching it, and outputs are cached on disk. +# IndexFromFactory also decomposes the index (pretransform and quantizer) +# and trains sub-indices independently. +class IndexBase: + def set_io(self, benchmark_io): + self.io = benchmark_io + + @staticmethod + def param_dict_list_to_name(param_dict_list): + if not param_dict_list: + return "" + l = 0 + n = "" + for param_dict in param_dict_list: + n += IndexBase.param_dict_to_name(param_dict, f"cp{l}") + return n + + @staticmethod + def param_dict_to_name(param_dict, prefix="sp"): + if not param_dict: + return "" + n = prefix + for name, val in param_dict.items(): + if name != "noop": + n += f"_{name}_{val}" + if n == prefix: + return "" + n += "." 
+ return n + + @staticmethod + def set_index_param_dict_list(index, param_dict_list): + if not param_dict_list: + return + index = faiss.downcast_index(index) + for param_dict in param_dict_list: + assert index is not None + IndexBase.set_index_param_dict(index, param_dict) + index = faiss.try_extract_index_ivf(index) + + @staticmethod + def set_index_param_dict(index, param_dict): + if not param_dict: + return + for name, val in param_dict.items(): + IndexBase.set_index_param(index, name, val) + + @staticmethod + def set_index_param(index, name, val): + index = faiss.downcast_index(index) + + if isinstance(index, faiss.IndexPreTransform): + Index.set_index_param(index.index, name, val) + elif name == "efSearch": + index.hnsw.efSearch + index.hnsw.efSearch = int(val) + elif name == "efConstruction": + index.hnsw.efConstruction + index.hnsw.efConstruction = int(val) + elif name == "nprobe": + index_ivf = faiss.extract_index_ivf(index) + index_ivf.nprobe + index_ivf.nprobe = int(val) + elif name == "k_factor": + index.k_factor + index.k_factor = int(val) + elif name == "parallel_mode": + index_ivf = faiss.extract_index_ivf(index) + index_ivf.parallel_mode + index_ivf.parallel_mode = int(val) + elif name == "noop": + pass + else: + raise RuntimeError(f"could not set param {name} on {index}") + + def is_flat(self): + codec = faiss.downcast_index(self.get_model()) + return isinstance(codec, faiss.IndexFlat) + + def is_ivf(self): + codec = self.get_model() + return faiss.try_extract_index_ivf(codec) is not None + + def is_pretransform(self): + codec = self.get_model() + if isinstance(codec, faiss.IndexRefine): + codec = faiss.downcast_index(codec.base_index) + return isinstance(codec, faiss.IndexPreTransform) + + # index is a codec + database vectors + # in other words: a trained Faiss index + # that contains database vectors + def get_index_name(self): + raise NotImplementedError + + def get_index(self): + raise NotImplementedError + + # codec is a trained model + # in 
other words: a trained Faiss index + # without any database vectors + def get_codec_name(self): + raise NotImplementedError + + def get_codec(self): + raise NotImplementedError + + # model is an untrained Faiss index + # it can be used for training (see codec) + # or to inspect its structure + def get_model_name(self): + raise NotImplementedError + + def get_model(self): + raise NotImplementedError + + def transform(self, vectors): + transformed_vectors = DatasetDescriptor( + tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy" + ) + if not self.io.file_exist(transformed_vectors.tablename): + codec = self.fetch_codec() + assert isinstance(codec, faiss.IndexPreTransform) + transform = faiss.downcast_VectorTransform(codec.chain.at(0)) + x = self.io.get_dataset(vectors) + xt = transform.apply(x) + self.io.write_nparray(xt, transformed_vectors.tablename) + return transformed_vectors + + def knn_search_quantizer(self, index, query_vectors, k): + if self.is_pretransform(): + pretransform = self.get_pretransform() + quantizer_query_vectors = pretransform.transform(query_vectors) + else: + pretransform = None + quantizer_query_vectors = query_vectors + + QD, QI, _, QP = self.get_quantizer(pretransform).knn_search( + search_parameters=None, + query_vectors=quantizer_query_vectors, + k=k, + ) + xqt = self.io.get_dataset(quantizer_query_vectors) + return xqt, QD, QI, QP + + def get_knn_search_name( + self, + search_parameters: Optional[Dict[str, int]], + query_vectors: DatasetDescriptor, + k: int, + ): + name = self.get_index_name() + name += Index.param_dict_to_name(search_parameters) + name += query_vectors.get_filename("q") + name += f"k_{k}." 
+ return name + + def knn_search( + self, + search_parameters: Optional[Dict[str, int]], + query_vectors: DatasetDescriptor, + k: int, + ): + logger.info("knn_seach: begin") + filename = ( + self.get_knn_search_name(search_parameters, query_vectors, k) + + "zip" + ) + if self.io.file_exist(filename): + logger.info(f"Using cached results for {filename}") + D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"]) + else: + xq = self.io.get_dataset(query_vectors) + index = self.get_index() + Index.set_index_param_dict(index, search_parameters) + + if self.is_ivf(): + xqt, QD, QI, QP = self.knn_search_quantizer( + index, query_vectors, search_parameters["nprobe"] + ) + index_ivf = faiss.extract_index_ivf(index) + if index_ivf.parallel_mode != 2: + logger.info("Setting IVF parallel mode") + index_ivf.parallel_mode = 2 + + (D, I), t, repeat = timer( + "knn_search_preassigned", + lambda: index_ivf.search_preassigned(xqt, k, QI, QD), + ) + else: + (D, I), t, _ = timer("knn_search", lambda: index.search(xq, k)) + if self.is_flat() or not hasattr(self, "database_vectors"): # TODO + R = D + else: + xb = self.io.get_dataset(self.database_vectors) + R = refine_distances_knn(D, I, xq, xb, self.metric_type) + P = { + "time": t, + "index": self.get_index_name(), + "codec": self.get_codec_name(), + "factory": self.factory if hasattr(self, "factory") else "", + "search_params": search_parameters, + "k": k, + } + if self.is_ivf(): + stats = faiss.cvar.indexIVF_stats + P |= { + "quantizer": QP, + "nq": int(stats.nq // repeat), + "nlist": int(stats.nlist // repeat), + "ndis": int(stats.ndis // repeat), + "nheap_updates": int(stats.nheap_updates // repeat), + "quantization_time": int( + stats.quantization_time // repeat + ), + "search_time": int(stats.search_time // repeat), + } + self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P]) + logger.info("knn_seach: end") + return D, I, R, P + + def range_search( + self, + search_parameters: Optional[Dict[str, int]], + 
query_vectors: DatasetDescriptor, + radius: Optional[float] = None, + ): + logger.info("range_search: begin") + filename = ( + self.get_range_search_name( + search_parameters, query_vectors, radius + ) + + "zip" + ) + if self.io.file_exist(filename): + logger.info(f"Using cached results for {filename}") + lims, D, I, R, P = self.io.read_file( + filename, ["lims", "D", "I", "R", "P"] + ) + else: + xq = self.io.get_dataset(query_vectors) + index = self.get_index() + Index.set_index_param_dict(index, search_parameters) + + if self.is_ivf(): + xqt, QD, QI, QP = self.knn_search_quantizer( + index, query_vectors, search_parameters["nprobe"] + ) + index_ivf = faiss.extract_index_ivf(index) + if index_ivf.parallel_mode != 2: + logger.info("Setting IVF parallel mode") + index_ivf.parallel_mode = 2 + + (lims, D, I), t, repeat = timer( + "range_search_preassigned", + lambda: index_ivf.range_search_preassigned( + xqt, radius, QI, QD + ), + ) + else: + (lims, D, I), t, _ = timer( + "range_search", lambda: index.range_search(xq, radius) + ) + if self.is_flat(): + R = D + else: + xb = self.io.get_dataset(self.database_vectors) + R = refine_distances_range( + lims, D, I, xq, xb, self.metric_type + ) + P = { + "time": t, + "index": self.get_codec_name(), + "codec": self.get_codec_name(), + "search_params": search_parameters, + "radius": radius, + "count": len(I), + } + if self.is_ivf(): + stats = faiss.cvar.indexIVF_stats + P |= { + "quantizer": QP, + "nq": int(stats.nq // repeat), + "nlist": int(stats.nlist // repeat), + "ndis": int(stats.ndis // repeat), + "nheap_updates": int(stats.nheap_updates // repeat), + "quantization_time": int( + stats.quantization_time // repeat + ), + "search_time": int(stats.search_time // repeat), + } + self.io.write_file( + filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P] + ) + logger.info("range_seach: end") + return lims, D, I, R, P + + +# Common base for IndexFromCodec and IndexFromFactory, +# but not for the sub-indices of codec-based 
indices +# IndexFromQuantizer and IndexFromPreTransform, because +# they share the configuration of their parent IndexFromCodec +@dataclass +class Index(IndexBase): + d: int + metric: str + database_vectors: DatasetDescriptor + construction_params: List[Dict[str, int]] + search_params: Dict[str, int] + + cached_codec_name: ClassVar[str] = None + cached_codec: ClassVar[faiss.Index] = None + cached_index_name: ClassVar[str] = None + cached_index: ClassVar[faiss.Index] = None + + def __post_init__(self): + if isinstance(self.metric, str): + if self.metric == "IP": + self.metric_type = faiss.METRIC_INNER_PRODUCT + elif self.metric == "L2": + self.metric_type = faiss.METRIC_L2 + else: + raise ValueError + elif isinstance(self.metric, int): + self.metric_type = self.metric + if self.metric_type == faiss.METRIC_INNER_PRODUCT: + self.metric = "IP" + elif self.metric_type == faiss.METRIC_L2: + self.metric = "L2" + else: + raise ValueError + else: + raise ValueError + + def supports_range_search(self): + codec = self.get_codec() + return not type(codec) in [ + faiss.IndexHNSWFlat, + faiss.IndexIVFFastScan, + faiss.IndexRefine, + faiss.IndexPQ, + ] + + def fetch_codec(self): + raise NotImplementedError + + def train(self): + # get triggers a train, if necessary + self.get_codec() + + def get_codec(self): + codec_name = self.get_codec_name() + if Index.cached_codec_name != codec_name: + Index.cached_codec = self.fetch_codec() + Index.cached_codec_name = codec_name + return Index.cached_codec + + def get_index_name(self): + name = self.get_codec_name() + assert self.database_vectors is not None + name += self.database_vectors.get_filename("xb") + return name + + def fetch_index(self): + index = faiss.clone_index(self.get_codec()) + assert index.ntotal == 0 + logger.info("Adding vectors to index") + xb = self.io.get_dataset(self.database_vectors) + + if self.is_ivf(): + xbt, QD, QI, QP = self.knn_search_quantizer( + index, self.database_vectors, 1 + ) + index_ivf = 
faiss.extract_index_ivf(index) + if index_ivf.parallel_mode != 2: + logger.info("Setting IVF parallel mode") + index_ivf.parallel_mode = 2 + + _, t, _ = timer( + "add_preassigned", + lambda: add_preassigned(index_ivf, xbt, QI.ravel()), + once=True, + ) + else: + _, t, _ = timer( + "add", + lambda: index.add(xb), + once=True, + ) + assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0] + logger.info("Added vectors to index") + return index + + def get_index(self): + index_name = self.get_index_name() + if Index.cached_index_name != index_name: + Index.cached_index = self.fetch_index() + Index.cached_index_name = index_name + return Index.cached_index + + def get_code_size(self): + def get_index_code_size(index): + index = faiss.downcast_index(index) + if isinstance(index, faiss.IndexPreTransform): + return get_index_code_size(index.index) + elif isinstance(index, faiss.IndexHNSWFlat): + return index.d * 4 # TODO + elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]: + return get_index_code_size( + index.base_index + ) + get_index_code_size(index.refine_index) + else: + return index.code_size + + codec = self.get_codec() + return get_index_code_size(codec) + + def get_operating_points(self): + op = OperatingPointsWithRanges() + + def add_range_or_val(name, range): + op.add_range( + name, + [self.search_params[name]] + if self.search_params and name in self.search_params + else range, + ) + + op.add_range("noop", [0]) + codec = faiss.downcast_index(self.get_codec()) + codec_ivf = faiss.try_extract_index_ivf(codec) + if codec_ivf is not None: + add_range_or_val( + "nprobe", + [ + 2**i + for i in range(12) + if 2**i <= codec_ivf.nlist * 0.25 + ], + ) + if isinstance(codec, faiss.IndexRefine): + add_range_or_val( + "k_factor", + [2**i for i in range(11)], + ) + if isinstance(codec, faiss.IndexHNSWFlat): + add_range_or_val( + "efSearch", + [2**i for i in range(3, 11)], + ) + return op + + def get_range_search_name( + self, + search_parameters: 
Optional[Dict[str, int]], + query_vectors: DatasetDescriptor, + radius: Optional[float] = None, + ): + name = self.get_index_name() + name += Index.param_dict_to_name(search_parameters) + name += query_vectors.get_filename("q") + if radius is not None: + name += f"r_{int(radius * 1000)}." + else: + name += "r_auto." + return name + + +# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform +# are used to wrap pre-trained Faiss indices (codecs) +@dataclass +class IndexFromCodec(Index): + path: str + bucket: Optional[str] = None + + def get_quantizer(self): + if not self.is_ivf(): + raise ValueError("Not an IVF index") + quantizer = IndexFromQuantizer(self) + quantizer.set_io(self.io) + return quantizer + + def get_pretransform(self): + if not self.is_ivf(): + raise ValueError("Not an IVF index") + quantizer = IndexFromPreTransform(self) + quantizer.set_io(self.io) + return quantizer + + def get_codec_name(self): + assert self.path is not None + name = os.path.basename(self.path) + name += Index.param_dict_list_to_name(self.construction_params) + return name + + def fetch_codec(self): + codec = self.io.read_index( + os.path.basename(self.path), + self.bucket, + os.path.dirname(self.path), + ) + assert self.d == codec.d + assert self.metric_type == codec.metric_type + Index.set_index_param_dict_list(codec, self.construction_params) + return codec + + def get_model(self): + return self.get_codec() + + +class IndexFromQuantizer(IndexBase): + ivf_index: Index + + def __init__(self, ivf_index: Index): + self.ivf_index = ivf_index + super().__init__() + + def get_codec_name(self): + return self.get_index_name() + + def get_codec(self): + return self.get_index() + + def get_index_name(self): + ivf_codec_name = self.ivf_index.get_codec_name() + return f"{ivf_codec_name}quantizer." 
+ + def get_index(self): + ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec()) + return ivf_codec.quantizer + + +class IndexFromPreTransform(IndexBase): + pre_transform_index: Index + + def __init__(self, pre_transform_index: Index): + self.pre_transform_index = pre_transform_index + super().__init__() + + def get_codec_name(self): + pre_transform_codec_name = self.pre_transform_index.get_codec_name() + return f"{pre_transform_codec_name}pretransform." + + def get_codec(self): + return self.get_codec() + + +# IndexFromFactory is for creating and training indices from scratch +@dataclass +class IndexFromFactory(Index): + factory: str + training_vectors: DatasetDescriptor + + def get_codec_name(self): + assert self.factory is not None + name = f"{self.factory.replace(',', '_')}." + assert self.d is not None + assert self.metric is not None + name += f"d_{self.d}.{self.metric.upper()}." + if self.factory != "Flat": + assert self.training_vectors is not None + name += self.training_vectors.get_filename("xt") + name += Index.param_dict_list_to_name(self.construction_params) + return name + + def fetch_codec(self): + codec_filename = self.get_codec_name() + "codec" + if self.io.file_exist(codec_filename): + codec = self.io.read_index(codec_filename) + assert self.d == codec.d + assert self.metric_type == codec.metric_type + else: + codec = self.assemble() + if self.factory != "Flat": + self.io.write_index(codec, codec_filename) + return codec + + def get_model(self): + model = faiss.index_factory(self.d, self.factory, self.metric_type) + Index.set_index_param_dict_list(model, self.construction_params) + return model + + def get_pretransform(self): + model = faiss.index_factory(self.d, self.factory, self.metric_type) + assert isinstance(model, faiss.IndexPreTransform) + sub_index = faiss.downcast_index(model.index) + if isinstance(sub_index, faiss.IndexFlat): + return self + # replace the sub-index with Flat + codec = faiss.clone_index(model) + codec.index = 
faiss.IndexFlat(codec.index.d, codec.index.metric_type) + pretransform = IndexFromFactory( + d=codec.d, + metric=codec.metric_type, + database_vectors=self.database_vectors, + construction_params=self.construction_params, + search_params=self.search_params, + factory=reverse_index_factory(codec), + training_vectors=self.training_vectors, + ) + pretransform.set_io(self.io) + return pretransform + + def get_quantizer(self, pretransform=None): + model = self.get_model() + model_ivf = faiss.extract_index_ivf(model) + assert isinstance(model_ivf, faiss.IndexIVF) + assert ord(model_ivf.quantizer_trains_alone) in [0, 2] + if pretransform is None: + training_vectors = self.training_vectors + else: + training_vectors = pretransform.transform(self.training_vectors) + centroids = self.k_means(training_vectors, model_ivf.nlist) + quantizer = IndexFromFactory( + d=model_ivf.quantizer.d, + metric=model_ivf.quantizer.metric_type, + database_vectors=centroids, + construction_params=None, # self.construction_params[1:], + search_params=None, # self.construction_params[0], # TODO: verify + factory=reverse_index_factory(model_ivf.quantizer), + training_vectors=centroids, + ) + quantizer.set_io(self.io) + return quantizer + + def k_means(self, vectors, k): + kmeans_vectors = DatasetDescriptor( + tablename=f"{vectors.get_filename()}kmeans_{k}.npy" + ) + if not self.io.file_exist(kmeans_vectors.tablename): + x = self.io.get_dataset(vectors) + kmeans = faiss.Kmeans(d=x.shape[1], k=k, gpu=True) + kmeans.train(x) + self.io.write_nparray(kmeans.centroids, kmeans_vectors.tablename) + return kmeans_vectors + + def assemble(self): + model = self.get_model() + codec = faiss.clone_index(model) + if isinstance(model, faiss.IndexPreTransform): + sub_index = faiss.downcast_index(model.index) + if not isinstance(sub_index, faiss.IndexFlat): + # replace the sub-index with Flat and fetch pre-trained + pretransform = self.get_pretransform() + codec = pretransform.fetch_codec() + assert codec.is_trained 
+ transformed_training_vectors = pretransform.transform( + self.training_vectors + ) + transformed_database_vectors = pretransform.transform( + self.database_vectors + ) + # replace the Flat index with the required sub-index + wrapper = IndexFromFactory( + d=sub_index.d, + metric=sub_index.metric_type, + database_vectors=transformed_database_vectors, + construction_params=self.construction_params, + search_params=self.search_params, + factory=reverse_index_factory(sub_index), + training_vectors=transformed_training_vectors, + ) + wrapper.set_io(self.io) + codec.index = wrapper.fetch_codec() + assert codec.index.is_trained + elif isinstance(model, faiss.IndexIVF): + # replace the quantizer + quantizer = self.get_quantizer() + replace_ivf_quantizer(codec, quantizer.fetch_index()) + assert codec.quantizer.is_trained + assert codec.nlist == codec.quantizer.ntotal + elif isinstance(model, faiss.IndexRefine) or isinstance( + model, faiss.IndexRefineFlat + ): + # replace base_index + wrapper = IndexFromFactory( + d=model.base_index.d, + metric=model.base_index.metric_type, + database_vectors=self.database_vectors, + construction_params=self.construction_params, + search_params=self.search_params, + factory=reverse_index_factory(model.base_index), + training_vectors=self.training_vectors, + ) + wrapper.set_io(self.io) + codec.base_index = wrapper.fetch_codec() + assert codec.base_index.is_trained + + xt = self.io.get_dataset(self.training_vectors) + codec.train(xt) + return codec diff --git a/benchs/bench_fw_ivf_flat.py b/benchs/bench_fw_ivf_flat.py new file mode 100644 index 0000000000..37b4bd7862 --- /dev/null +++ b/benchs/bench_fw_ivf_flat.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +from bench_fw.benchmark import Benchmark +from bench_fw.benchmark_io import BenchmarkIO +from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor + +logging.basicConfig(level=logging.INFO) + +benchmark = Benchmark( + training_vectors=DatasetDescriptor( + namespace="std_d", tablename="sift1M" + ), + database_vectors=DatasetDescriptor( + namespace="std_d", tablename="sift1M" + ), + query_vectors=DatasetDescriptor( + namespace="std_q", tablename="sift1M" + ), + index_descs=[ + IndexDescriptor( + factory=f"IVF{2 ** nlist},Flat", + ) + for nlist in range(8, 15) + ], + k=1, + distance_metric="L2", +) +io = BenchmarkIO( + path="/checkpoint", +) +benchmark.set_io(io) +print(benchmark.benchmark("result.json")) diff --git a/benchs/bench_fw_test.py b/benchs/bench_fw_test.py new file mode 100644 index 0000000000..55b9e16e65 --- /dev/null +++ b/benchs/bench_fw_test.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +from bench_fw.benchmark import Benchmark +from bench_fw.benchmark_io import BenchmarkIO +from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor + +logging.basicConfig(level=logging.INFO) + +benchmark = Benchmark( + training_vectors=DatasetDescriptor( + tablename="training.npy", num_vectors=200000 + ), + database_vectors=DatasetDescriptor( + tablename="database.npy", num_vectors=200000 + ), + query_vectors=DatasetDescriptor(tablename="query.npy", num_vectors=2000), + index_descs=[ + IndexDescriptor( + factory="Flat", + range_metrics={ + "weighted": [ + [0.1, 0.928], + [0.2, 0.865], + [0.3, 0.788], + [0.4, 0.689], + [0.5, 0.49], + [0.6, 0.308], + [0.7, 0.193], + [0.8, 0.0], + ] + }, + ), + IndexDescriptor( + factory="OPQ32_128,IVF512,PQ32", + ), + IndexDescriptor( + factory="OPQ32_256,IVF512,PQ32", + ), + IndexDescriptor( + factory="HNSW32", + construction_params=[ + { + "efConstruction": 64, + } + ], + ), + ], + k=10, + distance_metric="L2", + range_ref_index_desc="Flat", +) +io = BenchmarkIO( + path="/checkpoint", +) +benchmark.set_io(io) +print(benchmark.benchmark("result.json")) diff --git a/contrib/factory_tools.py b/contrib/factory_tools.py index 9623ad55f4..da90e986f8 100644 --- a/contrib/factory_tools.py +++ b/contrib/factory_tools.py @@ -72,6 +72,9 @@ def get_code_size(d, indexkey): raise RuntimeError("cannot parse " + indexkey) +def get_hnsw_M(index): + return index.hnsw.cum_nneighbor_per_level.at(1) // 2 + def reverse_index_factory(index): """ @@ -80,21 +83,47 @@ def reverse_index_factory(index): index = faiss.downcast_index(index) if isinstance(index, faiss.IndexFlat): return "Flat" - if isinstance(index, faiss.IndexIVF): + elif isinstance(index, faiss.IndexIVF): quantizer = faiss.downcast_index(index.quantizer) if isinstance(quantizer, faiss.IndexFlat): - prefix = "IVF%d" % index.nlist + prefix = f"IVF{index.nlist}" elif isinstance(quantizer, faiss.MultiIndexQuantizer): - prefix = "IMI%dx%d" % (quantizer.pq.M, 
quantizer.pq.nbit) + prefix = f"IMI{quantizer.pq.M}x{quantizer.pq.nbits}" elif isinstance(quantizer, faiss.IndexHNSW): - prefix = "IVF%d_HNSW%d" % (index.nlist, quantizer.hnsw.M) + prefix = f"IVF{index.nlist}_HNSW{get_hnsw_M(quantizer)}" else: - prefix = "IVF%d(%s)" % (index.nlist, reverse_index_factory(quantizer)) + prefix = f"IVF{index.nlist}({reverse_index_factory(quantizer)})" if isinstance(index, faiss.IndexIVFFlat): return prefix + ",Flat" if isinstance(index, faiss.IndexIVFScalarQuantizer): return prefix + ",SQ8" + if isinstance(index, faiss.IndexIVFPQ): + return prefix + f",PQ{index.pq.M}x{index.pq.nbits}" + + elif isinstance(index, faiss.IndexPreTransform): + assert index.chain.size() == 1 + vt = faiss.downcast_VectorTransform(index.chain.at(0)) + if isinstance(vt, faiss.OPQMatrix): + return f"OPQ{vt.M}_{vt.d_out},{reverse_index_factory(index.index)}" + + elif isinstance(index, faiss.IndexHNSW): + return f"HNSW{get_hnsw_M(index)}" + + elif isinstance(index, faiss.IndexRefine): + return f"{reverse_index_factory(index.base_index)},Refine({reverse_index_factory(index.refine_index)})" + + elif isinstance(index, faiss.IndexPQFastScan): + return f"PQ{index.pq.M}x{index.pq.nbits}fs" + + elif isinstance(index, faiss.IndexScalarQuantizer): + sqtypes = { + faiss.ScalarQuantizer.QT_8bit: "8", + faiss.ScalarQuantizer.QT_4bit: "4", + faiss.ScalarQuantizer.QT_6bit: "6", + faiss.ScalarQuantizer.QT_fp16: "fp16", + } + return f"SQ{sqtypes[index.sq.qtype]}" raise NotImplementedError()