Skip to content

Commit afd6096

Browse files
authored
generalize torch compatibility check (#3042)
* generalize torch compatibility check * remove test from ci
1 parent 01849b2 commit afd6096

File tree

3 files changed

+136
-13
lines changed

3 files changed

+136
-13
lines changed

.github/scripts/validate_binaries.sh

100644100755
File mode changed.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Testing compatibility
2+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
3+
# If the version of torch is not compatible with the version of torchao,
4+
# we expect to skip loading the .so files and a warning should be logged but no error
5+
6+
PREV_TORCH_VERSION = 2.8.0
7+
PREV_TORCHAO_VERSION = 0.13.0
8+
9+
# Function to check torchao import with configurable expectations
10+
check_torchao_import() {
11+
local expect_warning="$1"
12+
local warning_text="$2"
13+
local torch_incompatible="${3:-}"
14+
15+
if [ -n "$torch_incompatible" ]; then
16+
output=$(TORCH_INCOMPATIBLE=1 python -c "import torchao" 2>&1)
17+
else
18+
output=$(python -c "import torchao" 2>&1)
19+
fi
20+
exit_code=$?
21+
22+
if [ $exit_code -ne 0 ]; then
23+
echo "ERROR: Failed to import torchao"
24+
echo "Output: $output"
25+
exit 1
26+
fi
27+
28+
warning_found=false
29+
if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
30+
echo "Output: $output"
31+
warning_found=true
32+
fi
33+
34+
if [ "$expect_warning" != "$warning_found" ]; then
35+
echo echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
36+
exit 1
37+
fi
38+
}
39+
40+
## prev torch version, prev torchao version
41+
# Uninstall torch
42+
pip uninstall torch
43+
# Uninstall torchao
44+
pip uninstall torchao
45+
# Install prev compatible version of torch
46+
pip install PREV_TORCH_VERSION
47+
# Installs prev compatible version of torchao
48+
pip install PREV_TORCHAO_VERSION
49+
# hould import successfully without warning
50+
check_torchao_import "false" ""
51+
52+
## current torch, current torchao
53+
# Uninstall torch
54+
pip uninstall torch
55+
# Uninstall torchao
56+
pip uninstall torchao
57+
# Install specific compatible version of torch (nightly 2.9.0dev)
58+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu129
59+
# Build torchao from source
60+
python setup.py develop
61+
# Should import successfully without warning
62+
check_torchao_import "false" ""
63+
## prev torch, torchao from source (do not rebuild), env var = True
64+
# Uninstall torch
65+
pip uninstall torch
66+
# Install incompatible version of torch
67+
pip install torch==PREV_TORCH_VERSION
68+
# Should import with warning because optional env var is set to true
69+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version" "TORCHAO_SKIP_LOADING_SO_FILES=1"
70+
71+
72+
# current torch, prev torchao
73+
# Uninstall torch
74+
pip uninstall torch
75+
# Uninstall torchao
76+
pip uninstall torchao
77+
# Install non-ABI stable torch version
78+
pip install torch==2.9.0
79+
# Installs incompatible torchao
80+
pip install torchao==PREV_TORCHAO_VERSION
81+
# Should import with specific warning
82+
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"

torchao/__init__.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import os
3+
import re
24

35
# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
46
import warnings
@@ -20,28 +22,67 @@
2022
except PackageNotFoundError:
2123
__version__ = "unknown" # In case this logic breaks don't break the build
2224

25+
2326
logger = logging.getLogger(__name__)
2427

28+
29+
def _parse_version(version_string):
30+
"""
31+
Parse version string representing pre-release with -1
32+
33+
Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
34+
"""
35+
# Check for pre-release indicators
36+
is_prerelease = bool(re.search(r"(git|dev)", version_string))
37+
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
38+
if match:
39+
major, minor, patch = map(int, match.groups())
40+
if is_prerelease:
41+
patch = -1
42+
return [major, minor, patch]
43+
else:
44+
raise ValueError(f"Invalid version string format: {version_string}")
45+
46+
2547
skip_loading_so_files = False
48+
if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)):
49+
# user override
50+
# users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files
51+
# this way, if they are using an incompatbile torch version, they can still use the API by setting the env var
52+
skip_loading_so_files = True
2653
# if torchao version has "+git", assume it's locally built and we don't know
27-
# anything about the PyTorch version used to build it
54+
# anything about the PyTorch version used to build it unless user provides override flag
2855
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
2956
# assumptions about the PyTorch version used to build it.
30-
if (not "+git" in __version__) and not ("unknown" in __version__):
31-
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
32-
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
33-
# The following code skips importing the .so files if PyTorch 2.9+ is
34-
# detected, to avoid crashing the Python process with "Aborted (core
57+
elif not ("+git" in __version__) and not ("unknown" in __version__):
58+
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
59+
# The following code skips importing the .so files if incompatible torch version is detected,
60+
# to avoid crashing the Python process with "Aborted (core
3561
# dumped)".
36-
# TODO(#2901, and before next torchao release): make this generic for
37-
# future torchao and torch versions
38-
if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9":
39-
logger.warning(
40-
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
41-
)
62+
torch_version = _parse_version(torch.__version__)
63+
torchao_version = _parse_version(__version__)
64+
65+
v2_8_0 = _parse_version("2.8.0")
66+
v0_13_0 = _parse_version("0.13.0")
67+
v2_9_0_dev = _parse_version("2.9.0.dev")
68+
v0_14_0_dev = _parse_version("0.14.0.dev")
69+
70+
if torch_version == v2_8_0 and torchao_version == v0_13_0:
71+
# current torchao version and torch version, check here for clarity
72+
skip_loading_so_files = False
73+
elif torch_version == v2_9_0_dev and torchao_version == v0_14_0_dev:
74+
# .dev for nightlies since 2.9.0 and 0.14.0 has not been released
75+
skip_loading_so_files = False
76+
else:
4277
skip_loading_so_files = True
4378

44-
if not skip_loading_so_files:
79+
80+
if skip_loading_so_files:
81+
logger.warning(
82+
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \
83+
Please see GitHub issue #2919 for more info"
84+
)
85+
else:
4586
try:
4687
from pathlib import Path
4788

0 commit comments

Comments
 (0)