Allow passing model_args to ST #2578

Merged
merged 17 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
39 changes: 37 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -5,7 +5,7 @@
import shutil
from collections import OrderedDict
import warnings
from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
import numpy as np
from numpy import ndarray
import transformers
@@ -73,6 +73,24 @@ class SentenceTransformer(nn.Sequential):
:param token: Hugging Face authentication token to download private models.
:param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
only applicable during inference when `.encode` is called.
:param torch_dtype: Override the default `torch.dtype` and load the model under a specific `dtype`.
The different options are:
1. `torch.float16`, `torch.bfloat16` or `torch.float`: load in the specified `dtype`,
ignoring the model's `config.torch_dtype` if one exists. If not specified, the model is
loaded in `torch.float` (fp32).

2. `"auto"`: use the `torch_dtype` entry in the model's `config.json` if present. If that entry
isn't found, use the `dtype` of the first floating-point weight in the checkpoint. This loads the
model in the `dtype` it was saved in at the end of training, which is not necessarily the `dtype`
it was trained in: a model may be trained in a half-precision dtype but saved in fp32.
:param attn_implementation: The attention implementation to use in the model (if relevant). Can be any of
`"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`]
(https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)),
or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)).
By default, SDPA is used if available (torch>=2.1.1); otherwise the manual `"eager"` implementation
is the default.
:param model_kwargs: Additional keyword arguments to pass to the underlying Hugging Face Transformers model.
"""

def __init__(
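
For illustration, a minimal usage sketch of the new constructor arguments; the model name and the extra `low_cpu_mem_usage` kwarg are examples only, not part of this diff:

from sentence_transformers import SentenceTransformer
import torch

model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    low_cpu_mem_usage=True,  # any extra kwarg is forwarded to the Transformers model via **model_kwargs
)
embeddings = model.encode(["Hello there!"])
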
@@ -89,6 +107,9 @@ def __init__(
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
attn_implementation: Optional[str] = None,
**model_kwargs,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
@@ -209,6 +230,11 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
model_args={
"torch_dtype": torch_dtype,
"attn_implementation": attn_implementation,
**model_kwargs,
}
)
else:
modules = self._load_auto_model(
@@ -218,6 +244,11 @@
revision=revision,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
model_args={
"torch_dtype": torch_dtype,
"attn_implementation": attn_implementation,
**model_kwargs,
}
)

if modules is not None and not isinstance(modules, OrderedDict):
@@ -1196,6 +1227,7 @@ def _load_auto_model(
revision: Optional[str] = None,
trust_remote_code: bool = False,
local_files_only: bool = False,
model_args: Optional[Dict[str, Any]] = None,
):
"""
Creates a simple Transformer + Mean Pooling model and returns the modules
@@ -1213,7 +1245,7 @@
"trust_remote_code": trust_remote_code,
"revision": revision,
"local_files_only": local_files_only,
},
} | (model_args or {}),
tokenizer_args={
"token": token,
"trust_remote_code": trust_remote_code,
@@ -1232,6 +1264,7 @@ def _load_sbert_model(
revision: Optional[str] = None,
trust_remote_code: bool = False,
local_files_only: bool = False,
model_args: Optional[Dict[str, Any]] = None,
):
"""
Loads a full sentence-transformers model
@@ -1332,6 +1365,8 @@
kwargs["model_args"].update(hub_kwargs)
else:
kwargs["model_args"] = hub_kwargs
kwargs["model_args"].update(model_args or {})

if "tokenizer_args" in kwargs:
kwargs["tokenizer_args"].update(hub_kwargs)
else:
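
Both load paths merge the user-supplied `model_args` last, so they take precedence over the hub defaults: `_load_auto_model` uses a dict union (Python 3.9+) and `_load_sbert_model` chains `.update()` calls. A small sketch of the resulting precedence, with purely illustrative values:

hub_kwargs = {"revision": "main", "local_files_only": False}
model_args = {"torch_dtype": "auto", "local_files_only": True}

merged = hub_kwargs | model_args  # right-hand keys win on conflict
# merged == {"revision": "main", "local_files_only": True, "torch_dtype": "auto"}
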
22 changes: 10 additions & 12 deletions sentence_transformers/models/Transformer.py
@@ -32,8 +32,7 @@ def __init__(
self.config_keys = ["max_seq_length", "do_lower_case"]
self.do_lower_case = do_lower_case

config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
self._load_model(model_name_or_path, config, cache_dir, **model_args)
self._load_model(model_name_or_path, cache_dir, **model_args)

self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
@@ -55,33 +54,32 @@ def __init__(
if tokenizer_name_or_path is not None:
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
def _load_model(self, model_name_or_path, cache_dir, **model_args):
"""Loads the transformer model"""
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
if isinstance(config, T5Config):
self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
self._load_t5_model(model_name_or_path, cache_dir, **model_args)
elif isinstance(config, MT5Config):
self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
self._load_mt5_model(model_name_or_path, cache_dir, **model_args)
else:
self.auto_model = AutoModel.from_pretrained(
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)
self.auto_model = AutoModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)

def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args):
def _load_t5_model(self, model_name_or_path, cache_dir, **model_args):
"""Loads the encoder model from T5"""
from transformers import T5EncoderModel

T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
self.auto_model = T5EncoderModel.from_pretrained(
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
model_name_or_path, cache_dir=cache_dir, **model_args
)

def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args):
def _load_mt5_model(self, model_name_or_path, cache_dir, **model_args):
"""Loads the encoder model from T5"""
from transformers import MT5EncoderModel

MT5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
self.auto_model = MT5EncoderModel.from_pretrained(
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
model_name_or_path, cache_dir=cache_dir, **model_args
)

def __repr__(self):
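
After this refactor the config is loaded only to pick the model class, and `model_args` are forwarded untouched to `from_pretrained`, so options such as `torch_dtype` and `attn_implementation` reach Transformers directly. Roughly, the resulting call looks like the sketch below (the checkpoint name is illustrative, and `attn_implementation` needs a transformers release that supports it):

from transformers import AutoModel

auto_model = AutoModel.from_pretrained(
    "bert-base-uncased",          # illustrative checkpoint
    cache_dir=None,
    torch_dtype="auto",           # resolve dtype from config.json or the checkpoint weights
    attn_implementation="sdpa",   # use scaled_dot_product_attention where supported
)
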
40 changes: 40 additions & 0 deletions tests/test_sentence_transformer.py
@@ -339,6 +339,46 @@ def test_save_load_prompts() -> None:
assert fresh_model.default_prompt_name == "query"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
def test_load_with_torch_dtype() -> None:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32

with tempfile.TemporaryDirectory() as tmp_folder:
fp16_model_dir = Path(tmp_folder) / "fp16_model"
model.half()
model.save(str(fp16_model_dir))
del model

fp16_model = SentenceTransformer(str(fp16_model_dir), torch_dtype="auto")
assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16


def test_load_with_model_kwargs(monkeypatch: pytest.MonkeyPatch) -> None:
transformer_kwargs = {}
original_transformer_init = Transformer.__init__

def transformers_init(*args, **kwargs):
nonlocal transformer_kwargs
nonlocal original_transformer_init
transformer_kwargs = kwargs
return original_transformer_init(*args, **kwargs)

monkeypatch.setattr(Transformer, "__init__", transformers_init)

SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
attn_implementation="eager",
low_cpu_mem_usage=False
)

assert "low_cpu_mem_usage" in transformer_kwargs["model_args"]
assert transformer_kwargs["model_args"]["low_cpu_mem_usage"] is False
assert "attn_implementation" in transformer_kwargs["model_args"]
assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
def test_encode_fp16() -> None:
tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")