Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
class TestTorchVersion(unittest.TestCase):
def test_torch_version_at_least(self):
test_cases = [
("2.5.0a0+git9f17037", "2.5.0", True),
("2.5.0a0+git9f17037", "2.5.0", False),
("2.5.0a0+git9f17037", "2.4.0", True),
("2.5.0.dev20240708+cu121", "2.5.0", True),
("2.5.0.dev20240708+cu121", "2.5.0", False),
("2.5.0.dev20240708+cu121", "2.4.0", True),
("2.5.0", "2.4.0", True),
("2.5.0", "2.5.0", True),
Expand Down
9 changes: 8 additions & 1 deletion torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,14 @@ def is_fbcode():


def torch_version_at_least(min_version):
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0
from packaging.version import parse as parse_version

if is_fbcode():
return True

# Parser for local identifiers
current_version = re.sub(r"\+.*$", "", torch.__version__)
return parse_version(current_version) >= parse_version(min_version)
Copy link
Contributor

@jerryzh168 jerryzh168 Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some impression that we don't want to do this, but @msaroufim would have more context here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems fine to merge altho we probably want to delete the compare_versions function

Copy link
Contributor Author

@namgyu-youn namgyu-youn Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compare_versions & parse_version are unavailable to used here because parse_version only extracts \d+\.\d+\.\d+ (e.g., 2.5.0→[2, 5, 0]). Therefore, we can inject more parsers (e.g., a0, dev) into parse_version, but I am not certain because check_cpu_version & check_xpu_version are chained with them.

Copy link
Contributor

@jerryzh168 jerryzh168 Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw we also have

def is_package_at_least(package_name: str, min_version: str):
, should we just reuse that? not sure if this works for pre-releases as well, could you check?

I vaguely remember at some point that we don't want to use parse version, but don't remember why though

Copy link
Contributor Author

@namgyu-youn namgyu-youn Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw we also have

def is_package_at_least(package_name: str, min_version: str):

, should we just reuse that? not sure if this works for pre-releases as well, could you check?
I vaguely remember at some point that we don't want to use parse version, but don't remember why though

Installation

# 2.8.0 stable
pip install torch --index-url https://download.pytorch.org/whl/cpu
# 2.9.0 pre-release (nighty)
pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu

Test code

test_cases = ["2.8.0a0+git9f17037", "2.8.0", "2.9.0a0+git9f17037", "2.9.0", "2.10.0a0+git9f17037", "2.10.0"]

print(torch.__version__)
print(torch.version)

for test in test_cases:
	print(torch_version_at_least(test), is_package_at_least(package_name="torch", min_version=test))

Result

# stable
2.8.0+cpu
<module 'torch.version' from 'ao/.venv/lib/python3.12/site-packages/torch/version.py'>
True False
True True
False False
False False
False True
False True

# pre-release (nighty)
2.9.0.dev20250821+cpu
<module 'torch.version' from 'ao/.venv/lib/python3.12/site-packages/torch/version.py'>
True True
True True
True False
False True
False True
False True

It seems that is_package_at_least() doesn't work for both stable/pre-release becauseversion() returns the module object; __version__() returns the actual version string. In my naive guess, torch_version_at_least and is_package_at_least can be consolidated because they share "return true if installed <package/torch> is higher than required" (with PyTorch pre-release detection). How about consolidating them?



def _deprecated_torch_version_at_least(version_str: str) -> str:
Expand Down
Loading