Skip to content

Commit

Permalink
Clarify docstrings, help messages, assert messages in merge_peft_adapter.py (huggingface#838)
Browse files Browse the repository at this point in the history

An assertion was also corrected to the intended test condition
  • Loading branch information
larekrow authored and Andrew Lapp committed May 10, 2024
1 parent 0b45b3c commit aa4fc81
Showing 1 changed file with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@
@dataclass
class ScriptArguments:
"""
The name of the Casual LM model we wish to fine with PPO
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
merged model.
"""

adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.base_model_name is not None, "please provide the output name of the merged model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
# peft is for reward model so load sequence classification
# The sequence classification task is used for the reward model in PPO
model = AutoModelForSequenceClassification.from_pretrained(
script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
)
Expand All @@ -36,7 +37,7 @@ class ScriptArguments:

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the Lora model
# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

Expand Down

0 comments on commit aa4fc81

Please sign in to comment.