
sentence-transformers does not pick torch_dtype from model config #2579

Closed
satyamk7054 opened this issue Apr 6, 2024 · 9 comments · Fixed by #2578

Comments

@satyamk7054
Contributor

Hi,

SentenceTransformers does not use the torch_dtype specified in the model config. Newer models like e5-mistral-7b-instruct are large and ship with FP16 weights, but when loaded through SentenceTransformers they end up in FP32.

This is because HF transformers uses FP32 as the default unless torch_dtype='auto' is passed, as mentioned here.
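A minimal sketch of the difference (assuming the intfloat/e5-mistral-7b-instruct checkpoint, whose config declares float16):

```python
from transformers import AutoModel

# Default: transformers instantiates the model in float32, ignoring the
# torch_dtype declared in the checkpoint's config.json.
model = AutoModel.from_pretrained("intfloat/e5-mistral-7b-instruct")
print(model.dtype)  # torch.float32

# With torch_dtype="auto", transformers picks the dtype up from the config.
model = AutoModel.from_pretrained(
    "intfloat/e5-mistral-7b-instruct", torch_dtype="auto"
)
print(model.dtype)  # torch.float16
```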

I've tried to fix the issue in #2578.

Thank you

@tomaarsen
Collaborator

Hello!

I'm afraid that I cannot accept your proposal in #2578. I understand that some models are recommended to be used in fp16 or bf16, but setting the torch_dtype to "auto" for everyone would result in breaking changes: people's embedding models would give different results when updating to the next version.

I think the proper solution is to more clearly expose the torch_dtype parameter such that users can use whatever dtype they choose when loading a model. #2426 introduces a proposal that tries to address this.
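For illustration, this is the shape such an opt-in could take; the model_kwargs form shown here is the one that later appears in this thread, and the dtype is an explicit user choice rather than a changed default:

```python
import torch
from sentence_transformers import SentenceTransformer

# Explicit opt-in to half precision; the float32 default stays untouched,
# so existing users keep getting identical embeddings.
model = SentenceTransformer(
    "intfloat/e5-mistral-7b-instruct",
    model_kwargs={"torch_dtype": torch.float16},
)
```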

  • Tom Aarsen

@satyamk7054
Contributor Author

> I'm afraid that I cannot accept your proposal in #2578. I understand that some models are recommended to be used in fp16 or bf16, but setting the torch_dtype to "auto" for everyone would result in breaking changes: people's embedding models would give different results when updating to the next version.
>
> I think the proper solution is to more clearly expose the torch_dtype parameter such that users can use whatever dtype they choose when loading a model. #2426 introduces a proposal that tries to address this.

@tomaarsen Thank you for your response!

That's a very valid point on compatibility. I think #2426 does not handle this particular issue. This is because HF transformers does not pick dtype from config unless you pass "auto" as the dtype when loading the model, as mentioned here.

Exposing kwargs in SentenceTransformers makes sense. However, passing torch_dtype="auto" that way currently fails in Transformers, because we pass those kwargs to both AutoConfig and AutoModel:

  • AutoConfig does not accept "auto" and fails with AttributeError: 'str' object has no attribute 'is_floating_point'; the HuggingFace implementation even has an explicit check to handle this case.

One way to fix this would be to pass only the kwargs to from_pretrained, instead of both the self-built config object and the kwargs (the AutoModel class builds the config object itself). Since we still need checks like isinstance(config, T5Config), we'd build the config once with AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) (note: no model_args passed), but not pass it on to AutoModel.from_pretrained, roughly as sketched below.
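A rough sketch of that idea (hypothetical and heavily simplified compared to the actual loading code in sentence_transformers.models.Transformer):

```python
from transformers import AutoConfig, AutoModel, T5Config, T5EncoderModel


def load_auto_model(model_name_or_path, model_args, cache_dir=None):
    # Build the config without model_args, purely for the isinstance checks,
    # so a value like torch_dtype="auto" never reaches AutoConfig.
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    loader = T5EncoderModel if isinstance(config, T5Config) else AutoModel
    # The self-built config is NOT passed on; from_pretrained rebuilds its
    # own config internally and resolves torch_dtype="auto" correctly.
    return loader.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, **model_args
    )
```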

If this makes sense, I can make the required changes.

@tomaarsen
Collaborator

> Exposing kwargs in SentenceTransformers makes sense. However, passing torch_dtype="auto" that way currently fails in Transformers, because we pass those kwargs to both AutoConfig and AutoModel:

Indeed, we have to be a bit smarter here, I think. Perhaps the nicest solution is to avoid explicitly creating the config altogether and just use AutoModel instead. I'm not sure whether that would reduce the configuration options for ST, though.

@satyamk7054
Contributor Author

satyamk7054 commented Apr 16, 2024

>> Exposing kwargs in SentenceTransformers makes sense. However, passing torch_dtype="auto" that way currently fails in Transformers, because we pass those kwargs to both AutoConfig and AutoModel:
>
> Indeed, we have to be a bit smarter here, I think. Perhaps the nicest solution is to avoid explicitly creating the config altogether and just use AutoModel instead. I'm not sure whether that would reduce the configuration options for ST, though.

Makes sense. How about I do what I suggested for this issue/PR (not passing model_args to AutoConfig), and as a follow-up we can improve the logic to rely on AutoModel alone?

edit: I've updated my PR to allow passing model_args to ST, and I've kept the existing params to maintain backwards compatibility.

@satyamk7054
Contributor Author

@tomaarsen I understand you want to use #2426 as one way to solve this. Could you please let me know when that PR will be merged?

If it'll take a long time, could we please enable dtype in the interim using my PR (open to any suggestions)? There's a significant difference in inference time between float32 and float16 (166.58 ms vs. 35.8 ms for a ~64 token string).
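For context, a rough way to reproduce such a comparison (a sketch assuming a CUDA device and the model_kwargs API from the PR; absolute numbers will vary with hardware):

```python
import time

import torch
from sentence_transformers import SentenceTransformer

text = "word " * 64  # roughly a 64-token string

for dtype in (torch.float32, torch.float16):
    model = SentenceTransformer(
        "intfloat/e5-mistral-7b-instruct",
        device="cuda",
        model_kwargs={"torch_dtype": dtype},
    )
    model.encode(text)  # warm-up pass
    start = time.perf_counter()
    model.encode(text)  # encode() returns a NumPy array, so the GPU work completes here
    print(f"{dtype}: {(time.perf_counter() - start) * 1000:.2f} ms")
```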

@tomaarsen
Collaborator

In the last few days I had another look at #2426 to see whether I could include it in the next version, but my conclusion for now is that it is too hacky.
Instead, I think the best solution is to move forward with your PR, #2578. I do have some suggestions on the API that I think might work better; I'll write those out as comments in the PR.

I'll also request reviews from people who work on transformers as the API changes are fairly drastic.

  • Tom Aarsen

@fengshansi

When will the next version be released?

@tomaarsen
Collaborator

The next release, v3.0, will likely be this month. #2578 will be included.

  • Tom Aarsen

@rangehow

rangehow commented Jan 6, 2025

```python
model = SentenceTransformer(
    model_name,
    trust_remote_code=True,
    device=device,
    model_kwargs={"torch_dtype": "auto"},
    tokenizer_kwargs={"model_max_length": 8192, "truncation": True},
)
```

When loading jina-embeddings-v3, this fails with:

AttributeError: 'str' object has no attribute 'is_floating_point'
