-
Notifications
You must be signed in to change notification settings - Fork 27.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pix2StructImageProcessor
requires torch>=1.11.0
#24270
Conversation
mel_shrink = torch.nn.functional.interpolate( | ||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False | ||
) | ||
mel_shrink = torch.nn.functional.interpolate(mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the default value for antialias
is False
, no need to specify. Remove this so we don't need to require torch >= 1.11
The documentation is not available anymore as the PR was closed or merged. |
@@ -192,7 +192,7 @@ def _random_mel_fusion(self, mel, total_frames, chunk_frames): | |||
|
|||
mel = torch.tensor(mel[None, None, :]) | |||
mel_shrink = torch.nn.functional.interpolate( | |||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False | |||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know that it's False
by default - but have we confirmed that it had this default behaviour before being added? (I'm assuming yes, otherwise a lot of things would have broken)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
antialias=False
is there since the commit (main
) where clap is added, so nothing earlier to compare against. As the default value is False
, the change in this PR would just keep what we have on main
, so I think we are good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
|
||
logger = logging.get_logger(__name__) | ||
DEFAULT_FONT_PATH = "ybelkada/fonts" | ||
|
||
|
||
if is_torch_available() and not is_torch_greater_or_equal_than_1_11: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we'll want a way to specify versions using or with something similar to requires_xxx
or requires_backends(...)
.
Generally, users don't read warnings, and so have an error closer to when the object is used - instantiation or method calling - could be useful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but partially. In some modeling file (tapas
, vilt
), we have something
if not is_torch_greater_or_equal_than_xxx:
logger.warning(
f"You are using torch=={torch.__version__}, but torch>=xxx is required to use "
"TapasModel. Please upgrade torch."
)
However, I understand it's not really feasible to do what you describe in a modeling file, but it could be done in a processor file.
Happy to apply the suggestion to this single file image_processing_pix2struct.py
.
if is_torch_available() and not is_torch_greater_or_equal_than_1_11: | ||
raise ImportError( | ||
f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use " | ||
"Pix2StructImageProcessor. Please upgrade torch." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts Could you take a final look? 🙏
ImportError
is taken from sth we have before
class PytorchGELUTanh(nn.Module):
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.12.0"):
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"PytorchGELUTanh. Please upgrade torch."
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM 👍
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
What does this PR do?
So let's be nice to past CI ❤️ !
It's the argument
antialias
ininterpolate
only supported intorch>=1.11.0
: