feat: Add model and tokenizer kwargs to TransformersSimilarityRanker, SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder #8145

Merged · 8 commits · Aug 2, 2024
Changes from 7 commits
@@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

import numpy as np

@@ -27,6 +27,8 @@ def get_embedding_backend(
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"

@@ -38,6 +40,8 @@
auth_token=auth_token,
trust_remote_code=trust_remote_code,
truncate_dim=truncate_dim,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@@ -55,6 +59,8 @@ def __init__(
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
@@ -63,6 +69,8 @@
use_auth_token=auth_token.resolve_value() if auth_token else None,
trust_remote_code=trust_remote_code,
truncate_dim=truncate_dim,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
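For context, the factory above forwards the new `model_kwargs` and `tokenizer_kwargs` straight to the `SentenceTransformer` constructor. A minimal sketch of the equivalent direct call, with an illustrative model name and kwarg values (not taken from the diff):

```python
import torch
from sentence_transformers import SentenceTransformer

# Roughly what the backend builds internally; model name and kwargs are illustrative.
# Half-precision weights are usually only practical on a GPU; on CPU keep the default dtype.
model = SentenceTransformer(
    "sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"torch_dtype": torch.float16},   # forwarded to the underlying HF model load
    tokenizer_kwargs={"model_max_length": 512},    # forwarded to the HF tokenizer load
)
embeddings = model.encode(["An example sentence to embed."])
```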
@@ -9,6 +9,7 @@
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
@@ -37,7 +38,7 @@ class SentenceTransformersDocumentEmbedder:
```
"""

def __init__(
def __init__( # noqa: PLR0913
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[ComponentDevice] = None,
@@ -51,6 +52,8 @@ def __init__(
embedding_separator: str = "\n",
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates a SentenceTransformersDocumentEmbedder component.
@@ -86,6 +89,12 @@ def __init__(
The dimension to truncate sentence embeddings to. `None` does no truncation.
If the model wasn't trained with Matryoshka Representation Learning,
truncating embeddings can significantly affect performance.
:param model_kwargs:
Additional keyword arguments for the underlying Hugging Face model when it is loaded
(forwarded via `SentenceTransformer`'s `model_kwargs`). Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
"""

self.model = model
@@ -100,6 +109,9 @@ def __init__(
self.embedding_separator = embedding_separator
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.embedding_backend = None

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -114,7 +126,7 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
model=self.model,
device=self.device.to_dict(),
@@ -128,7 +140,12 @@ def to_dict(self) -> Dict[str, Any]:
embedding_separator=self.embedding_separator,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
@@ -144,19 +161,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
deserialize_secrets_inplace(init_params, keys=["token"])
if init_params.get("model_kwargs") is not None:
deserialize_hf_model_kwargs(init_params["model_kwargs"])
return default_from_dict(cls, data)

def warm_up(self):
"""
Initializes the component.
"""
if not hasattr(self, "embedding_backend"):
if self.embedding_backend is None:
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)

@component.output_types(documents=List[Document])
@@ -176,11 +197,9 @@ def run(self, documents: List[Document]):
"SentenceTransformersDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
)
if not hasattr(self, "embedding_backend"):
if self.embedding_backend is None:
raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

# TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
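To make the embedder change concrete, here is a short usage sketch for `SentenceTransformersDocumentEmbedder` with the new parameters; the model name and kwarg values are illustrative, and the kwargs only take effect when `warm_up()` creates the backend:

```python
import torch
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"torch_dtype": torch.float16},   # e.g. load the model in half precision (GPU)
    tokenizer_kwargs={"model_max_length": 512},    # e.g. cap the tokenizer's sequence length
)
embedder.warm_up()  # kwargs are applied here, when the embedding backend is created
result = embedder.run(documents=[Document(content="Berlin is the capital of Germany.")])
print(len(result["documents"][0].embedding))
```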
@@ -9,6 +9,7 @@
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
@@ -45,6 +46,8 @@ def __init__(
normalize_embeddings: bool = False,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -76,6 +79,12 @@
The dimension to truncate sentence embeddings to. `None` does no truncation.
If the model has not been trained with Matryoshka Representation Learning,
truncation of embeddings can significantly affect performance.
:param model_kwargs:
Additional keyword arguments for the underlying Hugging Face model when it is loaded
(forwarded via `SentenceTransformer`'s `model_kwargs`). Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
"""

self.model = model
@@ -88,6 +97,9 @@ def __init__(
self.normalize_embeddings = normalize_embeddings
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.embedding_backend = None

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -102,7 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
model=self.model,
device=self.device.to_dict(),
@@ -114,7 +126,12 @@
normalize_embeddings=self.normalize_embeddings,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
@@ -130,19 +147,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
deserialize_secrets_inplace(init_params, keys=["token"])
if init_params.get("model_kwargs") is not None:
deserialize_hf_model_kwargs(init_params["model_kwargs"])
return default_from_dict(cls, data)

def warm_up(self):
"""
Initializes the component.
"""
if not hasattr(self, "embedding_backend"):
if self.embedding_backend is None:
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)

@component.output_types(embedding=List[float])
@@ -162,7 +183,7 @@ def run(self, text: str):
"SentenceTransformersTextEmbedder expects a string as input."
"In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
)
if not hasattr(self, "embedding_backend"):
if self.embedding_backend is None:
raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

text_to_embed = self.prefix + text + self.suffix
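The extra `serialize_hf_model_kwargs` / `deserialize_hf_model_kwargs` calls in `to_dict` / `from_dict` are needed because torch dtypes are not JSON- or YAML-serializable. A minimal round-trip sketch with illustrative values, mirroring what the tests further down expect:

```python
import torch
from haystack.components.embedders import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"torch_dtype": torch.float16},
)
data = embedder.to_dict()
# The dtype is stored as a plain string so the dict stays serializable...
assert data["init_parameters"]["model_kwargs"] == {"torch_dtype": "torch.float16"}
# ...and is converted back to a real torch dtype when the component is rebuilt.
restored = SentenceTransformersTextEmbedder.from_dict(data)
assert restored.model_kwargs == {"torch_dtype": torch.float16}
```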
10 changes: 9 additions & 1 deletion haystack/components/rankers/transformers_similarity.py
@@ -56,6 +56,7 @@ def __init__(
calibration_factor: Optional[float] = 1.0,
score_threshold: Optional[float] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of TransformersSimilarityRanker.
@@ -89,6 +90,9 @@
:param model_kwargs:
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
when loading the model. Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.

:raises ValueError:
If `top_k` is not > 0.
@@ -112,6 +116,7 @@ def __init__(

model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs or {}

# Parameter validation
if self.scale_score and self.calibration_factor is None:
@@ -137,7 +142,9 @@ def warm_up(self):
self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path, token=self.token.resolve_value() if self.token else None
self.model_name_or_path,
token=self.token.resolve_value() if self.token else None,
**self.tokenizer_kwargs,
)
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))

@@ -162,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]:
calibration_factor=self.calibration_factor,
score_threshold=self.score_threshold,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
)

serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
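And a corresponding usage sketch for the ranker, where `tokenizer_kwargs` are now forwarded to `AutoTokenizer.from_pretrained` during `warm_up()`; the model name and values are illustrative:

```python
from haystack import Document
from haystack.components.rankers import TransformersSimilarityRanker

ranker = TransformersSimilarityRanker(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    tokenizer_kwargs={"model_max_length": 512},   # e.g. cap query+document length at tokenization
)
ranker.warm_up()
docs = [
    Document(content="Paris is the capital of France."),
    Document(content="Berlin is the capital of Germany."),
]
result = ranker.run(query="What is the capital of Germany?", documents=docs)
print(result["documents"][0].content)  # highest-scoring document first
```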
@@ -0,0 +1,5 @@
---
enhancements:
- |
Adds model_kwargs and tokenizer_kwargs to the components TransformersSimilarityRanker, SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder.
This allows passing things like model_max_length or torch_dtype for better management of model inference.
@@ -5,6 +5,7 @@

import numpy as np
import pytest
import torch

from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
@@ -73,6 +74,8 @@ def test_to_dict(self):
"meta_fields_to_embed": [],
"trust_remote_code": False,
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
},
}

@@ -90,6 +93,8 @@ def test_to_dict_with_custom_init_parameters(self):
embedding_separator=" - ",
trust_remote_code=True,
truncate_dim=256,
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
)
data = component.to_dict()

@@ -108,6 +113,8 @@ def test_to_dict_with_custom_init_parameters(self):
"trust_remote_code": True,
"meta_fields_to_embed": ["meta_field"],
"truncate_dim": 256,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
},
}

@@ -209,7 +216,13 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
model="model",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=None,
tokenizer_kwargs=None,
)

@patch(
@@ -40,6 +40,8 @@ def test_model_initialization(mock_sentence_transformer):
use_auth_token="fake-api-token",
trust_remote_code=True,
truncate_dim=256,
model_kwargs=None,
tokenizer_kwargs=None,
)

