diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index a194b38b84c8..f365e469c31f 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -89,8 +89,20 @@ try: _tf_version = importlib_metadata.version("tensorflow-cpu") except importlib_metadata.PackageNotFoundError: - _tf_version = None - _tf_available = False + try: + _tf_version = importlib_metadata.version("tensorflow-gpu") + except importlib_metadata.PackageNotFoundError: + try: + _tf_version = importlib_metadata.version("tf-nightly") + except importlib_metadata.PackageNotFoundError: + try: + _tf_version = importlib_metadata.version("tf-nightly-cpu") + except importlib_metadata.PackageNotFoundError: + try: + _tf_version = importlib_metadata.version("tf-nightly-gpu") + except importlib_metadata.PackageNotFoundError: + _tf_version = None + _tf_available = False if _tf_available: if version.parse(_tf_version) < version.parse("2"): logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")