diff --git a/.github/workflows/self-scheduled.yml b/.github/workflows/self-scheduled.yml index c49a967d2aba72..978d9e02a69d38 100644 --- a/.github/workflows/self-scheduled.yml +++ b/.github/workflows/self-scheduled.yml @@ -33,8 +33,7 @@ jobs: run: | apt -y update && apt install -y libsndfile1-dev pip install --upgrade pip - pip install .[sklearn,testing,onnxruntime,sentencepiece,speech] - pip install deepspeed + pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed] - name: Are GPUs recognized by our DL frameworks run: | @@ -156,9 +155,7 @@ jobs: run: | apt -y update && apt install -y libsndfile1-dev pip install --upgrade pip - pip install .[sklearn,testing,onnxruntime,sentencepiece,speech] - pip install fairscale - pip install deepspeed + pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale] - name: Are GPUs recognized by our DL frameworks run: | diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index bc9f248827ad6a..10a7a9d54aa3bf 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -274,6 +274,14 @@ Install the library via pypi: pip install fairscale +or via ``transformers``' ``extras``: + +.. code-block:: bash + + pip install transformers[fairscale] + +(will become available starting from ``transformers==4.6.0``) + or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__. If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`. @@ -419,6 +427,14 @@ Install the library via pypi: pip install deepspeed +or via ``transformers``' ``extras``: + +.. code-block:: bash + + pip install transformers[deepspeed] + +(will become available starting from ``transformers==4.6.0``) + or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and `advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__. 
diff --git a/setup.py b/setup.py index c3583a30700980..1142fd19838d86 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,9 @@ "cookiecutter==1.7.2", "dataclasses", "datasets", + "deepspeed>0.3.13", "docutils==0.16.0", + "fairscale>0.3", "faiss-cpu", "fastapi", "filelock", @@ -233,6 +235,8 @@ def run(self): extras["modelcreation"] = deps_list("cookiecutter") extras["sagemaker"] = deps_list("sagemaker") +extras["deepspeed"] = deps_list("deepspeed") +extras["fairscale"] = deps_list("fairscale") extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") extras["speech"] = deps_list("soundfile", "torchaudio") diff --git a/src/transformers/dependency_versions_check.py b/src/transformers/dependency_versions_check.py index 7e36aaef3091ba..e6e676481d79c9 100644 --- a/src/transformers/dependency_versions_check.py +++ b/src/transformers/dependency_versions_check.py @@ -14,7 +14,7 @@ import sys from .dependency_versions_table import deps -from .utils.versions import require_version_core +from .utils.versions import require_version, require_version_core # define which module versions we always want to check at run time @@ -41,3 +41,7 @@ require_version_core(deps[pkg]) else: raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 43f4c028feca57..bd070d7bdf254f 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -7,7 +7,9 @@ "cookiecutter": "cookiecutter==1.7.2", "dataclasses": "dataclasses", "datasets": "datasets", + "deepspeed": "deepspeed>0.3.13", "docutils": "docutils==0.16.0", + "fairscale": "fairscale>0.3", "faiss-cpu": "faiss-cpu", "fastapi": "fastapi", "filelock": "filelock", diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 
65824c25ca7468..7e4ab0f5c7a100 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -24,8 +24,8 @@ from copy import deepcopy from pathlib import Path +from .dependency_versions_check import dep_version_check from .utils import logging -from .utils.versions import require_version logger = logging.get_logger(__name__) @@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config): If it's already a dict, return a copy of it, so that we can freely modify it. """ - require_version("deepspeed>0.3.13") + dep_version_check("deepspeed") if isinstance(ds_config, dict): # Don't modify user's data should they want to reuse it (e.g. in tests), because once we diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index dc311643310bf6..41800b7fd3a32c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -54,6 +54,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .dependency_versions_check import dep_version_check from .file_utils import ( WEIGHTS_NAME, is_apex_available, @@ -139,17 +140,14 @@ import torch_xla.distributed.parallel_loader as pl if is_fairscale_available(): + dep_version_check("fairscale") import fairscale + from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP + from fairscale.nn.wrap import auto_wrap from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler - if version.parse(fairscale.__version__) >= version.parse("0.3"): - from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP - from fairscale.nn.wrap import auto_wrap - else: - FullyShardedDDP = None - if is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import 
DistributedDataParallel as DDP diff --git a/src/transformers/utils/versions.py b/src/transformers/utils/versions.py index b573a361b96ff7..73151487bc71f2 100644 --- a/src/transformers/utils/versions.py +++ b/src/transformers/utils/versions.py @@ -60,6 +60,12 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None: Args: requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met + + Example:: + + require_version("pandas>1.1.2") + require_version("numpy>1.18.5", "this is important to have for whatever reason") + """ hint = f"\n{hint}" if hint is not None else ""