Skip to content

Commit

Permalink
Fix version for Python<3.8 (#1363)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuki authored Feb 27, 2024
1 parent 2a2676e commit cf68d87
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ def is_accelerate_greater_20_0() -> bool:
return accelerate_version >= "0.20.0"


def is_transformers_greater_than(version: str) -> bool:
_transformers_version = importlib.metadata.version("transformers")
return _transformers_version > version
def is_transformers_greater_than(current_version: str) -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version

_transformers_version = version("transformers")
else:
import pkg_resources

_transformers_version = pkg_resources.get_distribution("transformers").version
return _transformers_version > current_version


def is_torch_greater_2_0() -> bool:
Expand Down

0 comments on commit cf68d87

Please sign in to comment.