
Populate torch_dtype from model to pipeline #28940

Merged

Conversation

B-Step62
Contributor

@B-Step62 B-Step62 commented Feb 9, 2024

What does this PR do?

When constructing a pipeline from a model, the pipeline doesn't inherit the torch_dtype attribute from the model's dtype. This creates an asymmetry between pipeline and model, since the model always inherits torch_dtype when the pipeline is created with the torch_dtype param. It can be confusing that the pipeline's torch_dtype is None (which could suggest the default dtype) while the underlying model has a different dtype.
Therefore, this PR updates the pipeline construction logic to set the torch_dtype attribute on the pipeline based on the model's dtype.
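The asymmetry can be sketched with minimal stand-in classes (illustrative only: `Model`, `Pipeline`, and the string dtypes below are simplifications, not the actual transformers objects):

```python
class Model:
    """Stand-in for a PreTrainedModel; dtype is a plain string here."""
    def __init__(self, dtype="float32"):
        self.dtype = dtype

class Pipeline:
    """Stand-in for transformers.Pipeline, mimicking the pre-PR behavior."""
    def __init__(self, model, torch_dtype=None):
        if torch_dtype is not None:
            # A torch_dtype argument propagates to the model...
            model.dtype = torch_dtype
        self.model = model
        # ...but the model's dtype never propagated back to the pipeline.
        self.torch_dtype = torch_dtype

# Constructing from an fp16 model leaves torch_dtype unset, which is misleading.
pipe = Pipeline(Model(dtype="float16"))
print(pipe.torch_dtype)   # None
print(pipe.model.dtype)   # float16
```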

Fixes #28817

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Rocketknight1

Collaborator

@ArthurZucker ArthurZucker left a comment


The dtype of the model is subject to change. A property might be a lot simpler 😉
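The property-based approach suggested here could look roughly like this (a sketch with stand-in classes and string dtypes, not the actual implementation):

```python
class Model:
    """Stand-in for a PreTrainedModel; dtype is a plain string here."""
    def __init__(self, dtype="float32"):
        self.dtype = dtype

    def half(self):
        self.dtype = "float16"

class Pipeline:
    """Stand-in pipeline that exposes torch_dtype as a read-only property."""
    def __init__(self, model):
        self.model = model

    @property
    def torch_dtype(self):
        # Read the dtype from the model on every access, so the value
        # stays in sync even after model.half()/model.float() calls.
        return getattr(self.model, "dtype", None)

pipe = Pipeline(Model())
print(pipe.torch_dtype)  # float32
pipe.model.half()
print(pipe.torch_dtype)  # float16
```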

@B-Step62 B-Step62 force-pushed the populate-model-dtype-to-pipeline branch 2 times, most recently from d250e1c to 00a61c1 Compare March 1, 2024 14:33
@B-Step62 B-Step62 changed the base branch from main to test_composition_remote_tool March 1, 2024 15:33
@B-Step62 B-Step62 changed the base branch from test_composition_remote_tool to main March 1, 2024 15:33
@B-Step62
Contributor Author

B-Step62 commented Mar 1, 2024

@ArthurZucker Sorry for my extremely late response...🙇

Using a property to handle model dtype updates makes sense. I've revised the logic accordingly, so I would appreciate it if you could take another look. Thanks!

@B-Step62 B-Step62 force-pushed the populate-model-dtype-to-pipeline branch from 02a4926 to c0e9e4a Compare March 1, 2024 15:43
Collaborator

@ArthurZucker ArthurZucker left a comment


Alright LGTM, but I think we should just always return the dtype of the model

Comment on lines 221 to 222
# If dtype is NOT specified in the pipeline constructor, the property should NOT return a dtype,
# as we don't know if the pipeline supports torch_dtype
Collaborator

I don't really agree here; we should always return the dtype of the model, just for consistency. The pipeline should error out normally, and we should not assume that a None torch_dtype means torch_dtype is not supported.

Contributor Author


I agree with the consistency concern, but I also wonder whether the model's dtype always translates to a torch_dtype, e.g. for a model loaded with a TensorFlow/JAX backend. If that sounds OK, I will proceed with simply propagating the model dtype :)

Collaborator


If you have a FlaxLlamaModel, then accessing the dtype should just be consistent, IMO.

Contributor Author


@ArthurZucker Makes sense. I found that TFPreTrainedModel doesn't have a dtype property, so my concern doesn't apply. My apologies for the misunderstanding!
I've updated the code, so I would appreciate it if you could take another look, thanks!

tests/pipelines/test_pipelines_common.py (review thread resolved)
@B-Step62 B-Step62 force-pushed the populate-model-dtype-to-pipeline branch from 7cba0d5 to e9ae6c9 Compare March 11, 2024 15:22
@ArthurZucker
Collaborator

Sorry, I was off for a week!


@ArthurZucker ArthurZucker left a comment


Thanks for iterating!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker merged commit 8e9a220 into huggingface:main Mar 25, 2024
21 checks passed
@B-Step62
Contributor Author

B-Step62 commented Mar 26, 2024

@ArthurZucker Sorry for bothering you repeatedly. I found an issue in this change.

I found that TFPretrainedModel doesn't have dtype property so my concern is not the case.

I've just tested a few models, and it turns out this is not accurate: TFPreTrainedModel inherits a dtype property from keras.Layer. Its value is not a PyTorch dtype, so returning the model dtype from the torch_dtype property is inaccurate, IMO.

I think there are a few workarounds:

  1. Check the type of model.dtype, and return None if it isn't a torch.dtype.
  2. Check whether the model is an instance of TFPreTrainedModel, and return None if so.
  3. Rename the property from torch_dtype to dtype.

I think the first one is the most convenient for users, but please let me know your thoughts!
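Workaround 1 could be sketched as follows (stand-in classes only; in the real code the check would be `isinstance(dtype, torch.dtype)`, and `FakeTorchDtype`, `TorchModel`, and `TFModel` below are hypothetical names used for illustration):

```python
class FakeTorchDtype:
    """Stand-in for torch.dtype (assumption: real code checks torch.dtype)."""

class TorchModel:
    dtype = FakeTorchDtype()

class TFModel:
    # keras.Layer.dtype is a string such as "float32", not a torch.dtype.
    dtype = "float32"

class Pipeline:
    def __init__(self, model):
        self.model = model

    @property
    def torch_dtype(self):
        # Only surface the dtype when it is a (fake) torch dtype; a TF
        # model's keras dtype would otherwise leak through this property.
        dtype = getattr(self.model, "dtype", None)
        return dtype if isinstance(dtype, FakeTorchDtype) else None

print(Pipeline(TFModel()).torch_dtype)  # None
```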

@ArthurZucker
Collaborator

That would be a new feature, given that there was no dtype before. It's not strictly necessary, but yeah, we can add a new one.

@B-Step62
Contributor Author

B-Step62 commented Mar 26, 2024

Yeah, if the first option sounds good, I'm happy to file a follow-up PR now, before this change is released. I think it's more correct behavior for the torch_dtype property (and matches what its type annotation says :)).

hovnatan pushed a commit to hovnatan/transformers that referenced this pull request Mar 27, 2024
* Populate torch_dtype from model to pipeline

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* use property

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* lint

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* Remove default handling

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
itazap pushed a commit that referenced this pull request May 14, 2024