-
Notifications
You must be signed in to change notification settings - Fork 432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torchao version check changes/BC import of TensorCoreTiledLayout #1812
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,23 @@ | |
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from torchtune.utils._version import torch_version_ge | ||
from torchtune.utils._version import ( | ||
_get_torchao_version, | ||
_is_fbcode, | ||
_nightly_version_ge, | ||
torch_version_ge, | ||
) | ||
|
||
# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above | ||
_SUPPORTS_FLEX_ATTENTION = ( | ||
torch_version_ge("2.5.0") | ||
and torch.cuda.is_available() | ||
and torch.cuda.get_device_capability() >= (7, 5) | ||
) | ||
|
||
torchao_version = _get_torchao_version() | ||
|
||
_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mind adding a quick comment explaining this similar to flex attention above? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also nit: can you name this as something that implies a yes/no true/false answer (like |
||
("dev" not in torchao_version and torchao_version >= "0.6.0") | ||
or ("dev" in torchao_version and _nightly_version_ge(torchao_version, "2024-10-10")) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,14 @@ | |
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from datetime import datetime | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
import torchao | ||
|
||
|
||
def torch_version_ge(version: str) -> bool: | ||
""" | ||
|
@@ -23,3 +29,36 @@ def torch_version_ge(version: str) -> bool: | |
True | ||
""" | ||
return version in torch.__version__ or torch.__version__ >= version | ||
|
||
|
||
def _is_fbcode(): | ||
return not hasattr(torch.version, "git_version") | ||
|
||
|
||
def _nightly_version_ge(ao_version_str: str, date: str) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this doesn't generalize to pytorch nightly version? if's ao specific, let's include ao in the function name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually it should generalize since PyTorch versions use the same format There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if so, should we change the variable name to indicate general use? |
||
""" | ||
Compare a torchao nightly version to a date of the form | ||
%Y-%m-%d. | ||
|
||
Returns True if the nightly version is greater than or equal to | ||
the date, False otherwise | ||
""" | ||
ao_datetime = datetime.strptime( | ||
ao_version_str.split("+")[0].split("dev")[1], "%Y%m%d" | ||
) | ||
return ao_datetime >= datetime.strptime(date, "%Y-%m-%d") | ||
|
||
|
||
def _get_torchao_version() -> Optional[str]: | ||
""" | ||
Get torchao version. | ||
|
||
Checks: | ||
1) is_fbcode, then | ||
2) torchao.__version__ (only defined for torchao >= 0.3.0), then | ||
|
||
""" | ||
if _is_fbcode(): | ||
return None | ||
else: | ||
return torchao.__version__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
couldn't this return None if is_fbcode? a bit awkward to return either a version string or None, but then you check for fbcode again below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah good point. Just did a check and in fbcode
torchao.__version__ == 'unknown'
, so it at least won't throw an error. Then I can just explicitly sayao_version = torchao.__version__
in_import_guard.py
, gate behind_is_fbcode()
in_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
, and delete_get_torchao_version
altogether