|
1 | 1 | import logging |
| 2 | +import os |
| 3 | +import re |
2 | 4 |
|
3 | 5 | # torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy' |
4 | 6 | import warnings |
|
20 | 22 | except PackageNotFoundError: |
21 | 23 | __version__ = "unknown" # In case this logic breaks don't break the build |
22 | 24 |
|
| 25 | + |
23 | 26 | logger = logging.getLogger(__name__) |
24 | 27 |
|
| 28 | + |
| 29 | +def _parse_version(version_string): |
| 30 | + """ |
| 31 | + Parse version string representing pre-release with -1 |
| 32 | +
|
| 33 | + Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] |
| 34 | + """ |
| 35 | + # Check for pre-release indicators |
| 36 | + is_prerelease = bool(re.search(r"(git|dev)", version_string)) |
| 37 | + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) |
| 38 | + if match: |
| 39 | + major, minor, patch = map(int, match.groups()) |
| 40 | + if is_prerelease: |
| 41 | + patch = -1 |
| 42 | + return [major, minor, patch] |
| 43 | + else: |
| 44 | + raise ValueError(f"Invalid version string format: {version_string}") |
| 45 | + |
| 46 | + |
25 | 47 | skip_loading_so_files = False |
26 | | -# if torchao version has "+git", assume it's locally built and we don't know |
27 | | -# anything about the PyTorch version used to build it |
28 | | -# otherwise, assume it's prebuilt by torchao's build scripts and we can make |
29 | | -# assumptions about the PyTorch version used to build it. |
30 | | -if (not "+git" in __version__) and not ("unknown" in __version__): |
31 | | - # torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so |
32 | | - # files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. |
33 | | - # The following code skips importing the .so files if PyTorch 2.9+ is |
34 | | - # detected, to avoid crashing the Python process with "Aborted (core |
| 48 | +if bool(os.getenv("TORCHAO_SKIP_LOADING_SO_FILES", False)): |
| 49 | + # user override |
| 50 | + # users can set env var TORCH_INCOMPATIBLE=1 to skip loading .so files |
| 51 | + # this way, if they are using an incompatbile torch version, they can still use the API by setting the env var |
| 52 | + skip_loading_so_files = True |
| 53 | + # if torchao version has "+git", assume it's locally built and we don't know |
| 54 | + # anything about the PyTorch version used to build it unless user provides override flag |
| 55 | + # otherwise, assume it's prebuilt by torchao's build scripts and we can make |
| 56 | + # assumptions about the PyTorch version used to build it. |
| 57 | +elif not ("+git" in __version__) and not ("unknown" in __version__): |
| 58 | + # We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. (see #2919) |
| 59 | + # The following code skips importing the .so files if incompatible torch version is detected, |
| 60 | + # to avoid crashing the Python process with "Aborted (core |
35 | 61 | # dumped)". |
36 | | - # TODO(#2901, and before next torchao release): make this generic for |
37 | | - # future torchao and torch versions |
38 | | - if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9": |
39 | | - logger.warning( |
40 | | - f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}" |
41 | | - ) |
| 62 | + if _parse_version(torch.__version__) == _parse_version("2.8.0") and _parse_version( |
| 63 | + __version__ |
| 64 | + ) == _parse_version("0.13.0"): |
| 65 | + # current torchao version and torch version, check here for clarity |
| 66 | + skip_loading_so_files = False |
| 67 | + if _parse_version(torch.__version__) == _parse_version( |
| 68 | + "2.9.0.dev" |
| 69 | + ) and _parse_version(__version__) == _parse_version("0.14.0.dev"): |
| 70 | + # .dev for nightlies since 2.9.0 and 0.14.0 has not been released |
| 71 | + skip_loading_so_files = False |
| 72 | + else: |
42 | 73 | skip_loading_so_files = True |
43 | 74 |
|
44 | | -if not skip_loading_so_files: |
| 75 | + |
| 76 | +if skip_loading_so_files: |
| 77 | + logger.warning( |
| 78 | + f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__} \ |
| 79 | + Please see GitHub issue #2919 for more info" |
| 80 | + ) |
| 81 | +else: |
45 | 82 | try: |
46 | 83 | from pathlib import Path |
47 | 84 |
|
|
0 commit comments