diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 656103966..93bc360c0 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -5,7 +5,7 @@ import shutil
 from collections import OrderedDict
 import warnings
-from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
+from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -72,7 +72,40 @@ class SentenceTransformer(nn.Sequential):
     :param local_files_only: If `True`, avoid downloading the model.
     :param token: Hugging Face authentication token to download private models.
     :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
-        only applicable during inference when `.encode` is called.
+        only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+    :param model_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers model.
+        Particularly useful options are:
+
+        - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
+          The different options are:
+
+          1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified ``dtype``,
+             ignoring the model's ``config.torch_dtype`` if one exists. If not specified, the model is
+             loaded in ``torch.float`` (fp32).
+
+          2. ``"auto"``: the ``torch_dtype`` entry in the model's ``config.json`` file is used if present.
+             If that entry isn't found, the ``dtype`` of the first floating-point weight in the checkpoint
+             is used instead. This loads the model in the ``dtype`` it was saved in at the end of training,
+             which is not necessarily the ``dtype`` it was trained in: a model can be trained in one of the
+             half-precision dtypes but saved in fp32.
+        - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
+          `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
+          <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
+          or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
+          By default, SDPA is used if available for torch>=2.1.1; otherwise, the default is the manual `"eager"`
+          implementation.
+
+        See the `PreTrainedModel.from_pretrained
+        <https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
+        documentation for more details.
+    :param tokenizer_kwargs: Additional tokenizer configuration parameters to be passed to the Huggingface Transformers tokenizer.
+        See the `AutoTokenizer.from_pretrained
+        <https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
+        documentation for more details.
+    :param config_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers config.
+        See the `AutoConfig.from_pretrained
+        <https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
+        documentation for more details.
     """

     def __init__(
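For illustration only, not part of this diff: a minimal usage sketch of the new `model_kwargs` parameter documented above. The model name is the tiny test model used in the tests below; the dtype and attention choices are hypothetical.

```python
import torch
from sentence_transformers import SentenceTransformer

# Load the underlying transformers model in fp16 with the manual ("eager")
# attention implementation, via the new model_kwargs parameter.
model = SentenceTransformer(
    "sentence-transformers-testing/stsb-bert-tiny-safetensors",
    model_kwargs={"torch_dtype": torch.float16, "attn_implementation": "eager"},
)
embeddings = model.encode(["Hello there!"], convert_to_tensor=True)
print(embeddings.dtype)  # torch.float16 (fp16 inference may require a GPU)
```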
""" def __init__( @@ -89,6 +122,9 @@ def __init__( token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, truncate_dim: Optional[int] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, ): # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name` self.prompts = prompts or {} @@ -209,6 +245,9 @@ def __init__( revision=revision, trust_remote_code=trust_remote_code, local_files_only=local_files_only, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + config_kwargs=config_kwargs, ) else: modules = self._load_auto_model( @@ -218,6 +257,9 @@ def __init__( revision=revision, trust_remote_code=trust_remote_code, local_files_only=local_files_only, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + config_kwargs=config_kwargs, ) if modules is not None and not isinstance(modules, OrderedDict): @@ -423,7 +465,10 @@ def encode( all_embeddings = torch.Tensor() elif convert_to_numpy: if not isinstance(all_embeddings, np.ndarray): - all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + if all_embeddings[0].dtype == torch.bfloat16: + all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings]) + else: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) elif isinstance(all_embeddings, np.ndarray): all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings] @@ -1196,30 +1241,35 @@ def _load_auto_model( revision: Optional[str] = None, trust_remote_code: bool = False, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, ): """ Creates a simple Transformer + Mean Pooling model and returns the modules """ logger.warning( - "No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format( + "No sentence-transformers model found with name {}. 
@@ -1196,30 +1241,35 @@ def _load_auto_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning(
-            "No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
+            "No sentence-transformers model found with name {}. Creating a new one with mean pooling.".format(
                 model_name_or_path
            )
        )
+
+        shared_kwargs = {
+            "token": token,
+            "trust_remote_code": trust_remote_code,
+            "revision": revision,
+            "local_files_only": local_files_only,
+        }
+        model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+        tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
+        config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+
         transformer_model = Transformer(
             model_name_or_path,
             cache_dir=cache_folder,
-            model_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
-            tokenizer_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
+            model_args=model_kwargs,
+            tokenizer_args=tokenizer_kwargs,
+            config_args=config_kwargs,
         )
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
         return [transformer_model, pooling_model]
@@ -1232,6 +1282,9 @@ def _load_sbert_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Loads a full sentence-transformers model
@@ -1321,21 +1374,42 @@ def _load_sbert_model(
                     if config_path is not None:
                         with open(config_path) as fIn:
                             kwargs = json.load(fIn)
+                            # Don't allow configs to set trust_remote_code
+                            if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
+                                kwargs["model_args"].pop("trust_remote_code")
+                            if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
+                                kwargs["tokenizer_args"].pop("trust_remote_code")
+                            if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
+                                kwargs["config_args"].pop("trust_remote_code")
                         break
+
                 hub_kwargs = {
                     "token": token,
                     "trust_remote_code": trust_remote_code,
                     "revision": revision,
                     "local_files_only": local_files_only,
                 }
-                if "model_args" in kwargs:
-                    kwargs["model_args"].update(hub_kwargs)
-                else:
-                    kwargs["model_args"] = hub_kwargs
-                if "tokenizer_args" in kwargs:
-                    kwargs["tokenizer_args"].update(hub_kwargs)
-                else:
-                    kwargs["tokenizer_args"] = hub_kwargs
+                # 3rd priority: config file
+                if "model_args" not in kwargs:
+                    kwargs["model_args"] = {}
+                if "tokenizer_args" not in kwargs:
+                    kwargs["tokenizer_args"] = {}
+                if "config_args" not in kwargs:
+                    kwargs["config_args"] = {}
+
+                # 2nd priority: hub_kwargs
+                kwargs["model_args"].update(hub_kwargs)
+                kwargs["tokenizer_args"].update(hub_kwargs)
+                kwargs["config_args"].update(hub_kwargs)
+
+                # 1st priority: kwargs passed to SentenceTransformer
+                if model_kwargs:
+                    kwargs["model_args"].update(model_kwargs)
+                if tokenizer_kwargs:
+                    kwargs["tokenizer_args"].update(tokenizer_kwargs)
+                if config_kwargs:
+                    kwargs["config_args"].update(config_kwargs)
+
                 module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
             else:
                 # Normalize does not require any files to be loaded
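A sketch of the precedence the three-step merge above produces, with hypothetical values; later `update` calls win:

```python
# 3rd priority: values read from the saved config file
model_args = {"torch_dtype": "auto", "revision": "v1.0"}
# 2nd priority: hub kwargs derived from SentenceTransformer.__init__ arguments
model_args.update({"revision": "main", "local_files_only": False})
# 1st priority: the user's model_kwargs
model_args.update({"torch_dtype": "float16"})
print(model_args)
# {'torch_dtype': 'float16', 'revision': 'main', 'local_files_only': False}
```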
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 25727ab1e..5a50c778b 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -1,7 +1,7 @@
 from torch import nn
 from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
 import json
-from typing import List, Dict, Optional, Union, Tuple
+from typing import Any, List, Dict, Optional, Union, Tuple
 import os
@@ -11,9 +11,10 @@ class Transformer(nn.Module):
     :param model_name_or_path: Huggingface models name (https://huggingface.co/models)
     :param max_seq_length: Truncate any inputs longer than max_seq_length
-    :param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model
+    :param model_args: Keyword arguments passed to the Huggingface Transformers model
+    :param tokenizer_args: Keyword arguments passed to the Huggingface Transformers tokenizer
+    :param config_args: Keyword arguments passed to the Huggingface Transformers config
     :param cache_dir: Cache dir for Huggingface Transformers to store/load models
-    :param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer model
     :param do_lower_case: If true, lowercases the input (independent if the model is cased or not)
     :param tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used
     """
@@ -22,17 +23,24 @@ def __init__(
         self,
         model_name_or_path: str,
         max_seq_length: Optional[int] = None,
-        model_args: Dict = {},
+        model_args: Optional[Dict[str, Any]] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        config_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        tokenizer_args: Dict = {},
         do_lower_case: bool = False,
         tokenizer_name_or_path: str = None,
     ):
         super(Transformer, self).__init__()
         self.config_keys = ["max_seq_length", "do_lower_case"]
         self.do_lower_case = do_lower_case
-
-        config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
+        if model_args is None:
+            model_args = {}
+        if tokenizer_args is None:
+            tokenizer_args = {}
+        if config_args is None:
+            config_args = {}
+
+        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
         self._load_model(model_name_or_path, config, cache_dir, **model_args)

         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -182,6 +190,10 @@ def load(input_path: str):
         with open(sbert_config_path) as fIn:
             config = json.load(fIn)
         # Don't allow configs to set trust_remote_code
-        if "model_args" in config:
+        if "model_args" in config and "trust_remote_code" in config["model_args"]:
             config["model_args"].pop("trust_remote_code")
+        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
+            config["tokenizer_args"].pop("trust_remote_code")
+        if "config_args" in config and "trust_remote_code" in config["config_args"]:
+            config["config_args"].pop("trust_remote_code")
         return Transformer(model_name_or_path=input_path, **config)
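For illustration, not part of the diff: the new `config_args` are forwarded to `AutoConfig.from_pretrained`, separately from `model_args`. The dropout override here is a hypothetical example for a BERT-style config.

```python
from sentence_transformers.models import Transformer

module = Transformer(
    "sentence-transformers-testing/stsb-bert-tiny-safetensors",
    config_args={"attention_probs_dropout_prob": 0.0},  # applied to the AutoConfig
    model_args={"low_cpu_mem_usage": False},            # applied to the model loading
)
print(module.auto_model.config.attention_probs_dropout_prob)  # 0.0
```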
model_kwargs={"torch_dtype": "auto"}, + ) + assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16 + + +def test_load_with_model_kwargs(monkeypatch: pytest.MonkeyPatch) -> None: + transformer_kwargs = {} + original_transformer_init = Transformer.__init__ + + def transformers_init(*args, **kwargs): + nonlocal transformer_kwargs + nonlocal original_transformer_init + transformer_kwargs = kwargs + return original_transformer_init(*args, **kwargs) + + monkeypatch.setattr(Transformer, "__init__", transformers_init) + + SentenceTransformer( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", + model_kwargs={"attn_implementation": "eager", "low_cpu_mem_usage": False}, + ) + + assert "low_cpu_mem_usage" in transformer_kwargs["model_args"] + assert transformer_kwargs["model_args"]["low_cpu_mem_usage"] is False + assert "attn_implementation" in transformer_kwargs["model_args"] + assert transformer_kwargs["model_args"]["attn_implementation"] == "eager" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.") def test_encode_fp16() -> None: tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")