Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pix2StructImageProcessor requires torch>=1.11.0 #24270

Merged
merged 3 commits into from
Jun 14, 2023
Merged

Pix2StructImageProcessor requires torch>=1.11.0 #24270

merged 3 commits into from
Jun 14, 2023

Conversation

ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Jun 14, 2023

What does this PR do?

So let's be nice to past CI ❤️ !

It's the argument antialias in interpolate only supported in torch>=1.11.0:

torch.nn.functional.interpolate(..., antialias=True)

@ydshieh ydshieh requested a review from amyeroberts June 14, 2023 09:01
mel_shrink = torch.nn.functional.interpolate(
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
)
mel_shrink = torch.nn.functional.interpolate(mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the default value for antialias is False, no need to specify. Remove this so we don't need to require torch >= 1.11

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 14, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -192,7 +192,7 @@ def _random_mel_fusion(self, mel, total_frames, chunk_frames):

mel = torch.tensor(mel[None, None, :])
mel_shrink = torch.nn.functional.interpolate(
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that it's False by default - but have we confirmed that it had this default behaviour before being added? (I'm assuming yes, otherwise a lot of things would have broken)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

antialias=False is there since the commit (main) where clap is added, so nothing earlier to compare against. As the default value is False, the change in this PR would just keep what we have on main, so I think we are good.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!


logger = logging.get_logger(__name__)
DEFAULT_FONT_PATH = "ybelkada/fonts"


if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we'll want a way to specify versions using or with something similar to requires_xxx or requires_backends(...).

Generally, users don't read warnings, and so have an error closer to when the object is used - instantiation or method calling - could be useful.

Copy link
Collaborator Author

@ydshieh ydshieh Jun 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but partially. In some modeling file (tapas, vilt), we have something

if not is_torch_greater_or_equal_than_xxx:
    logger.warning(
        f"You are using torch=={torch.__version__}, but torch>=xxx is required to use "
        "TapasModel. Please upgrade torch."
    )

However, I understand it's not really feasible to do what you describe in a modeling file, but it could be done in a processor file.

Happy to apply the suggestion to this single file image_processing_pix2struct.py.

Comment on lines +56 to +60
if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
"Pix2StructImageProcessor. Please upgrade torch."
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts Could you take a final look? 🙏

ImportError is taken from sth we have before

class PytorchGELUTanh(nn.Module):
    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM 👍

@ydshieh ydshieh merged commit a04ebc8 into main Jun 14, 2023
@ydshieh ydshieh deleted the fix_antialias branch June 14, 2023 15:05
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants