Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

device agnostic fsdp testing #27120

Merged
merged 3 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,22 @@ def require_torch_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def require_torch_multi_accelerator(test_case):
"""
Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
without multiple accelerators.

To run *only* the multi_accelerator tests, assuming all test names contain multi_accelerator: $ pytest -sv ./tests
-k "multi_accelerator"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)

return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
test_case
)


def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
Expand Down
21 changes: 11 additions & 10 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_gpu_count,
mockenv_context,
require_accelerate,
require_fsdp,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_accelerator,
require_torch_multi_accelerator,
slow,
torch_device,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import FSDPOption, set_seed
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device


if is_torch_available():
Expand All @@ -46,7 +47,7 @@
# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"
dtypes = ["fp16"]
if is_torch_bf16_gpu_available():
if is_torch_bf16_available_on_device(torch_device):
dtypes += ["bf16"]
sharding_strategies = ["full_shard", "shard_grad_op"]
state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
# - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with mode gpus because we use very little data)
num_gpus = min(2, get_gpu_count()) if distributed else 1
num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
master_port = get_master_port(real_launcher=True)
if use_accelerate:
return f"""accelerate launch
Expand All @@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):


@require_accelerate
@require_torch_gpu
@require_torch_accelerator
@require_fsdp_version
class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
def setUp(self):
Expand Down Expand Up @@ -170,7 +171,7 @@ def test_fsdp_config(self, sharding_strategy, dtype):
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
@require_torch_multi_gpu
@require_torch_multi_accelerator
@slow
def test_basic_run(self, sharding_strategy, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
Expand All @@ -182,7 +183,7 @@ def test_basic_run(self, sharding_strategy, dtype):
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(dtypes)
@require_torch_multi_gpu
@require_torch_multi_accelerator
@slow
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
def test_basic_run_with_cpu_offload(self, dtype):
Expand All @@ -195,7 +196,7 @@ def test_basic_run_with_cpu_offload(self, dtype):
execute_subprocess_async(cmd, env=self.get_env())

@parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
@require_torch_multi_gpu
@require_torch_multi_accelerator
@slow
def test_training_and_can_resume_normally(self, state_dict_type):
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
Expand Down
Loading