Skip to content

Commit 2f28062

Browse files
committed
generalize torch compatibility check
1 parent 1591603 commit 2f28062

File tree

3 files changed

+110
-14
lines changed

3 files changed

+110
-14
lines changed

.github/scripts/validate_binaries.sh

100644100755
Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,67 @@
1-
pip install ${PYTORCH_PIP_PREFIX} torchao --index-url ${PYTORCH_PIP_DOWNLOAD_URL}
2-
# Intial smoke test, tries importing torchao
3-
python ./test/smoke_tests/smoke_tests.py
4-
# Now we install dev-requirments and try to run the tests
5-
pip install -r dev-requirements.txt
6-
pytest test --verbose -s
1+
# pip install ${PYTORCH_PIP_PREFIX} torchao --index-url ${PYTORCH_PIP_DOWNLOAD_URL}
2+
# # Intial smoke test, tries importing torchao
3+
# python ./test/smoke_tests/smoke_tests.py
4+
# # Now we install dev-requirments and try to run the tests
5+
# pip install -r dev-requirements.txt
6+
# pytest test --verbose -s
7+
8+
# Testing compatibility
9+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
10+
# If the version of torch is not compatible with the version of torchao,
11+
# we expect to skip loading the .so files and a warning should be logged but no error
12+
13+
# Function to check torchao import with configurable expectations
14+
check_torchao_import() {
15+
local expect_warning="$1"
16+
local warning_text="$2"
17+
output=$(python -c "import torchao" 2>&1)
18+
exit_code=$?
19+
20+
if [ $exit_code -ne 0 ]; then
21+
echo "ERROR: Failed to import torchao"
22+
echo "Output: $output"
23+
exit 1
24+
fi
25+
26+
warning_found=false
27+
if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
28+
echo "Output: $output"
29+
warning_found=true
30+
fi
31+
32+
if [ "$expect_warning" != "$warning_found" ]; then
33+
echo echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
34+
exit 1
35+
fi
36+
}
37+
38+
# Uninstall torch
39+
pip uninstall torch
40+
# Install compatible version of torch
41+
pip install torch==2.8.0
42+
# Build torchao
43+
pip install torchao==0.13.0
44+
# Test 1: Should import successfully without warning
45+
check_torchao_import "false" ""
46+
47+
# Uninstall torch
48+
pip uninstall torch
49+
# Uninstall torchao
50+
pip uninstall torchao
51+
# Install compatible version of torch (nightly 2.9.0dev)
52+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu129
53+
# Build torchao (nightly 0.14.0dev...)
54+
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu129
55+
# Test 1: Should import successfully without warning
56+
check_torchao_import "false" ""
57+
58+
# Uninstall torch
59+
pip uninstall torch
60+
# Uninstall torchao
61+
pip uninstall torchao
62+
# Install non-ABI stable torch version with older version of torchao
63+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu129
64+
# Build torchao
65+
pip install torchao==0.13.0
66+
# Test 2: Should import with specific warning
67+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"

torchao/__init__.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import re
23

34
# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
45
import warnings
@@ -20,24 +21,50 @@
2021
except PackageNotFoundError:
2122
__version__ = "unknown" # In case this logic breaks don't break the build
2223

24+
2325
logger = logging.getLogger(__name__)
2426

27+
28+
def parse_version(version_string):
29+
"""
30+
Parses the major and minor of a torch version
31+
32+
Examples:
33+
- 2.9.0dev... (pre-release) becomes [2, 9]
34+
- 2.8.0 becomes [2, 8]
35+
- 2.8.1 becomes [2, 8]
36+
"""
37+
# Check for pre-release indicators
38+
match = re.match(r"(\d+)\.(\d+)", version_string)
39+
if match:
40+
major, minor = map(int, match.groups())
41+
return [major, minor]
42+
else:
43+
raise ValueError(f"Invalid version string format: {version_string}")
44+
45+
2546
skip_loading_so_files = False
2647
# if torchao version has "+git", assume it's locally built and we don't know
2748
# anything about the PyTorch version used to build it
2849
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
2950
# assumptions about the PyTorch version used to build it.
3051
if (not "+git" in __version__) and not ("unknown" in __version__):
31-
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
32-
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
33-
# The following code skips importing the .so files if PyTorch 2.9+ is
34-
# detected, to avoid crashing the Python process with "Aborted (core
52+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
53+
# The following code skips importing the .so files if incompatible torch version is detected,
54+
# to avoid crashing the Python process with "Aborted (core
3555
# dumped)".
36-
# TODO(#2901, and before next torchao release): make this generic for
37-
# future torchao and torch versions
38-
if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9":
56+
if parse_version(torch.__version__) == parse_version("2.8") and parse_version(
57+
__version__
58+
) == parse_version("0.13"):
59+
pass
60+
elif parse_version(torch.__version__) == parse_version("2.9") and parse_version(
61+
__version__
62+
) == parse_version("0.14"):
63+
pass
64+
else:
3965
logger.warning(
40-
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
66+
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \
67+
Please see GitHub issue #2919 for more info"
4168
)
4269
skip_loading_so_files = True
4370

torchao/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ def torch_version_at_least(min_version):
378378
return parse_version(torch.__version__) >= parse_version(min_version)
379379

380380

381+
def _torch_version_greater_than(min_version):
382+
if is_fbcode():
383+
return True
384+
385+
# Parser for local identifiers
386+
return parse_version(torch.__version__) > parse_version(min_version)
387+
388+
381389
def _deprecated_torch_version_at_least(version_str: str) -> str:
382390
"""
383391
Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log

0 commit comments

Comments
 (0)