Skip to content

Commit 77f3af6

Browse files
committed
generalize torch compatibility check
1 parent 1591603 commit 77f3af6

File tree

4 files changed

+153
-17
lines changed

4 files changed

+153
-17
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Testing compatibility
2+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
3+
# If the version of torch is not compatible with the version of torchao,
4+
# we expect to skip loading the .so files and a warning should be logged but no error
5+
6+
# Function to check torchao import with configurable expectations
7+
check_torchao_import() {
8+
local expect_warning="$1"
9+
local warning_text="$2"
10+
local torch_incompatible="${3:-}"
11+
12+
if [ -n "$torch_incompatible" ]; then
13+
output=$(TORCH_INCOMPATIBLE=1 python -c "import torchao" 2>&1)
14+
else
15+
output=$(python -c "import torchao" 2>&1)
16+
fi
17+
exit_code=$?
18+
19+
if [ $exit_code -ne 0 ]; then
20+
echo "ERROR: Failed to import torchao"
21+
echo "Output: $output"
22+
exit 1
23+
fi
24+
25+
warning_found=false
26+
if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
27+
echo "Output: $output"
28+
warning_found=true
29+
fi
30+
31+
if [ "$expect_warning" != "$warning_found" ]; then
32+
echo echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
33+
exit 1
34+
fi
35+
}
36+
37+
## torch == 2.8.0, torchao == 0.13.0
38+
# Uninstall torch
39+
pip uninstall torch
40+
# Uninstall torchao
41+
pip uninstall torchao
42+
# Install current compatible version of torch (2.8.0)
43+
pip install torch==2.8.0
44+
# Installs current compatible version of torchao (0.13.0)
45+
pip install torchao==0.13.0
46+
# hould import successfully without warning
47+
check_torchao_import "false" ""
48+
49+
## torch == 2.9.0.dev..., torchao == 0.14.0dev...
50+
# Uninstall torch
51+
pip uninstall torch
52+
# Uninstall torchao
53+
pip uninstall torchao
54+
# Install specific compatible version of torch (nightly 2.9.0dev)
55+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
56+
# Installs specific nightly torchao (0.14.0dev...)
57+
pip install torchao==0.14.0.dev20250901+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
58+
# Should import successfully without warning
59+
check_torchao_import "false" ""
60+
61+
## torch == 2.9.0.dev..., torchao from source, env var = True
62+
# Uninstall torch
63+
pip uninstall torch
64+
# Uninstall torchao
65+
pip uninstall torchao
66+
# Install compatible version of torch (nightly 2.9.0dev)
67+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
68+
# Build torchao from source
69+
python setup.py develop
70+
# Should import with warning because optional env var is set to true
71+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version" "TORCHAO_SKIP_LOADING_SO_FILES=1"
72+
73+
## torch == 2.9.0.dev..., torchao from source, env var not set
74+
# Uninstall torch
75+
pip uninstall torch
76+
# Uninstall torchao
77+
pip uninstall torchao
78+
# Install compatible version of torch (nightly 2.9.0dev)
79+
pip install torch==2.9.0.dev20250905+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
80+
# Build torchao from source
81+
python setup.py develop
82+
# Should import successfully without warning because torchao was built from source and env var is not set
83+
check_torchao_import "false" ""
84+
85+
## torch == 2.8.0..., torchao == 0.14.0.dev... (incompatible)
86+
# Uninstall torch
87+
pip uninstall torch
88+
# Uninstall torchao
89+
pip uninstall torchao
90+
# Install non-ABI stable torch version with current version of torchao
91+
pip install torch==2.8.0
92+
# Installs specific nightly torchao (0.14.0dev...)
93+
pip install torchao==0.14.0.dev20250831+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129
94+
# Should import with specific warning
95+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"

.github/scripts/validate_binaries.sh

100644100755
File mode changed.

.github/workflows/validate-binaries.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ jobs:
4545
repository: "pytorch/ao"
4646
with_cuda: "enable"
4747
with_rocm: "disable"
48-
smoke_test: "source ./.github/scripts/validate_binaries.sh"
48+
smoke_test: "source ./.github/scripts/validate_binaries.sh && source ./.github/scripts/test_torch_version_torchao_version_compatibility.sh"
4949
install_torch: true

torchao/__init__.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import os
3+
import re
24

35
# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
46
import warnings
@@ -20,28 +22,67 @@
2022
except PackageNotFoundError:
2123
__version__ = "unknown" # In case this logic breaks don't break the build
2224

25+
2326
logger = logging.getLogger(__name__)
2427

28+
29+
def _parse_version(version_string):
30+
"""
31+
Parse version string representing pre-release with -1
32+
33+
Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
34+
"""
35+
# Check for pre-release indicators
36+
is_prerelease = bool(re.search(r"(git|dev)", version_string))
37+
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
38+
if match:
39+
major, minor, patch = map(int, match.groups())
40+
if is_prerelease:
41+
patch = -1
42+
return [major, minor, patch]
43+
else:
44+
raise ValueError(f"Invalid version string format: {version_string}")
45+
46+
2547
skip_loading_so_files = False
26-
# if torchao version has "+git", assume it's locally built and we don't know
27-
# anything about the PyTorch version used to build it
28-
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
29-
# assumptions about the PyTorch version used to build it.
30-
if (not "+git" in __version__) and not ("unknown" in __version__):
31-
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
32-
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
33-
# The following code skips importing the .so files if PyTorch 2.9+ is
34-
# detected, to avoid crashing the Python process with "Aborted (core
48+
if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)):
49+
# user override
50+
# users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files
51+
# this way, if they are using an incompatbile torch version, they can still use the API by setting the env var
52+
skip_loading_so_files = True
53+
# if torchao version has "+git", assume it's locally built and we don't know
54+
# anything about the PyTorch version used to build it unless user provides override flag
55+
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
56+
# assumptions about the PyTorch version used to build it.
57+
elif not ("+git" in __version__) and not ("unknown" in __version__):
58+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
59+
# The following code skips importing the .so files if incompatible torch version is detected,
60+
# to avoid crashing the Python process with "Aborted (core
3561
# dumped)".
36-
# TODO(#2901, and before next torchao release): make this generic for
37-
# future torchao and torch versions
38-
if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9":
39-
logger.warning(
40-
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
41-
)
62+
torch_version = _parse_version(torch.__version__)
63+
torchao_version = _parse_version(__version__)
64+
65+
v2_8_0 = _parse_version("2.8.0")
66+
v0_13_0 = _parse_version("0.13.0")
67+
v2_9_0_dev = _parse_version("2.9.0.dev")
68+
v0_14_0_dev = _parse_version("0.14.0.dev")
69+
70+
if torch_version == v2_8_0 and torchao_version == v0_13_0:
71+
# current torchao version and torch version, check here for clarity
72+
skip_loading_so_files = False
73+
elif torch_version == v2_9_0_dev and torchao_version == v0_14_0_dev:
74+
# .dev for nightlies since 2.9.0 and 0.14.0 has not been released
75+
skip_loading_so_files = False
76+
else:
4277
skip_loading_so_files = True
4378

44-
if not skip_loading_so_files:
79+
80+
if skip_loading_so_files:
81+
logger.warning(
82+
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \
83+
Please see GitHub issue #2919 for more info"
84+
)
85+
else:
4586
try:
4687
from pathlib import Path
4788

0 commit comments

Comments
 (0)