From 961318c3a51b1f1714b5272107c4c5d4d1684126 Mon Sep 17 00:00:00 2001 From: Kyung Jae Lee Date: Sun, 19 Nov 2023 18:24:12 -0500 Subject: [PATCH] Update AToMiC demo page to support dense search --- pyserini/demo/atomic.py | 199 ++++++++++++++++++++++------ pyserini/demo/templates/atomic.html | 147 ++++++++++++++++---- 2 files changed, 278 insertions(+), 68 deletions(-) diff --git a/pyserini/demo/atomic.py b/pyserini/demo/atomic.py index b9cb53488..fafce8b3c 100644 --- a/pyserini/demo/atomic.py +++ b/pyserini/demo/atomic.py @@ -29,41 +29,134 @@ from typing import Callable, Optional, Tuple, Union from flask import Flask, render_template, request, flash, jsonify -from pyserini.search import LuceneSearcher, FaissSearcher +from pyserini.search import LuceneSearcher, FaissSearcher, QueryEncoder + + +RETRIEVER_TO_INDEXES = { + 'BM25': [ + 'atomic_image_v0.2_small_validation', + 'atomic_image_v0.2_base', + 'atomic_image_v0.2_large', + 'atomic_text_v0.2.1_small_validation', + 'atomic_text_v0.2.1_base', + 'atomic_text_v0.2.1_large', + ], + 'ViT-L-14.laion2b_s32b_b82k': [ + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large', + ], + 'ViT-H-14.laion2b_s32b_b79k': [ + 'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large', + 'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large', + ], + 'ViT-bigG-14.laion2b_s39b_b160k': [ + 'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large', + 'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large', + ], + 'ViT-B-32.laion2b_e16': [ + 'atomic-v0.2.ViT-B-32.laion2b_e16.image.large', + 'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large', + ], + 'ViT-B-32.laion400m_e32': [ + 'atomic-v0.2.ViT-B-32.laion400m_e32.image.large', + 'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large', + ], + 'openai.clip-vit-base-patch32': [ + 'atomic-v0.2.openai.clip-vit-base-patch32.image.large', + 'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large', + ], + 'openai.clip-vit-large-patch14': [ + 'atomic-v0.2.openai.clip-vit-large-patch14.image.large', + 'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large', + ], + 'Salesforce.blip-itm-base-coco': [ + 'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large', + 'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large', + ], + 'Salesforce.blip-itm-large-coco': [ + 'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large', + 'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large', + ], + 'facebook.flava-full': [ + 'atomic-v0.2.facebook.flava-full.image.large', + 'atomic-v0.2.1.facebook.flava-full.text.large', + ], +} + +INDEX_TO_ENCODED_QUERIES = { + # 'ViT-L-14.laion2b_s32b_b82k' + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + # ViT-H-14.laion2b_s32b_b79k + 'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large': 'atomic-v0.2.1-text-ViT-H-14.laion2b_s32b_b79k-validation', + 'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large': 'atomic-v0.2-image-ViT-H-14.laion2b_s32b_b79k-validation', + # ViT-bigG-14.laion2b_s39b_b160k + 'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large': 'atomic-v0.2.1-text-ViT-bigG-14.laion2b_s39b_b160k-validation', + 'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large': 'atomic-v0.2-image-ViT-bigG-14.laion2b_s39b_b160k-validation', + # ViT-B-32.laion2b_e16 + 'atomic-v0.2.ViT-B-32.laion2b_e16.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion2b_e16-validation', + 'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large': 'atomic-v0.2-image-ViT-B-32.laion2b_e16-validation', + # ViT-B-32.laion400m_e32 + 'atomic-v0.2.ViT-B-32.laion400m_e32.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion400m_e32-validation', + 'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large': 'atomic-v0.2-image-ViT-B-32.laion400m_e32-validation', + # openai.clip-vit-base-patch32 + 'atomic-v0.2.openai.clip-vit-base-patch32.image.large': 'atomic-v0.2.1-text-openai.clip-vit-base-patch32-validation', + 'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large': 'atomic-v0.2-image-openai.clip-vit-base-patch32-validation', + # openai.clip-vit-large-patch14 + 'atomic-v0.2.openai.clip-vit-large-patch14.image.large': 'atomic-v0.2.1-text-openai.clip-vit-large-patch14-validation', + 'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large': 'atomic-v0.2-image-openai.clip-vit-large-patch14-validation', + # Salesforce.blip-itm-base-coco + 'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-base-coco-validation', + 'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-base-coco-validation', + # Salesforce.blip-itm-large-coco + 'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-large-coco-validation', + 'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-large-coco-validation', + # facebook.flava-full + 'atomic-v0.2.facebook.flava-full.image.large': 'atomic-v0.2.1-text-facebook.flava-full-validation', + 'atomic-v0.2.1.facebook.flava-full.text.large': 'atomic-v0.2-image-facebook.flava-full-validation', +} - -INDEX_NAMES = ( - 'atomic_image_v0.2_small_validation', - 'atomic_image_v0.2_base', - 'atomic_image_v0.2_large', - 'atomic_text_v0.2.1_small_validation', - 'atomic_text_v0.2.1_base', - 'atomic_text_v0.2.1_large', -) Searcher = Union[FaissSearcher, LuceneSearcher] -def create_app(k: int, load_searcher_fn: Callable[[str], Tuple[Searcher, str]]): +def create_app(k: int, load_searcher_fn: Callable[[str], Searcher]): app = Flask(__name__) - index_name = INDEX_NAMES[0] - searcher, retriever = load_searcher_fn(index_name=index_name) + # Use BM25 as default retriever upon page load + retriever = "BM25" + index_name = RETRIEVER_TO_INDEXES[retriever][0] + searcher = load_searcher_fn(index_name=index_name) + query_options = [] # for dense search only @app.route('/') def index(): - nonlocal searcher, retriever - return render_template('atomic.html', index_name=index_name, retriever=retriever) + return render_template( + 'atomic.html', index_name=index_name, retriever=retriever, retriever_to_indexes=RETRIEVER_TO_INDEXES + ) @app.route('/search', methods=['GET', 'POST']) def search(): - nonlocal searcher, retriever query = request.form['q'] + if retriever != "BM25": + query = query_options[int(query)] if not query: search_results = [] flash('Question is required') # NOTE: this throws an exception unless we set a secret session key else: - hits = searcher.search(query, k=k) + try: + hits = searcher.search(query, k=k) + except KeyError: + hits = [] + flash('Invalid query given') docs = [json.loads(searcher.doc(hit.docid).raw()) for hit in hits] search_results = [ { @@ -76,56 +169,74 @@ def search(): for r, hit in enumerate(hits) ] return render_template( - 'atomic.html', index_name=index_name, search_results=search_results, query=query, retriever=retriever + 'atomic.html', index_name=index_name, retriever=retriever, + retriever_to_indexes=RETRIEVER_TO_INDEXES, search_results=search_results, query=query, ) + def _change_index(new_index_name): + nonlocal index_name, searcher, query_options + index_name = new_index_name + searcher = load_searcher_fn(index_name=index_name) + if retriever != "BM25": + query_options = {i: option for i, option in enumerate(searcher.query_encoder.embedding.keys())} + + @app.route('/retriever', methods=['GET']) + def change_retriever(): + nonlocal retriever + new_retriever = request.args.get('new_retriever_name', '', type=str) + if not new_retriever or new_retriever not in list(RETRIEVER_TO_INDEXES.keys()): + return + + retriever = new_retriever + _change_index(new_index_name=RETRIEVER_TO_INDEXES[retriever][0]) + return jsonify(index_list=RETRIEVER_TO_INDEXES[retriever]) + @app.route('/index', methods=['GET']) def change_index_name(): - nonlocal index_name, searcher, retriever new_index_name = request.args.get('new_index_name', '', type=str) - if not new_index_name or new_index_name not in INDEX_NAMES: + if not new_index_name or new_index_name not in RETRIEVER_TO_INDEXES[retriever]: return - - index_name = new_index_name - searcher, retriever = load_searcher_fn(index_name=index_name) + _change_index(new_index_name) return jsonify(index_name=index_name) + @app.route('/search_options', methods=['GET']) + def search_options(): + query = request.args.get('query', '') + + matching_options = { + i: option + for i, option in query_options.items() + if option.lower().startswith(query.lower()) + } + return jsonify(matching_options) + return app -def _load_sparse_searcher(index_name, language: str, k1: Optional[float]=None, b: Optional[float]=None) -> (Searcher, str): - searcher = LuceneSearcher.from_prebuilt_index(index_name) - if k1 is not None and b is not None: - searcher.set_bm25(k1, b) - retriever_name = f'BM25 (k1={k1}, b={b})' +def _load_searcher(index_name: str, language: str, k1: Optional[float]=None, b: Optional[float]=None): + if index_name in RETRIEVER_TO_INDEXES['BM25']: + searcher = LuceneSearcher.from_prebuilt_index(index_name) + if k1 is not None and b is not None: + searcher.set_bm25(k1, b) else: - retriever_name = 'BM25' - - return searcher, retriever_name + query_encoder = QueryEncoder.load_encoded_queries(INDEX_TO_ENCODED_QUERIES[index_name]) + searcher = FaissSearcher.from_prebuilt_index( + index_name, query_encoder + ) + return searcher def main(): parser = ArgumentParser() - parser.add_argument('--k1', type=float, help='BM25 k1 parameter.') parser.add_argument('--b', type=float, help='BM25 b parameter.') parser.add_argument('--hits', type=int, default=10, help='Number of hits returned by the retriever') parser.add_argument( - '--device', - type=str, - default='cpu', - help='Device to run query encoder, cpu or [cuda:0, cuda:1, ...] (used only when index is based on FAISS)', - ) - parser.add_argument( - '--port', - default=8080, - type=int, - help='Web server port', + '--port', default=8080, type=int, help='Web server port', ) - args = parser.parse_args() - load_fn = partial(_load_sparse_searcher, language='en', k1=args.k1, b=args.b) + load_fn = partial(_load_searcher, language='en', k1=args.k1, b=args.b) app = create_app(args.hits, load_fn) app.run(host='0.0.0.0', port=args.port) diff --git a/pyserini/demo/templates/atomic.html b/pyserini/demo/templates/atomic.html index ff916be02..f278fdf60 100644 --- a/pyserini/demo/templates/atomic.html +++ b/pyserini/demo/templates/atomic.html @@ -15,12 +15,91 @@