Torchao version check changes/BC import of TensorCoreTiledLayout #1812
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1812
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 024b812 with merge base c5b7386.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchtune/utils/_import_guard.py (outdated)

    torchao_version = _get_torchao_version()

    _NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
mind adding a quick comment explaining this similar to flex attention above?
also nit: can you name this as something that implies a yes/no true/false answer (like _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API)?
few minor comments, but no concerns
    return not hasattr(torch.version, "git_version")


    def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
this doesn't generalize to pytorch nightly versions? if it's ao-specific, let's include ao in the function name
Actually it should generalize since PyTorch versions use the same format
if so, should we change the variable name to indicate general use?
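For context, here is a minimal sketch of how a date-based nightly version check like _nightly_version_ge can work for both torch and torchao, assuming nightly version strings carry the usual ".devYYYYMMDD" suffix. The generalized parameter name and the example assertions are illustrative, not the exact torchtune implementation.

```python
import re
from datetime import datetime


def _nightly_version_ge(version_str: str, date: str) -> bool:
    """Return True if a nightly version string (e.g. "2.6.0.dev20241001")
    was cut on or after ``date`` (formatted "2024-10-01").

    Both torch and torchao nightlies use the ".devYYYYMMDD" convention,
    so the check is not ao-specific.
    """
    match = re.search(r"dev(\d{8})", version_str)
    if match is None:
        # not a nightly build
        return False
    nightly_date = datetime.strptime(match.group(1), "%Y%m%d").date()
    return nightly_date >= datetime.strptime(date, "%Y-%m-%d").date()


# Works the same for torch and torchao nightly strings:
assert _nightly_version_ge("2.6.0.dev20241001+cu124", "2024-09-28")
assert not _nightly_version_ge("0.6.0.dev20240901", "2024-09-28")
```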
torchtune/utils/_import_guard.py (outdated)

    # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
    _SUPPORTS_FLEX_ATTENTION = (
        torch_version_ge("2.5.0")
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (7, 5)
    )

    torchao_version = _get_torchao_version()
couldn't this return None if is_fbcode? a bit awkward to return either a version string or None, but then you check for fbcode again below
Yeah good point. Just did a check and in fbcode torchao.__version__ == 'unknown', so it at least won't throw an error. Then I can just explicitly say ao_version = torchao.__version__ in _import_guard.py, gate behind _is_fbcode() in _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API, and delete _get_torchao_version altogether.
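For reference, a rough sketch of the shape _import_guard.py could take under this suggestion: read torchao.__version__ directly, gate on _is_fbcode(), and drop _get_torchao_version(). The 0.6.0 stable and 2024-10-10 nightly cutoffs below are placeholders for illustration, not the thresholds used in the PR, and _nightly_version_ge is the date-based helper sketched earlier in the thread.

```python
import torch
import torchao

# _nightly_version_ge is assumed to be the date-based helper sketched above.


def _is_fbcode() -> bool:
    # fbcode builds of torch do not expose a git_version attribute
    return not hasattr(torch.version, "git_version")


# In fbcode torchao.__version__ is the literal string "unknown", so reading it
# never raises; the short-circuit on _is_fbcode() means we never try to parse it there.
ao_version = torchao.__version__

_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
    # nightly builds: compare the .devYYYYMMDD date (placeholder cutoff)
    _nightly_version_ge(ao_version, "2024-10-10")
    if "dev" in ao_version
    # stable releases: compare the (major, minor, patch) tuple (placeholder cutoff)
    else tuple(int(p) for p in ao_version.split("+")[0].split(".")[:3]) >= (0, 6, 0)
)
```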
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

    @@            Coverage Diff             @@
    ##              main    #1812       +/-   ##
    ===========================================
    - Coverage    67.05%   25.72%   -41.34%
    ===========================================
      Files          305      304        -1
      Lines        15937    16000       +63
    ===========================================
    - Hits         10687     4116     -6571
    - Misses        5250    11884     +6634

☔ View full report in Codecov by Sentry.
ah, our diff train is not landed yet, should we land first? it will break internal torchtune I think

Yeah that sounds good, we can wait. I don't have full context on whether this is blocking anything.

@jerryzh168 can you clarify? I don't fully understand why this PR would break internal. It checks fbcode and uses the old API, which is what I currently see in internal.

oh what I meant is that torchao internal is not updated with the new name yet. If this PR can fix both internal and external, maybe this one should land first.
This PR:

- checks torchao __version__ directly in _import_guard.py
- defines _NEW_TENSOR_CORE_TILED_LAYOUT_API based on the following conditions: not running in fbcode, and the installed torchao version (stable or nightly) is recent enough to have the renamed API
- imports TensorCoreTiledLayoutType in training/quantization.py and aliases it to TensorCoreTiledLayout (the new API name) in either case (a sketch of this alias pattern is shown below)
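As mentioned in the last bullet, here is a short sketch of the backward-compatible alias in training/quantization.py; the torchao module path is assumed for illustration and may differ from the actual import.

```python
from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
    # newer torchao: the class is already named TensorCoreTiledLayout
    from torchao.dtypes import TensorCoreTiledLayout
else:
    # older torchao (and fbcode): import the old name and alias it so the rest
    # of the file can always refer to TensorCoreTiledLayout
    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
```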
Test plan

Test the quantization recipe on both stable and nightly torchao versions. Prereq: download Llama2 7B.

- Test on torchao 0.5
- Test on torchao nightly