From 605c6e21710b09fe11c0f48e9e5f3d90599ccdb4 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 18 Aug 2022 03:09:09 -0400 Subject: [PATCH] Allow users to force TF availability (#18650) * Allow users to force TF availability * Correctly name the envvar! --- src/transformers/utils/import_utils.py | 73 ++++++++++++++------------ 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 37172d14fcc289..219552976a0a6c 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -42,6 +42,8 @@ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -57,40 +59,45 @@ _tf_version = "N/A" -if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - _tf_available = importlib.util.find_spec("tensorflow") is not None - if _tf_available: - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - "tensorflow-aarch64", - ) - _tf_version = None - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for pkg in candidates: - try: - _tf_version = importlib_metadata.version(pkg) - break - except importlib_metadata.PackageNotFoundError: - pass - _tf_available = _tf_version is not None - if _tf_available: - if version.parse(_tf_version) < version.parse("2"): - logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.") - _tf_available = False - else: - logger.info(f"TensorFlow version {_tf_version} available.") +if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: + _tf_available = True else: - logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False + if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info( + f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum." + ) + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") + else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: