Commit ecee037
Added encode_queries and encode_documents to EmbeddingModel, made task optional
x-tabdeveloping committed Jan 26, 2024
1 parent 6742fdd commit ecee037
Showing 7 changed files with 56 additions and 14 deletions.
34 changes: 31 additions & 3 deletions src/seb/interfaces/model.py
@@ -1,6 +1,7 @@
 import json
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable
+from typing import (TYPE_CHECKING, Any, Callable, Optional, Protocol,
+                    runtime_checkable)
 
 from numpy.typing import ArrayLike
 from pydantic import BaseModel
@@ -21,7 +22,7 @@ def encode(
         self,
         sentences: list[str],
         *,
-        task: "Task",
+        task: Optional["Task"] = None,
         batch_size: int = 32,
         **kwargs: Any,
     ) -> ArrayLike:
@@ -110,7 +111,7 @@ def encode(
         self,
         sentences: list[str],
         *,
-        task: "Task",
+        task: Optional["Task"] = None,
         batch_size: int = 32,
         **kwargs: Any,
     ) -> ArrayLike:
@@ -127,3 +128,30 @@ def encode(
             Embeddings for the given documents
         """
         return self.model.encode(sentences, batch_size=batch_size, task=task, **kwargs)
+
+    def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
+        try:
+            return self.model.encode_queries(queries, batch_size=batch_size, **kwargs)
+        except AttributeError:
+            return self.encode(queries, task=None, batch_size=batch_size, **kwargs)
+
+    def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
+        try:
+            return self.model.encode_corpus(corpus, batch_size=batch_size, **kwargs)
+        except AttributeError:
+            sep = " "
+            if type(corpus) is dict:
+                sentences = [
+                    (corpus["title"][i] + sep + corpus["text"][i]).strip()
+                    if "title" in corpus
+                    else corpus["text"][i].strip()
+                    for i in range(len(corpus["text"]))
+                ]
+            else:
+                sentences = [
+                    (doc["title"] + sep + doc["text"]).strip()
+                    if "title" in doc
+                    else doc["text"].strip()
+                    for doc in corpus
+                ]
+            return self.encode(sentences, task=None, batch_size=batch_size, **kwargs)
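
The new encode_corpus fallback accepts either a dict of parallel lists or a list of per-document dicts, and flattens both into plain strings before delegating to encode. A minimal, self-contained sketch of that flattening with illustrative data (not taken from the repository):

    # Sketch of the flattening encode_corpus falls back to; data is illustrative.
    sep = " "

    corpus_as_list = [
        {"title": "Doc A", "text": "First body."},
        {"text": "Second body."},  # "title" is optional per the fallback
    ]
    sentences = [
        (doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
        for doc in corpus_as_list
    ]
    assert sentences == ["Doc A First body.", "Second body."]

    corpus_as_dict = {
        "title": ["Doc A", "Doc B"],
        "text": ["First body.", "Second body."],
    }
    sentences = [
        (corpus_as_dict["title"][i] + sep + corpus_as_dict["text"][i]).strip()
        for i in range(len(corpus_as_dict["text"]))
    ]
    assert sentences == ["Doc A First body.", "Doc B Second body."]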
4 changes: 2 additions & 2 deletions src/seb/registered_models/cohere_models.py
@@ -5,7 +5,7 @@
 
 import logging
 from functools import partial
-from typing import Any
+from typing import Any, Optional
 
 import torch
@@ -41,7 +41,7 @@ def encode(
         sentences: list[str],
         batch_size: int = 32,  # noqa: ARG002
         *,
-        task: Task,
+        task: Optional[Task] = None,
         **kwargs: Any,  # noqa: ARG002
     ) -> torch.Tensor:
         if task.task_type == "Classification":
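
One caveat worth noting: the body shown above still reads task.task_type unconditionally, so calling this encode with the new default task=None would raise an AttributeError at that line. A None-safe guard along these lines (a sketch, not part of this commit) would keep the default usable:

    # Sketch: guard the now-optional task before reading task_type (not in the commit).
    if task is not None and task.task_type == "Classification":
        ...  # original classification branch
    else:
        ...  # original default branch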
5 changes: 4 additions & 1 deletion src/seb/registered_models/e5_mistral.py
@@ -1,6 +1,6 @@
 from collections.abc import Iterable, Sequence
 from itertools import islice
-from typing import Any, TypeVar
+from typing import Any, Optional, TypeVar
 
 import torch
 import torch.nn.functional as F
@@ -10,6 +10,7 @@
 
 from seb import models
 from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta
+from seb.interfaces.task import Task
 
 T = TypeVar("T")
@@ -77,6 +78,8 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     def encode(
         self,
         sentences: list[str],
+        *,
+        task: Optional[Task] = None,
         batch_size: int = 32,
         **kwargs: Any,  # noqa
     ) -> ArrayLike:
2 changes: 2 additions & 0 deletions src/seb/registered_models/e5_models.py
@@ -27,10 +27,12 @@ def encode(
         return self.encode_queries(sentences, batch_size=batch_size, **kwargs)
 
     def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
+        print("ENCODING QUERYYYYY!!!!")
         sentences = ["query: " + sentence for sentence in queries]
         return self.mdl.encode(sentences, batch_size=batch_size, **kwargs)
 
     def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
+        print("ENCODING CORPUS!!!!")
         if type(corpus) is dict:
             sentences = [
                 (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
19 changes: 14 additions & 5 deletions src/seb/registered_models/fairseq_models.py
@@ -6,6 +6,7 @@
 import torch
 
 from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta
+from seb.interfaces.task import Task
 from seb.registries import models

@@ -35,9 +36,7 @@ def __init__(
             Norwegian Nynorsk, and Norwegian Bokmål, respectively.
         """
         from sonar.models.sonar_text import (  # type: ignore
-            load_sonar_text_encoder_model,
-            load_sonar_tokenizer,
-        )
+            load_sonar_text_encoder_model, load_sonar_tokenizer)
 
         super().__init__()
@@ -60,6 +59,8 @@ def __init__(
     def encode(
         self,
         input: Union[Path, Sequence[str]],  # noqa: A002
+        *,
+        task: Optional[Task] = None,
         batch_size: int,
         **kwargs: dict,  # noqa: ARG002
     ) -> torch.Tensor:
@@ -72,7 +73,11 @@
         tokenizer_encoder = self.tokenizer.create_encoder(lang=self.source_lang)  # type: ignore
 
         pipeline = (
-            (read_text(input) if isinstance(input, (str, Path)) else read_sequence(input))
+            (
+                read_text(input)
+                if isinstance(input, (str, Path))
+                else read_sequence(input)
+            )
             .map(tokenizer_encoder)
             .bucket(batch_size)
             .map(Collater(self.tokenizer.vocab_info.pad_idx))  # type: ignore
@@ -96,7 +101,11 @@ def get_sonar_model(source_lang: str) -> SonarTextToEmbeddingModelPipeline:
             source_lang=source_lang,
         )
     except ImportError:
-        msg = "Could not fetch Sonar Models. Make sure you have" + "fairseq2 installed. This is currently only supported for " + "Linux."
+        msg = (
+            "Could not fetch Sonar Models. Make sure you have"
+            + "fairseq2 installed. This is currently only supported for "
+            + "Linux."
+        )
         raise ImportError(msg)  # noqa B904
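
As an aside, both the old and new versions of this message concatenate "have" and "fairseq2" with no separating space, so the rendered error reads "…you havefairseq2 installed…". A corrected literal (a suggestion, not part of this commit) would use implicit string concatenation with the space included:

    msg = (
        "Could not fetch Sonar Models. Make sure you have "
        "fairseq2 installed. This is currently only supported for Linux."
    )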
2 changes: 1 addition & 1 deletion src/seb/registered_models/hf_models.py
@@ -29,7 +29,7 @@ def encode(
         sentences: list[str],
         *,
         batch_size: int,
-        task: Task,  # noqa: ARG002
+        task: Optional[Task] = None,  # noqa: ARG002
         **kwargs: Any,
     ) -> ArrayLike:
         return super().encode(sentences, batch_size=batch_size, **kwargs)  # type: ignore
4 changes: 2 additions & 2 deletions src/seb/registered_models/translate_e5_models.py
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from functools import partial
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
@@ -34,7 +34,7 @@ def encode(
         self,
         sentences: list[str],
         *,
-        task: seb.Task,  # noqa: ARG002
+        task: Optional[seb.Task] = None,  # noqa: ARG002
         batch_size: int = 32,
         **kwargs: Any,
     ) -> torch.Tensor:
