[fix] Add dtype cast for modules other than Transformer (#2889)
* add dtype cast to other modules

* check if kwargs are None

* format

* Cast the entire model to the dtype of the first parameter

This should be a tad safer. The original fix failed for torch_dtype = "auto", "float16", and "bfloat16", and would not handle models automatically loaded in fp16 via a custom Module.

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
ir2718 and tomaarsen authored Sep 10, 2024
1 parent 6257cb0 commit 597d5ed
Showing 2 changed files with 23 additions and 0 deletions.
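
For context, a minimal sketch of the behavior this commit targets (see #2887): when the first module is loaded in a lower precision via model_kwargs["torch_dtype"], the remaining modules (e.g. a Dense head) should end up in the same dtype. The model name, device, and model_kwargs usage below are taken from the new test in this commit; the dtype check at the end is illustrative and assumes a CUDA device.

import torch
from sentence_transformers import SentenceTransformer

# Load the tiny test model in half precision (requires a CUDA device).
model = SentenceTransformer(
    "sentence-transformers-testing/all-nli-bert-tiny-dense",
    device="cuda",
    model_kwargs={"torch_dtype": torch.float16},
)

# After this fix, the whole model is cast to the dtype of the first parameter,
# so every parameter (Transformer and Dense alike) should be torch.float16.
print({p.dtype for p in model.parameters()})  # expected: {torch.float16}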
11 changes: 11 additions & 0 deletions sentence_transformers/SentenceTransformer.py
@@ -320,6 +320,16 @@ def __init__(

        super().__init__(modules)

        # Ensure all tensors in the model are of the same dtype as the first tensor
        # This is necessary if the first module has been given a lower precision via
        # model_kwargs["torch_dtype"]. The rest of the model should be loaded in the same dtype
        # See #2887 for more details
        try:
            dtype = next(self.parameters()).dtype
            self.to(dtype)
        except StopIteration:
            pass

        self.to(device)
        self.is_hpu_graph_enabled = False
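
As a standalone illustration (not code from the repository), the same cast-to-the-first-parameter's-dtype pattern used in the block above, applied to a plain nn.Sequential in which a stand-in Dense head starts out in fp32:

import torch
from torch import nn

encoder = nn.Linear(8, 8).half()  # stand-in for a Transformer module loaded in fp16
head = nn.Linear(8, 4)            # stand-in for a Dense module created in fp32
model = nn.Sequential(encoder, head)

try:
    dtype = next(model.parameters()).dtype  # dtype of the first parameter (torch.float16)
    model.to(dtype)                         # cast every submodule to match
except StopIteration:                       # a model with no parameters has nothing to cast
    pass

assert all(p.dtype == torch.float16 for p in model.parameters())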

@@ -1651,6 +1661,7 @@ def _load_sbert_model(
                    local_files_only=local_files_only,
                )
                module = module_class.load(module_path)

            modules[module_config["name"]] = module
            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])

12 changes: 12 additions & 0 deletions tests/test_sentence_transformer.py
@@ -87,6 +87,18 @@ def test_to() -> None:
    assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test fp16 and bf16 inference.")
@pytest.mark.parametrize("torch_dtype", ["auto", "float16", "bfloat16", torch.float16, torch.bfloat16])
def test_torch_dtype(torch_dtype) -> None:
    model = SentenceTransformer(
        "sentence-transformers-testing/all-nli-bert-tiny-dense",
        device="cuda",
        model_kwargs={"torch_dtype": torch_dtype},
    )
    embedding = model.encode("Test sentence")
    assert embedding.shape[-1] == model.get_sentence_embedding_dimension()


def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
    def mock_create_repo(self, repo_id, **kwargs):
        return RepoUrl(f"https://huggingface.co/{repo_id}")
