
Torchao version check changes/BC import of TensorCoreTiledLayout #1812

Merged · 5 commits · Oct 12, 2024

Conversation

@ebsmothers (Contributor) commented Oct 11, 2024

  • Move torchao version check utilities out of modules and into utils (that's where they should've been all along)
  • Update the nightly version check to match the latest torchao nightly version format
    • We also no longer need all the extra importlib machinery now that all recent ao versions have `__version__` defined
  • Define the variable `_NEW_TENSOR_CORE_TILED_LAYOUT_API` based on the following condition:
    • Not fbcode and (ao version >= 0.7.0 or (ao version is nightly and ao nightly date >= "2024-10-10"))
  • Gate the import of `TensorCoreTiledLayoutType` in `training/quantization.py` and alias it to `TensorCoreTiledLayout` (the new API name) in either case
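A minimal, self-contained sketch of the gating condition described above (the helper name and parsing details are illustrative assumptions, not the merged code):

```python
# Hypothetical helper mirroring the bullets above; the real change
# lives in torchtune's import-guard code and may differ in detail.
def use_new_tensor_core_tiled_layout_api(ao_version: str, is_fbcode: bool) -> bool:
    # fbcode builds always use the old API name
    if is_fbcode:
        return False
    if "dev" in ao_version:
        # Nightly versions look like '0.7.0.dev20241011+cu124';
        # compare the embedded build date against 2024-10-10
        return ao_version.split("dev")[1][:8] >= "20241010"
    # Stable versions: require ao >= 0.7.0
    major, minor = (int(x) for x in ao_version.split("+")[0].split(".")[:2])
    return (major, minor) >= (0, 7)

print(use_new_tensor_core_tiled_layout_api("0.7.0.dev20241011+cu124", False))  # True
print(use_new_tensor_core_tiled_layout_api("0.5.0", False))                    # False
```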

Test plan

Test quantization recipe on both stable and nightly torchao versions. Prereq: download Llama2 7B:

tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf

Test on torchao 0.5

$ conda create -n ao-testing-stable python=3.11 -y 
$ conda activate ao-testing-stable
$ pip install torch torchvision torchao
$ pip install -e ".[dev]"
$ pip list | grep torchao
torchao                   0.5.0
$ tune run quantize --config quantization quantizer=torchtune.training.quantization.Int4WeightOnlyQuantizer quantizer.groupsize=128
...
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Time for quantization: 0.24 sec
INFO:torchtune.utils._logging:Memory used: 13.95 GB
INFO:torchtune.utils._logging:Model checkpoint of size 3.79 GB saved to /tmp/Llama-2-7b-hf/pytorch_model-00001-of-00002-4w.pt

Test on torchao nightly

$ conda create -n ao-testing-nightly python=3.11 -y 
$ conda activate ao-testing-nightly
$ pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu124
$ pip install -e ".[dev]"
$ pip list | grep torchao
torchao                   0.7.0.dev20241011+cu124
$ tune run quantize --config quantization quantizer=torchtune.training.quantization.Int4WeightOnlyQuantizer quantizer.groupsize=128
...
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Time for quantization: 0.37 sec
INFO:torchtune.utils._logging:Memory used: 13.95 GB
INFO:torchtune.utils._logging:Model checkpoint of size 3.79 GB saved to /tmp/Llama-2-7b-hf/pytorch_model-00001-of-00002-4w.pt

pytorch-bot bot commented Oct 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1812

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 024b812 with merge base c5b7386:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 11, 2024

torchao_version = _get_torchao_version()

_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
Contributor:

mind adding a quick comment explaining this similar to flex attention above?

Contributor:

also nit: can you name this as something that implies a yes/no true/false answer (like _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API)

@RdoubleA (Contributor) left a comment:

few minor comments, but no concerns

return not hasattr(torch.version, "git_version")


def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
Contributor:

This doesn't generalize to PyTorch nightly versions? If it's ao-specific, let's include "ao" in the function name.

ebsmothers (Contributor, author):

Actually it should generalize since PyTorch versions use the same format

Contributor:

if so, should we change the variable name to indicate general use?
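Since torch and torchao nightlies share the same `X.Y.Z.devYYYYMMDD+…` version format, the date comparison can indeed be sketched generically (parsing details here are an assumption for illustration):

```python
def nightly_version_ge(version_str: str, date: str) -> bool:
    """True if the build date embedded in a nightly version string
    (e.g. '0.7.0.dev20241011+cu124') is on or after date ('YYYY-MM-DD')."""
    build_date = version_str.split("dev")[1][:8]
    return build_date >= date.replace("-", "")

# Works for both torchao and torch nightly strings:
print(nightly_version_ge("0.7.0.dev20241011+cu124", "2024-10-10"))  # True
print(nightly_version_ge("2.6.0.dev20240901+cu124", "2024-10-10"))  # False
```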


# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
_SUPPORTS_FLEX_ATTENTION = (
torch_version_ge("2.5.0")
and torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (7, 5)
)

torchao_version = _get_torchao_version()
@RdoubleA (Contributor) commented Oct 11, 2024:

couldn't this return None if is_fbcode? a bit awkward to return either a version string or None, but then you check for fbcode again below

ebsmothers (Contributor, author):

Yeah good point. Just did a check and in fbcode `torchao.__version__ == 'unknown'`, so it at least won't throw an error. Then I can just explicitly say `ao_version = torchao.__version__` in `_import_guard.py`, gate behind `_is_fbcode()` in `_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API`, and delete `_get_torchao_version` altogether


@codecov-commenter commented Oct 11, 2024

Codecov Report

Attention: Patch coverage is 73.33333% with 4 lines in your changes missing coverage. Please review.

Project coverage is 25.72%. Comparing base (54673b7) to head (024b812).
Report is 13 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/training/quantization.py 60.00% 2 Missing ⚠️
torchtune/utils/_version.py 66.66% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1812       +/-   ##
===========================================
- Coverage   67.05%   25.72%   -41.34%     
===========================================
  Files         305      304        -1     
  Lines       15937    16000       +63     
===========================================
- Hits        10687     4116     -6571     
- Misses       5250    11884     +6634     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@jerryzh168 (Contributor):

ah, our diff train is not landed yet, should we land first? it will break internal torchtune I think

@RdoubleA (Contributor):

> ah, our diff train is not landed yet, should we land first? it will break internal torchtune I think

Yeah, that sounds good, we can wait. I don't have full context on whether this is blocking anything

@ebsmothers (Contributor, author):

@jerryzh168 can you clarify? I don’t fully understand why this PR would break internal. It checks fbcode and uses the old API, which is what I currently see in internal

@jerryzh168 (Contributor):

> @jerryzh168 can you clarify? I don’t fully understand why this PR would break internal. It checks fbcode and uses the old API, which is what I currently see in internal

oh what I meant is that torchao internal is not updated with the new name yet

If this PR can fix both internal and external, maybe this one should land first.

@ebsmothers ebsmothers merged commit 7744608 into pytorch:main Oct 12, 2024
17 checks passed
mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024