Allow passing model_args to ST #2612

Closed · wants to merge 8 commits
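For orientation before the diff: a minimal sketch of the call this PR is meant to enable. The parameter name `model_args` is as proposed in this PR (not necessarily the name that shipped); the keys are ordinary Hugging Face `from_pretrained` keyword arguments, and the checkpoint is the tiny test model used in the PR's own test below.

    from sentence_transformers import SentenceTransformer

    # Hypothetical usage of the proposed parameter: every key is forwarded
    # to the underlying Hugging Face from_pretrained call.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
        model_args={
            "torch_dtype": "auto",      # resolve dtype from the checkpoint
            "low_cpu_mem_usage": True,  # standard loading kwarg (needs accelerate)
        },
    )
    embeddings = model.encode(["An example sentence"])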
13 changes: 11 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -5,7 +5,7 @@
 import shutil
 from collections import OrderedDict
 import warnings
-from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
+from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -73,6 +73,7 @@ class SentenceTransformer(nn.Sequential):
     :param token: Hugging Face authentication token to download private models.
     :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
         only applicable during inference when `.encode` is called.
+    :param model_args: Arguments (key, value pairs) passed to the Hugging Face Transformers model.
     """

     def __init__(
@@ -89,6 +90,7 @@ def __init__(
         token: Optional[Union[bool, str]] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         truncate_dim: Optional[int] = None,
+        model_args: Optional[Dict[str, Any]] = None,
     ):
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -208,6 +210,7 @@ def __init__(
                 cache_folder=cache_folder,
                 revision=revision,
                 trust_remote_code=trust_remote_code,
+                model_args=model_args,
                 local_files_only=local_files_only,
             )
         else:
@@ -217,6 +220,7 @@ def __init__(
                 cache_folder=cache_folder,
                 revision=revision,
                 trust_remote_code=trust_remote_code,
+                model_args=model_args,
                 local_files_only=local_files_only,
             )

@@ -1195,6 +1199,7 @@ def _load_auto_model(
         cache_folder: Optional[str],
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
+        model_args: Optional[Dict[str, Any]] = None,
         local_files_only: bool = False,
     ):
         """
@@ -1213,7 +1218,7 @@ def _load_auto_model(
                 "trust_remote_code": trust_remote_code,
                 "revision": revision,
                 "local_files_only": local_files_only,
-            },
+            } | (model_args or {}),
             tokenizer_args={
                 "token": token,
                 "trust_remote_code": trust_remote_code,
@@ -1231,6 +1236,7 @@ def _load_sbert_model(
         cache_folder: Optional[str],
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
+        model_args: Optional[Dict[str, Any]] = None,
         local_files_only: bool = False,
     ):
         """
@@ -1332,6 +1338,9 @@ def _load_sbert_model(
                     kwargs["model_args"].update(hub_kwargs)
                 else:
                     kwargs["model_args"] = hub_kwargs
+                if model_args is not None:
+                    kwargs["model_args"].update(model_args)
+
                 if "tokenizer_args" in kwargs:
                     kwargs["tokenizer_args"].update(hub_kwargs)
                 else:
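One detail worth calling out in the `} | (model_args or {})` line above: the dict union operator requires Python 3.9+, and the right-hand operand wins on duplicate keys, so user-supplied `model_args` override the hub defaults (`token`, `trust_remote_code`, `revision`, `local_files_only`). A standalone illustration with made-up values:

    # Sketch of the merge semantics in _load_auto_model (Python 3.9+):
    # on duplicate keys, the right-hand operand of `|` wins.
    hub_kwargs = {"token": None, "revision": "main", "local_files_only": False}
    model_args = {"revision": "v1.0", "torch_dtype": "auto"}

    merged = hub_kwargs | (model_args or {})
    assert merged["revision"] == "v1.0"         # user value overrides the default
    assert merged["torch_dtype"] == "auto"      # extra keys pass straight through
    assert merged["local_files_only"] is False  # untouched defaults survive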
22 changes: 10 additions & 12 deletions sentence_transformers/models/Transformer.py
@@ -32,8 +32,7 @@ def __init__(
         self.config_keys = ["max_seq_length", "do_lower_case"]
         self.do_lower_case = do_lower_case

-        config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
-        self._load_model(model_name_or_path, config, cache_dir, **model_args)
+        self._load_model(model_name_or_path, cache_dir, **model_args)

         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
@@ -55,33 +54,32 @@ def __init__(
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

-    def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
+    def _load_model(self, model_name_or_path, cache_dir, **model_args):
         """Loads the transformer model"""
+        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
         if isinstance(config, T5Config):
-            self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
+            self._load_t5_model(model_name_or_path, cache_dir, **model_args)
         elif isinstance(config, MT5Config):
-            self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
+            self._load_mt5_model(model_name_or_path, cache_dir, **model_args)
         else:
-            self.auto_model = AutoModel.from_pretrained(
-                model_name_or_path, config=config, cache_dir=cache_dir, **model_args
-            )
+            self.auto_model = AutoModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)

-    def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args):
+    def _load_t5_model(self, model_name_or_path, cache_dir, **model_args):
         """Loads the encoder model from T5"""
         from transformers import T5EncoderModel

         T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
         self.auto_model = T5EncoderModel.from_pretrained(
-            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+            model_name_or_path, cache_dir=cache_dir, **model_args
         )

-    def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args):
+    def _load_mt5_model(self, model_name_or_path, cache_dir, **model_args):
         """Loads the encoder model from T5"""
         from transformers import MT5EncoderModel

         MT5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
         self.auto_model = MT5EncoderModel.from_pretrained(
-            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+            model_name_or_path, cache_dir=cache_dir, **model_args
         )

     def __repr__(self):
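To restate the restructuring above: `AutoConfig.from_pretrained` no longer receives `**model_args`, the config is loaded only to pick the loader class, and every user-supplied key is interpreted by `from_pretrained` itself. A self-contained sketch of the new dispatch (the free function name and the example checkpoint are illustrative, not part of the PR):

    import torch
    from transformers import AutoConfig, AutoModel, MT5Config, T5Config

    def load_auto_model(model_name_or_path: str, cache_dir=None, **model_args):
        # The config only dispatches on architecture; model_args never touch it.
        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        if isinstance(config, T5Config):
            from transformers import T5EncoderModel
            return T5EncoderModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)
        elif isinstance(config, MT5Config):
            from transformers import MT5EncoderModel
            return MT5EncoderModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)
        return AutoModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)

    # e.g. load_auto_model("t5-small", torch_dtype=torch.float16)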
16 changes: 16 additions & 0 deletions tests/test_sentence_transformer.py
@@ -339,6 +339,22 @@ def test_save_load_prompts() -> None:
     assert fresh_model.default_prompt_name == "query"


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
+def test_load_with_model_args() -> None:
+    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+
+    assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32
+
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        fp16_model_dir = Path(tmp_folder) / "fp16_model"
+        model.half()
+        model.save(str(fp16_model_dir))
+        del model
+
+        fp16_model = SentenceTransformer(str(fp16_model_dir), model_args={"torch_dtype": "auto"})
+        assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
 def test_encode_fp16() -> None:
     tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
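In the new test, `torch_dtype="auto"` lets transformers pick the dtype recorded in the saved checkpoint, which is why a model saved after `.half()` comes back as float16. For completeness, a hedged variant that requests the dtype explicitly instead of reading it back (same tiny checkpoint; a CUDA device is assumed, matching the test's skip condition):

    import torch
    from sentence_transformers import SentenceTransformer

    # Hypothetical explicit-dtype variant of test_load_with_model_args.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
        model_args={"torch_dtype": torch.float16},
        device="cuda",  # fp16 inference is only well supported on GPU
    )
    assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16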