From 7de0696cb7caa7279eee0074a09fa2b72396ccc3 Mon Sep 17 00:00:00 2001
From: satyamk7054 <43010011+satyamk7054@users.noreply.github.com>
Date: Sat, 6 Apr 2024 07:15:48 +0000
Subject: [PATCH 1/6] Use torch_dtype='auto' when loading auto-class
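
Pass torch_dtype="auto" so the auto-class path loads a checkpoint in the
dtype it was saved in, instead of always loading in float32. A rough sketch
of the effective call (illustrative only; the checkpoint is the tiny test
model used in the test below):

    from transformers import AutoModel

    # "auto" defers to the dtype recorded in the checkpoint's config,
    # e.g. float16 for a model saved after .half()
    model = AutoModel.from_pretrained(
        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
        torch_dtype="auto",
    )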
---
 sentence_transformers/models/Transformer.py |  2 +-
 tests/test_sentence_transformer.py          | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 3e72c865c..584e9a63b 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -63,7 +63,7 @@ def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
             self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
         else:
             self.auto_model = AutoModel.from_pretrained(
-                model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+                model_name_or_path, config=config, cache_dir=cache_dir, **model_args, torch_dtype="auto"
             )
 
     def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args):
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index 7328c6afe..409cd5bd4 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -335,6 +335,21 @@ def test_save_load_prompts() -> None:
     assert fresh_model.default_prompt_name == "query"
 
 
+def test_load_with_auto_dtype() -> None:
+    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+
+    assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32
+
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        fp16_model_dir = Path(tmp_folder) / "fp16_model"
+        model.half()
+        model.save(str(fp16_model_dir))
+        del model
+
+        fp16_model = SentenceTransformer(str(fp16_model_dir))
+        assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
 def test_encode_fp16() -> None:
     tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

From 97cb87bfa69bed271e085678ac958df058869c8a Mon Sep 17 00:00:00 2001
From: satyamk7054 <43010011+satyamk7054@users.noreply.github.com>
Date: Sat, 6 Apr 2024 09:11:13 +0000
Subject: [PATCH 2/6] Use torch_dtype='auto' as default if model_args doesn't have it
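
Only fall back to torch_dtype="auto" when the caller has not already set a
torch_dtype in model_args, so an explicit value wins. An illustrative
override, mirroring the new test_load_with_dtype_arg test:

    import torch
    from sentence_transformers.models import Transformer

    # An explicit torch_dtype in model_args takes precedence over the
    # "auto" fallback
    transformer = Transformer(
        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
        model_args={"torch_dtype": torch.float16},
    )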
"trust_remote_code": trust_remote_code, "revision": revision} | ( + model_args or {}), tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision}, ) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean") @@ -1147,6 +1152,7 @@ def _load_sbert_model( cache_folder: Optional[str], revision: Optional[str] = None, trust_remote_code: bool = False, + model_args: Optional[Dict[str, Any]] = None ): """ Loads a full sentence-transformers model @@ -1226,6 +1232,9 @@ def _load_sbert_model( kwargs["model_args"].update(hub_kwargs) else: kwargs["model_args"] = hub_kwargs + if model_args is not None: + kwargs["model_args"].update(model_args) + if "tokenizer_args" in kwargs: kwargs["tokenizer_args"].update(hub_kwargs) else: diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 407846ec6..339d79489 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -32,8 +32,7 @@ def __init__( self.config_keys = ["max_seq_length", "do_lower_case"] self.do_lower_case = do_lower_case - config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir) - self._load_model(model_name_or_path, config, cache_dir, **model_args) + self._load_model(model_name_or_path, cache_dir, **model_args) self.tokenizer = AutoTokenizer.from_pretrained( tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path, @@ -55,18 +54,15 @@ def __init__( if tokenizer_name_or_path is not None: self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__ - def _load_model(self, model_name_or_path, config, cache_dir, **model_args): + def _load_model(self, model_name_or_path, cache_dir, **model_args): """Loads the transformer model""" + config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) if isinstance(config, T5Config): self._load_t5_model(model_name_or_path, config, cache_dir, **model_args) elif isinstance(config, MT5Config): self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args) else: - if "torch_dtype" not in model_args: - model_args["torch_dtype"] = "auto" - self.auto_model = AutoModel.from_pretrained( - model_name_or_path, config=config, cache_dir=cache_dir, **model_args - ) + self.auto_model = AutoModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args) def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args): """Loads the encoder model from T5""" diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index e2a5a16b1..6f1deaac5 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -335,7 +335,7 @@ def test_save_load_prompts() -> None: assert fresh_model.default_prompt_name == "query" -def test_load_defaults_to_auto_dtype() -> None: +def test_load_with_model_args() -> None: model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32 @@ -346,21 +346,10 @@ def test_load_defaults_to_auto_dtype() -> None: model.save(str(fp16_model_dir)) del model - fp16_model = SentenceTransformer(str(fp16_model_dir)) + fp16_model = SentenceTransformer(str(fp16_model_dir), model_args={"torch_dtype": "auto"}) assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16 -def test_load_with_dtype_arg() -> None: - transformer = Transformer( - 
"sentence-transformers-testing/stsb-bert-tiny-safetensors", - model_args={"torch_dtype": torch.float16}, - ) - pooling = Pooling(transformer.get_word_embedding_dimension()) - pytorch_model = SentenceTransformer(modules=[transformer, pooling]) - - assert pytorch_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16 - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.") def test_encode_fp16() -> None: tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") From 6106364fdea41d63b4d4827bd6ff76ae6e60a719 Mon Sep 17 00:00:00 2001 From: satyamk7054 <43010011+satyamk7054@users.noreply.github.com> Date: Wed, 17 Apr 2024 00:46:10 +0000 Subject: [PATCH 4/6] Make same change for T5 and MT5 --- sentence_transformers/models/Transformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 339d79489..787b8c6df 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -58,28 +58,28 @@ def _load_model(self, model_name_or_path, cache_dir, **model_args): """Loads the transformer model""" config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) if isinstance(config, T5Config): - self._load_t5_model(model_name_or_path, config, cache_dir, **model_args) + self._load_t5_model(model_name_or_path, cache_dir, **model_args) elif isinstance(config, MT5Config): - self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args) + self._load_mt5_model(model_name_or_path, cache_dir, **model_args) else: self.auto_model = AutoModel.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args) - def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args): + def _load_t5_model(self, model_name_or_path, cache_dir, **model_args): """Loads the encoder model from T5""" from transformers import T5EncoderModel T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"] self.auto_model = T5EncoderModel.from_pretrained( - model_name_or_path, config=config, cache_dir=cache_dir, **model_args + model_name_or_path, cache_dir=cache_dir, **model_args ) - def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args): + def _load_mt5_model(self, model_name_or_path, cache_dir, **model_args): """Loads the encoder model from T5""" from transformers import MT5EncoderModel MT5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"] self.auto_model = MT5EncoderModel.from_pretrained( - model_name_or_path, config=config, cache_dir=cache_dir, **model_args + model_name_or_path, cache_dir=cache_dir, **model_args ) def __repr__(self): From 48c5ee6ce98301be89e5a82c2dd875aeafb22dc5 Mon Sep 17 00:00:00 2001 From: satyamk7054 <43010011+satyamk7054@users.noreply.github.com> Date: Wed, 17 Apr 2024 02:30:28 +0000 Subject: [PATCH 5/6] Update method documentation --- sentence_transformers/SentenceTransformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index c30c72630..4fbc90822 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -68,6 +68,7 @@ class SentenceTransformer(nn.Sequential): This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on 
your local machine. :param token: Hugging Face authentication token to download private models. + :param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model """ def __init__( From c2566ef8922a85079d3f389408788b67ac99067e Mon Sep 17 00:00:00 2001 From: satyamk7054 <43010011+satyamk7054@users.noreply.github.com> Date: Wed, 17 Apr 2024 02:55:20 +0000 Subject: [PATCH 6/6] Disable test if CUDA is not available --- tests/test_sentence_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 6f1deaac5..59e2ea164 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -335,6 +335,7 @@ def test_save_load_prompts() -> None: assert fresh_model.default_prompt_name == "query" +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.") def test_load_with_model_args() -> None: model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")