diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ce4e1f24067187..90b25cfb2097bc 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2226,16 +2226,35 @@ def _load_best_model(self):
                         state_dict["_smp_is_partial"] = False
                         load_result = model.load_state_dict(state_dict, strict=True)
                 else:
-                    # We load the model state dict on the CPU to avoid an OOM error.
-                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
+                        # If training an 8-bit base model with PEFT & LoRA, assume the adapter has been saved properly.
+                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+                                # `load_adapter` has no return value for now; update this when it gets one.
+                                from torch.nn.modules.module import _IncompatibleKeys
+
+                                load_result = _IncompatibleKeys([], [])
+                            else:
+                                logger.warning(
+                                    "The intermediate checkpoints of PEFT may not be saved correctly. "
+                                    "Consider using a `TrainerCallback` to save adapter_model.bin in the corresponding folders; "
+                                    "here are some examples: https://github.com/huggingface/peft/issues/96"
+                                )
+                        else:
+                            # We can't do pure 8-bit training using transformers.
+                            logger.warning("Could not load a quantized checkpoint.")
                     else:
-                        state_dict = torch.load(best_model_path, map_location="cpu")
+                        # We load the model state dict on the CPU to avoid an OOM error.
+                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                        else:
+                            state_dict = torch.load(best_model_path, map_location="cpu")
 
-                    # If the model is on the GPU, it still works!
-                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
-                    # which takes *args instead of **kwargs
-                    load_result = model.load_state_dict(state_dict, False)
+                        # If the model is on the GPU, it still works!
+                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                        # which takes *args instead of **kwargs
+                        load_result = model.load_state_dict(state_dict, False)
                 if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
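
Note on the warning added above: until intermediate PEFT checkpoints are saved automatically, a `TrainerCallback` can write `adapter_model.bin` into each checkpoint folder so that the `load_adapter` branch in this diff can find it. Below is a minimal sketch of such a callback, assuming the model passed to the `Trainer` is a `PeftModel`; the class name `SavePeftCheckpointCallback` is illustrative and not part of this change.

```python
import os

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class SavePeftCheckpointCallback(TrainerCallback):
    """Hypothetical workaround callback: save the PEFT adapter into every checkpoint folder."""

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Trainer typically writes checkpoints to {output_dir}/checkpoint-{global_step}.
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # PeftModel.save_pretrained writes adapter_model.bin and adapter_config.json,
        # which is exactly what the `load_adapter` branch above looks for.
        kwargs["model"].save_pretrained(checkpoint_dir)
        return control
```

Registering it with `trainer.add_callback(SavePeftCheckpointCallback())` before `trainer.train()` should make `load_best_model_at_end=True` usable with 8-bit PEFT models until checkpointing handles adapters natively.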