forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
benchmark view results (facebookresearch#3144)
Summary: Pull Request resolved: facebookresearch#3144 Visualize results of running the benchmark with Pareto optima filtering: 1. per index or across indices 2. for space, time or space & time 3. knn or range search, the latter @ specific precision Reviewed By: mdouze Differential Revision: D51552775 fbshipit-source-id: d4f29e3d46ef044e71b54439b3972548c86af5a7
- Loading branch information
1 parent
9519a19
commit 4c83965
Showing
1 changed file
with
289 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "be081589-e1b2-4569-acb7-44203e273899", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"import itertools\n", | ||
"from faiss.contrib.evaluation import OperatingPoints\n", | ||
"from enum import Enum\n", | ||
"from bench_fw.benchmark_io import BenchmarkIO as BIO" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"root = \"/checkpoint\"\n", | ||
"results = BIO(root).read_json(\"result.json\")\n", | ||
"results.keys()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0875d269-aef4-426d-83dd-866970f43777", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"results['indices']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a7ff7078-29c7-407c-a079-201877b764ad", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class Cost:\n", | ||
" def __init__(self, values):\n", | ||
" self.values = values\n", | ||
"\n", | ||
" def __le__(self, other):\n", | ||
" return all(v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True))\n", | ||
"\n", | ||
" def __lt__(self, other):\n", | ||
" return all(v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True))\n", | ||
"\n", | ||
"class ParetoMode(Enum):\n", | ||
" DISABLE = 1 # no Pareto filtering\n", | ||
" INDEX = 2 # index-local optima\n", | ||
" GLOBAL = 3 # global optima\n", | ||
"\n", | ||
"\n", | ||
"class ParetoMetric(Enum):\n", | ||
" TIME = 0 # time vs accuracy\n", | ||
" SPACE = 1 # space vs accuracy\n", | ||
" TIME_SPACE = 2 # (time, space) vs accuracy\n", | ||
"\n", | ||
"def range_search_recall_at_precision(experiment, precision):\n", | ||
" return round(max(r for r, p in zip(experiment['range_search_pr']['recall'], experiment['range_search_pr']['precision']) if p > precision), 6)\n", | ||
"\n", | ||
"def filter_results(\n", | ||
" results,\n", | ||
" evaluation,\n", | ||
" accuracy_metric, # str or func\n", | ||
" time_metric=None, # func or None -> use default\n", | ||
" space_metric=None, # func or None -> use default\n", | ||
" min_accuracy=0,\n", | ||
" max_space=0,\n", | ||
" max_time=0,\n", | ||
" scaling_factor=1.0,\n", | ||
" \n", | ||
" pareto_mode=ParetoMode.DISABLE,\n", | ||
" pareto_metric=ParetoMetric.TIME,\n", | ||
"):\n", | ||
" if isinstance(accuracy_metric, str):\n", | ||
" accuracy_key = accuracy_metric\n", | ||
" accuracy_metric = lambda v: v[accuracy_key]\n", | ||
"\n", | ||
" if time_metric is None:\n", | ||
" time_metric = lambda v: v['time'] * scaling_factor + (v['quantizer']['time'] if 'quantizer' in v else 0)\n", | ||
"\n", | ||
" if space_metric is None:\n", | ||
" space_metric = lambda v: results['indices'][v['codec']]['code_size']\n", | ||
" \n", | ||
" fe = []\n", | ||
" ops = {}\n", | ||
" if pareto_mode == ParetoMode.GLOBAL:\n", | ||
" op = OperatingPoints()\n", | ||
" ops[\"global\"] = op\n", | ||
" for k, v in results['experiments'].items():\n", | ||
" if f\".{evaluation}\" in k:\n", | ||
" accuracy = accuracy_metric(v)\n", | ||
" if min_accuracy > 0 and accuracy < min_accuracy:\n", | ||
" continue\n", | ||
" space = space_metric(v)\n", | ||
" if max_space > 0 and space > max_space:\n", | ||
" continue\n", | ||
" time = time_metric(v)\n", | ||
" if max_time > 0 and time > max_time:\n", | ||
" continue\n", | ||
" idx_name = v['index']\n", | ||
" experiment = (accuracy, space, time, k, v)\n", | ||
" if pareto_mode == ParetoMode.DISABLE:\n", | ||
" fe.append(experiment)\n", | ||
" continue\n", | ||
" if pareto_mode == ParetoMode.INDEX:\n", | ||
" if idx_name not in ops:\n", | ||
" ops[idx_name] = OperatingPoints()\n", | ||
" op = ops[idx_name]\n", | ||
" if pareto_metric == ParetoMetric.TIME:\n", | ||
" op.add_operating_point(experiment, accuracy, time)\n", | ||
" elif pareto_metric == ParetoMetric.SPACE:\n", | ||
" op.add_operating_point(experiment, accuracy, space)\n", | ||
" else:\n", | ||
" op.add_operating_point(experiment, accuracy, Cost([time, space]))\n", | ||
"\n", | ||
" if ops:\n", | ||
" for op in ops.values():\n", | ||
" for v, _, _ in op.operating_points:\n", | ||
" fe.append(v)\n", | ||
"\n", | ||
" fe.sort()\n", | ||
" return fe" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f080a6e2-1565-418b-8732-4adeff03a099", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def plot_metric(experiments, accuracy_title, cost_title, plot_space=False):\n", | ||
" x = {}\n", | ||
" y = {}\n", | ||
" for accuracy, space, time, k, v in experiments:\n", | ||
" idx_name = v['index']\n", | ||
" if idx_name not in x:\n", | ||
" x[idx_name] = []\n", | ||
" y[idx_name] = []\n", | ||
" x[idx_name].append(accuracy)\n", | ||
" if plot_space:\n", | ||
" y[idx_name].append(space)\n", | ||
" else:\n", | ||
" y[idx_name].append(time)\n", | ||
"\n", | ||
" #plt.figure(figsize=(10,6))\n", | ||
" plt.yscale(\"log\")\n", | ||
" plt.title(accuracy_title)\n", | ||
" plt.xlabel(accuracy_title)\n", | ||
" plt.ylabel(cost_title)\n", | ||
" marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", | ||
" for index in x.keys():\n", | ||
" plt.plot(x[index], y[index], marker=next(marker), label=index)\n", | ||
" plt.legend(bbox_to_anchor=(1, 1), loc='upper left')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "61007155-5edc-449e-835e-c141a01a2ae5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"accuracy_metric = \"knn_intersection\"\n", | ||
"fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", | ||
"plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 16 cores)\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "36e82084-18f6-4546-a717-163eb0224ee8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"precision = 0.8\n", | ||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", | ||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", | ||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# index local optima\n", | ||
"precision = 0.2\n", | ||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", | ||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", | ||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# global optima\n", | ||
"precision = 0.8\n", | ||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", | ||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", | ||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9aead830-6209-4956-b7ea-4a5e0029d616", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def plot_range_search_pr_curves(experiments):\n", | ||
" x = {}\n", | ||
" y = {}\n", | ||
" show = {\n", | ||
" 'Flat': None,\n", | ||
" }\n", | ||
" for _, _, _, k, v in fr:\n", | ||
" if \".weighted\" in k: # and v['index'] in show:\n", | ||
" x[k] = v['range_search_pr']['recall']\n", | ||
" y[k] = v['range_search_pr']['precision']\n", | ||
" \n", | ||
" plt.title(\"range search recall\")\n", | ||
" plt.xlabel(\"recall\")\n", | ||
" plt.ylabel(\"precision\")\n", | ||
" for index in x.keys():\n", | ||
" plt.plot(x[index], y[index], '.', label=index)\n", | ||
" plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "92e45502-7a31-4a15-90df-fa3032d7d350", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"precision = 0.8\n", | ||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", | ||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", | ||
"plot_range_search_pr_curves(fr)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python [conda env:faiss_cpu_from_source] *", | ||
"language": "python", | ||
"name": "conda-env-faiss_cpu_from_source-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |