From 5458301beacbe03a09ed75da8cc74f4ac8f8c537 Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Tue, 18 Jun 2024 05:37:40 +0900
Subject: [PATCH 01/15] Support prefix for retrieval and reranking

---
 src/jmteb/embedders/base.py                 | 21 +++++++++---
 src/jmteb/embedders/sbert_embedder.py       | 37 +++++++++++++++++++--
 src/jmteb/evaluators/reranking/evaluator.py |  9 +++++
 src/jmteb/evaluators/retrieval/evaluator.py | 10 ++++++
 tests/evaluator/test_reranking_evaluator.py | 27 ++++++++++++---
 tests/evaluator/test_retrieval_evaluator.py | 32 +++++++++++++++---
 6 files changed, 121 insertions(+), 15 deletions(-)

diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index 145f543..4276d3d 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -14,11 +14,12 @@ class TextEmbedder(ABC):
     The base class of text embedder.
     """
 
-    def encode(self, text: str | list[str]) -> np.ndarray:
+    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
         """Convert a text string or a list of texts to embedding.
 
         Args:
             text (str | list[str]): text string, or a list of texts.
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
         """
         raise NotImplementedError
 
@@ -31,7 +32,12 @@ def get_output_dim(self) -> int:
         raise NotImplementedError
 
     def _batch_encode_and_save_on_disk(
-        self, text_list: list[str], save_path: str | PathLike[str], batch_size: int = 64, dtype: str = "float32"
+        self,
+        text_list: list[str],
+        save_path: str | PathLike[str],
+        prompt: str | None = None,
+        batch_size: int = 64,
+        dtype: str = "float32",
     ) -> np.memmap:
         """
         Encode a list of texts and save the embeddings on disk using memmap.
@@ -39,6 +45,7 @@
         Args:
             text_list (list[str]): list of texts
             save_path (str): path to save the embeddings
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
             dtype (str, optional): data type. Defaults to "float32".
             batch_size (int): batch size. Defaults to 64.
         """
@@ -50,7 +57,7 @@
         with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
             for i in range(0, num_samples, batch_size):
                 batch = text_list[i : i + batch_size]
-                batch_embeddings = self.encode(batch)
+                batch_embeddings = self.encode(batch, prompt=prompt)
                 batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
                 embeddings[i : i + batch_size] = batch_embeddings
                 pbar.update(len(batch))
@@ -61,6 +68,7 @@
     def batch_encode_with_cache(
         self,
         text_list: list[str],
+        prompt: str | None = None,
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
         batch_size: int = 64,
@@ -71,6 +79,7 @@
 
         Args:
            text_list (list[str]): list of texts
+            prompt (str, optional): the prompt to use for encoding. Defaults to None.
            cache_path (str, optional): path to save the embeddings. Defaults to None.
            overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
            batch_size (int): batch size. Defaults to 64.
@@ -79,12 +88,14 @@ def batch_encode_with_cache(
 
         if cache_path is None:
             logger.info("Encoding embeddings")
-            return self.encode(text_list).astype(dtype)
+            return self.encode(text_list, prompt=prompt).astype(dtype)
 
         if Path(cache_path).exists() and not overwrite_cache:
             logger.info(f"Loading embeddings from {cache_path}")
             return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))
 
         logger.info(f"Encoding and saving embeddings to {cache_path}")
-        embeddings = self._batch_encode_and_save_on_disk(text_list, cache_path, batch_size=batch_size, dtype=dtype)
+        embeddings = self._batch_encode_and_save_on_disk(
+            text_list, cache_path, prompt=prompt, batch_size=batch_size, dtype=dtype
+        )
         return embeddings
diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index 48ab984..6fbc48e 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -15,20 +15,53 @@ def __init__(
         batch_size: int = 32,
         device: str | None = None,
         normalize_embeddings: bool = False,
+        max_seq_length: int | None = None,
+        tokenizer_padding_side: str | None = None,
+        add_eos: bool = False,
     ) -> None:
-        self.model = SentenceTransformer(model_name_or_path)
+        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
+        if max_seq_length:
+            self.model.max_seq_length = max_seq_length
+        if tokenizer_padding_side:
+            try:
+                self.model.tokenizer.padding_side = tokenizer_padding_side
+            except AttributeError:
+                pass
+
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
+        self.max_seq_length = max_seq_length
+        self.tokenizer_padding_side = tokenizer_padding_side
+        self.add_eos = add_eos
+
+        if self.max_seq_length:
+            self.model.max_seq_length = self.max_seq_length
+        if self.tokenizer_padding_side:
+            setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)
 
-    def encode(self, text: str | list[str]) -> np.ndarray:
+    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
+        if self.add_eos:
+            text = self.add_eos_func(text)
         return self.model.encode(
             text,
+            prompt=prompt,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             device=self.device,
             normalize_embeddings=self.normalize_embeddings,
         )
 
+    def add_eos_func(self, text: str | list[str]) -> str | list[str]:
+        try:
+            eos_token = getattr(self.model.tokenizer, "eos_token")
+        except AttributeError:
+            return text
+
+        if isinstance(text, str):
+            return text + eos_token
+        elif isinstance(text, list):
+            return [t + eos_token for t in text]
+
     def get_output_dim(self) -> int:
         return self.model.get_sentence_embedding_dimension()
diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py
index 4b71dfe..0089c0a 100644
--- a/src/jmteb/evaluators/reranking/evaluator.py
+++ b/src/jmteb/evaluators/reranking/evaluator.py
@@ -28,6 +28,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
         test_query_dataset (RerankingQueryDataset): test query dataset used for computing the scores
         doc_dataset (RerankingDocDataset): document dataset
         ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Discounted Cumulative Gain).
+        query_prefix (str | None): prefix for queries. Defaults to None.
+        doc_prefix (str | None): prefix for documents. Defaults to None.
""" def __init__( @@ -36,12 +38,16 @@ def __init__( test_query_dataset: RerankingQueryDataset, doc_dataset: RerankingDocDataset, ndcg_at_k: list[int] | None = None, + query_prefix: str | None = None, + doc_prefix: str | None = None, ) -> None: self.test_query_dataset = test_query_dataset self.val_query_dataset = val_query_dataset self.doc_dataset = doc_dataset self.ndcg_at_k = ndcg_at_k or [10, 20, 40] self.main_metric = f"ndcg@{self.ndcg_at_k[0]}" + self.query_prefix = query_prefix + self.doc_prefix = doc_prefix def __call__( self, @@ -54,6 +60,7 @@ def __call__( val_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.val_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -62,11 +69,13 @@ def __call__( else: test_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.test_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], + prompt=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index bc97e33..9af7af4 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -31,6 +31,8 @@ class RetrievalEvaluator(EmbeddingEvaluator): doc_chunk_size (int): The maximum size of corpus chunk. Smaller chunk requires less memory but lowers speed. ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain). accuracy_at_k (list[int] | None): accuracy in top k hits. + query_prefix (str | None): prefix for queries. Defaults to None. + doc_prefix (str | None): prefix for documents. Defaults to None. 
""" def __init__( @@ -41,6 +43,8 @@ def __init__( doc_chunk_size: int = 1000000, accuracy_at_k: list[int] | None = None, ndcg_at_k: list[int] | None = None, + query_prefix: str | None = None, + doc_prefix: str | None = None, ) -> None: self.val_query_dataset = val_query_dataset self.test_query_dataset = test_query_dataset @@ -53,6 +57,9 @@ def __init__( self.max_top_k = max(sum([self.accuracy_at_k, self.ndcg_at_k], [])) self.main_metric = f"ndcg@{self.ndcg_at_k[0]}" + self.query_prefix = query_prefix + self.doc_prefix = doc_prefix + def __call__( self, model: TextEmbedder, @@ -64,6 +71,7 @@ def __call__( val_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.val_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -72,12 +80,14 @@ def __call__( else: test_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.test_query_dataset], + prompt=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], + prompt=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/tests/evaluator/test_reranking_evaluator.py b/tests/evaluator/test_reranking_evaluator.py index 0d903cb..ef847a9 100644 --- a/tests/evaluator/test_reranking_evaluator.py +++ b/tests/evaluator/test_reranking_evaluator.py @@ -12,11 +12,13 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} +QUERY_PREFIX = "クエリ: " +DOC_PREFIX = "ドキュメント: " class DummyDocDataset(RerankingDocDataset): - def __init__(self): - self._items = [RerankingDoc(id=str(i), text=f"dummy document {i}") for i in range(30)] + def __init__(self, prefix: str = ""): + self._items = [RerankingDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] def __len__(self): return len(self._items) @@ -26,9 +28,10 @@ def __getitem__(self, idx): class DummyQueryDataset(RerankingQueryDataset): - def __init__(self): + def __init__(self, prefix: str = ""): self._items = [ - RerankingQuery(query=f"dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1]) for i in range(10) + RerankingQuery(query=f"{prefix}dummy query {i}", retrieved_docs=[str(i)], relevance_scores=[1]) + for i in range(10) ] def __len__(self): @@ -57,6 +60,22 @@ def test_reranking_evaluator(embedder): assert any(score.startswith(metric) for metric in ["ndcg"]) +def test_reranking_evaluator_with_prefix(embedder): + evaluator_with_prefix = RerankingEvaluator( + val_query_dataset=DummyQueryDataset(), + test_query_dataset=DummyQueryDataset(), + doc_dataset=DummyDocDataset(), + query_prefix=QUERY_PREFIX, + doc_prefix=DOC_PREFIX, + ) + evaluator_with_manual_prefix = RerankingEvaluator( + val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + doc_dataset=DummyDocDataset(prefix=DOC_PREFIX), + ) + assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder) + + def test_jsonl_reranking_datasets(): query = JsonlRerankingQueryDataset( filename="tests/test_data/dummy_reranking/val.jsonl", diff --git a/tests/evaluator/test_retrieval_evaluator.py 
b/tests/evaluator/test_retrieval_evaluator.py index 82b7944..fa52c52 100644 --- a/tests/evaluator/test_retrieval_evaluator.py +++ b/tests/evaluator/test_retrieval_evaluator.py @@ -12,11 +12,13 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} +QUERY_PREFIX = "クエリ: " +DOC_PREFIX = "ドキュメント: " class DummyDocDataset(RetrievalDocDataset): - def __init__(self): - self._items = [RetrievalDoc(id=str(i), text=f"dummy document {i}") for i in range(30)] + def __init__(self, prefix: str = ""): + self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] def __len__(self): return len(self._items) @@ -26,8 +28,8 @@ def __getitem__(self, idx): class DummyQueryDataset(RetrievalQueryDataset): - def __init__(self): - self._items = [RetrievalQuery(f"dummy query {i}", relevant_docs=[str(i)]) for i in range(10)] + def __init__(self, prefix: str = ""): + self._items = [RetrievalQuery(f"{prefix}dummy query {i}", relevant_docs=[str(i)]) for i in range(10)] def __len__(self): return len(self._items) @@ -58,6 +60,28 @@ def test_retrieval_evaluator(embedder): assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"]) +def test_retrieval_evaluator_with_prefix(embedder): + evaluator_with_prefix = RetrievalEvaluator( + val_query_dataset=DummyQueryDataset(), + test_query_dataset=DummyQueryDataset(), + doc_dataset=DummyDocDataset(), + query_prefix=QUERY_PREFIX, + doc_prefix=DOC_PREFIX, + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + ) + evaluator_with_manual_prefix = RetrievalEvaluator( + val_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + test_query_dataset=DummyQueryDataset(prefix=QUERY_PREFIX), + doc_dataset=DummyDocDataset(prefix=DOC_PREFIX), + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + ) + assert evaluator_with_prefix(model=embedder) == evaluator_with_manual_prefix(model=embedder) + + def test_if_chunking_does_not_change_result(embedder): evaluator1 = RetrievalEvaluator( val_query_dataset=DummyQueryDataset(), From 32fed50be3452b3eeed0b96e67d0c3f950d29109 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 05:50:36 +0900 Subject: [PATCH 02/15] Fix DummyTextEmbedder --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9a104d9..a34c028 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser): class DummyTextEmbedder(TextEmbedder): - def encode(self, text: str | list[str]) -> np.ndarray: + def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray: if isinstance(text, str): batch_size = 1 else: From 6a93870cbc7f51914d0311f3aafb30e8272db381 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 12:58:04 +0900 Subject: [PATCH 03/15] Add tokenizer_kwargs in SentenceTransformer init --- src/jmteb/embedders/sbert_embedder.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 6fbc48e..dbed505 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -16,29 +16,21 @@ def __init__( device: str | None = None, normalize_embeddings: bool = False, max_seq_length: int | None = None, - tokenizer_padding_side: str | None = 
None,
         add_eos: bool = False,
+        tokenizer_kwargs: dict | None = None,
     ) -> None:
-        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
+        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs)
         if max_seq_length:
             self.model.max_seq_length = max_seq_length
-        if tokenizer_padding_side:
-            try:
-                self.model.tokenizer.padding_side = tokenizer_padding_side
-            except AttributeError:
-                pass
 
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
         self.max_seq_length = max_seq_length
-        self.tokenizer_padding_side = tokenizer_padding_side
         self.add_eos = add_eos
 
         if self.max_seq_length:
             self.model.max_seq_length = self.max_seq_length
-        if self.tokenizer_padding_side:
-            setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)

From c1e54340589917d9d7a76b0c9fbc90ef8fdf3e69 Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Tue, 18 Jun 2024 13:06:53 +0900
Subject: [PATCH 04/15] Add a test case for tokenizer_kwargs

---
 tests/embedders/test_sbert.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/embedders/test_sbert.py b/tests/embedders/test_sbert.py
index 48d7e4f..77f0585 100644
--- a/tests/embedders/test_sbert.py
+++ b/tests/embedders/test_sbert.py
@@ -17,3 +17,8 @@ def test_encode(self):
 
     def test_get_output_dim(self):
         assert self.model.get_output_dim() == OUTPUT_DIM
+
+    def test_tokenizer_kwargs(self):
+        assert self.model.model.tokenizer.sep_token == "[SEP]"
+        model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": ""})
+        assert model.model.tokenizer.sep_token == ""

From 9a16b8ffcf5bf8d7eaeedf17d73657f4fa79c7cd Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Tue, 18 Jun 2024 16:35:50 +0900
Subject: [PATCH 05/15] Unify argument name prefix/prompt -> prefix

---
 src/jmteb/embedders/base.py                 | 18 +++++++++---------
 src/jmteb/embedders/sbert_embedder.py       |  9 +++------
 src/jmteb/evaluators/reranking/evaluator.py |  6 +++---
 src/jmteb/evaluators/retrieval/evaluator.py |  6 +++---
 tests/conftest.py                           |  3 ++-
 5 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py
index 4276d3d..c74fddf 100644
--- a/src/jmteb/embedders/base.py
+++ b/src/jmteb/embedders/base.py
@@ -14,12 +14,12 @@ class TextEmbedder(ABC):
     The base class of text embedder.
     """
 
-    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
+    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
         """Convert a text string or a list of texts to embedding.
 
         Args:
             text (str | list[str]): text string, or a list of texts.
-            prompt (str, optional): the prompt to use for encoding. Defaults to None.
+            prefix (str, optional): the prefix to use for encoding. Defaults to None.
         """
         raise NotImplementedError
 
@@ -35,7 +35,7 @@ def _batch_encode_and_save_on_disk(
         self,
         text_list: list[str],
         save_path: str | PathLike[str],
-        prompt: str | None = None,
+        prefix: str | None = None,
         batch_size: int = 64,
         dtype: str = "float32",
     ) -> np.memmap:
@@ -45,7 +45,7 @@
         Args:
             text_list (list[str]): list of texts
             save_path (str): path to save the embeddings
-            prompt (str, optional): the prompt to use for encoding. Defaults to None.
+            prefix (str, optional): the prefix to use for encoding. Defaults to None.
             dtype (str, optional): data type. Defaults to "float32".
            batch_size (int): batch size. Defaults to 64.
         """
@@ -57,7 +57,7 @@
         with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
             for i in range(0, num_samples, batch_size):
                 batch = text_list[i : i + batch_size]
-                batch_embeddings = self.encode(batch, prompt=prompt)
+                batch_embeddings = self.encode(batch, prefix=prefix)
                 batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
                 embeddings[i : i + batch_size] = batch_embeddings
                 pbar.update(len(batch))
@@ -68,7 +68,7 @@
     def batch_encode_with_cache(
         self,
         text_list: list[str],
-        prompt: str | None = None,
+        prefix: str | None = None,
         cache_path: str | PathLike[str] | None = None,
         overwrite_cache: bool = False,
         batch_size: int = 64,
@@ -79,7 +79,7 @@
 
         Args:
             text_list (list[str]): list of texts
-            prompt (str, optional): the prompt to use for encoding. Defaults to None.
+            prefix (str, optional): the prefix to use for encoding. Defaults to None.
             cache_path (str, optional): path to save the embeddings. Defaults to None.
             overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
             batch_size (int): batch size. Defaults to 64.
@@ -88,14 +88,14 @@
 
         if cache_path is None:
             logger.info("Encoding embeddings")
-            return self.encode(text_list, prompt=prompt).astype(dtype)
+            return self.encode(text_list, prefix=prefix).astype(dtype)
 
         if Path(cache_path).exists() and not overwrite_cache:
             logger.info(f"Loading embeddings from {cache_path}")
             return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))
 
         logger.info(f"Encoding and saving embeddings to {cache_path}")
         embeddings = self._batch_encode_and_save_on_disk(
-            text_list, cache_path, prompt=prompt, batch_size=batch_size, dtype=dtype
+            text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype
         )
         return embeddings
diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index dbed505..98f3602 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -26,18 +26,15 @@ def __init__(
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
-        self.max_seq_length = max_seq_length
+        self.max_seq_length = self.model.max_seq_length
         self.add_eos = add_eos
 
-        if self.max_seq_length:
-            self.model.max_seq_length = self.max_seq_length
-
-    def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
+    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
         if self.add_eos:
             text = self.add_eos_func(text)
         return self.model.encode(
             text,
-            prompt=prompt,
+            prompt=prefix,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             device=self.device,
diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py
index 0089c0a..1029aab 100644
--- a/src/jmteb/evaluators/reranking/evaluator.py
+++ b/src/jmteb/evaluators/reranking/evaluator.py
@@ -60,7 +60,7 @@ def __call__(
 
         val_query_embeddings = model.batch_encode_with_cache(
             text_list=[item.query for item in self.val_query_dataset],
-            prompt=self.query_prefix,
+            prefix=self.query_prefix,
             cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
             overwrite_cache=overwrite_cache,
         )
@@ -69,13 +69,13 @@
         else:
             test_query_embeddings = model.batch_encode_with_cache(
                 text_list=[item.query for item in self.test_query_dataset],
-                prompt=self.query_prefix,
+                prefix=self.query_prefix,
                 cache_path=Path(cache_dir) /
"test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], - prompt=self.doc_prefix, + prefix=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 9af7af4..64d48d9 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -71,7 +71,7 @@ def __call__( val_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.val_query_dataset], - prompt=self.query_prefix, + prefix=self.query_prefix, cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -80,14 +80,14 @@ def __call__( else: test_query_embeddings = model.batch_encode_with_cache( text_list=[item.query for item in self.test_query_dataset], - prompt=self.query_prefix, + prefix=self.query_prefix, cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) doc_embeddings = model.batch_encode_with_cache( text_list=[item.text for item in self.doc_dataset], - prompt=self.doc_prefix, + prefix=self.doc_prefix, cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/tests/conftest.py b/tests/conftest.py index a34c028..3080cf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,8 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser): class DummyTextEmbedder(TextEmbedder): - def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray: + + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if isinstance(text, str): batch_size = 1 else: From fc8f558a25f9ab7b2770422147fbbef9054dc71c Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 17:59:23 +0900 Subject: [PATCH 06/15] Add prefix for clustering, sts, pair classification and classification --- .../evaluators/classification/evaluator.py | 6 ++++ src/jmteb/evaluators/clustering/evaluator.py | 12 ++++++-- .../pair_classification/evaluator.py | 10 ++++++- src/jmteb/evaluators/sts/evaluator.py | 16 +++++++++-- .../test_classification_evaluator.py | 28 +++++++++++++++++-- tests/evaluator/test_clustering_evaluator.py | 21 ++++++++++++-- .../test_pair_classification_evaluator.py | 23 +++++++++++++-- tests/evaluator/test_sts_evaluator.py | 24 ++++++++++++++-- 8 files changed, 127 insertions(+), 13 deletions(-) diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py index 5136e40..dbe2d8e 100644 --- a/src/jmteb/evaluators/classification/evaluator.py +++ b/src/jmteb/evaluators/classification/evaluator.py @@ -27,6 +27,7 @@ class ClassificationEvaluator(EmbeddingEvaluator): and delimited by comma, e.g., `macro, micro`. The first one is specified as the main index. classifiers (dict[str, Classifier]): classifiers to be evaluated. + prefix (str | None): prefix for sentences. Defaults to None. 
""" def __init__( @@ -36,6 +37,7 @@ def __init__( test_dataset: ClassificationDataset, average: str = "macro", classifiers: dict[str, Classifier] | None = None, + prefix: str | None = None, ) -> None: self.train_dataset = train_dataset self.val_dataset = val_dataset @@ -49,6 +51,7 @@ def __init__( for average_name in average if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary") ] or ["macro"] + self.prefix = prefix self.main_metric = f"{self.average[0]}_f1" def __call__( @@ -60,6 +63,7 @@ def __call__( logger.info("Encoding training and validation sentences...") X_train = model.batch_encode_with_cache( [item.text for item in self.train_dataset], + prefix=self.prefix, cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -67,6 +71,7 @@ def __call__( X_val = model.batch_encode_with_cache( [item.text for item in self.val_dataset], + prefix=self.prefix, cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -79,6 +84,7 @@ def __call__( else: X_test = model.batch_encode_with_cache( [item.text for item in self.test_dataset], + prefix=self.prefix, cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py index e43fd45..d8ef443 100644 --- a/src/jmteb/evaluators/clustering/evaluator.py +++ b/src/jmteb/evaluators/clustering/evaluator.py @@ -30,9 +30,13 @@ def __init__( self, val_dataset: ClusteringDataset, test_dataset: ClusteringDataset, + prefix: str | None = None, + random_seed: int | None = None, ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset + self.prefix = prefix + self.random_seed = random_seed self.main_metric = "v_measure_score" def __call__( @@ -44,6 +48,7 @@ def __call__( logger.info("Converting validation data to embeddings...") val_embeddings = model.batch_encode_with_cache( [item.text for item in self.val_dataset], + prefix=self.prefix, cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -56,6 +61,7 @@ def __call__( else: test_embeddings = model.batch_encode_with_cache( [item.text for item in self.test_dataset], + prefix=self.prefix, cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) @@ -63,9 +69,11 @@ def __call__( n_clusters = len(set(test_labels)) model_constructors: dict[str, Callable[[], ClusterMixin]] = { - "MiniBatchKMeans": lambda: MiniBatchKMeans(n_clusters=n_clusters, n_init="auto"), + "MiniBatchKMeans": lambda: MiniBatchKMeans( + n_clusters=n_clusters, n_init="auto", random_state=self.random_seed + ), "AgglomerativeClustering": lambda: AgglomerativeClustering(n_clusters=n_clusters), - "BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters), + "BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters, random_state=self.random_seed), "Birch": lambda: Birch(n_clusters=n_clusters), } diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py index d8ae1fa..6ec30d0 100644 --- a/src/jmteb/evaluators/pair_classification/evaluator.py +++ b/src/jmteb/evaluators/pair_classification/evaluator.py @@ -20,15 +20,21 @@ class PairClassificationEvaluator(EmbeddingEvaluator): Args: val_dataset 
(PairClassificationDataset): validation dataset test_dataset (PairClassificationDataset): test dataset + sentence1_prefix (str | None): prefix for sentence1. Defaults to None. + sentence2_prefix (str | None): prefix for sentence2. Defaults to None. """ def __init__( self, val_dataset: PairClassificationDataset, test_dataset: PairClassificationDataset, + sentence1_prefix: str | None = None, + sentence2_prefix: str | None = None, ) -> None: self.test_dataset = test_dataset self.val_dataset = val_dataset + self.sentence1_prefix = sentence1_prefix + self.sentence2_prefix = sentence2_prefix self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()] self.main_metric = "binary_f1" @@ -101,8 +107,8 @@ def __call__( }, ) - @staticmethod def _convert_to_embeddings( + self, model: TextEmbedder, dataset: PairClassificationDataset, split: str = "test", @@ -111,11 +117,13 @@ def _convert_to_embeddings( ) -> tuple[np.ndarray, np.ndarray, list[float]]: embeddings1 = model.batch_encode_with_cache( [item.sentence1 for item in dataset], + prefix=self.sentence1_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) embeddings2 = model.batch_encode_with_cache( [item.sentence2 for item in dataset], + prefix=self.sentence2_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py index 33f2ffb..8a20b3d 100644 --- a/src/jmteb/evaluators/sts/evaluator.py +++ b/src/jmteb/evaluators/sts/evaluator.py @@ -23,11 +23,21 @@ class STSEvaluator(EmbeddingEvaluator): Args: val_dataset (STSDataset): dev dataset for hyperparameter tuning test_dataset (STSDataset): test dataset + sentence1_prefix (str | None): prefix for sentence1. Defaults to None. + sentence2_prefix (str | None): prefix for sentence2. Defaults to None. 
""" - def __init__(self, val_dataset: STSDataset, test_dataset: STSDataset) -> None: + def __init__( + self, + val_dataset: STSDataset, + test_dataset: STSDataset, + sentence1_prefix: str | None = None, + sentence2_prefix: str | None = None, + ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset + self.sentence1_prefix = sentence1_prefix + self.sentence2_prefix = sentence2_prefix self.main_metric = "spearman" def __call__( @@ -98,8 +108,8 @@ def _compute_similarity( "spearman": spearmanr(golden_scores, test_sim_score)[0], } - @staticmethod def _convert_to_embeddings( + self, model: TextEmbedder, dataset: STSDataset, split: str = "test", @@ -108,11 +118,13 @@ def _convert_to_embeddings( ) -> tuple[Tensor, Tensor, list[float]]: embeddings1 = model.batch_encode_with_cache( [item.sentence1 for item in dataset], + prefix=self.sentence1_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) embeddings2 = model.batch_encode_with_cache( [item.sentence2 for item in dataset], + prefix=self.sentence2_prefix, cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None, overwrite_cache=overwrite_cache, ) diff --git a/tests/evaluator/test_classification_evaluator.py b/tests/evaluator/test_classification_evaluator.py index 60e911c..bce9964 100644 --- a/tests/evaluator/test_classification_evaluator.py +++ b/tests/evaluator/test_classification_evaluator.py @@ -8,11 +8,12 @@ from jmteb.evaluators.classification.data import JsonlClassificationDataset EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_classifier_name"} +PREFIX = "以下の文を分類する: " class DummyClassificationDataset(ClassificationDataset): - def __init__(self): - self._items = [ClassificationInstance(text=f"dummy text {i}", label=i // 2) for i in range(10)] + def __init__(self, prefix: str = ""): + self._items = [ClassificationInstance(text=f"{prefix}dummy text {i}", label=i // 2) for i in range(10)] def __len__(self): return len(self._items) @@ -43,6 +44,29 @@ def test_classification_evaluator(embedder): assert set(value.keys()) == expected_metrics +def test_classification_evaluator_with_prefix(embedder): + evaluator_with_prefix = ClassificationEvaluator( + train_dataset=DummyClassificationDataset(), + val_dataset=DummyClassificationDataset(), + test_dataset=DummyClassificationDataset(), + prefix=PREFIX, + classifiers={ + "logreg": LogRegClassifier(), + "knn": KnnClassifier(k=2, distance_metric="cosine"), + }, + ) + evaluator_with_manual_prefix = ClassificationEvaluator( + train_dataset=DummyClassificationDataset(prefix=PREFIX), + val_dataset=DummyClassificationDataset(prefix=PREFIX), + test_dataset=DummyClassificationDataset(prefix=PREFIX), + classifiers={ + "logreg": LogRegClassifier(), + "knn": KnnClassifier(k=2, distance_metric="cosine"), + }, + ) + assert evaluator_with_prefix(embedder) == evaluator_with_manual_prefix(embedder) + + def test_classification_jsonl_dataset(): dummy_jsonl_dataset = JsonlClassificationDataset( filename="tests/test_data/dummy_classification/val.jsonl", diff --git a/tests/evaluator/test_clustering_evaluator.py b/tests/evaluator/test_clustering_evaluator.py index 9272cfd..217d850 100644 --- a/tests/evaluator/test_clustering_evaluator.py +++ b/tests/evaluator/test_clustering_evaluator.py @@ -8,11 +8,13 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_clustering_model_name"} EXPECTED_METRIC_NAMES = {"v_measure_score", "completeness_score", 
"homogeneity_score"} EXPECTED_MODEL_NAMES = {"MiniBatchKMeans", "AgglomerativeClustering", "BisectingKMeans", "Birch"} +PREFIX = "以下の文を主題により分類する: " +RANDOM_SEED = 42 class DummyClusteringDataset(ClusteringDataset): - def __init__(self): - self._items = [ClusteringInstance(text=f"dummy text {i}", label=i // 2) for i in range(10)] + def __init__(self, prefix: str = ""): + self._items = [ClusteringInstance(text=f"{prefix}dummy text {i}", label=i // 2) for i in range(10)] def __len__(self): return len(self._items) @@ -37,6 +39,21 @@ def test_kmeans_clustering(embedder): assert set(results.details[score_splitname][clustering_model].keys()) == expected_metrics +def test_clustering_with_prefix(embedder): + evaluator_with_prefix = ClusteringEvaluator( + val_dataset=DummyClusteringDataset(), + test_dataset=DummyClusteringDataset(), + prefix=PREFIX, + random_seed=RANDOM_SEED, + ) + evaluator_with_manual_prefix = ClusteringEvaluator( + val_dataset=DummyClusteringDataset(prefix=PREFIX), + test_dataset=DummyClusteringDataset(prefix=PREFIX), + random_seed=RANDOM_SEED, + ) + assert evaluator_with_prefix(embedder) == evaluator_with_manual_prefix(embedder) + + def test_clustering_jsonl_dataset(): dataset = JsonlClusteringDataset( filename="tests/test_data/dummy_clustering/val.jsonl", diff --git a/tests/evaluator/test_pair_classification_evaluator.py b/tests/evaluator/test_pair_classification_evaluator.py index 443568d..c2ea8c4 100644 --- a/tests/evaluator/test_pair_classification_evaluator.py +++ b/tests/evaluator/test_pair_classification_evaluator.py @@ -8,12 +8,17 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_METRIC_NAMES = {"accuracy", "binary_f1", "accuracy_threshold", "binary_f1_threshold"} EXPECTED_DIST_FUNC_NAMES = {"cosine_distances", "dot_similarities", "manhatten_distances", "euclidean_distances"} +SENT1_PREFIX = "文1: " +SENT2_PREFIX = "文2: " class DummyBinaryDataset(PairClassificationDataset): - def __init__(self): + def __init__(self, sent1_prefix: str = "", sent2_prefix: str = ""): self._items = [ - PairClassificationInstance(f"dummy sentence 1 {i}", f"dummy sentence 2 {i}", i % 2) for i in range(10) + PairClassificationInstance( + f"{sent1_prefix}dummy sentence 1 {i}", f"{sent2_prefix}dummy sentence 2 {i}", i % 2 + ) + for i in range(10) ] def __len__(self): @@ -39,6 +44,20 @@ def test_pair_classification_binary(embedder): assert set(value.keys()) == EXPECTED_METRIC_NAMES +def test_pair_classification_binary_with_prefix(embedder): + evaluator_with_prefix = PairClassificationEvaluator( + val_dataset=DummyBinaryDataset(), + test_dataset=DummyBinaryDataset(), + sentence1_prefix=SENT1_PREFIX, + sentence2_prefix=SENT2_PREFIX, + ) + evaluator_with_manual_prefix = PairClassificationEvaluator( + val_dataset=DummyBinaryDataset(sent1_prefix=SENT1_PREFIX, sent2_prefix=SENT2_PREFIX), + test_dataset=DummyBinaryDataset(sent1_prefix=SENT1_PREFIX, sent2_prefix=SENT2_PREFIX), + ) + assert evaluator_with_prefix(embedder) == evaluator_with_manual_prefix(embedder) + + def test_pair_classification_jsonl_dataset(): dataset = JsonlPairClassificationDataset( filename="tests/test_data/dummy_pair_classification/binary.jsonl", diff --git a/tests/evaluator/test_sts_evaluator.py b/tests/evaluator/test_sts_evaluator.py index a151814..1f750b7 100644 --- a/tests/evaluator/test_sts_evaluator.py +++ b/tests/evaluator/test_sts_evaluator.py @@ -4,11 +4,17 @@ EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_similarity_metric"} EXPECTED_SIM_FUNC_NAMES = 
{"cosine_similarity", "manhatten_distance", "euclidean_distance", "dot_score"} EXPECTED_METRIC_NAMES = {"pearson", "spearman"} +SENT1_PREFIX = "文1: " +SENT2_PREFIX = "文2: " class DummySTSDataset(STSDataset): - def __init__(self): - self._items = [STSInstance("dummy sentence 1", "dummy sentence 2", i * 0.3) for i in range(10)] + + def __init__(self, sent1_prefix: str = "", sent2_prefix: str = ""): + self._items = [ + STSInstance(f"{sent1_prefix}dummy sentence 1", f"{sent2_prefix}dummy sentence 2", i * 0.3) + for i in range(10) + ] def __len__(self): return len(self._items) @@ -32,6 +38,20 @@ def test_sts(embedder): assert set(results.details[score_splitname][dist].keys()) == EXPECTED_METRIC_NAMES +def test_sts_with_prefix(embedder): + evaluator_with_prefix = STSEvaluator( + val_dataset=DummySTSDataset(), + test_dataset=DummySTSDataset(), + sentence1_prefix=SENT1_PREFIX, + sentence2_prefix=SENT2_PREFIX, + ) + evaluator_with_manual_prefix = STSEvaluator( + val_dataset=DummySTSDataset(sent1_prefix=SENT1_PREFIX, sent2_prefix=SENT2_PREFIX), + test_dataset=DummySTSDataset(sent1_prefix=SENT1_PREFIX, sent2_prefix=SENT2_PREFIX), + ) + assert evaluator_with_prefix(embedder) == evaluator_with_manual_prefix(embedder) + + def test_sts_jsonl_dataset(): dataset = JsonlSTSDataset( filename="tests/test_data/dummy_sts/val.jsonl", From 7a9e0cdb402db496d1667dc6b6297dd187a404cf Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 18:02:14 +0900 Subject: [PATCH 07/15] Fix linting --- tests/evaluator/test_sts_evaluator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/evaluator/test_sts_evaluator.py b/tests/evaluator/test_sts_evaluator.py index 1f750b7..69469cc 100644 --- a/tests/evaluator/test_sts_evaluator.py +++ b/tests/evaluator/test_sts_evaluator.py @@ -9,7 +9,6 @@ class DummySTSDataset(STSDataset): - def __init__(self, sent1_prefix: str = "", sent2_prefix: str = ""): self._items = [ STSInstance(f"{sent1_prefix}dummy sentence 1", f"{sent2_prefix}dummy sentence 2", i * 0.3) From dbb9ba6acbeff6b770a2584afa1591c2ed048a40 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 18:05:36 +0900 Subject: [PATCH 08/15] Fix linting --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3080cf2..b6d95ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser): class DummyTextEmbedder(TextEmbedder): - def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if isinstance(text, str): batch_size = 1 From 3cf3264be052c8d307e3db0aec703362789845eb Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 18:28:58 +0900 Subject: [PATCH 09/15] Update private func name --- src/jmteb/embedders/sbert_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 98f3602..8f48d31 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -31,7 +31,7 @@ def __init__( def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if self.add_eos: - text = self.add_eos_func(text) + text = self._add_eos_func(text) return self.model.encode( text, prompt=prefix, @@ -41,7 +41,7 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray normalize_embeddings=self.normalize_embeddings, ) - def add_eos_func(self, text: str | list[str]) -> 
str | list[str]: + def _add_eos_func(self, text: str | list[str]) -> str | list[str]: try: eos_token = getattr(self.model.tokenizer, "eos_token") except AttributeError: From 4584167be9aa1ee4fd7ebf40fa252d854fc9e59f Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 18 Jun 2024 22:23:14 +0900 Subject: [PATCH 10/15] Fix get max_seq_length --- src/jmteb/embedders/sbert_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 8f48d31..bab3da7 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -26,7 +26,7 @@ def __init__( self.batch_size = batch_size self.device = device self.normalize_embeddings = normalize_embeddings - self.max_seq_length = self.model.max_seq_length + self.max_seq_length = getattr(self.model, "max_seq_length", None) self.add_eos = add_eos def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: From 2b5bd922403f9412d75271d7feb277718b4413b3 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Wed, 19 Jun 2024 11:37:03 +0900 Subject: [PATCH 11/15] Add prefix argument in OpenAIEmbedder --- src/jmteb/embedders/openai_embedder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jmteb/embedders/openai_embedder.py b/src/jmteb/embedders/openai_embedder.py index d216752..859dcdf 100644 --- a/src/jmteb/embedders/openai_embedder.py +++ b/src/jmteb/embedders/openai_embedder.py @@ -60,7 +60,7 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None else: self.dim = dim - def encode(self, text: str | list[str]) -> np.ndarray: + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {} # specifying `dimensions` is not allowed for "text-embedding-ada-002" if isinstance(text, str): @@ -84,10 +84,10 @@ def encode(self, text: str | list[str]) -> np.ndarray: def get_output_dim(self) -> int: return self.dim - def encode_and_truncate_text(self, text: str) -> list[int]: + def encode_and_truncate_text(self, text: str, prefix: str | None = None) -> list[int]: # Refer to https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken # return a list of token IDs if not text: text = " " logger.warning("Found empty string!") - return self.encoding.encode(text)[: self.max_token_length] + return self.encoding.encode(text, prefix)[: self.max_token_length] From 5bb0e32ebc66d6ee5cb378047d28a4e869b21a1d Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Wed, 19 Jun 2024 11:43:22 +0900 Subject: [PATCH 12/15] Fix prefix in OpenAIEmbedder --- src/jmteb/embedders/openai_embedder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/jmteb/embedders/openai_embedder.py b/src/jmteb/embedders/openai_embedder.py index 859dcdf..0108c83 100644 --- a/src/jmteb/embedders/openai_embedder.py +++ b/src/jmteb/embedders/openai_embedder.py @@ -64,9 +64,9 @@ def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {} # specifying `dimensions` is not allowed for "text-embedding-ada-002" if isinstance(text, str): - token_ids: list[int] = self.encode_and_truncate_text(text) + token_ids: list[int] = self.encode_and_truncate_text(text, prefix) else: - token_ids: list[list[int]] = [self.encode_and_truncate_text(t) for t in text] + token_ids: list[list[int]] = 
[self.encode_and_truncate_text(t, prefix) for t in text] result = np.asarray( [ data.embedding @@ -90,4 +90,5 @@ def encode_and_truncate_text(self, text: str, prefix: str | None = None) -> list if not text: text = " " logger.warning("Found empty string!") - return self.encoding.encode(text, prefix)[: self.max_token_length] + # Ignore prefix in OpenAIEmbedder + return self.encoding.encode(text)[: self.max_token_length] From ca8546f190c0e4e9b345038db58df171de7c8fe4 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Wed, 19 Jun 2024 13:59:57 +0900 Subject: [PATCH 13/15] Don't upgrade to numpy >=2.0 --- poetry.lock | 5 ++--- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2ba271f..8937812 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2025,7 +2025,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3384,4 +3383,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "c97d2d6a18d934c8566bd48e5a5f1f8d6c0da28d7e529b7f071bf850efea5e1f" +content-hash = "e40f842f52270eeceaf8810710c4709262b2ca68b3c38ef6117f0e50d81de8ff" diff --git a/pyproject.toml b/pyproject.toml index 468adc8..c7adf84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ smart-open = "^7.0.1" openai = "^1.16.2" pytest-mock = "^3.14.0" tiktoken = "^0.6.0" +numpy = "^1.26" [tool.poetry.group.dev.dependencies] black = "^23.11.0" From a57cac38c9c05121c28a1278111fbf602d630547 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Wed, 19 Jun 2024 16:27:19 +0900 Subject: [PATCH 14/15] Add merge confirmation in PR template --- .github/pull_request_template.md | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index ab121f2..bbaae44 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -20,6 +20,7 @@ base branchを`dev`にするよう、お願いいたします。 ## 動作確認 - [ ] テストが通ることを確認した +- [ ] マージ先がdevブランチであることを確認した - [ ] ...
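
Usage note: the sketch below shows, end to end, how the prefix plumbing introduced by this series is meant to be used. It is a minimal illustration, not part of any patch above: the checkpoint name and dataset arguments are placeholders, the prefix strings are the ones used in the test suite, and it assumes a sentence-transformers version whose encode() accepts the prompt keyword, to which the embedder forwards the prefix.

    from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
    from jmteb.evaluators.retrieval.evaluator import RetrievalEvaluator

    # The embedder prepends the prefix via SentenceTransformer.encode(prompt=...).
    embedder = SentenceBertEmbedder(
        "path/to/sentence-transformers-model",  # placeholder checkpoint
        normalize_embeddings=True,
    )

    # Prefixes are passed per call, so queries and documents can be marked differently.
    query_vecs = embedder.encode(["dummy query 1"], prefix="クエリ: ")
    doc_vecs = embedder.encode(["dummy document 1"], prefix="ドキュメント: ")

    # Evaluators take the prefixes once and forward them to every
    # batch_encode_with_cache call, so cached query and corpus embeddings
    # are built with the matching prefix.
    evaluator = RetrievalEvaluator(
        val_query_dataset=...,   # a RetrievalQueryDataset instance
        test_query_dataset=...,  # a RetrievalQueryDataset instance
        doc_dataset=...,         # a RetrievalDocDataset instance
        query_prefix="クエリ: ",
        doc_prefix="ドキュメント: ",
    )
    results = evaluator(model=embedder)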