Skip to content

Commit

Permalink
Refactor Clip and OpenAi query encoders (#2014)
Browse files Browse the repository at this point in the history
Closes #2013
  • Loading branch information
lintool authored Oct 15, 2024
1 parent 1adca72 commit 75aea32
Show file tree
Hide file tree
Showing 27 changed files with 162 additions and 117 deletions.
2 changes: 1 addition & 1 deletion docs/experiments-ance.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,4 @@ Top100 accuracy: 0.8522
+ Results reproduced by [@ArthurChen189](https://github.com/ArthurChen189) on 2021-07-06 (commit [`c9f44b`](https://github.com/castorini/pyserini/commit/c9f44b2a24103fff4887cade831f9b7c2472b190))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-23 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-bpr.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ Top100 accuracy: 0.8571
+ Results reproduced by [@HAKSOAT](https://github.com/HAKSOAT) on 2022-03-11 (commit [`779668`](https://github.com/castorini/pyserini/commit/77966851755163e36489544fb08f73171e98103f))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-24 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-distilbert_kd.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ recall_1000 all 0.9553
+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-04-26 (commit [`854c19`](https://github.com/castorini/pyserini/commit/854c1930ba00819245c0a9fbcf2090ce14db4db0))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-23 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-distilbert_tasb.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ recall_1000 all 0.9771
+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-05-28 (commit [`102ed2`](https://github.com/castorini/pyserini/commit/102ed2b2e8770978e4b3e09804913dcffb63c4a7))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-23 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-dkrr.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,4 @@ Running hybrid sparse-dense retrieval with DKKR and [GAR-T5](https://github.com/
+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-02-12 (commit [`52a1e7`](https://github.com/castorini/pyserini/commit/52a1e7f241b7b833a3ec1d739e629c08417a324c))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-23 (commit [`90676b`](https://github.com/castorini/pyserini/commit/90676b351b47585084aa8136265d02a67ced3803))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-dpr.md
Original file line number Diff line number Diff line change
Expand Up @@ -647,4 +647,4 @@ Top100 accuracy: 0.8837
+ Results reproduced by [@manveertamber](https://github.com/manveertamber) on 2022-01-22 (commit [`ef70c6`](https://github.com/castorini/pyserini/commit/ef70c63efd773e87afd9708338827342f4960540))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-25 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
2 changes: 1 addition & 1 deletion docs/experiments-sbert.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@ recall_1000 all 0.9659
+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-04-26 (commit [`854c19`](https://github.com/castorini/pyserini/commit/854c1930ba00819245c0a9fbcf2090ce14db4db0))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-23 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
4 changes: 2 additions & 2 deletions docs/experiments-tct_colbert-v2.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Pyserini: TCT-ColBERTv2 for MS MARCO (V1) Collections
# Pyserini: Reproducing TCT-ColBERTv2 for MS MARCO V1

This guide provides instructions to reproduce the family of TCT-ColBERT-V2 dense retrieval models described in the following paper:

Expand Down Expand Up @@ -398,4 +398,4 @@ ndcg_cut_10 all 0.6094
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-25 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-05-06 (commit [`dcc0ba`](https://github.com/castorini/pyserini/commit/dcc0ba06585a08d7c78cbffac4217b57e170fc3a))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
4 changes: 2 additions & 2 deletions docs/experiments-tct_colbert.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Pyserini: TCT-ColBERT for MS MARCO (V1) Collections
# Pyserini: Reproducing TCT-ColBERT for MS MARCO V1

This guide provides instructions to reproduce the TCT-ColBERT dense retrieval model described in the following paper:

Expand Down Expand Up @@ -415,4 +415,4 @@ recall_100 all 0.9083
+ Results reproduced by [@lintool](https://github.com/lintool) on 2022-12-24 (commit [`0c495c`](https://github.com/castorini/pyserini/commit/0c495cf2999dda980eb1f85efa30a4323cef5855))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-01-10 (commit [`7dafc4`](https://github.com/castorini/pyserini/commit/7dafc4f918bd44ada3771a5c81692ab19cc2cae9))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2023-05-06 (commit [`dcc0ba`](https://github.com/castorini/pyserini/commit/dcc0ba06585a08d7c78cbffac4217b57e170fc3a))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-16 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
+ Results reproduced by [@lintool](https://github.com/lintool) on 2024-10-07 (commit [`3f7609`](https://github.com/castorini/pyserini/commit/3f76099a73820afee12496c0354d52ca6a6175c2))
4 changes: 2 additions & 2 deletions pyserini/2cr/msmarco-v1-passage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ conditions:
display: "OpenAI ada2: Faiss flat, cached queries"
display-html: "OpenAI ada2: Faiss flat, cached queries"
display-row: "[<a href=\"#\" data-mdb-toggle=\"tooltip\" title=\"Lin et al. (2023) Vector Search with OpenAI Embeddings: Lucene Is All You Need.\">11</a>]"
command: python -m pyserini.search.faiss --threads ${sparse_threads} --batch-size ${sparse_batch_size} --index msmarco-v1-passage.openai-ada2 --topics $topics --encoded-queries openai-ada2-$topics --output $output
command: python -m pyserini.search.faiss --threads ${dense_threads} --batch-size ${dense_batch_size} --index msmarco-v1-passage.openai-ada2 --topics $topics --encoded-queries openai-ada2-$topics --output $output
topics:
- topic_key: msmarco-passage-dev-subset
eval_key: msmarco-passage-dev-subset
Expand All @@ -1279,7 +1279,7 @@ conditions:
display: "HyDE-OpenAI ada2: Faiss flat, cached queries"
display-html: "HyDE-OpenAI ada2: Faiss flat, cached queries"
display-row: "[<a href=\"#\" data-mdb-toggle=\"tooltip\" title=\"Gao et al. (ACL&nbsp;2023) Precise Zero-Shot Dense Retrieval without Relevance Labels.\">12</a>]"
command: python -m pyserini.search.faiss --threads ${sparse_threads} --batch-size ${sparse_batch_size} --index msmarco-v1-passage.openai-ada2 --topics $topics --encoded-queries openai-ada2-$topics-hyde --output $output
command: python -m pyserini.search.faiss --threads ${dense_threads} --batch-size ${dense_batch_size} --index msmarco-v1-passage.openai-ada2 --topics $topics --encoded-queries openai-ada2-$topics-hyde --output $output
topics:
- topic_key: dl19-passage
eval_key: dl19-passage
Expand Down
3 changes: 1 addition & 2 deletions pyserini/2cr/msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,7 @@ def run_conditions(args):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate regression matrix for MS MARCO corpora.')
parser.add_argument('--collection', type=str,
help='Collection = {v1-passage, v1-doc, v2-passage, v2-doc}.', required=True)
parser.add_argument('--collection', type=str, help='Collection = {v1-passage, v1-doc, v2-passage, v2-doc}.', required=True)
# To list all conditions
parser.add_argument('--list-conditions', action='store_true', default=False, help='List available conditions.')
# For generating reports
Expand Down
5 changes: 3 additions & 2 deletions pyserini/encode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@
# Then import these...
from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
from ._arctic import ArcticDocumentEncoder, ArcticQueryEncoder
from ._auto import AutoQueryEncoder, AutoDocumentEncoder
from ._bpr import BprQueryEncoder
from ._cached_data import CachedDataQueryEncoder
from ._clip import ClipDocumentEncoder, ClipTextEncoder, ClipImageEncoder, ClipQueryEncoder
from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
from ._dkrr import DkrrDprQueryEncoder
from ._dpr import DprDocumentEncoder, DprQueryEncoder
from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
from ._openai import OpenAiDocumentEncoder, OpenAiQueryEncoder, OPENAI_API_RETRY_DELAY
from ._slim import SlimQueryEncoder
from ._splade import SpladeQueryEncoder
from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder
from ._tok_freq import TokFreqQueryEncoder
from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder
from ._arctic import ArcticDocumentEncoder, ArcticQueryEncoder
14 changes: 8 additions & 6 deletions pyserini/encode/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import argparse
import sys

from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, \
AggretrieverDocumentEncoder, AutoDocumentEncoder, CosDprDocumentEncoder, JsonlRepresentationWriter, \
JsonlCollectionIterator, UniCoilDocumentEncoder, OpenAIDocumentEncoder, ArcticDocumentEncoder, OPENAI_API_RETRY_DELAY
from pyserini.encode._clip import ClipDocumentEncoder
from pyserini.encode._faiss import FaissRepresentationWriter
from pyserini.encode import AutoDocumentEncoder
from pyserini.encode import ArcticDocumentEncoder, AggretrieverDocumentEncoder, AnceDocumentEncoder, \
ClipDocumentEncoder, CosDprDocumentEncoder, DprDocumentEncoder, TctColBertDocumentEncoder, UniCoilDocumentEncoder
from pyserini.encode import OpenAiDocumentEncoder, OPENAI_API_RETRY_DELAY
from pyserini.encode import JsonlRepresentationWriter, JsonlCollectionIterator
from pyserini.encode.optional import FaissRepresentationWriter

encoder_class_map = {
"dpr": DprDocumentEncoder,
Expand All @@ -30,14 +31,15 @@
"ance": AnceDocumentEncoder,
"sentence-transformers": AutoDocumentEncoder,
"unicoil": UniCoilDocumentEncoder,
"openai-api": OpenAIDocumentEncoder,
"openai-api": OpenAiDocumentEncoder,
"cosdpr": CosDprDocumentEncoder,
"auto": AutoDocumentEncoder,
"clip": ClipDocumentEncoder,
"contriever": AutoDocumentEncoder,
"arctic": ArcticDocumentEncoder,
}


def init_encoder(encoder, encoder_class, device, pooling, l2_norm, prefix, multimodal):
_encoder_class = encoder_class

Expand Down
18 changes: 18 additions & 0 deletions pyserini/encode/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,21 @@ def __init__(self, model_name, device='cuda:0', l2_norm=False, prefix=None, mult

def encode(self, *args, **kwargs):
return self.encoder.encode(*args, **kwargs)


class ClipQueryEncoder(QueryEncoder):
"""Encodes queries using a CLIP model, supporting both images and texts."""

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cuda:0',
l2_norm: bool = False, prefix: str = None, multimodal: bool = False, **kwargs):
super().__init__(encoded_query_dir)
if encoder_dir:
self.device = device
self.encoder = ClipEncoder(encoder_dir, device, l2_norm, prefix, multimodal)
self.has_model = True

if not self.has_model and not self.has_encoded_query:
raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')

def encode(self, query: str):
return self.encoder.encode(query).flatten()
47 changes: 38 additions & 9 deletions pyserini/encode/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
client = openai.OpenAI(api_key=api_key, organization=org_key)
OPENAI_API_RETRY_DELAY = 5


def retry_with_delay(func, delay: int = OPENAI_API_RETRY_DELAY, max_retries: int = 10, errors: tuple = (openai.RateLimitError)):
def wrapper(*args, **kwargs):
num_retries = 0
Expand All @@ -45,7 +46,8 @@ def wrapper(*args, **kwargs):
raise e
return wrapper

class OpenAIDocumentEncoder(DocumentEncoder):

class OpenAiDocumentEncoder(DocumentEncoder):
def __init__(self, model_name: str = 'text-embedding-ada-002', tokenizer_name: str = 'cl100k_base', **kwargs):
self.model = model_name
self.tokenizer = tiktoken.get_encoding(tokenizer_name)
Expand All @@ -62,15 +64,42 @@ def encode(self, texts: List[str], titles = None, max_length: int = 512, **kwarg
inputs = [embedding[:max_length] for embedding in inputs]
return self.get_embeddings(inputs)

class OpenAIQueryEncoder(QueryEncoder):
def __init__(self, model_name: str = 'text-embedding-ada-002', tokenizer_name: str = 'cl100k_base', device = None):
self.model = model_name
self.tokenizer = tiktoken.get_encoding(tokenizer_name)
# class OpenAIQueryEncoder(QueryEncoder):
# def __init__(self, model_name: str = 'text-embedding-ada-002', tokenizer_name: str = 'cl100k_base', device = None):
# self.model = model_name
# self.tokenizer = tiktoken.get_encoding(tokenizer_name)
#
# @retry_with_delay
# def get_embedding(self, text: str):
# return np.array(client.embeddings.create(input=text, model=self.model)['data'][0]['embedding'])
#
# def encode(self, text: str, max_length: int = 512, **kwargs):
# inputs = self.tokenizer.encode(text=text)[:max_length]
# return self.get_embedding(inputs)


class OpenAiQueryEncoder(QueryEncoder):
def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None,
tokenizer_name: str = None, max_length: int = 512, **kwargs):
super().__init__(encoded_query_dir)
if encoder_dir:
api_key = '' if os.getenv("OPENAI_API_KEY") is None else os.getenv("OPENAI_API_KEY")
org_key = '' if os.getenv("OPENAI_ORG_KEY") is None else os.getenv("OPENAI_ORG_KEY")
self.client = openai.OpenAI(api_key=api_key, organization=org_key)
self.model = encoder_dir
self.tokenizer = tiktoken.get_encoding(tokenizer_name)
self.max_length = max_length
self.has_model = True
if (not self.has_model) and (not self.has_encoded_query):
raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')

@retry_with_delay
def get_embedding(self, text: str):
return np.array(client.embeddings.create(input=text, model=self.model)['data'][0]['embedding'])
return np.array(self.client.embeddings.create(input=text, model=self.model)['data'][0]['embedding'])

def encode(self, text: str, max_length: int = 512, **kwargs):
inputs = self.tokenizer.encode(text=text)[:max_length]
return self.get_embedding(inputs)
def encode(self, query: str, **kwargs):
if self.has_model:
inputs = self.tokenizer.encode(text=query)[:self.max_length]
return self.get_embedding(inputs)
else:
return super().encode(query)
18 changes: 18 additions & 0 deletions pyserini/encode/optional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._pca import PcaEncoder
from ._faiss import FaissRepresentationWriter
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 75aea32

Please sign in to comment.