diff --git a/haystack/components/embedders/backends/sentence_transformers_backend.py b/haystack/components/embedders/backends/sentence_transformers_backend.py
index 7e57f4c43b..a7547d5967 100644
--- a/haystack/components/embedders/backends/sentence_transformers_backend.py
+++ b/haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Optional, cast
+from typing import Any, Dict, List, Optional, cast
 
 import numpy as np
 
@@ -27,6 +27,8 @@ def get_embedding_backend(
         auth_token: Optional[Secret] = None,
         trust_remote_code: bool = False,
         truncate_dim: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
         embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
 
@@ -38,6 +40,8 @@ def get_embedding_backend(
             auth_token=auth_token,
             trust_remote_code=trust_remote_code,
             truncate_dim=truncate_dim,
+            model_kwargs=model_kwargs,
+            tokenizer_kwargs=tokenizer_kwargs,
         )
         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
@@ -55,6 +59,8 @@ def __init__(
         auth_token: Optional[Secret] = None,
         trust_remote_code: bool = False,
         truncate_dim: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
         sentence_transformers_import.check()
         self.model = SentenceTransformer(
@@ -63,6 +69,8 @@ def __init__(
             use_auth_token=auth_token.resolve_value() if auth_token else None,
             trust_remote_code=trust_remote_code,
             truncate_dim=truncate_dim,
+            model_kwargs=model_kwargs,
+            tokenizer_kwargs=tokenizer_kwargs,
         )
 
     def embed(self, data: List[str], **kwargs) -> List[List[float]]:
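The factory and backend changes above do nothing with the two new dictionaries except forward them to the `SentenceTransformer` constructor. As a rough sketch of what that amounts to (the dtype and max length below are illustrative values, not defaults introduced by this change):

```python
import torch
from sentence_transformers import SentenceTransformer

# Equivalent direct call: model_kwargs reach the underlying Hugging Face model loader,
# tokenizer_kwargs reach the tokenizer loader.
model = SentenceTransformer(
    "sentence-transformers/all-mpnet-base-v2",
    device="cpu",
    model_kwargs={"torch_dtype": torch.float16},
    tokenizer_kwargs={"model_max_length": 512},
)
```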
""" self.model = model @@ -100,6 +109,9 @@ def __init__( self.embedding_separator = embedding_separator self.trust_remote_code = trust_remote_code self.truncate_dim = truncate_dim + self.model_kwargs = model_kwargs + self.tokenizer_kwargs = tokenizer_kwargs + self.embedding_backend = None def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -114,7 +126,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( + serialization_dict = default_to_dict( self, model=self.model, device=self.device.to_dict(), @@ -128,7 +140,12 @@ def to_dict(self) -> Dict[str, Any]: embedding_separator=self.embedding_separator, trust_remote_code=self.trust_remote_code, truncate_dim=self.truncate_dim, + model_kwargs=self.model_kwargs, + tokenizer_kwargs=self.tokenizer_kwargs, ) + if serialization_dict["init_parameters"].get("model_kwargs") is not None: + serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) + return serialization_dict @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder": @@ -144,19 +161,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedde if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) deserialize_secrets_inplace(init_params, keys=["token"]) + if init_params.get("model_kwargs") is not None: + deserialize_hf_model_kwargs(init_params["model_kwargs"]) return default_from_dict(cls, data) def warm_up(self): """ Initializes the component. """ - if not hasattr(self, "embedding_backend"): + if self.embedding_backend is None: self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model=self.model, device=self.device.to_torch_str(), auth_token=self.token, trust_remote_code=self.trust_remote_code, truncate_dim=self.truncate_dim, + model_kwargs=self.model_kwargs, + tokenizer_kwargs=self.tokenizer_kwargs, ) @component.output_types(documents=List[Document]) @@ -176,11 +197,9 @@ def run(self, documents: List[Document]): "SentenceTransformersDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." ) - if not hasattr(self, "embedding_backend"): + if self.embedding_backend is None: raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") - # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here - texts_to_embed = [] for doc in documents: meta_values_to_embed = [ diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 22c449a8ce..0a60fc8988 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -9,6 +9,7 @@ _SentenceTransformersEmbeddingBackendFactory, ) from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace +from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs @component @@ -45,6 +46,8 @@ def __init__( normalize_embeddings: bool = False, trust_remote_code: bool = False, truncate_dim: Optional[int] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, ): """ Create a SentenceTransformersTextEmbedder component. 
diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py
index 22c449a8ce..0a60fc8988 100644
--- a/haystack/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -9,6 +9,7 @@
     _SentenceTransformersEmbeddingBackendFactory,
 )
 from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
+from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 
 
 @component
@@ -45,6 +46,8 @@ def __init__(
         normalize_embeddings: bool = False,
         trust_remote_code: bool = False,
         truncate_dim: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Create a SentenceTransformersTextEmbedder component.
@@ -76,6 +79,12 @@ def __init__(
             The dimension to truncate sentence embeddings to. `None` does no truncation.
             If the model has not been trained with Matryoshka Representation Learning,
             truncation of embeddings can significantly affect performance.
+        :param model_kwargs:
+            Additional keyword arguments for `AutoModel.from_pretrained` when loading the model.
+            Refer to specific model documentation for available kwargs.
+        :param tokenizer_kwargs:
+            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
+            Refer to specific model documentation for available kwargs.
         """
 
         self.model = model
@@ -88,6 +97,9 @@ def __init__(
         self.normalize_embeddings = normalize_embeddings
         self.trust_remote_code = trust_remote_code
         self.truncate_dim = truncate_dim
+        self.model_kwargs = model_kwargs
+        self.tokenizer_kwargs = tokenizer_kwargs
+        self.embedding_backend = None
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -102,7 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
-        return default_to_dict(
+        serialization_dict = default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
@@ -114,7 +126,12 @@ def to_dict(self) -> Dict[str, Any]:
             normalize_embeddings=self.normalize_embeddings,
             trust_remote_code=self.trust_remote_code,
             truncate_dim=self.truncate_dim,
+            model_kwargs=self.model_kwargs,
+            tokenizer_kwargs=self.tokenizer_kwargs,
         )
+        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
+            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
+        return serialization_dict
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
@@ -130,19 +147,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
         if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(init_params, keys=["token"])
+        if init_params.get("model_kwargs") is not None:
+            deserialize_hf_model_kwargs(init_params["model_kwargs"])
         return default_from_dict(cls, data)
 
     def warm_up(self):
         """
         Initializes the component.
         """
-        if not hasattr(self, "embedding_backend"):
+        if self.embedding_backend is None:
             self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                 model=self.model,
                 device=self.device.to_torch_str(),
                 auth_token=self.token,
                 trust_remote_code=self.trust_remote_code,
                 truncate_dim=self.truncate_dim,
+                model_kwargs=self.model_kwargs,
+                tokenizer_kwargs=self.tokenizer_kwargs,
             )
 
     @component.output_types(embedding=List[float])
@@ -162,7 +183,7 @@ def run(self, text: str):
                 "SentenceTransformersTextEmbedder expects a string as input."
                 "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
             )
-        if not hasattr(self, "embedding_backend"):
+        if self.embedding_backend is None:
             raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
 
         text_to_embed = self.prefix + text + self.suffix
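Because a raw `torch.dtype` is not JSON-serializable, `to_dict()` now routes `model_kwargs` through `serialize_hf_model_kwargs`, and `from_dict()` reverses it with `deserialize_hf_model_kwargs`; `tokenizer_kwargs` values are plain and pass through unchanged. A sketch of the round trip (values are illustrative):

```python
import torch
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"torch_dtype": torch.float32},
    tokenizer_kwargs={"model_max_length": 512},
)

data = embedder.to_dict()
# The dtype is stored as a string so the dictionary stays JSON/YAML friendly.
assert data["init_parameters"]["model_kwargs"] == {"torch_dtype": "torch.float32"}
assert data["init_parameters"]["tokenizer_kwargs"] == {"model_max_length": 512}

restored = SentenceTransformersTextEmbedder.from_dict(data)
# from_dict converts the string back into the torch dtype object.
assert restored.model_kwargs == {"torch_dtype": torch.float32}
```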
diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py
index 5086c14d1e..ca9cd25191 100644
--- a/haystack/components/rankers/transformers_similarity.py
+++ b/haystack/components/rankers/transformers_similarity.py
@@ -56,6 +56,7 @@ def __init__(
         calibration_factor: Optional[float] = 1.0,
         score_threshold: Optional[float] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates an instance of TransformersSimilarityRanker.
@@ -89,6 +90,9 @@ def __init__(
         :param model_kwargs:
             Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
             when loading the model. Refer to specific model documentation for available kwargs.
+        :param tokenizer_kwargs:
+            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
+            Refer to specific model documentation for available kwargs.
         :raises ValueError:
             If `top_k` is not > 0.
 
@@ -112,6 +116,7 @@ def __init__(
 
         model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
         self.model_kwargs = model_kwargs
+        self.tokenizer_kwargs = tokenizer_kwargs or {}
 
         # Parameter validation
         if self.scale_score and self.calibration_factor is None:
@@ -137,7 +142,9 @@ def warm_up(self):
                 self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
             )
             self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name_or_path, token=self.token.resolve_value() if self.token else None
+                self.model_name_or_path,
+                token=self.token.resolve_value() if self.token else None,
+                **self.tokenizer_kwargs,
             )
             self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
 
@@ -162,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]:
             calibration_factor=self.calibration_factor,
             score_threshold=self.score_threshold,
             model_kwargs=self.model_kwargs,
+            tokenizer_kwargs=self.tokenizer_kwargs,
         )
 
         serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
diff --git a/releasenotes/notes/add-model-and-tokenizer-kwargs-4b7618806665f8ba.yaml b/releasenotes/notes/add-model-and-tokenizer-kwargs-4b7618806665f8ba.yaml
new file mode 100644
index 0000000000..5f0395d94e
--- /dev/null
+++ b/releasenotes/notes/add-model-and-tokenizer-kwargs-4b7618806665f8ba.yaml
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Adds `model_kwargs` and `tokenizer_kwargs` to the TransformersSimilarityRanker, SentenceTransformersDocumentEmbedder, and SentenceTransformersTextEmbedder components.
+    This allows passing parameters such as `model_max_length` or `torch_dtype` for finer control over model loading and inference.
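For the ranker, `tokenizer_kwargs` defaults to an empty dict and is unpacked into `AutoTokenizer.from_pretrained` during `warm_up()`. A usage sketch (the model name is the component's default cross-encoder; the max length is an example value):

```python
from haystack import Document
from haystack.components.rankers import TransformersSimilarityRanker

ranker = TransformersSimilarityRanker(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    tokenizer_kwargs={"model_max_length": 512},  # cap query/document pairs at the model's limit
)
ranker.warm_up()

docs = [Document(content="Berlin is the capital of Germany."), Document(content="Paris is in France.")]
result = ranker.run(query="What is the capital of Germany?", documents=docs, top_k=1)
print(result["documents"][0].content)
```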
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py
index 8587c745b5..ecc0b8ef8d 100644
--- a/test/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pytest
+import torch
 
 from haystack import Document
 from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
 
@@ -73,6 +74,8 @@ def test_to_dict(self):
                 "meta_fields_to_embed": [],
                 "trust_remote_code": False,
                 "truncate_dim": None,
+                "model_kwargs": None,
+                "tokenizer_kwargs": None,
             },
         }
 
@@ -90,6 +93,8 @@ def test_to_dict_with_custom_init_parameters(self):
             embedding_separator=" - ",
             trust_remote_code=True,
             truncate_dim=256,
+            model_kwargs={"torch_dtype": torch.float32},
+            tokenizer_kwargs={"model_max_length": 512},
         )
         data = component.to_dict()
 
@@ -108,6 +113,8 @@ def test_to_dict_with_custom_init_parameters(self):
                 "trust_remote_code": True,
                 "meta_fields_to_embed": ["meta_field"],
                 "truncate_dim": 256,
+                "model_kwargs": {"torch_dtype": "torch.float32"},
+                "tokenizer_kwargs": {"model_max_length": 512},
             },
         }
 
@@ -125,6 +132,8 @@ def test_from_dict(self):
             "meta_fields_to_embed": ["meta_field"],
             "trust_remote_code": True,
             "truncate_dim": 256,
+            "model_kwargs": {"torch_dtype": "torch.float32"},
+            "tokenizer_kwargs": {"model_max_length": 512},
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
@@ -144,6 +153,8 @@ def test_from_dict(self):
         assert component.trust_remote_code
         assert component.meta_fields_to_embed == ["meta_field"]
         assert component.truncate_dim == 256
+        assert component.model_kwargs == {"torch_dtype": torch.float32}
+        assert component.tokenizer_kwargs == {"model_max_length": 512}
 
     def test_from_dict_no_default_parameters(self):
         component = SentenceTransformersDocumentEmbedder.from_dict(
@@ -209,7 +220,13 @@ def test_warmup(self, mocked_factory):
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         mocked_factory.get_embedding_backend.assert_called_once_with(
-            model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
+            model="model",
+            device="cpu",
+            auth_token=None,
+            trust_remote_code=False,
+            truncate_dim=None,
+            model_kwargs=None,
+            tokenizer_kwargs=None,
         )
 
     @patch(
diff --git a/test/components/embedders/test_sentence_transformers_embedding_backend.py b/test/components/embedders/test_sentence_transformers_embedding_backend.py
index 9b7d2b1a11..7ca42aab91 100644
--- a/test/components/embedders/test_sentence_transformers_embedding_backend.py
+++ b/test/components/embedders/test_sentence_transformers_embedding_backend.py
@@ -40,6 +40,8 @@ def test_model_initialization(mock_sentence_transformer):
         use_auth_token="fake-api-token",
         trust_remote_code=True,
         truncate_dim=256,
+        model_kwargs=None,
+        tokenizer_kwargs=None,
     )
 
 
diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py
index 7352a8809e..2845c38f26 100644
--- a/test/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from unittest.mock import MagicMock, patch
+import torch
 
 import numpy as np
 import pytest
@@ -64,6 +65,8 @@ def test_to_dict(self):
                 "normalize_embeddings": False,
"trust_remote_code": False, "truncate_dim": None, + "model_kwargs": None, + "tokenizer_kwargs": None, }, } @@ -79,6 +82,8 @@ def test_to_dict_with_custom_init_parameters(self): normalize_embeddings=True, trust_remote_code=True, truncate_dim=256, + model_kwargs={"torch_dtype": torch.float32}, + tokenizer_kwargs={"model_max_length": 512}, ) data = component.to_dict() assert data == { @@ -94,6 +99,8 @@ def test_to_dict_with_custom_init_parameters(self): "normalize_embeddings": True, "trust_remote_code": True, "truncate_dim": 256, + "model_kwargs": {"torch_dtype": "torch.float32"}, + "tokenizer_kwargs": {"model_max_length": 512}, }, } @@ -116,6 +123,8 @@ def test_from_dict(self): "normalize_embeddings": False, "trust_remote_code": False, "truncate_dim": None, + "model_kwargs": {"torch_dtype": "torch.float32"}, + "tokenizer_kwargs": {"model_max_length": 512}, }, } component = SentenceTransformersTextEmbedder.from_dict(data) @@ -129,6 +138,8 @@ def test_from_dict(self): assert component.normalize_embeddings is False assert component.trust_remote_code is False assert component.truncate_dim is None + assert component.model_kwargs == {"torch_dtype": torch.float32} + assert component.tokenizer_kwargs == {"model_max_length": 512} def test_from_dict_no_default_parameters(self): data = { @@ -183,7 +194,13 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None + model="model", + device="cpu", + auth_token=None, + trust_remote_code=False, + truncate_dim=None, + model_kwargs=None, + tokenizer_kwargs=None, ) @patch( diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 9e083ffa68..45e564b0f5 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -33,6 +33,7 @@ def test_to_dict(self): "calibration_factor": 1.0, "score_threshold": None, "model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()}, + "tokenizer_kwargs": {}, }, } @@ -48,6 +49,7 @@ def test_to_dict_with_custom_init_parameters(self): calibration_factor=None, score_threshold=0.01, model_kwargs={"torch_dtype": torch.float16}, + tokenizer_kwargs={"model_max_length": 512}, ) data = component.to_dict() assert data == { @@ -68,6 +70,7 @@ def test_to_dict_with_custom_init_parameters(self): "torch_dtype": "torch.float16", "device_map": ComponentDevice.from_str("cuda:0").to_hf(), }, # torch_dtype is correctly serialized + "tokenizer_kwargs": {"model_max_length": 512}, }, } @@ -102,6 +105,7 @@ def test_to_dict_with_quantization_options(self): "bnb_4bit_compute_dtype": "torch.bfloat16", "device_map": ComponentDevice.resolve_device(None).to_hf(), }, + "tokenizer_kwargs": {}, }, } @@ -132,6 +136,7 @@ def test_to_dict_device_map(self, device_map, expected): "calibration_factor": 1.0, "score_threshold": None, "model_kwargs": {"device_map": expected}, + "tokenizer_kwargs": {}, }, } @@ -151,6 +156,7 @@ def test_from_dict(self): "calibration_factor": None, "score_threshold": 0.01, "model_kwargs": {"torch_dtype": "torch.float16"}, + "tokenizer_kwargs": {"model_max_length": 512}, }, } @@ -171,6 +177,7 @@ def test_from_dict(self): "torch_dtype": torch.float16, "device_map": ComponentDevice.resolve_device(None).to_hf(), } + assert component.tokenizer_kwargs == {"model_max_length": 
+        assert component.tokenizer_kwargs == {"model_max_length": 512}
 
     def test_from_dict_no_default_parameters(self):
         data = {
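The tests above cover serialization and the arguments forwarded by `warm_up()`. The switch from the `hasattr` check to an explicit `embedding_backend = None` sentinel keeps the existing guard in `run()` intact; a minimal, hypothetical test for that behaviour (not part of this diff) would look like:

```python
import pytest

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder


def test_run_before_warm_up_raises():
    embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-mpnet-base-v2")
    # embedding_backend is initialized to None, so run() must fail until warm_up() is called
    with pytest.raises(RuntimeError):
        embedder.run(text="some text")
```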