Commit

Your Name committed Nov 14, 2024
2 parents ccb84b1 + a18205c commit 83af9cb
Showing 6 changed files with 115 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -43,7 +43,7 @@ docs = [
"mkdocs-material==9.1.21",
"mkdocstrings[python-legacy]==0.22.0",
# for managing tables
"datawrapper>=0.5.3",
"datawrapper>=0.5.3,<0.6.0",
# for tutorials
"jupyter>=1.0.0",
]
2 changes: 1 addition & 1 deletion src/seb/cache/jinaai__jina-embeddings-v3/LCC.json
@@ -1 +1 @@
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-13T21:33:38.301520","scores":{"da":{"accuracy":0.5946666666666667,"f1":0.5872722607515735,"accuracy_stderr":0.03222145592958552,"f1_stderr":0.0278698114421992,"main_score":0.5946666666666667}},"main_score":"accuracy"}
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-14T13:57:48.545664","scores":{"da":{"accuracy":0.5999999999999999,"f1":0.5885173430161176,"accuracy_stderr":0.03055050463303893,"f1_stderr":0.02601687642478716,"main_score":0.5999999999999999}},"main_score":"accuracy"}
1 change: 1 addition & 0 deletions src/seb/registered_models/__init__.py
@@ -8,3 +8,4 @@
from .translate_e5_models import *
from .voyage_models import *
from .bge_models import *
from .jina_models import *
108 changes: 108 additions & 0 deletions src/seb/registered_models/jina_models.py
@@ -0,0 +1,108 @@
import logging
from datetime import date
from functools import partial
from typing import Any, Literal, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

from seb.interfaces.model import LazyLoadEncoder, ModelMeta, SebModel
from seb.interfaces.task import Task
from seb.registries import models

from .normalize_to_ndarray import normalize_to_ndarray
from .sentence_transformer_models import silence_warnings_from_sentence_transformers, wrap_sentence_transformer


class Jinav3EncoderWithTaskEncode(SentenceTransformer):
    """
    A sentence transformer wrapper that allows for encoding with a task.
    """

    def encode( # type: ignore
        self,
        sentences: list[str],
        *,
        batch_size: int = 32,
        task: Optional[Task] = None,
        encode_type: Literal["query", "passage"] = "passage",
        **kwargs: Any,
    ) -> np.ndarray:
        task_prompt = None
        if task is not None:
            if task.task_type in ["STS", "BitextMining"]:
                task_prompt = "text-matching"
            if task.task_type in ["Classification"]:
                task_prompt = "classification"
            if task.task_type in ["Clustering"]:
                task_prompt = "separation"
            if task.task_type in ["Retrieval"] and encode_type == "query":
                task_prompt = "retrieval.query"
            if task.task_type in ["Retrieval"] and encode_type == "passage":
                task_prompt = "retrieval.passage"

        if task_prompt is None:
            emb = super().encode(sentences, batch_size=batch_size, **kwargs)
        else:
            emb = super().encode(sentences, batch_size=batch_size, prompt=task_prompt, **kwargs)
        return normalize_to_ndarray(emb)

    def encode_corpus(self, corpus: list[dict[str, str]], **kwargs: Any) -> np.ndarray:
        sep = " "
        if isinstance(corpus, dict):
            sentences = [
                (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore
                for i in range(len(corpus["text"])) # type: ignore
            ]
        else:
            sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode(sentences, encode_type="passage", **kwargs)

    def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
        return self.encode(queries, encode_type="query", **kwargs)


def wrap_jina_sentence_transformer(model_name: str, max_seq_length: Optional[int] = None, **kwargs: Any) -> Jinav3EncoderWithTaskEncode:
    silence_warnings_from_sentence_transformers()
    mdl = Jinav3EncoderWithTaskEncode(model_name, **kwargs)
    if max_seq_length is not None:
        mdl.max_seq_length = max_seq_length
    return mdl


@models.register("jina-embeddings-v3")
def create_jina_embeddings_v3() -> SebModel:
    hf_name = "jinaai/jina-embeddings-v3"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=[],
        open_source=True,
        embedding_size=1024,
        architecture="XLM-R",
        release_date=date(2024, 8, 5),
    )
    return SebModel(
        encoder=LazyLoadEncoder(partial(wrap_jina_sentence_transformer, model_name=hf_name, trust_remote_code=True)), # type: ignore
        meta=meta,
    )


@models.register("jina-embedding-b-en-v1")
def create_jina_base() -> SebModel:
    hf_name = "jinaai/jina-embedding-b-en-v1"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=["en"],
        open_source=True,
        embedding_size=768,
        architecture="T5",
        release_date=date(2023, 7, 7),
    )
    return SebModel(
        encoder=LazyLoadEncoder(partial(wrap_sentence_transformer, model_name=hf_name)), # type: ignore
        meta=meta,
    )
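For context, the sketch below shows how the new module might be exercised; it is not part of the diff. It assumes the seb package from this commit is installed, that downloading jinaai/jina-embeddings-v3 with trust_remote_code=True is acceptable in the environment, and it uses only names introduced in jina_models.py above.

```python
# Minimal usage sketch (an assumption-laden illustration, not part of the commit).
from seb.registered_models.jina_models import (
    create_jina_embeddings_v3,
    wrap_jina_sentence_transformer,
)

# Direct use of the wrapper: with no Task supplied, task_prompt stays None and the
# call falls through to SentenceTransformer.encode() without a task-specific prompt.
encoder = wrap_jina_sentence_transformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
emb = encoder.encode(["En dansk sætning.", "A second sentence."], encode_type="passage")
print(emb.shape)  # expected (2, 1024), matching embedding_size in the ModelMeta

# The registered factory wires the same wrapper into a lazily loaded SebModel.
model = create_jina_embeddings_v3()
print(model.meta.name, model.meta.embedding_size)  # jina-embeddings-v3 1024
```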
18 changes: 0 additions & 18 deletions src/seb/registered_models/sentence_transformer_models.py
@@ -87,24 +87,6 @@ def create_all_mini_lm_l6_v2() -> SebModel:
    )


@models.register("jina-embeddings-v3")
def create_jina_embeddings_v3() -> SebModel:
    hf_name = "jinaai/jina-embeddings-v3"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=[],
        open_source=True,
        embedding_size=1024,
        architecture="XLM-R",
        release_date=date(2024, 8, 5),
    )
    return SebModel(
        encoder=LazyLoadEncoder(partial(wrap_sentence_transformer, model_name=hf_name, trust_remote_code=True)), # type: ignore
        meta=meta,
    )


@models.register("paraphrase-multilingual-MiniLM-L12-v2")
def create_multilingual_mini_lm_l12_v2() -> SebModel:
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default.
