Skip to content

Commit

Permalink
when from_pretrained is called in finetune of lora with flag "is_trai…
Browse files Browse the repository at this point in the history
…nable" True, should not call model.eval()
  • Loading branch information
sywangyi committed Jun 16, 2023
1 parent 38e9c65 commit 92e1752
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=F
model = cls(model, config, adapter_name)
else:
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
model.load_adapter(model_id, adapter_name, **kwargs)
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
return model

def _setup_prompt_encoder(self, adapter_name):
Expand Down Expand Up @@ -508,7 +508,8 @@ def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
add_hook_to_module(self.get_base_model(), hook)

# Set model in evaluation mode to deactivate Dropout modules by default
self.eval()
if not is_trainable:
self.eval()
return load_result

def set_adapter(self, adapter_name):
Expand Down

0 comments on commit 92e1752

Please sign in to comment.