diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index b675420d0e..7cd85678d1 100644 --- a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -20,6 +20,7 @@ import torch.nn as nn from accelerate import Accelerator from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError from transformers import PreTrainedModel from ..import_utils import is_peft_available @@ -163,7 +164,7 @@ class and the arguments that are specific to trl models. The kwargs "adapter_config.json", token=token, ) - except: # noqa + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): remote_adapter_config = None else: remote_adapter_config = None @@ -181,7 +182,8 @@ class and the arguments that are specific to trl models. The kwargs if local_adapter_present: trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) else: - trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_config) + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) # Load the pretrained base model pretrained_model = cls.transformers_parent_class.from_pretrained( @@ -253,7 +255,7 @@ class and the arguments that are specific to trl models. The kwargs token=token, ) # sharded - except: # noqa + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): if os.path.exists(sharded_index_filename): index_file_name = sharded_index_filename else: @@ -263,7 +265,8 @@ class and the arguments that are specific to trl models. The kwargs "pytorch_model.bin.index.json", token=token, ) - except ValueError: # not continue training, do not have v_head weight + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): + # not continue training, do not have v_head weight is_resuming_training = False logging.warning( f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " @@ -279,6 +282,7 @@ class and the arguments that are specific to trl models. The kwargs if any([module in k for module in cls.supported_modules]): files_to_download.add(v) is_shared = True + if is_resuming_training: if is_shared: # download each file and add it to the state_dict