[CLAP] Fix dtype of logit scales in init #25682
Conversation
@@ -1956,8 +1955,8 @@ def __init__(self, config: ClapConfig):
        text_config = config.text_config
        audio_config = config.audio_config

        self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
The aforementioned behaviour is a result of the np.log operation defaulting to float64.
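For context, a minimal standalone sketch of that dtype behaviour (not the actual CLAP code; init_value is a placeholder, not the config default):

import math

import numpy as np
import torch
import torch.nn as nn

init_value = 2.0  # placeholder value, not the actual CLAP default

# np.log returns a NumPy float64 scalar, and torch.tensor preserves that dtype,
# so the parameter ends up as float64 even though torch defaults to float32.
scale_np = nn.Parameter(torch.tensor(np.log(init_value)))
print(scale_np.dtype)  # torch.float64

# Python's math.log returns a plain float, so torch.tensor falls back to the
# default dtype (float32).
scale_py = nn.Parameter(torch.tensor(math.log(init_value)))
print(scale_py.dtype)  # torch.float32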
Given the original code, we might need to init in float64 and then cast to float if it makes a difference. No idea if the actual saved value is in float64!
The parameters are initialised in float64 but are stored in float32 in the state dict
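(One way to check this, as a sketch: download the published checkpoint and inspect the stored tensor. The repo id, weight filename, and key name below are assumptions about the public CLAP checkpoint, not taken from this PR.)

import torch
from huggingface_hub import hf_hub_download

# Assumed repo id / filename / key for a public CLAP checkpoint; adjust as needed.
path = hf_hub_download("laion/clap-htsat-unfused", "pytorch_model.bin")
state_dict = torch.load(path, map_location="cpu")
print(state_dict["logit_scale_a"].dtype)  # float32 in the saved state dict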
Thanks!
As mentioned offline, this was never used in the original repo. It's a bit breaking, but it is a bug fix. Let's just add one
Note that in the original repo, the model is always cast to float16 for all training / inference. Thus, they likely never used the model in its default dtype, and always relied on explicitly casting to float16.
[CLAP] Fix dtype of logit scales
What does this PR do?
The dtype of the CLAP logit scale parameters was always float64 by default (even if the rest of the model was initialised in float32). This PR fixes the logit scales so that they respect the default dtype of the model.
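A quick way to see the intended effect (a sketch, assuming a default ClapConfig with random weights; no checkpoint needed):

import torch
from transformers import ClapConfig, ClapModel

model = ClapModel(ClapConfig())

print(model.dtype)                # torch.float32, the model's default dtype
# Before this fix these two parameters were float64; with the fix they follow
# the model's default dtype.
print(model.logit_scale_a.dtype)
print(model.logit_scale_t.dtype)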