Skip to content

Commit

Permalink
Merge pull request #92 from KennethEnevoldsen/custom_embeddings
Browse files Browse the repository at this point in the history
Custom embeddings for E5 and Cohere + Interface changes to accommodate this
  • Loading branch information
x-tabdeveloping authored Jan 26, 2024
2 parents 5f9d47b + 2414337 commit aeb32ce
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 75 deletions.
38 changes: 34 additions & 4 deletions src/seb/interfaces/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable
from typing import (TYPE_CHECKING, Any, Callable, Optional, Protocol,
runtime_checkable)

from numpy.typing import ArrayLike
from pydantic import BaseModel
Expand All @@ -21,7 +23,7 @@ def encode(
self,
sentences: list[str],
*,
task: "Task",
task: Optional["Task"] = None,
batch_size: int = 32,
**kwargs: Any,
) -> ArrayLike:
Expand Down Expand Up @@ -78,7 +80,8 @@ def from_disk(cls, path: Path) -> "ModelMeta":
return cls(**model_meta)


class EmbeddingModel(BaseModel):
@dataclass
class EmbeddingModel:
"""
An embedding model as implemented in SEB. It notably dynamically loads models (such that models are not loaded when a cache is hit)
and includes metadata pertaining to the specific model.
Expand Down Expand Up @@ -110,7 +113,7 @@ def encode(
self,
sentences: list[str],
*,
task: "Task",
task: Optional["Task"] = None,
batch_size: int = 32,
**kwargs: Any,
) -> ArrayLike:
Expand All @@ -127,3 +130,30 @@ def encode(
Embeddings for the given documents
"""
return self.model.encode(sentences, batch_size=batch_size, task=task, **kwargs)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs: Any) -> "ArrayLike":
    """
    Embed retrieval queries, preferring the wrapped model's own query encoder.

    Args:
        queries: The query strings to embed.
        batch_size: Batch size forwarded to the underlying encoder.
        kwargs: Extra keyword arguments forwarded unchanged to the underlying encoder.

    Returns:
        The result of the wrapped model's `encode_queries` when it defines one,
        otherwise the result of `self.encode` called with `task=None`.
    """
    # Look the method up explicitly instead of wrapping the call in
    # `except AttributeError`: an AttributeError raised *inside* a custom
    # `encode_queries` is a real bug and must not silently trigger the fallback.
    query_encoder = getattr(self.model, "encode_queries", None)
    if query_encoder is None:
        return self.encode(queries, task=None, batch_size=batch_size, **kwargs)
    return query_encoder(queries, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs: Any) -> "ArrayLike":
    """
    Embed corpus documents, preferring the wrapped model's own corpus encoder.

    When the wrapped model has no `encode_corpus`, each document is flattened to
    "<title> <text>" (or just the text when there is no title), stripped of
    surrounding whitespace, and passed to `self.encode` with `task=None`.

    Args:
        corpus: Either a list of documents (dicts with a "text" key and an optional
            "title" key) or a single dict of parallel lists under the same keys.
        batch_size: Batch size forwarded to the underlying encoder.
        kwargs: Extra keyword arguments forwarded unchanged to the underlying encoder.

    Returns:
        The result of the wrapped model's `encode_corpus` when it defines one,
        otherwise the result of `self.encode` on the flattened documents.
    """
    # Explicit lookup instead of `except AttributeError` around the call, so that an
    # AttributeError raised inside a custom `encode_corpus` is not swallowed.
    corpus_encoder = getattr(self.model, "encode_corpus", None)
    if corpus_encoder is not None:
        return corpus_encoder(corpus, batch_size=batch_size, **kwargs)

    sep = " "
    # isinstance (not `type(...) is dict`) so dict subclasses take the columnar branch too.
    if isinstance(corpus, dict):
        texts = corpus["text"]
        if "title" in corpus:
            titles = corpus["title"]
            # Index into titles so a too-short title column still raises IndexError.
            sentences = [(titles[i] + sep + text).strip() for i, text in enumerate(texts)]
        else:
            sentences = [text.strip() for text in texts]
    else:
        sentences = [
            (doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
            for doc in corpus
        ]
    return self.encode(sentences, task=None, batch_size=batch_size, **kwargs)
22 changes: 11 additions & 11 deletions src/seb/interfaces/mteb_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from functools import partial
from typing import Any, Union

import numpy as np
Expand All @@ -12,15 +13,6 @@
from .task import DescriptiveDatasetStats, Task


class MTEBTaskModel(Encoder):
    """Adapter that supplies a fixed task to every `encode` call of a wrapped encoder."""

    def __init__(self, mteb_model: Encoder, task: Task) -> None:
        """Store the encoder to delegate to and the task to pass along with each call."""
        self.task = task
        self.mteb_model = mteb_model

    def encode(self, texts: list[str], **kwargs: Any) -> ArrayLike:  # type: ignore
        """Encode `texts` with the wrapped encoder, passing the stored task as `task`."""
        wrapped_encode = self.mteb_model.encode
        return wrapped_encode(texts, task=self.task, **kwargs)


class MTEBTask(Task):
def __init__(self, mteb_task: AbsTask) -> None:
self.mteb_task = mteb_task
Expand Down Expand Up @@ -76,8 +68,16 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats:

def evaluate(self, model: Encoder) -> TaskResult:
split = self.mteb_task.description["eval_splits"][0]
task_model = MTEBTaskModel(model, self)
scores = self.mteb_task.evaluate(task_model, split=split)
# Infusing task into encode()
original_encode = model.encode
try:
model.encode = partial(model.encode, task=self)
scores = self.mteb_task.evaluate(model, split=split)
except Exception as e:
raise e
finally:
# Resetting encode to original
model.encode = original_encode
if scores is None:
raise ValueError("MTEBTask evaluation failed.")

Expand Down
68 changes: 44 additions & 24 deletions src/seb/registered_models/cohere_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,74 @@


import logging
from collections.abc import Sequence
from functools import partial
from typing import Any
from typing import Any, Optional

import torch

import seb
from seb.interfaces.task import Task
from seb.registries import models

logger = logging.getLogger(__name__)


class CohereTextEmbeddingModel(seb.Encoder):
def __init__(self, model_name: str) -> None:
def __init__(self, model_name: str, sep: str = " ") -> None:
self.model_name = model_name

@staticmethod
def create_sentence_blocks(
sentences: Sequence[str],
block_size: int,
) -> list[Sequence[str]]:
sent_blocks: list[Sequence[str]] = []
for i in range(0, len(sentences), block_size):
sent_blocks.append(sentences[i : i + block_size])
return sent_blocks
self.sep = sep

def get_embedding_dim(self) -> int:
v = self.encode(["get emb dim"])
v = self._embed(["get emb dim"], input_type="classification")
return v.shape[1]

def encode(
self,
sentences: Sequence[str],
batch_size: int = 32, # noqa: ARG002
embed_type: str = "classification",
**kwargs: Any, # noqa: ARG002
) -> torch.Tensor:
import cohere # type: ignore
def _embed(self, sentences: list[str], input_type: str) -> torch.Tensor:
import cohere

client = cohere.Client()
response = client.embed(
texts=list(sentences),
model=self.model_name,
input_type=embed_type,
input_type=input_type,
)

return torch.tensor(response.embeddings)

def encode(
    self,
    sentences: list[str],
    batch_size: int = 32,  # noqa: ARG002
    *,
    task: Optional["Task"] = None,
    **kwargs: Any,  # noqa: ARG002
) -> "torch.Tensor":
    """
    Embed sentences, choosing the Cohere `input_type` from the task type.

    Args:
        sentences: The texts to embed.
        batch_size: Accepted for interface compatibility; not used here.
        task: The task being evaluated. Its `task_type` selects the input type:
            "Classification" -> "classification", "Clustering" -> "clustering",
            anything else (including no task at all) -> "search_document".
        kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The embeddings produced by `self._embed` for the given sentences.
    """
    # `task` is optional, so it must not be dereferenced unconditionally;
    # a missing task falls through to the same default as an unrecognised task type.
    task_type = None if task is None else task.task_type
    input_type_by_task = {
        "Classification": "classification",
        "Clustering": "clustering",
    }
    input_type = input_type_by_task.get(task_type, "search_document")
    return self._embed(sentences, input_type=input_type)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
    """
    Embed search queries with the "search_query" input type.

    `batch_size` and `kwargs` are accepted for interface compatibility only;
    neither is forwarded to `self._embed`.
    """
    query_input_type = "search_query"
    return self._embed(queries, input_type=query_input_type)

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
    """
    Embed corpus documents with the "search_document" input type.

    Each document is flattened to "<title><sep><text>" (or just the text when
    there is no title) and stripped of surrounding whitespace before embedding.

    Args:
        corpus: Either a list of documents (dicts with a "text" key and an optional
            "title" key) or a single dict of parallel lists under the same keys.
        batch_size: Accepted for interface compatibility; not used here.
        kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The embeddings produced by `self._embed` for the flattened documents.
    """
    sep = self.sep
    # isinstance (not `type(...) is dict`) so dict subclasses take the columnar branch too.
    if isinstance(corpus, dict):
        texts = corpus["text"]
        if "title" in corpus:
            titles = corpus["title"]
            # Index into titles so a too-short title column still raises IndexError.
            sentences = [(titles[i] + sep + text).strip() for i, text in enumerate(texts)]
        else:
            sentences = [text.strip() for text in texts]
    else:
        sentences = [
            (doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
            for doc in corpus
        ]
    return self._embed(sentences, input_type="search_document")


@models.register("embed-multilingual-v3.0")
def create_embed_multilingual_v3() -> seb.EmbeddingModel:
Expand Down
5 changes: 4 additions & 1 deletion src/seb/registered_models/e5_mistral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable, Sequence
from itertools import islice
from typing import Any, TypeVar
from typing import Any, Optional, TypeVar

import torch
import torch.nn.functional as F
Expand All @@ -10,6 +10,7 @@

from seb import models
from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta
from seb.interfaces.task import Task

T = TypeVar("T")

Expand Down Expand Up @@ -73,6 +74,8 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso
def encode(
self,
sentences: list[str],
*,
task: Optional[Task] = None,
batch_size: int = 32,
**kwargs: Any, # noqa
) -> ArrayLike:
Expand Down
42 changes: 28 additions & 14 deletions src/seb/registered_models/e5_models.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from functools import partial
from typing import Any
from typing import Any, Optional

from numpy.typing import ArrayLike
from sentence_transformers import SentenceTransformer

from seb import models

from ..interfaces.model import EmbeddingModel, Encoder, ModelMeta
from ..interfaces.task import Task
from .hf_models import get_sentence_transformer


class E5Wrapper(Encoder):
def __init__(self, model_name: str):
def __init__(self, model_name: str, sep: str = " "):
self.model_name = model_name
self.mdl = get_sentence_transformer(model_name)

@staticmethod
def preprocess(sentences: list[str]) -> list[str]:
# following the documentation it is better to generally do this:
return ["query: " + sentence for sentence in sentences]

# but it does not work slightly better than this:
# return sentences # noqa
self.mdl = SentenceTransformer(model_name)
self.sep = sep

def encode(
self,
Expand All @@ -31,8 +24,29 @@ def encode(
batch_size: int = 32,
**kwargs: Any,
) -> ArrayLike:
sentences = self.preprocess(sentences)
return self.mdl.encode(sentences, batch_size=batch_size, task=task, **kwargs) # type: ignore
return self.encode_queries(sentences, batch_size=batch_size, **kwargs)

def encode_queries(self, queries: list[str], batch_size: int, **kwargs):
    """Prepend "query: " to every query and encode the result with the wrapped model."""
    prefix = "query: "
    return self.mdl.encode([prefix + query for query in queries], batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs):
    """
    Encode corpus documents as E5 passages.

    Each document is flattened to "<title><sep><text>" (or just the text when
    there is no title), stripped of surrounding whitespace, and prefixed with
    "passage: " before being encoded by the wrapped model.

    Args:
        corpus: Either a list of documents (dicts with a "text" key and an optional
            "title" key) or a single dict of parallel lists under the same keys.
        batch_size: Batch size forwarded to the wrapped model's `encode`.
        kwargs: Extra keyword arguments forwarded unchanged to the wrapped model.

    Returns:
        Whatever the wrapped model's `encode` returns for the prefixed passages.
    """
    sep = self.sep
    # isinstance (not `type(...) is dict`) so dict subclasses take the columnar branch too.
    if isinstance(corpus, dict):
        texts = corpus["text"]
        if "title" in corpus:
            titles = corpus["title"]
            # Index into titles so a too-short title column still raises IndexError.
            documents = [(titles[i] + sep + text).strip() for i, text in enumerate(texts)]
        else:
            documents = [text.strip() for text in texts]
    else:
        documents = [
            (doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
            for doc in corpus
        ]
    passages = ["passage: " + document for document in documents]
    return self.mdl.encode(passages, batch_size=batch_size, **kwargs)


# English
Expand Down
18 changes: 13 additions & 5 deletions src/seb/registered_models/fairseq_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta
from seb.interfaces.task import Task
from seb.registries import models


Expand Down Expand Up @@ -35,9 +36,7 @@ def __init__(
Norwegian Nynorsk, and Norwegian Bokmål, respectively.
"""
from sonar.models.sonar_text import ( # type: ignore
load_sonar_text_encoder_model,
load_sonar_tokenizer,
)
load_sonar_text_encoder_model, load_sonar_tokenizer)

super().__init__()

Expand All @@ -61,6 +60,7 @@ def encode( # type: ignore
self,
input: Union[Path, Sequence[str]], # noqa: A002
*,
task: Optional[Task] = None,
batch_size: int,
**kwargs: Any, # noqa: ARG002
) -> torch.Tensor:
Expand All @@ -73,7 +73,11 @@ def encode( # type: ignore
tokenizer_encoder = self.tokenizer.create_encoder(lang=self.source_lang) # type: ignore

pipeline = (
(read_text(input) if isinstance(input, (str, Path)) else read_sequence(input))
(
read_text(input)
if isinstance(input, (str, Path))
else read_sequence(input)
)
.map(tokenizer_encoder)
.bucket(batch_size)
.map(Collater(self.tokenizer.vocab_info.pad_idx)) # type: ignore
Expand All @@ -97,7 +101,11 @@ def get_sonar_model(source_lang: str) -> SonarTextToEmbeddingModelPipeline:
source_lang=source_lang,
)
except ImportError:
msg = "Could not fetch Sonar Models. Make sure you have" + "fairseq2 installed. This is currently only supported for " + "Linux."
msg = (
"Could not fetch Sonar Models. Make sure you have"
+ "fairseq2 installed. This is currently only supported for "
+ "Linux."
)
raise ImportError(msg) # noqa B904


Expand Down
2 changes: 1 addition & 1 deletion src/seb/registered_models/hf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def encode( # type: ignore
sentences: list[str],
*,
batch_size: int,
task: Task, # noqa: ARG002
task: Optional[Task] = None, # noqa: ARG002
**kwargs: Any,
) -> ArrayLike:
return super().encode(sentences, batch_size=batch_size, **kwargs) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions src/seb/registered_models/translate_e5_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from functools import partial
from typing import Any
from typing import Any, Optional

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
Expand Down Expand Up @@ -30,7 +30,7 @@ def encode(
self,
sentences: list[str],
*,
task: seb.Task, # noqa: ARG002
task: Optional[seb.Task] = None, # noqa: ARG002
batch_size: int = 32,
**kwargs: Any,
) -> torch.Tensor:
Expand Down
Loading

0 comments on commit aeb32ce

Please sign in to comment.