FSDP tests and checkpointing fixes #26180

Merged: 26 commits, Sep 20, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -128,7 +128,7 @@ def is_fsdp_enabled():


def is_fsdp_enabled_and_dist_rank_0():
-    return is_fsdp_enabled() and torch.distributed.get_rank() == 0
+    return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0


if is_sagemaker_mp_enabled():
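
For background on the one-line change above: reading LOCAL_RANK from the environment works even before the default process group exists, whereas torch.distributed.get_rank() raises a RuntimeError until init_process_group has been called. A minimal illustrative sketch (the helper names are ours, not from the PR):

import os

import torch.distributed as dist


def rank_zero_via_env() -> bool:
    # LOCAL_RANK is set by launchers such as torchrun/accelerate before any
    # process group is created, so this check is safe to run early.
    return int(os.environ.get("LOCAL_RANK", -1)) == 0


def rank_zero_via_process_group() -> bool:
    # Raises RuntimeError if dist.init_process_group(...) has not run yet.
    return dist.get_rank() == 0
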
10 changes: 10 additions & 0 deletions src/transformers/testing_utils.py
@@ -61,6 +61,7 @@
is_essentia_available,
is_faiss_available,
is_flax_available,
+    is_fsdp_available,
is_ftfy_available,
is_ipex_available,
is_jieba_available,
@@ -315,6 +316,15 @@ def require_accelerate(test_case):
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


+def require_fsdp(test_case, min_version: str = "1.12.0"):
+    """
+    Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
+    """
+    return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
+        test_case
+    )


def require_safetensors(test_case):
"""
Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
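
As a usage illustration (not part of the diff; the test class and body are hypothetical), the new decorator is applied like the other require_* helpers and skips the test on torch versions older than 1.12.0:

import unittest

from transformers.testing_utils import require_fsdp


class HypotheticalFSDPTests(unittest.TestCase):
    @require_fsdp  # skipped when torch < 1.12.0
    def test_placeholder(self):
        # Stand-in body; the real FSDP tests are added elsewhere in this PR.
        self.assertTrue(True)
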
66 changes: 43 additions & 23 deletions src/transformers/trainer.py
@@ -1690,9 +1690,6 @@ def _inner_training_loop(

model = self._wrap_model(self.model_wrapped)

-        if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
-            self._load_from_checkpoint(resume_from_checkpoint, model)
-
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
@@ -1718,7 +1715,7 @@
)

if self.is_fsdp_enabled:
-            self.model = model
+            self.model = self.model_wrapped = model

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
@@ -1728,16 +1725,20 @@
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped

-        # deepspeed ckpt loading
-        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
-            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+        # ckpt loading
+        if resume_from_checkpoint is not None:
+            if self.is_deepspeed_enabled:
+                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)

# important: at this point:
# self.model is the Transformers Model
-        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
+        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

# Train!
logger.info("***** Running training *****")
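
For orientation, a minimal end-to-end sketch of the flow these trainer changes target (not part of the diff; the model, dataset and arguments are placeholders, and the script is assumed to be launched with torchrun or accelerate so that FSDP is active). With this PR, resuming loads the checkpoint into the wrapped model (self.model_wrapped) after _wrap_model instead of before wrapping:

import torch
from torch.utils.data import Dataset

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class TinyDataset(Dataset):
    # Tiny stand-in dataset so the sketch is self-contained.
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"input_ids": torch.tensor([101, 2023, 102]), "labels": torch.tensor(0)}


model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
args = TrainingArguments(
    output_dir="tmp_fsdp_run",
    fsdp="full_shard auto_wrap",  # enable PyTorch FSDP inside the Trainer
    per_device_train_batch_size=2,
    num_train_epochs=1,
)
trainer = Trainer(model=model, args=args, train_dataset=TinyDataset())
trainer.train(resume_from_checkpoint=True)  # resumes from the latest checkpoint in output_dir (must exist)
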
@@ -2078,17 +2079,28 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
+            WEIGHTS_NAME.split(".")[0] in folder_name
+            for folder_name in os.listdir(resume_from_checkpoint)
+            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+        )

-        if not any(
-            os.path.isfile(f)
-            for f in [
-                weights_file,
-                safe_weights_file,
-                weights_index_file,
-                safe_weights_index_file,
-                adapter_weights_file,
-                adapter_safe_weights_file,
-            ]
+        if is_fsdp_ckpt and not self.is_fsdp_enabled:
+            raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")

+        if not (
+            any(
+                os.path.isfile(f)
+                for f in [
+                    weights_file,
+                    safe_weights_file,
+                    weights_index_file,
+                    safe_weights_index_file,
+                    adapter_weights_file,
+                    adapter_safe_weights_file,
+                ]
+            )
+            or is_fsdp_ckpt
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

@@ -2104,7 +2116,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
"yield to errors or unwanted behaviors."
)

-        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
+        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2174,6 +2186,10 @@ def _load_best_model(self):
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+        elif self.is_fsdp_enabled:
+            load_result = load_fsdp_model(
+                self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+            )
elif (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
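
The load_fsdp_model helper used in the relocated branch comes from accelerate's FSDP utilities (exposed in recent accelerate releases); a sketch of the call shape, where the accelerator and model are assumed to come from a live Trainer and the path is a placeholder:

from accelerate import Accelerator
from accelerate.utils import load_fsdp_model
from torch import nn


def load_best_fsdp_weights(accelerator: Accelerator, model: nn.Module, checkpoint_dir: str) -> None:
    # Same call shape as the Trainer branch above; checkpoint_dir would be
    # trainer.state.best_model_checkpoint in the real code path, and
    # accelerator.state.fsdp_plugin is set when FSDP is configured.
    load_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, checkpoint_dir)
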
@@ -2201,10 +2217,6 @@ def _load_best_model(self):

state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
-        elif self.is_fsdp_enabled:
-            load_result = load_fsdp_model(
-                self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
-            )
else:
if is_peft_available() and isinstance(model, PeftModel):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
@@ -2493,6 +2505,14 @@ def _load_optimizer_and_scheduler(self, checkpoint):
else (
os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
+                or (
+                    os.path.isdir(checkpoint)
+                    and any(
+                        OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
+                        for folder_name in os.listdir(checkpoint)
+                        if os.path.isdir(os.path.join(checkpoint, folder_name))
+                    )
+                )
)
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
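
Analogous to the model-weights case, the new clause above treats a sub-directory whose name contains the optimizer-file stem as valid optimizer state. A standalone sketch (the constant value and folder name are assumptions mirroring the accelerate FSDP save format):

import os

OPTIMIZER_NAME_BIN = "optimizer.bin"  # value of the constant used above (assumption)


def has_sharded_optimizer_state(checkpoint: str) -> bool:
    # FSDP sharded optimizer state is saved in a sub-directory such as
    # "optimizer_0/" rather than a single optimizer.pt / optimizer.bin file.
    return os.path.isdir(checkpoint) and any(
        OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
        for folder_name in os.listdir(checkpoint)
        if os.path.isdir(os.path.join(checkpoint, folder_name))
    )
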
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -115,6 +115,7 @@
is_essentia_available,
is_faiss_available,
is_flax_available,
+    is_fsdp_available,
is_ftfy_available,
is_in_notebook,
is_ipex_available,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -605,6 +605,10 @@ def is_accelerate_available(min_version: str = None):
return _accelerate_available


+def is_fsdp_available(min_version: str = "1.12.0"):
+    return version.parse(_torch_version) >= version.parse(min_version)


def is_optimum_available():
return _optimum_available

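
Finally, a quick usage note for the new availability helper (not part of the diff; the printed values depend on the locally installed torch version). It is re-exported from transformers.utils by the __init__.py change above:

from transformers.utils import is_fsdp_available

print(is_fsdp_available())                     # True when torch >= 1.12.0
print(is_fsdp_available(min_version="2.1.0"))  # stricter gate, e.g. for newer FSDP features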