
Allow passing model_args to ST #2612

Closed
wants to merge 8 commits into from

Conversation

@satyamk7054 (Contributor) commented Apr 24, 2024

Summary

Allow passing model_args to ST

Details

This fixes #2579.
Newer models such as e5-mistral-7b-instruct use FP16 as their dtype. However, when loaded through sentence_transformers they end up in FP32, because HF transformers defaults to FP32 unless torch_dtype='auto' is passed, as mentioned here.

Passing "auto" through to the underlying model does not work either, because of the error below. Moreover, the SentenceTransformer class does not currently expose a model_args param at all.

cls = <class 'transformers.models.bert.modeling_bert.BertModel'>, dtype = 'auto'

    @classmethod
    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
        """
        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
        under specific dtype.
    
        Args:
            dtype (`torch.dtype`):
                a floating dtype to set to.
    
        Returns:
            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
            modified. If it wasn't, returns `None`.
    
        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
        """
>       if not dtype.is_floating_point:
E       AttributeError: 'str' object has no attribute 'is_floating_point'

../venv/lib/python3.10/site-packages/transformers/modeling_utils.py:1412: AttributeError
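The failure comes down to transformers expecting an actual torch.dtype object at this point: the string "auto" has no is_floating_point attribute, so the guard itself blows up before any validation runs. A minimal stdlib-only sketch of that guard (FakeDtype and set_default_dtype_guard are hypothetical stand-ins, not transformers code):

```python
class FakeDtype:
    """Hypothetical stand-in for torch.dtype (illustration only)."""
    def __init__(self, is_floating_point):
        self.is_floating_point = is_floating_point

def set_default_dtype_guard(dtype):
    # Mirrors the guard in transformers' _set_default_torch_dtype: the
    # string "auto" has no .is_floating_point attribute, so this check
    # raises AttributeError before any real dtype validation happens.
    if not dtype.is_floating_point:
        raise ValueError("only floating-point dtypes are supported")
    return dtype

set_default_dtype_guard(FakeDtype(True))   # OK: behaves like a float dtype
# set_default_dtype_guard("auto")          # AttributeError: 'str' object ...
```

This is why "auto" has to be intercepted and resolved higher up (in from_pretrained) rather than passed down raw, and why exposing a model_args param on SentenceTransformer is needed to plumb a proper torch_dtype through.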

Testing Done

Added a new unit test that validates the dtype of the loaded model via the embedding tensor it produces.
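The shape of that check can be sketched as follows, with a hypothetical stub standing in for a SentenceTransformer loaded in FP16 (the stub, its names, and the dict-based "tensor" are illustrative, not the PR's actual test code):

```python
class StubFP16Model:
    """Hypothetical stand-in for a model loaded with torch_dtype=float16."""
    dtype = "float16"

    def encode(self, sentences):
        # A real model returns embeddings whose dtype matches its weights;
        # the stub just echoes the dtype so the assertion shape is visible.
        return {"dtype": self.dtype, "shape": (len(sentences), 4)}

def test_loaded_model_dtype():
    model = StubFP16Model()
    embeddings = model.encode(["hello", "world"])
    # Validate the loaded model's dtype via the embedding tensor it produced
    assert embeddings["dtype"] == "float16"
    assert embeddings["shape"][0] == 2

test_loaded_model_dtype()
```

Checking the dtype on the output tensor, rather than inspecting model internals, keeps the test independent of how sentence_transformers wires the underlying HF model.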

Development

Successfully merging this pull request may close these issues.

sentence-transformers does not pick torch_dtype from model config