Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Add dtype cast for modules other than Transformer #2889

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,16 @@ def __init__(

super().__init__(modules)

# Ensure all tensors in the model are of the same dtype as the first tensor
# This is necessary if the first module has been given a lower precision via
# model_kwargs["torch_dtype"]. The rest of the model should be loaded in the same dtype
# See #2887 for more details
try:
dtype = next(self.parameters()).dtype
self.to(dtype)
except StopIteration:
pass

self.to(device)
self.is_hpu_graph_enabled = False

Expand Down Expand Up @@ -1581,6 +1591,7 @@ def _load_sbert_model(
local_files_only=local_files_only,
)
module = module_class.load(module_path)

modules[module_config["name"]] = module

if revision is None:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def test_to() -> None:
assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test fp16 and bf16 inference.")
@pytest.mark.parametrize("torch_dtype", ["auto", "float16", "bfloat16", torch.float16, torch.bfloat16])
def test_torch_dtype(torch_dtype) -> None:
model = SentenceTransformer(
"sentence-transformers-testing/all-nli-bert-tiny-dense",
device="cuda",
model_kwargs={"torch_dtype": torch_dtype},
)
embedding = model.encode("Test sentence")
assert embedding.shape[-1] == model.get_sentence_embedding_dimension()


def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")
Expand Down