Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: load_best_model_at_end error when load_in_8bit is True #23443

Merged
merged 1 commit into from
May 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,16 +2226,35 @@ def _load_best_model(self):
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
dkqkxx marked this conversation as resolved.
Show resolved Hide resolved
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
dkqkxx marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
dkqkxx marked this conversation as resolved.
Show resolved Hide resolved
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys

load_result = _IncompatibleKeys([], [])
dkqkxx marked this conversation as resolved.
Show resolved Hide resolved
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96"
)
else:
# We can't do pure 8bit training using transformers.
logger.warning("Could not loading a quantized checkpoint.")
else:
state_dict = torch.load(best_model_path, map_location="cpu")
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu")

# If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
# If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
if not is_sagemaker_mp_enabled():
dkqkxx marked this conversation as resolved.
Show resolved Hide resolved
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
Expand Down