-
Notifications
You must be signed in to change notification settings - Fork 530
/
Copy pathfind_tensorflow.py
238 lines (214 loc) · 7.92 KB
/
find_tensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
import site
from functools import (
lru_cache,
)
from importlib.machinery import (
FileFinder,
)
from importlib.util import (
find_spec,
)
from pathlib import (
Path,
)
from sysconfig import (
get_path,
)
from typing import (
List,
Optional,
Tuple,
Union,
)
from packaging.specifiers import (
SpecifierSet,
)
@lru_cache
def find_tensorflow() -> Tuple[Optional[str], List[str]]:
    """Find TensorFlow library.

    Tries to find TensorFlow in the order of:

    1. Environment variable `TENSORFLOW_ROOT` if set
    2. The current Python environment.
    3. user site packages directory if enabled
    4. system site packages directory (purelib)
    5. add as a requirement (detect TENSORFLOW_VERSION or the latest) and let pip install it

    Setting the environment variable `DP_ENABLE_TENSORFLOW` to ``"0"`` skips
    the search entirely and returns ``(None, [])``.

    Returns
    -------
    str or None
        TensorFlow library path if found; None if not found or disabled.
    list of str
        TensorFlow requirement if not found. Empty if found or disabled.

    Raises
    ------
    RuntimeError
        If TensorFlow is not found while running under cibuildwheel
        (``CIBUILDWHEEL=1``) and ``CUDA_VERSION`` is neither empty nor in the
        12.x or 11.x ranges.
    """
    if os.environ.get("DP_ENABLE_TENSORFLOW", "1") == "0":
        return None, []
    requires: List[str] = []

    tf_spec = None
    tf_root = os.environ.get("TENSORFLOW_ROOT")
    if tf_root is not None:
        # A `tensorflow` package is searched for inside the *parent* of
        # TENSORFLOW_ROOT, i.e. TENSORFLOW_ROOT is expected to be the package
        # directory itself.
        site_packages = Path(tf_root).parent.absolute()
        tf_spec = FileFinder(str(site_packages)).find_spec("tensorflow")

    # get tensorflow spec
    # note: isolated build will not work for backend
    if not tf_spec:
        tf_spec = find_spec("tensorflow")
    if not tf_spec and site.ENABLE_USER_SITE:
        # first search TF from user site-packages before global site-packages
        site_packages = site.getusersitepackages()
        if site_packages:
            tf_spec = FileFinder(site_packages).find_spec("tensorflow")
    if not tf_spec:
        # purelib gets site-packages path
        site_packages = get_path("purelib")
        if site_packages:
            tf_spec = FileFinder(site_packages).find_spec("tensorflow")

    # get install dir from spec
    try:
        tf_install_dir = tf_spec.submodule_search_locations[0]  # type: ignore
        # AttributeError if tf_spec is None
        # TypeError if submodule_search_locations is None
        # IndexError if submodule_search_locations is an empty list
    except (AttributeError, TypeError, IndexError):
        tf_version = ""
        if os.environ.get("CIBUILDWHEEL", "0") == "1":
            cuda_version = os.environ.get("CUDA_VERSION", "12.2")
            if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
                # CUDA 12.2
                requires.extend(
                    [
                        "tensorflow-cpu>=2.15.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'",
                    ]
                )
            elif cuda_version in SpecifierSet(">=11,<12"):
                # CUDA 11.8
                requires.extend(
                    [
                        "tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",
                    ]
                )
                tf_version = "2.14.1"
            else:
                raise RuntimeError("Unsupported CUDA version")
        requires.extend(get_tf_requirement(tf_version)["cpu"])
        # setuptools will re-find tensorflow after installing setup_requires
        tf_install_dir = None
    return tf_install_dir, requires
@lru_cache
def get_tf_requirement(tf_version: str = "") -> dict:
    """Build the pip requirement lists for TensorFlow when it is not installed.

    An empty ``tf_version`` falls back to the ``TENSORFLOW_VERSION``
    environment variable. If that is empty or unset too, unpinned requirements
    are produced.

    Parameters
    ----------
    tf_version : str, optional
        TensorFlow version to pin. ``None`` yields empty requirement lists.

    Returns
    -------
    dict
        Mapping with the keys ``"cpu"``, ``"gpu"`` and ``"mpi"``, each holding
        a list of requirement strings.

    Notes
    -----
    Results are memoized by ``lru_cache`` on ``tf_version`` alone. Repeated
    calls therefore return the same dict object, which callers should not
    mutate. A change to ``TENSORFLOW_VERSION`` after the first call with a
    given argument is not observed.
    """
    if tf_version is None:
        return {"cpu": [], "gpu": [], "mpi": []}

    version = os.environ.get("TENSORFLOW_VERSION", "") if tf_version == "" else tf_version
    unpinned = version == ""

    # Pinned releases older than 2.12 additionally need an old protobuf.
    extras = []
    if not unpinned and version not in SpecifierSet(">=2.12", prereleases=True):
        extras.append("protobuf<3.20")

    # keras 3 is not compatible with tf.compat.v1
    # 2024/04/24: deepmd.tf doesn't import tf.keras any more
    if unpinned or version in SpecifierSet(">=1.15", prereleases=True):
        mpi = ["horovod", "mpi4py"]
    else:
        mpi = []

    if unpinned:
        # TODO: build(wheel): unpin h5py on aarch64
        # Revert after https://github.com/h5py/h5py/issues/2408 is fixed;
        # or set UV_PREFER_BINARY when https://github.com/astral-sh/uv/issues/1794 is resolved.
        # 3.6.0 is the first version to have aarch64 wheels.
        h5py_pin = "h5py>=3.6.0,<3.11.0; platform_system=='Linux' and platform_machine=='aarch64'"
        cpu = [
            "tensorflow-cpu; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')",
            "tensorflow; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')",
            # https://github.com/tensorflow/tensorflow/issues/61830
            "tensorflow-cpu!=2.15.*; platform_system=='Windows'",
            h5py_pin,
            *extras,
        ]
        gpu = [
            "tensorflow",
            "tensorflow-metal; platform_machine=='arm64' and platform_system == 'Darwin'",
            h5py_pin,
            *extras,
        ]
    elif version in SpecifierSet("<1.15", prereleases=True) or version in SpecifierSet(
        ">=2.0,<2.1", prereleases=True
    ):
        # Old releases: "tensorflow" was the CPU package and GPU builds were
        # published separately as "tensorflow-gpu".
        cpu = [f"tensorflow=={version}", *extras]
        gpu = [
            f"tensorflow-gpu=={version}; platform_machine!='aarch64'",
            f"tensorflow=={version}; platform_machine=='aarch64'",
            *extras,
        ]
    else:
        cpu = [
            f"tensorflow-cpu=={version}; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')",
            f"tensorflow=={version}; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')",
            *extras,
        ]
        gpu = [
            f"tensorflow=={version}",
            "tensorflow-metal; platform_machine=='arm64' and platform_system == 'Darwin'",
            *extras,
        ]
    return {"cpu": cpu, "gpu": gpu, "mpi": mpi}
@lru_cache
def get_tf_version(tf_path: Optional[Union[str, Path]]) -> str:
    """Get TF version from a TF Python library path.

    The version is parsed from the C/C++ header
    ``<tf_path>/include/tensorflow/core/public/version.h``. The parser reads
    the ``TF_MAJOR_VERSION``, ``TF_MINOR_VERSION`` and ``TF_PATCH_VERSION``
    macros plus an optional ``TF_VERSION_SUFFIX`` (surrounding double quotes
    are stripped).

    Parameters
    ----------
    tf_path : str or Path or None
        TF Python library path. ``None`` or ``""`` means "no path".

    Returns
    -------
    str
        version as ``"<major>.<minor>.<patch><suffix>"``; ``""`` if
        ``tf_path`` is ``None`` or empty.

    Raises
    ------
    OSError
        If the version header cannot be opened.
    RuntimeError
        If the major, minor or patch macro is missing from the header.
    """
    if tf_path is None or tf_path == "":
        return ""
    version_file = (
        Path(tf_path) / "include" / "tensorflow" / "core" / "public" / "version.h"
    )
    major = minor = patch = None
    # The suffix macro is optional. Default to no suffix so a header without it
    # does not leave `suffix` unbound (UnboundLocalError) at the return below.
    suffix = ""
    with open(version_file, encoding="utf-8") as f:
        for line in f:
            if line.startswith("#define TF_MAJOR_VERSION"):
                major = line.split()[-1]
            elif line.startswith("#define TF_MINOR_VERSION"):
                minor = line.split()[-1]
            elif line.startswith("#define TF_PATCH_VERSION"):
                patch = line.split()[-1]
            elif line.startswith("#define TF_VERSION_SUFFIX"):
                suffix = line.split()[-1].strip('"')
    if None in (major, minor, patch):
        raise RuntimeError("Failed to read TF version")
    return ".".join((major, minor, patch)) + suffix