Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified .github/scripts/validate_binaries.sh
100644 → 100755
Empty file.
82 changes: 82 additions & 0 deletions scripts/test_torch_version_torchao_version_compatibility.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Testing compatibility
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
# If the version of torch is not compatible with the version of torchao,
# we expect to skip loading the .so files and a warning should be logged but no error

PREV_TORCH_VERSION = 2.8.0
PREV_TORCHAO_VERSION = 0.13.0

# Function to check torchao import with configurable expectations
check_torchao_import() {
local expect_warning="$1"
local warning_text="$2"
local torch_incompatible="${3:-}"

if [ -n "$torch_incompatible" ]; then
output=$(TORCH_INCOMPATIBLE=1 python -c "import torchao" 2>&1)
else
output=$(python -c "import torchao" 2>&1)
fi
exit_code=$?

if [ $exit_code -ne 0 ]; then
echo "ERROR: Failed to import torchao"
echo "Output: $output"
exit 1
fi

warning_found=false
if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
echo "Output: $output"
warning_found=true
fi

if [ "$expect_warning" != "$warning_found" ]; then
echo echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
exit 1
fi
}

## prev torch version, prev torchao version
# Uninstall torch
pip uninstall torch
# Uninstall torchao
pip uninstall torchao
# Install prev compatible version of torch
pip install PREV_TORCH_VERSION
# Installs prev compatible version of torchao
pip install PREV_TORCHAO_VERSION
# hould import successfully without warning
check_torchao_import "false" ""

## current torch, current torchao
# Uninstall torch
pip uninstall torch
# Uninstall torchao
pip uninstall torchao
# Install specific compatible version of torch (nightly 2.9.0dev)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu129
# Build torchao from source
python setup.py develop
# Should import successfully without warning
check_torchao_import "false" ""
## prev torch, torchao from source (do not rebuild), env var = True
# Uninstall torch
pip uninstall torch
# Install incompatible version of torch
pip install torch==PREV_TORCH_VERSION
# Should import with warning because optional env var is set to true
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version" "TORCHAO_SKIP_LOADING_SO_FILES=1"


# current torch, prev torchao
# Uninstall torch
pip uninstall torch
# Uninstall torchao
pip uninstall torchao
# Install non-ABI stable torch version
pip install torch==2.9.0
# Installs incompatible torchao
pip install torchao==PREV_TORCHAO_VERSION
# Should import with specific warning
check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"
67 changes: 54 additions & 13 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
import re

# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
import warnings
Expand All @@ -20,28 +22,67 @@
except PackageNotFoundError:
__version__ = "unknown" # In case this logic breaks don't break the build


logger = logging.getLogger(__name__)


def _parse_version(version_string):
"""
Parse version string representing pre-release with -1

Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
"""
# Check for pre-release indicators
is_prerelease = bool(re.search(r"(git|dev)", version_string))
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
if match:
major, minor, patch = map(int, match.groups())
if is_prerelease:
patch = -1
return [major, minor, patch]
else:
raise ValueError(f"Invalid version string format: {version_string}")


skip_loading_so_files = False
if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)):
# user override
# users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files
# this way, if they are using an incompatbile torch version, they can still use the API by setting the env var
skip_loading_so_files = True
# if torchao version has "+git", assume it's locally built and we don't know
# anything about the PyTorch version used to build it
# anything about the PyTorch version used to build it unless user provides override flag
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
# assumptions about the PyTorch version used to build it.
if (not "+git" in __version__) and not ("unknown" in __version__):
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
# The following code skips importing the .so files if PyTorch 2.9+ is
# detected, to avoid crashing the Python process with "Aborted (core
elif not ("+git" in __version__) and not ("unknown" in __version__):
# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919)
# The following code skips importing the .so files if incompatible torch version is detected,
# to avoid crashing the Python process with "Aborted (core
# dumped)".
# TODO(#2901, and before next torchao release): make this generic for
# future torchao and torch versions
if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9":
logger.warning(
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
)
torch_version = _parse_version(torch.__version__)
torchao_version = _parse_version(__version__)

v2_8_0 = _parse_version("2.8.0")
v0_13_0 = _parse_version("0.13.0")
v2_9_0_dev = _parse_version("2.9.0.dev")
v0_14_0_dev = _parse_version("0.14.0.dev")

if torch_version == v2_8_0 and torchao_version == v0_13_0:
# current torchao version and torch version, check here for clarity
skip_loading_so_files = False
elif torch_version == v2_9_0_dev and torchao_version == v0_14_0_dev:
# .dev for nightlies since 2.9.0 and 0.14.0 has not been released
skip_loading_so_files = False
else:
skip_loading_so_files = True

if not skip_loading_so_files:

if skip_loading_so_files:
logger.warning(
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \
Please see GitHub issue #2919 for more info"
)
else:
try:
from pathlib import Path

Expand Down
Loading