Skip to content

Commit

Permalink
Allow users to force TF availability (huggingface#18650)
Browse files Browse the repository at this point in the history
* Allow users to force TF availability

* Correctly name the envvar!
  • Loading branch information
Rocketknight1 authored and oneraghavan committed Sep 26, 2022
1 parent 0a08e1e commit 605c6e2
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

_torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
Expand All @@ -57,40 +59,45 @@


_tf_version = "N/A"
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available:
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
"tensorflow-aarch64",
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
_tf_version = importlib_metadata.version(pkg)
break
except importlib_metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
_tf_available = True
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available:
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
"tensorflow-aarch64",
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
_tf_version = importlib_metadata.version(pkg)
break
except importlib_metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
)
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False


if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
Expand Down

0 comments on commit 605c6e2

Please sign in to comment.