Skip to content
9 changes: 7 additions & 2 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
package_exists = False
elif pkg_name == "triton":
try:
package_version = importlib.metadata.version("pytorch-triton")
# import triton works for both linux and windows
package = importlib.import_module(pkg_name)
package_version = getattr(package, "__version__", "N/A")
except Exception:
package_exists = False
try:
package_version = importlib.metadata.version("pytorch-triton") # pytorch-triton
except Exception:
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
Expand Down