diff --git a/pyproject.toml b/pyproject.toml
index aacbed96..6186bc7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,7 @@ docs = [
     "mkdocs-material==9.1.21",
     "mkdocstrings[python-legacy]==0.22.0",
     # for managing tables
-    "datawrapper>=0.5.3",
+    "datawrapper>=0.5.3,<0.6.0",
     # for tutorials
     "jupyter>=1.0.0",
 ]
diff --git a/src/seb/cache/jinaai__jina-embeddings-v3/LCC.json b/src/seb/cache/jinaai__jina-embeddings-v3/LCC.json
index a1514230..7d2bf3af 100644
--- a/src/seb/cache/jinaai__jina-embeddings-v3/LCC.json
+++ b/src/seb/cache/jinaai__jina-embeddings-v3/LCC.json
@@ -1 +1 @@
-{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-13T21:33:38.301520","scores":{"da":{"accuracy":0.5946666666666667,"f1":0.5872722607515735,"accuracy_stderr":0.03222145592958552,"f1_stderr":0.0278698114421992,"main_score":0.5946666666666667}},"main_score":"accuracy"}
\ No newline at end of file
+{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-11-14T13:57:48.545664","scores":{"da":{"accuracy":0.5999999999999999,"f1":0.5885173430161176,"accuracy_stderr":0.03055050463303893,"f1_stderr":0.02601687642478716,"main_score":0.5999999999999999}},"main_score":"accuracy"}
\ No newline at end of file
diff --git a/src/seb/registered_models/__init__.py b/src/seb/registered_models/__init__.py
index d5dcdc4f..2331cb60 100644
--- a/src/seb/registered_models/__init__.py
+++ b/src/seb/registered_models/__init__.py
@@ -8,3 +8,4 @@
 from .translate_e5_models import *
 from .voyage_models import *
 from .bge_models import *
+from .jina_models import *
diff --git a/src/seb/registered_models/jina_models.py b/src/seb/registered_models/jina_models.py
new file mode 100644
index 00000000..e526e4fd
--- /dev/null
+++ b/src/seb/registered_models/jina_models.py
@@ -0,0 +1,108 @@
+import logging
+from datetime import date
+from functools import partial
+from typing import Any, Literal, Optional
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from seb.interfaces.model import LazyLoadEncoder, ModelMeta, SebModel
+from seb.interfaces.task import Task
+from seb.registries import models
+
+from .normalize_to_ndarray import normalize_to_ndarray
+from .sentence_transformer_models import silence_warnings_from_sentence_transformers, wrap_sentence_transformer
+
+
+class Jinav3EncoderWithTaskEncode(SentenceTransformer):
+    """
+    A sentence transformer wrapper that allows for encoding with a task.
+ """ + + def encode( # type: ignore + self, + sentences: list[str], + *, + batch_size: int = 32, + task: Optional[Task] = None, + encode_type: Literal["query", "passage"] = "passage", + **kwargs: Any, + ) -> np.ndarray: + task_prompt = None + if task is not None: + if task.task_type in ["STS", "BitextMining"]: + task_prompt = "text-matching" + if task.task_type in ["Classification"]: + task_prompt = "classification" + if task.task_type in ["Clustering"]: + task_prompt = "seperation" + if task.task_type in ["Retrieval"] and encode_type == "query": + task_prompt = "retrieval.query" + if task.task_type in ["Retrieval"] and encode_type == "passage": + task_prompt = "retrieval.passage" + + if task_prompt is None: + emb = super().encode(sentences, batch_size=batch_size, **kwargs) + else: + emb = super().encode(sentences, batch_size=batch_size, prompt=task_prompt, **kwargs) + return normalize_to_ndarray(emb) + + def encode_corpus(self, corpus: list[dict[str, str]], **kwargs: Any) -> np.ndarray: + sep = " " + if isinstance(corpus, dict): + sentences = [ + (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore + for i in range(len(corpus["text"])) # type: ignore + ] + else: + sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus] + return self.encode(sentences, encode_type="passage", **kwargs) + + def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray: + return self.encode(queries, encode_type="query", **kwargs) + + +def wrap_jina_sentence_transformer(model_name: str, max_seq_length: Optional[int] = None, **kwargs: Any) -> Jinav3EncoderWithTaskEncode: + silence_warnings_from_sentence_transformers() + mdl = Jinav3EncoderWithTaskEncode(model_name, **kwargs) + if max_seq_length is not None: + mdl.max_seq_length = max_seq_length + return mdl + + +@models.register("jina-embeddings-v3") +def create_jina_embeddings_v3() -> SebModel: + hf_name = "jinaai/jina-embeddings-v3" + meta = ModelMeta( + name=hf_name.split("/")[-1], + huggingface_name=hf_name, + reference=f"https://huggingface.co/{hf_name}", + languages=[], + open_source=True, + embedding_size=1024, + architecture="XLM-R", + release_date=date(2024, 8, 5), + ) + return SebModel( + encoder=LazyLoadEncoder(partial(wrap_jina_sentence_transformer, model_name=hf_name, trust_remote_code=True)), # type: ignore + meta=meta, + ) + + +@models.register("jina-embedding-b-en-v1") +def create_jina_base() -> SebModel: + hf_name = "jinaai/jina-embedding-b-en-v1" + meta = ModelMeta( + name=hf_name.split("/")[-1], + huggingface_name=hf_name, + reference=f"https://huggingface.co/{hf_name}", + languages=["en"], + open_source=True, + embedding_size=768, + architecture="T5", + release_date=date(2023, 7, 7), + ) + return SebModel( + encoder=LazyLoadEncoder(partial(wrap_sentence_transformer, model_name=hf_name)), # type: ignore + meta=meta, + ) diff --git a/src/seb/registered_models/sentence_transformer_models.py b/src/seb/registered_models/sentence_transformer_models.py index a086b1b5..91e63e77 100644 --- a/src/seb/registered_models/sentence_transformer_models.py +++ b/src/seb/registered_models/sentence_transformer_models.py @@ -87,24 +87,6 @@ def create_all_mini_lm_l6_v2() -> SebModel: ) -@models.register("jina-embeddings-v3") -def create_jina_embeddings_v3() -> SebModel: - hf_name = "jinaai/jina-embeddings-v3" - meta = ModelMeta( - name=hf_name.split("/")[-1], - huggingface_name=hf_name, - 
reference=f"https://huggingface.co/{hf_name}", - languages=[], - open_source=True, - embedding_size=1024, - architecture="XLM-R", - release_date=date(2024, 8, 5), - ) - return SebModel( - encoder=LazyLoadEncoder(partial(wrap_sentence_transformer, model_name=hf_name, trust_remote_code=True)), # type: ignore - meta=meta, - ) - @models.register("paraphrase-multilingual-MiniLM-L12-v2") def create_multilingual_mini_lm_l12_v2() -> SebModel: diff --git a/uv.lock b/uv.lock index 7bc5b4ea..80f7284d 100644 --- a/uv.lock +++ b/uv.lock @@ -679,7 +679,7 @@ wheels = [ [[package]] name = "datawrapper" -version = "0.6.1" +version = "0.5.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, @@ -688,9 +688,9 @@ dependencies = [ { name = "requests" }, { name = "rich" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/6c/28ddff3217c44d2a2eb33337b0e62db632480daaf020ac25fd5d010255a1/datawrapper-0.6.1.tar.gz", hash = "sha256:3926727cb8f7e2873f9791ee6b241fbf0e9525750bd3a914a015821124e5b3a3", size = 176006 } +sdist = { url = "https://files.pythonhosted.org/packages/32/dc/02b2f96c890580eae8b4fe42c8e063c3d1b54280cd88b63eeafe5945b802/datawrapper-0.5.6.tar.gz", hash = "sha256:0a0a1734c5eeee046b5a0235aa4028f29c07cb3bd1827885946e8e5de258176b", size = 161688 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/59/7f240469e75459f5e2e94587f4069c026ba6de3a97f264268c6b20c68ce2/datawrapper-0.6.1-py3-none-any.whl", hash = "sha256:2d5705a87cf26609eff18c9b243d1369e2e2cce5b8323b7f65224ee30c0f1fda", size = 14554 }, + { url = "https://files.pythonhosted.org/packages/c0/bd/94016a1189a58629e6267ec7ac3296501622fccab05b0ded97a02ba0c321/datawrapper-0.5.6-py3-none-any.whl", hash = "sha256:d9df761adc8a9f078eacda027a9153cca9fcd9e67f272aa356d29129f180d669", size = 10304 }, ] [[package]] @@ -3963,7 +3963,7 @@ requires-dist = [ { name = "cohere", marker = "extra == 'cohere'", specifier = ">=4.34" }, { name = "cruft", marker = "extra == 'dev'", specifier = ">=2.0.0" }, { name = "datasets", specifier = "<2.20.0" }, - { name = "datawrapper", marker = "extra == 'docs'", specifier = ">=0.5.3" }, + { name = "datawrapper", marker = "extra == 'docs'", specifier = ">=0.5.3,<0.6.0" }, { name = "einops", marker = "extra == 'jina'" }, { name = "fairseq2", marker = "extra == 'sonar'", specifier = ">=0.1.0" }, { name = "fasttext-wheel", marker = "extra == 'fasttext'", specifier = ">=0.9.0" },