Skip to content

Commit

Permalink
Extend import utils to cover "editable" torch versions (huggingface#2…
Browse files Browse the repository at this point in the history
…9000)

* Extend import utils to cover "editable" torch versions

* Re-add type hint

* Remove whitespaces

* Double quote strings

* Update comment

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Restore package_exists

* Revert "Restore package_exists"

This reverts commit 66fd2cd.

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
  • Loading branch information
bhack and ydshieh authored Mar 15, 2024
1 parent 56b64bf commit f62407f
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,32 @@

# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
# Check if the package spec exists and grab its version to avoid importing a local directory
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
package_exists = False
logger.debug(f"Detected {pkg_name} version {package_version}")
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logger.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else:
Expand Down

0 comments on commit f62407f

Please sign in to comment.