From ecee037d73448de50320d58063d08917fcd9d77b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Fri, 26 Jan 2024 11:20:13 +0100 Subject: [PATCH] Added encode_queries and encode_documents to EmbeddingModel, made task optional --- src/seb/interfaces/model.py | 34 +++++++++++++++++-- src/seb/registered_models/cohere_models.py | 4 +-- src/seb/registered_models/e5_mistral.py | 5 ++- src/seb/registered_models/e5_models.py | 2 ++ src/seb/registered_models/fairseq_models.py | 19 ++++++++--- src/seb/registered_models/hf_models.py | 2 +- .../registered_models/translate_e5_models.py | 4 +-- 7 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/seb/interfaces/model.py b/src/seb/interfaces/model.py index 5ad07a28..18f624cb 100644 --- a/src/seb/interfaces/model.py +++ b/src/seb/interfaces/model.py @@ -1,6 +1,7 @@ import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable +from typing import (TYPE_CHECKING, Any, Callable, Optional, Protocol, + runtime_checkable) from numpy.typing import ArrayLike from pydantic import BaseModel @@ -21,7 +22,7 @@ def encode( self, sentences: list[str], *, - task: "Task", + task: Optional["Task"] = None, batch_size: int = 32, **kwargs: Any, ) -> ArrayLike: @@ -110,7 +111,7 @@ def encode( self, sentences: list[str], *, - task: "Task", + task: Optional["Task"] = None, batch_size: int = 32, **kwargs: Any, ) -> ArrayLike: @@ -127,3 +128,30 @@ def encode( Embeddings for the given documents """ return self.model.encode(sentences, batch_size=batch_size, task=task, **kwargs) + + def encode_queries(self, queries: list[str], batch_size: int, **kwargs): + try: + return self.model.encode_queries(queries, batch_size=batch_size, **kwargs) + except AttributeError: + return self.encode(queries, task=None, batch_size=batch_size, **kwargs) + + def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): + try: + return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs) + except AttributeError: + sep = " " + if type(corpus) is dict: + sentences = [ + (corpus["title"][i] + sep + corpus["text"][i]).strip() + if "title" in corpus + else corpus["text"][i].strip() + for i in range(len(corpus["text"])) + ] + else: + sentences = [ + (doc["title"] + sep + doc["text"]).strip() + if "title" in doc + else doc["text"].strip() + for doc in corpus + ] + return self.encode(sentences, task=None, batch_size=batch_size, **kwargs) diff --git a/src/seb/registered_models/cohere_models.py b/src/seb/registered_models/cohere_models.py index b8384fa2..d11830db 100644 --- a/src/seb/registered_models/cohere_models.py +++ b/src/seb/registered_models/cohere_models.py @@ -5,7 +5,7 @@ import logging from functools import partial -from typing import Any +from typing import Any, Optional import torch @@ -41,7 +41,7 @@ def encode( sentences: list[str], batch_size: int = 32, # noqa: ARG002 *, - task: Task, + task: Optional[Task] = None, **kwargs: Any, # noqa: ARG002 ) -> torch.Tensor: if task.task_type == "Classification": diff --git a/src/seb/registered_models/e5_mistral.py b/src/seb/registered_models/e5_mistral.py index 4dafe5b7..6a2e5916 100644 --- a/src/seb/registered_models/e5_mistral.py +++ b/src/seb/registered_models/e5_mistral.py @@ -1,6 +1,6 @@ from collections.abc import Iterable, Sequence from itertools import islice -from typing import Any, TypeVar +from typing import Any, Optional, TypeVar import torch import torch.nn.functional as F @@ -10,6 +10,7 @@ from seb import models from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta +from seb.interfaces.task import Task T = TypeVar("T") @@ -77,6 +78,8 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso def encode( self, sentences: list[str], + *, + task: Optional[Task] = None, batch_size: int = 32, **kwargs: Any, # noqa ) -> ArrayLike: diff --git a/src/seb/registered_models/e5_models.py b/src/seb/registered_models/e5_models.py index 278a1e36..cf9205e5 100644 --- a/src/seb/registered_models/e5_models.py +++ b/src/seb/registered_models/e5_models.py @@ -27,10 +27,12 @@ def encode( return self.encode_queries(sentences, batch_size=batch_size, **kwargs) def encode_queries(self, queries: list[str], batch_size: int, **kwargs): + print("ENCODING QUERYYYYY!!!!") sentences = ["query: " + sentence for sentence in queries] return self.mdl.encode(sentences, batch_size=batch_size, **kwargs) def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): + print("ENCODING CORPUS!!!!") if type(corpus) is dict: sentences = [ (corpus["title"][i] + self.sep + corpus["text"][i]).strip() diff --git a/src/seb/registered_models/fairseq_models.py b/src/seb/registered_models/fairseq_models.py index 8d76f82f..55b7137e 100644 --- a/src/seb/registered_models/fairseq_models.py +++ b/src/seb/registered_models/fairseq_models.py @@ -6,6 +6,7 @@ import torch from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta +from seb.interfaces.task import Task from seb.registries import models @@ -35,9 +36,7 @@ def __init__( Norwegian Nynorsk, and Norwegian Bokmål, respectively. """ from sonar.models.sonar_text import ( # type: ignore - load_sonar_text_encoder_model, - load_sonar_tokenizer, - ) + load_sonar_text_encoder_model, load_sonar_tokenizer) super().__init__() @@ -60,6 +59,8 @@ def __init__( def encode( self, input: Union[Path, Sequence[str]], # noqa: A002 + *, + task: Optional[Task] = None, batch_size: int, **kwargs: dict, # noqa: ARG002 ) -> torch.Tensor: @@ -72,7 +73,11 @@ def encode( tokenizer_encoder = self.tokenizer.create_encoder(lang=self.source_lang) # type: ignore pipeline = ( - (read_text(input) if isinstance(input, (str, Path)) else read_sequence(input)) + ( + read_text(input) + if isinstance(input, (str, Path)) + else read_sequence(input) + ) .map(tokenizer_encoder) .bucket(batch_size) .map(Collater(self.tokenizer.vocab_info.pad_idx)) # type: ignore @@ -96,7 +101,11 @@ def get_sonar_model(source_lang: str) -> SonarTextToEmbeddingModelPipeline: source_lang=source_lang, ) except ImportError: - msg = "Could not fetch Sonar Models. Make sure you have" + "fairseq2 installed. This is currently only supported for " + "Linux." + msg = ( + "Could not fetch Sonar Models. Make sure you have" + + "fairseq2 installed. This is currently only supported for " + + "Linux." + ) raise ImportError(msg) # noqa B904 diff --git a/src/seb/registered_models/hf_models.py b/src/seb/registered_models/hf_models.py index 2e91d7c0..57a2a12f 100644 --- a/src/seb/registered_models/hf_models.py +++ b/src/seb/registered_models/hf_models.py @@ -29,7 +29,7 @@ def encode( sentences: list[str], *, batch_size: int, - task: Task, # noqa: ARG002 + task: Optional[Task] = None, # noqa: ARG002 **kwargs: Any, ) -> ArrayLike: return super().encode(sentences, batch_size=batch_size, **kwargs) # type: ignore diff --git a/src/seb/registered_models/translate_e5_models.py b/src/seb/registered_models/translate_e5_models.py index ea687588..165e712a 100644 --- a/src/seb/registered_models/translate_e5_models.py +++ b/src/seb/registered_models/translate_e5_models.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from functools import partial -from typing import Any +from typing import Any, Optional import torch from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer @@ -34,7 +34,7 @@ def encode( self, sentences: list[str], *, - task: seb.Task, # noqa: ARG002 + task: Optional[seb.Task] = None, # noqa: ARG002 batch_size: int = 32, **kwargs: Any, ) -> torch.Tensor: