-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
retry version guard fix #679
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/679
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 1650f9d with merge base 6199f89 (): FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This does not seem to support source build of pytorch without
I'm wondering why a custom version parser is implemented, instead of using diff --git a/setup.py b/setup.py
index 47f7714..b2ac860 100644
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,7 @@ setup(
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
+ install_requires=["packaging"],
ext_modules=get_extensions() if use_cpp != "0" else None,
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
diff --git a/torchao/utils.py b/torchao/utils.py
index 47227b1..c3e2785 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -7,6 +7,7 @@ import torch.nn.utils.parametrize as parametrize
import itertools
import time
import warnings
+from packaging.version import parse
__all__ = [
"benchmark_model",
@@ -279,23 +280,16 @@ def unwrap_tensor_subclass(model, filter_fn=None):
unwrap_tensor_subclass(child)
return model
-def parse_version(version_string):
- # Remove any suffixes like '+cu121' or '.dev'
- version = version_string.split('+')[0].split('.dev')[0]
- return [int(x) for x in version.split('.')]
-
def compare_versions(v1, v2):
- v1_parts = parse_version(v1)
- v2_parts = parse_version(v2)
-
- for i in range(max(len(v1_parts), len(v2_parts))):
- v1_part = v1_parts[i] if i < len(v1_parts) else 0
- v2_part = v2_parts[i] if i < len(v2_parts) else 0
- if v1_part > v2_part:
- return 1
- elif v1_part < v2_part:
- return -1
- return 0
+ v1_version = parse(v1)
+ v2_version = parse(v2)
+
+ if v1_version == v2_version:
+ return 0
+ elif v1_version > v2_version:
+ return 1
+ else:
+ return -1
def is_fbcode():
return not hasattr(torch.version, "git_version") |
So we've been bit before by having dependencies in general - you can see these 2 threads for more context. The short of it is forcing dependencies was causing forced version downgrades for our partners and most of our dependencies did not really seem necessary Unfortunately PyTorch does not depend on packaging https://gist.github.com/msaroufim/09c316d10cd314e189e1b2ad28823c3b but I wish it did Specifically regarding the version check logic thank you for catching that, this does indeed bother me I did though try applying your diff and the check is unfortunately also incorrect but if you have a forward fix that would solve this case and the test case I mention the PR description then we should be good. I'll poke around a bit more myself at this as well import torch
from torchao.utils import torch_version_at_least
torch.__version__ = "2.5.0a0+git9f17037" # nightly version
print(torch_version_at_least("2.5.0")) # Returns False diff --git a/setup.py b/setup.py
index 47f7714..b2ac860 100644
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,7 @@ setup(
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
+ install_requires=["packaging"],
ext_modules=get_extensions() if use_cpp != "0" else None,
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
diff --git a/torchao/utils.py b/torchao/utils.py
index 47227b1..05a04b5 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -7,6 +7,8 @@ import torch.nn.utils.parametrize as parametrize
import itertools
import time
import warnings
+from packaging.version import parse
+
__all__ = [
"benchmark_model",
@@ -279,23 +281,17 @@ def unwrap_tensor_subclass(model, filter_fn=None):
unwrap_tensor_subclass(child)
return model
-def parse_version(version_string):
- # Remove any suffixes like '+cu121' or '.dev'
- version = version_string.split('+')[0].split('.dev')[0]
- return [int(x) for x in version.split('.')]
def compare_versions(v1, v2):
- v1_parts = parse_version(v1)
- v2_parts = parse_version(v2)
-
- for i in range(max(len(v1_parts), len(v2_parts))):
- v1_part = v1_parts[i] if i < len(v1_parts) else 0
- v2_part = v2_parts[i] if i < len(v2_parts) else 0
- if v1_part > v2_part:
- return 1
- elif v1_part < v2_part:
- return -1
- return 0
+ v1_version = parse(v1)
+ v2_version = parse(v2)
+
+ if v1_version == v2_version:
+ return 0
+ elif v1_version > v2_version:
+ return 1
+ else:
+ return -1
def is_fbcode():
return not hasattr(torch.version, "git_version")
|
NIT: torch.__version__ = "2.5.0.dev20240708+cu121"
...
print(torch_version_at_least("2.5.0")) # Return True Relase import torch
from packaging import version
torch.__version__ = "2.5.0.dev20240708+cu121"
version.parse(torch.__version__) >= version.parse("2.4.0.dev") # True
version.parse(torch.__version__) >= version.parse("2.4.0") # True
version.parse(torch.__version__) >= version.parse("2.5.0.dev") # True
version.parse(torch.__version__) >= version.parse("2.5.0") # False !!!! HERE !!!!
torch.__version__ = "2.5.0"
version.parse(torch.__version__) >= version.parse("2.4.0.dev") # True
version.parse(torch.__version__) >= version.parse("2.4.0") # True
version.parse(torch.__version__) >= version.parse("2.5.0.dev") # True
version.parse(torch.__version__) >= version.parse("2.5.0") # True
torch.__version__ = "2.4.0"
version.parse(torch.__version__) >= version.parse("2.4.0.dev") # True
version.parse(torch.__version__) >= version.parse("2.4.0") # True
version.parse(torch.__version__) >= version.parse("2.5.0.dev") # False
version.parse(torch.__version__) >= version.parse("2.5.0") # False |
@ptrblck I agree this is somewhat confusing but here's the reasoning #684 (comment) |
EDIT: There were many version conflict issues so I'm reopening this PR #485
The only interesting changes to review are in
torchao/utils.py
and the version guard check is a tad complex right now but it's because we're insisting not to have any dependencies for torchaoThis PR solves 2 issues
TORCH_VERSION_AFTER_2_3(2.3)
would return false for 2.3 (see below for more detail) and adding .dev was obscuring the actual intentTORCH_VERSION_AFTER_X_X
toTORCH_VERSION_AT_LEAST_X_X,
But after changes in this PR
In practice this change didn't matter too much but these tests are now failing because they were intended to work with versions greater than 2.3 not 2.3 and above so I skipped them