Skip to content

Commit ea007fc

Browse files
committed
generalize torch compatibility check
1 parent 1591603 commit ea007fc

File tree

4 files changed

+144
-17
lines changed

4 files changed

+144
-17
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Testing compatibility
2+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
3+
# If the version of torch is not compatible with the version of torchao,
4+
# we expect to skip loading the .so files and a warning should be logged but no error
5+
6+
# Function to check torchao import with configurable expectations
7+
check_torchao_import() {
8+
local expect_warning="$1"
9+
local warning_text="$2"
10+
local torch_incompatible="${3:-}"
11+
12+
if [ -n "$torch_incompatible" ]; then
13+
output=$(TORCH_INCOMPATIBLE=1 python -c "import torchao" 2>&1)
14+
else
15+
output=$(python -c "import torchao" 2>&1)
16+
fi
17+
exit_code=$?
18+
19+
if [ $exit_code -ne 0 ]; then
20+
echo "ERROR: Failed to import torchao"
21+
echo "Output: $output"
22+
exit 1
23+
fi
24+
25+
warning_found=false
26+
if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
27+
echo "Output: $output"
28+
warning_found=true
29+
fi
30+
31+
if [ "$expect_warning" != "$warning_found" ]; then
32+
echo echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
33+
exit 1
34+
fi
35+
}
36+
37+
# Uninstall torch
38+
pip uninstall torch
39+
# Uninstall torchao
40+
pip uninstall torchao
41+
# Install current compatible version of torch (2.8.0)
42+
pip install torch==2.8.0
43+
# Installs current compatible version of torchao (0.13.0)
44+
pip install torchao==0.13.0
45+
# hould import successfully without warning
46+
check_torchao_import "false" ""
47+
48+
# Uninstall torch
49+
pip uninstall torch
50+
# Uninstall torchao
51+
pip uninstall torchao
52+
# Install specific compatible version of torch (nightly 2.9.0dev)
53+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
54+
# Installs specific nightly torchao (0.14.0dev...)
55+
pip install torchao==0.14.0.dev20250901+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
56+
# Should import successfully without warning
57+
check_torchao_import "false" ""
58+
59+
# Uninstall torch
60+
pip uninstall torch
61+
# Uninstall torchao
62+
pip uninstall torchao
63+
# Install compatible version of torch (nightly 2.9.0dev)
64+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
65+
# Build torchao from source
66+
python setup.py develop
67+
# Should import with warning because optional env var is set to true
68+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version" "TORCHAO_SKIP_LOADING_SO_FILES=1"
69+
70+
# Uninstall torch
71+
pip uninstall torch
72+
# Uninstall torchao
73+
pip uninstall torchao
74+
# Install compatible version of torch (nightly 2.9.0dev)
75+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
76+
# Build torchao from source
77+
python setup.py develop
78+
# Should import successfully without warning because torchao was built from source and env var is not set
79+
check_torchao_import "false" ""
80+
81+
# Uninstall torch
82+
pip uninstall torch
83+
# Uninstall torchao
84+
pip uninstall torchao
85+
# Install non-ABI stable torch version with current version of torchao
86+
pip install torch==2.8.0
87+
# Installs specific nightly torchao (0.14.0dev...)
88+
pip install torchao==0.14.0.dev20250831+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
89+
# Should import with specific warning
90+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"

.github/scripts/validate_binaries.sh

100644100755
File mode changed.

.github/workflows/validate-binaries.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ jobs:
4545
repository: "pytorch/ao"
4646
with_cuda: "enable"
4747
with_rocm: "disable"
48-
smoke_test: "source ./.github/scripts/validate_binaries.sh"
48+
smoke_test: "source ./.github/scripts/validate_binaries.sh && source ./.github/scripts/test_compatibility.sh"
4949
install_torch: true

torchao/__init__.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import os
3+
import re
24

35
# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
46
import warnings
@@ -20,28 +22,63 @@
2022
except PackageNotFoundError:
2123
__version__ = "unknown" # In case this logic breaks don't break the build
2224

25+
2326
logger = logging.getLogger(__name__)
2427

28+
29+
def _parse_version(version_string):
30+
"""
31+
Parse version string representing pre-release with -1
32+
33+
Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
34+
"""
35+
# Check for pre-release indicators
36+
is_prerelease = bool(re.search(r"(git|dev)", version_string))
37+
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
38+
if match:
39+
major, minor, patch = map(int, match.groups())
40+
if is_prerelease:
41+
patch = -1
42+
return [major, minor, patch]
43+
else:
44+
raise ValueError(f"Invalid version string format: {version_string}")
45+
46+
2547
skip_loading_so_files = False
26-
# if torchao version has "+git", assume it's locally built and we don't know
27-
# anything about the PyTorch version used to build it
28-
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
29-
# assumptions about the PyTorch version used to build it.
30-
if (not "+git" in __version__) and not ("unknown" in __version__):
31-
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
32-
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
33-
# The following code skips importing the .so files if PyTorch 2.9+ is
34-
# detected, to avoid crashing the Python process with "Aborted (core
48+
if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)):
49+
# user override
50+
# users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files
51+
# this way, if they are using an incompatbile torch version, they can still use the API by setting the env var
52+
skip_loading_so_files = True
53+
# if torchao version has "+git", assume it's locally built and we don't know
54+
# anything about the PyTorch version used to build it unless user provides override flag
55+
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
56+
# assumptions about the PyTorch version used to build it.
57+
elif not ("+git" in __version__) and not ("unknown" in __version__):
58+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
59+
# The following code skips importing the .so files if incompatible torch version is detected,
60+
# to avoid crashing the Python process with "Aborted (core
3561
# dumped)".
36-
# TODO(#2901, and before next torchao release): make this generic for
37-
# future torchao and torch versions
38-
if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9":
39-
logger.warning(
40-
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
41-
)
62+
if _parse_version(torch.__version__) == _parse_version("2.8.0") and _parse_version(
63+
__version__
64+
) == _parse_version("0.13.0"):
65+
# current torchao version and torch version, check here for clarity
66+
skip_loading_so_files = False
67+
if _parse_version(torch.__version__) == _parse_version(
68+
"2.9.0.dev"
69+
) and _parse_version(__version__) == _parse_version("0.14.0.dev"):
70+
# .dev for nightlies since 2.9.0 and 0.14.0 has not been released
71+
skip_loading_so_files = False
72+
else:
4273
skip_loading_so_files = True
4374

44-
if not skip_loading_so_files:
75+
76+
if skip_loading_so_files:
77+
logger.warning(
78+
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \
79+
Please see GitHub issue #2919 for more info"
80+
)
81+
else:
4582
try:
4683
from pathlib import Path
4784

0 commit comments

Comments
 (0)