Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'add-jina' of https://github.com/KennethEnevoldsen/scandinavian-embedding-benchmark into add-jina
Showing 6 changed files with 115 additions and 24 deletions.
@@ -1 +1 @@
-{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-13T21:33:38.301520","scores":{"da":{"accuracy":0.5946666666666667,"f1":0.5872722607515735,"accuracy_stderr":0.03222145592958552,"f1_stderr":0.0278698114421992,"main_score":0.5946666666666667}},"main_score":"accuracy"}
+{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-14T13:57:48.545664","scores":{"da":{"accuracy":0.5999999999999999,"f1":0.5885173430161176,"accuracy_stderr":0.03055050463303893,"f1_stderr":0.02601687642478716,"main_score":0.5999999999999999}},"main_score":"accuracy"}
@@ -0,0 +1,108 @@
import logging
from datetime import date
from functools import partial
from typing import Any, Literal, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

from seb.interfaces.model import LazyLoadEncoder, ModelMeta, SebModel
from seb.interfaces.task import Task
from seb.registries import models

from .normalize_to_ndarray import normalize_to_ndarray
from .sentence_transformer_models import silence_warnings_from_sentence_transformers, wrap_sentence_transformer


class Jinav3EncoderWithTaskEncode(SentenceTransformer):
    """
    A SentenceTransformer wrapper that selects a task-specific prompt for
    jina-embeddings-v3 based on the SEB task being encoded.
    """

    def encode(  # type: ignore
        self,
        sentences: list[str],
        *,
        batch_size: int = 32,
        task: Optional[Task] = None,
        encode_type: Literal["query", "passage"] = "passage",
        **kwargs: Any,
    ) -> np.ndarray:
        # Map the SEB task type (and the query/passage role for retrieval)
        # to the prompt names documented for jina-embeddings-v3.
        task_prompt = None
        if task is not None:
            if task.task_type in ["STS", "BitextMining"]:
                task_prompt = "text-matching"
            if task.task_type in ["Classification"]:
                task_prompt = "classification"
            if task.task_type in ["Clustering"]:
                task_prompt = "separation"  # Jina v3's prompt name for clustering-style tasks
            if task.task_type in ["Retrieval"] and encode_type == "query":
                task_prompt = "retrieval.query"
            if task.task_type in ["Retrieval"] and encode_type == "passage":
                task_prompt = "retrieval.passage"

        if task_prompt is None:
            emb = super().encode(sentences, batch_size=batch_size, **kwargs)
        else:
            emb = super().encode(sentences, batch_size=batch_size, prompt=task_prompt, **kwargs)
        return normalize_to_ndarray(emb)

    def encode_corpus(self, corpus: list[dict[str, str]], **kwargs: Any) -> np.ndarray:
        # Corpora arrive either as a dict of columns or as a list of records,
        # each with an optional "title" and a required "text" field.
        sep = " "
        if isinstance(corpus, dict):
            sentences = [
                (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()  # type: ignore
                for i in range(len(corpus["text"]))  # type: ignore
            ]
        else:
            sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode(sentences, encode_type="passage", **kwargs)

    def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
        return self.encode(queries, encode_type="query", **kwargs)


def wrap_jina_sentence_transformer(model_name: str, max_seq_length: Optional[int] = None, **kwargs: Any) -> Jinav3EncoderWithTaskEncode:
    silence_warnings_from_sentence_transformers()
    mdl = Jinav3EncoderWithTaskEncode(model_name, **kwargs)
    if max_seq_length is not None:
        mdl.max_seq_length = max_seq_length
    return mdl


@models.register("jina-embeddings-v3")
def create_jina_embeddings_v3() -> SebModel:
    hf_name = "jinaai/jina-embeddings-v3"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=[],
        open_source=True,
        embedding_size=1024,
        architecture="XLM-R",
        release_date=date(2024, 8, 5),
    )
    return SebModel(
        encoder=LazyLoadEncoder(partial(wrap_jina_sentence_transformer, model_name=hf_name, trust_remote_code=True)),  # type: ignore
        meta=meta,
    )


@models.register("jina-embedding-b-en-v1")
def create_jina_base() -> SebModel:
    hf_name = "jinaai/jina-embedding-b-en-v1"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=["en"],
        open_source=True,
        embedding_size=768,
        architecture="T5",
        release_date=date(2023, 7, 7),
    )
    return SebModel(
        encoder=LazyLoadEncoder(partial(wrap_sentence_transformer, model_name=hf_name)),  # type: ignore
        meta=meta,
    )
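For context, a minimal usage sketch of the wrapper defined above (not part of the commit): the Danish sentences are illustrative, the call downloads the model on first use, and trust_remote_code is needed for Jina v3's custom modules.

# Sketch only: exercises the wrapper's query/passage paths without a Task object,
# so no task-specific prompt is selected and the default encode path is used.
mdl = wrap_jina_sentence_transformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
q_emb = mdl.encode_queries(["Hvad er hovedstaden i Danmark?"])  # encode_type="query"
d_emb = mdl.encode_corpus([{"title": "Danmark", "text": "København er Danmarks hovedstad."}])  # encode_type="passage"
print(q_emb.shape, d_emb.shape)  # expected (1, 1024) each, matching embedding_size in ModelMeta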
Some generated files are not rendered by default.