|
1 | 1 | import logging |
| 2 | +import os |
| 3 | +import re |
2 | 4 |
|
3 | 5 | # torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy' |
4 | 6 | import warnings |
|
20 | 22 | except PackageNotFoundError: |
21 | 23 | __version__ = "unknown" # In case this logic breaks don't break the build |
22 | 24 |
|
| 25 | + |
23 | 26 | logger = logging.getLogger(__name__) |
24 | 27 |
|
| 28 | + |
| 29 | +def _parse_version(version_string): |
| 30 | + """ |
| 31 | + Parse version string representing pre-release with -1 |
| 32 | +
|
| 33 | + Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] |
| 34 | + """ |
| 35 | + # Check for pre-release indicators |
| 36 | + is_prerelease = bool(re.search(r"(git|dev)", version_string)) |
| 37 | + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) |
| 38 | + if match: |
| 39 | + major, minor, patch = map(int, match.groups()) |
| 40 | + if is_prerelease: |
| 41 | + patch = -1 |
| 42 | + return [major, minor, patch] |
| 43 | + else: |
| 44 | + raise ValueError(f"Invalid version string format: {version_string}") |
| 45 | + |
| 46 | + |
25 | 47 | skip_loading_so_files = False |
| 48 | +if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)): |
| 49 | + # user override |
| 50 | + # users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files |
| 51 | + # this way, if they are using an incompatbile torch version, they can still use the API by setting the env var |
| 52 | + skip_loading_so_files = True |
26 | 53 | # if torchao version has "+git", assume it's locally built and we don't know |
27 | | -# anything about the PyTorch version used to build it |
| 54 | +# anything about the PyTorch version used to build it unless user provides override flag |
28 | 55 | # otherwise, assume it's prebuilt by torchao's build scripts and we can make |
29 | 56 | # assumptions about the PyTorch version used to build it. |
30 | | -if (not "+git" in __version__) and not ("unknown" in __version__): |
31 | | - # torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so |
32 | | - # files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. |
33 | | - # The following code skips importing the .so files if PyTorch 2.9+ is |
34 | | - # detected, to avoid crashing the Python process with "Aborted (core |
| 57 | +elif not ("+git" in __version__) and not ("unknown" in __version__): |
| 58 | + # We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919) |
| 59 | + # The following code skips importing the .so files if incompatible torch version is detected, |
| 60 | + # to avoid crashing the Python process with "Aborted (core |
35 | 61 | # dumped)". |
36 | | - # TODO(#2901, and before next torchao release): make this generic for |
37 | | - # future torchao and torch versions |
38 | | - if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9": |
39 | | - logger.warning( |
40 | | - f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}" |
41 | | - ) |
| 62 | + torch_version = _parse_version(torch.__version__) |
| 63 | + torchao_version = _parse_version(__version__) |
| 64 | + |
| 65 | + v2_8_0 = _parse_version("2.8.0") |
| 66 | + v0_13_0 = _parse_version("0.13.0") |
| 67 | + v2_9_0_dev = _parse_version("2.9.0.dev") |
| 68 | + v0_14_0_dev = _parse_version("0.14.0.dev") |
| 69 | + |
| 70 | + if torch_version == v2_8_0 and torchao_version == v0_13_0: |
| 71 | + # current torchao version and torch version, check here for clarity |
| 72 | + skip_loading_so_files = False |
| 73 | + elif torch_version == v2_9_0_dev and torchao_version == v0_14_0_dev: |
| 74 | + # .dev for nightlies since 2.9.0 and 0.14.0 has not been released |
| 75 | + skip_loading_so_files = False |
| 76 | + else: |
42 | 77 | skip_loading_so_files = True |
43 | 78 |
|
44 | | -if not skip_loading_so_files: |
| 79 | + |
| 80 | +if skip_loading_so_files: |
| 81 | + logger.warning( |
| 82 | + f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \ |
| 83 | + Please see GitHub issue #2919 for more info" |
| 84 | + ) |
| 85 | +else: |
45 | 86 | try: |
46 | 87 | from pathlib import Path |
47 | 88 |
|
|
0 commit comments