Skip to content

Commit

Permalink
Allow search Index without Gt (#3827)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3827

There are few fixes in this diff which allows us to execute search on an existing index without needing to compare it with ground truth. This has been currently added to only knn search (not range)

Reviewed By: satymish

Differential Revision: D61825365

fbshipit-source-id: ee1e39260ed3480ed32aeeb8d7232e975f56bbfa
  • Loading branch information
kuarora authored and facebook-github-bot committed Sep 5, 2024
1 parent a4ebcb1 commit 202a204
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 34 deletions.
64 changes: 48 additions & 16 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
if hasattr(knn_desc, "index"):
return

if knn_desc.index_desc.index is not None:
assert knn_desc.index_desc is not None
if hasattr(knn_desc.index_desc, "index"):
knn_desc.index = knn_desc.index_desc.index
knn_desc.index.knn_name = knn_desc.get_name()
knn_desc.index.search_params = knn_desc.search_params
Expand All @@ -359,6 +360,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
metric=self.distance_metric,
bucket=knn_desc.index_desc.bucket,
index_path=knn_desc.index_desc.path,
index_name=knn_desc.index_desc.get_name(),
# knn_name=knn_desc.get_name(),
search_params=knn_desc.search_params,
)
Expand Down Expand Up @@ -544,15 +546,40 @@ def experiment(parameters, cost_metric, perf_metric):
def knn_search_benchmark(
self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor
):
gt_knn_D = None
gt_knn_I = None
if hasattr(self, "gt_knn_D"):
gt_knn_D = self.gt_knn_D
gt_knn_I = self.gt_knn_I

assert hasattr(knn_desc, "index")
if not knn_desc.index.is_flat_index() and gt_knn_I is None:
key = knn_desc.index.get_knn_search_name(
search_parameters=knn_desc.search_params,
query_vectors=knn_desc.query_dataset,
k=knn_desc.k,
reconstruct=False,
)
metrics, requires = knn_desc.index.knn_search(
dry_run,
knn_desc.search_params,
knn_desc.query_dataset,
knn_desc.k,
)[3:]
if requires is not None:
return results, requires
results["experiments"][key] = metrics
return results, requires

return self.search_benchmark(
name="knn_search",
search_func=lambda parameters: knn_desc.index.knn_search(
dry_run,
parameters,
knn_desc.query_dataset,
knn_desc.k,
self.gt_knn_I,
self.gt_knn_D,
gt_knn_I,
gt_knn_D,
)[3:],
key_func=lambda parameters: knn_desc.index.get_knn_search_name(
search_parameters=parameters,
Expand Down Expand Up @@ -634,6 +661,7 @@ class ExecutionOperator:
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
if self.distance_metric == "IP":
Expand Down Expand Up @@ -698,16 +726,11 @@ def search_one(
faiss.omp_set_num_threads(self.num_threads)
assert self.search_op is not None

if not dry_run:
if not dry_run and self.compute_gt:
self.create_gt_knn(knn_desc)
self.create_range_ref_knn(knn_desc)

self.search_op.build_index_wrapper(knn_desc)
meta, requires = knn_desc.index.fetch_meta(dry_run=dry_run)
if requires is not None:
# return results, (requires if train else None)
return results, requires
results["indices"][knn_desc.index.get_codec_name()] = meta

# results, requires = self.reconstruct_benchmark(
# dry_run=True,
Expand Down Expand Up @@ -766,9 +789,11 @@ def search_one(
ref_index_desc.search_params,
range_metric,
)
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
gt_rsm = None
if self.compute_gt:
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
results, requires = self.search_op.range_search_benchmark(
dry_run=True,
results=results,
Expand Down Expand Up @@ -847,9 +872,13 @@ def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]:
if self.search_op:
gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name())
if gt_knn_desc is None:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
if knn_desc.index_desc is not None:
gt_index_desc = knn_desc.gt_index_desc
else:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
knn_desc.gt_index_desc = gt_index_desc
assert gt_index_desc is not None
gt_knn_desc = KnnDescriptor(
d=knn_desc.d,
Expand Down Expand Up @@ -933,7 +962,10 @@ def execute(self, results: Dict[str, Any], dry_run: False):
if self.search_op is not None:
for desc in self.search_op.knn_descs:
results, requires = self.search_one(
knn_desc=desc, results=results, dry_run=dry_run, range=self.search_op.range
knn_desc=desc,
results=results,
dry_run=dry_run,
range=self.search_op.range,
)
if dry_run:
if requires is None:
Expand Down
64 changes: 46 additions & 18 deletions benchs/bench_fw/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,10 @@ def knn_search(
D_gt=None,
):
logger.info("knn_search: begin")
if search_parameters is not None and search_parameters["snap"] == 1:
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
Expand Down Expand Up @@ -322,7 +325,11 @@ def knn_search(
else:
xq = self.io.get_dataset(query_vectors)
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
if (
self.is_flat() or
not hasattr(self, "database_vectors") or
(self.database_vectors is None)
): # TODO
R = D
else:
xq = self.io.get_dataset(query_vectors)
Expand Down Expand Up @@ -352,20 +359,24 @@ def knn_search(
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None,
"distance_ratio": distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None,
"knn_intersection": (
knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None
),
"distance_ratio": (
distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None
),
}
logger.info("knn_search: end")
return D, I, R, P, None
Expand Down Expand Up @@ -467,7 +478,10 @@ def range_search(
radius: Optional[float] = None,
):
logger.info("range_search: begin")
if search_parameters is not None and search_parameters.get("snap") == 1:
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_range_search_name(
Expand Down Expand Up @@ -607,6 +621,12 @@ def get_codec(self):
Index.cached_codec.popitem(last=False)
return Index.cached_codec[codec_name]

def get_model(self):
return self.get_index()

def get_model_name(self):
return self.get_index_name()

def get_codec_name(self) -> Optional[str]:
return self.codec_name

Expand Down Expand Up @@ -709,6 +729,11 @@ def get_operating_points(self):
def add_range_or_val(name, range):
op.add_range(
name,
(
[self.search_params[name]]
if self.search_params and name in self.search_params
else range
),
[self.search_params[name]]
if self.search_params and name in self.search_params
else range,
Expand Down Expand Up @@ -808,7 +833,10 @@ def get_pretransform(self):
return quantizer

def get_model_name(self):
return os.path.basename(self.path)
if self.path is not None:
return os.path.basename(self.path)
else:
return self.get_codec_name()

def fetch_meta(self, dry_run=False):
return None, None
Expand Down

0 comments on commit 202a204

Please sign in to comment.