Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing model_args to ST #2578

Merged
merged 17 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 98 additions & 24 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
from collections import OrderedDict
import warnings
from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
import numpy as np
from numpy import ndarray
import transformers
Expand Down Expand Up @@ -72,7 +72,40 @@ class SentenceTransformer(nn.Sequential):
:param local_files_only: If `True`, avoid downloading the model.
:param token: Hugging Face authentication token to download private models.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
only applicable during inference when `.encode` is called.
only applicable during inference when :meth:`SentenceTransformer.encode` is called.
:param model_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers model.
Particularly useful options are:

- ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
The different options are:

1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified
``dtype``, ignoring the model's ``config.torch_dtype`` if one exists. If not specified - the model will
get loaded in ``torch.float`` (fp32).

2. ``"auto"`` - A ``torch_dtype`` entry in the ``config.json`` file of the model will be
attempted to be used. If this entry isn't found then next check the ``dtype`` of the first weight in
the checkpoint that's of a floating point type and use that as ``dtype``. This will load the model
using the ``dtype`` it was saved in at the end of the training. It can't be used as an indicator of how
the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
- ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
`"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
<https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"`
implementation.

See the `PreTrainedModel.from_pretrained
<https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
documentation for more details.
:param tokenizer_kwargs: Additional tokenizer configuration parameters to be passed to the Huggingface Transformers tokenizer.
See the `AutoTokenizer.from_pretrained
<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
documentation for more details.
:param config_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers config.
See the `AutoConfig.from_pretrained
<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
documentation for more details.
"""

def __init__(
Expand All @@ -89,6 +122,9 @@ def __init__(
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
Expand Down Expand Up @@ -209,6 +245,9 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
)
else:
modules = self._load_auto_model(
Expand All @@ -218,6 +257,9 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
)

if modules is not None and not isinstance(modules, OrderedDict):
Expand Down Expand Up @@ -423,7 +465,10 @@ def encode(
all_embeddings = torch.Tensor()
elif convert_to_numpy:
if not isinstance(all_embeddings, np.ndarray):
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
if all_embeddings[0].dtype == torch.bfloat16:
all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
else:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
elif isinstance(all_embeddings, np.ndarray):
all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]

Expand Down Expand Up @@ -1196,30 +1241,35 @@ def _load_auto_model(
revision: Optional[str] = None,
trust_remote_code: bool = False,
local_files_only: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates a simple Transformer + Mean Pooling model and returns the modules
"""
logger.warning(
"No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
"No sentence-transformers model found with name {}. Creating a new one with mean pooling.".format(
model_name_or_path
)
)

shared_kwargs = {
"token": token,
"trust_remote_code": trust_remote_code,
"revision": revision,
"local_files_only": local_files_only,
}
model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}

transformer_model = Transformer(
model_name_or_path,
cache_dir=cache_folder,
model_args={
"token": token,
"trust_remote_code": trust_remote_code,
"revision": revision,
"local_files_only": local_files_only,
},
tokenizer_args={
"token": token,
"trust_remote_code": trust_remote_code,
"revision": revision,
"local_files_only": local_files_only,
},
model_args=model_kwargs,
tokenizer_args=tokenizer_kwargs,
config_args=config_kwargs,
)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
return [transformer_model, pooling_model]
Expand All @@ -1232,6 +1282,9 @@ def _load_sbert_model(
revision: Optional[str] = None,
trust_remote_code: bool = False,
local_files_only: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Loads a full sentence-transformers model
Expand Down Expand Up @@ -1321,21 +1374,42 @@ def _load_sbert_model(
if config_path is not None:
with open(config_path) as fIn:
kwargs = json.load(fIn)
# Don't allow configs to set trust_remote_code
if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
kwargs["model_args"].pop("trust_remote_code")
if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
kwargs["tokenizer_args"].pop("trust_remote_code")
if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
kwargs["config_args"].pop("trust_remote_code")
break

hub_kwargs = {
"token": token,
"trust_remote_code": trust_remote_code,
"revision": revision,
"local_files_only": local_files_only,
}
if "model_args" in kwargs:
kwargs["model_args"].update(hub_kwargs)
else:
kwargs["model_args"] = hub_kwargs
if "tokenizer_args" in kwargs:
kwargs["tokenizer_args"].update(hub_kwargs)
else:
kwargs["tokenizer_args"] = hub_kwargs
# 3rd priority: config file
if "model_args" not in kwargs:
kwargs["model_args"] = {}
if "tokenizer_args" not in kwargs:
kwargs["tokenizer_args"] = {}
if "config_args" not in kwargs:
kwargs["config_args"] = {}

# 2nd priority: hub_kwargs
kwargs["model_args"].update(hub_kwargs)
kwargs["tokenizer_args"].update(hub_kwargs)
kwargs["config_args"].update(hub_kwargs)

# 1st priority: kwargs passed to SentenceTransformer
if model_kwargs:
kwargs["model_args"].update(model_kwargs)
if tokenizer_kwargs:
kwargs["tokenizer_args"].update(tokenizer_kwargs)
if config_kwargs:
kwargs["config_args"].update(config_kwargs)

module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
else:
# Normalize does not require any files to be loaded
Expand Down
28 changes: 20 additions & 8 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
import json
from typing import List, Dict, Optional, Union, Tuple
from typing import Any, List, Dict, Optional, Union, Tuple
import os


Expand All @@ -11,9 +11,10 @@ class Transformer(nn.Module):

:param model_name_or_path: Huggingface models name (https://huggingface.co/models)
:param max_seq_length: Truncate any inputs longer than max_seq_length
:param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model
:param model_args: Keyword arguments passed to the Huggingface Transformers model
:param tokenizer_args: Keyword arguments passed to the Huggingface Transformers tokenizer
:param config_args: Keyword arguments passed to the Huggingface Transformers config
:param cache_dir: Cache dir for Huggingface Transformers to store/load models
:param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer model
:param do_lower_case: If true, lowercases the input (independent if the model is cased or not)
:param tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used
"""
Expand All @@ -22,17 +23,24 @@ def __init__(
self,
model_name_or_path: str,
max_seq_length: Optional[int] = None,
model_args: Dict = {},
model_args: Optional[Dict[str, Any]] = None,
tokenizer_args: Optional[Dict[str, Any]] = None,
config_args: Optional[Dict[str, Any]] = None,
cache_dir: Optional[str] = None,
tokenizer_args: Dict = {},
do_lower_case: bool = False,
tokenizer_name_or_path: str = None,
):
super(Transformer, self).__init__()
self.config_keys = ["max_seq_length", "do_lower_case"]
self.do_lower_case = do_lower_case

config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
if model_args is None:
model_args = {}
if tokenizer_args is None:
tokenizer_args = {}
if config_args is None:
config_args = {}

config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
self._load_model(model_name_or_path, config, cache_dir, **model_args)

self.tokenizer = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -182,6 +190,10 @@ def load(input_path: str):
with open(sbert_config_path) as fIn:
config = json.load(fIn)
# Don't allow configs to set trust_remote_code
if "model_args" in config:
if "model_args" in config and "trust_remote_code" in config["model_args"]:
config["model_args"].pop("trust_remote_code")
if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
config["tokenizer_args"].pop("trust_remote_code")
if "config_args" in config and "trust_remote_code" in config["config_args"]:
config["config_args"].pop("trust_remote_code")
return Transformer(model_name_or_path=input_path, **config)
42 changes: 42 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,48 @@ def test_save_load_prompts() -> None:
assert fresh_model.default_prompt_name == "query"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
def test_load_with_torch_dtype() -> None:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32

with tempfile.TemporaryDirectory() as tmp_folder:
fp16_model_dir = Path(tmp_folder) / "fp16_model"
model.half()
model.save(str(fp16_model_dir))
del model

fp16_model = SentenceTransformer(
str(fp16_model_dir),
model_kwargs={"torch_dtype": "auto"},
)
assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16


def test_load_with_model_kwargs(monkeypatch: pytest.MonkeyPatch) -> None:
transformer_kwargs = {}
original_transformer_init = Transformer.__init__

def transformers_init(*args, **kwargs):
nonlocal transformer_kwargs
nonlocal original_transformer_init
transformer_kwargs = kwargs
return original_transformer_init(*args, **kwargs)

monkeypatch.setattr(Transformer, "__init__", transformers_init)

SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
model_kwargs={"attn_implementation": "eager", "low_cpu_mem_usage": False},
)

assert "low_cpu_mem_usage" in transformer_kwargs["model_args"]
assert transformer_kwargs["model_args"]["low_cpu_mem_usage"] is False
assert "attn_implementation" in transformer_kwargs["model_args"]
assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
def test_encode_fp16() -> None:
tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
Expand Down
Loading