-
Notifications
You must be signed in to change notification settings - Fork 27k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Populate torch_dtype from model to pipeline #28940
Populate torch_dtype from model to pipeline #28940
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype of the model is subject to changes. A property might be a lot simpler 😉
d250e1c
to
00a61c1
Compare
@ArthurZucker Sorry for my extremely late response...🙇 Using property to handle model dtype update makes sense. Revised the logic as such so would appreciate if you could take another look. Thanks! |
02a4926
to
c0e9e4a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright LGTM, but I think we should just always return the dtype of the model
# If dtype is NOT specified in the pipeline constructor, the property should NOT return type | ||
# as we don't know if the pipeline supports torch_dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really agree here, we should always return the dtype of the model just for consistence. The pipeline should error out normally and we should not assume that None torch_dtype == not supported
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the consistency concern, but also wonder that the model's dtype doesn't always translate to torch_dtype, like a model loaded with tensorflow/jax backend. If that sounds ok, I will proceed with simply propagating the model dtype:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you have a FlaxLlamaModel
then accessing the dtype should just be consistent imo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ArthurZucker Make sense, I found that TFPretrainedModel doesn't have dtype
property so my concern is not the case. My apologies for misunderstanding!
I've updated the code so would appreciate if you could take another look, thanks!
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
7cba0d5
to
e9ae6c9
Compare
Sorry was off for a week ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@ArthurZucker Sorry for bothering you repeated times. I found an issue in this change.
I've just tested a few models and turned out this is no accurate. TFPretrainedModel inherits I think there are a few workaround,
I think the first one is most convenient for users, but please let me know your thoughts! |
That would be a new feature! Given that before there were no |
Yeah if the first option sounds good I'm happy to file a follow-up PR now before this change is released. I think it's more proper behavior for |
* Populate torch_dtype from model to pipeline Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * use property Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * lint Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * Remove default handling Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> --------- Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
* Populate torch_dtype from model to pipeline Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * use property Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * lint Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * Remove default handling Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> --------- Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
What does this PR do?
When constructing a pipeline from a model, it doesn't inherit the
torch_dtype
attribute from the model's dtype. This causes asymmetry of pipeline and model, as the model always inherit thetorch_dtype
when the pipeline is created withtorch_dtype
param. Sometimes it's a bit confusing that the pipeline'storch_dtype
isNone
(which could mislead the dtype is default one), while the underlying model has different dtype.Therefore, this PR updates the pipeline construction logic to set
torch_dtype
attribute on pipeline based on model's dtype.Fixes #28817
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@ArthurZucker @Rocketknight1