Allow user to specify huggingface link or local path to pretrained lora weights #3572

Merged · 18 commits · Sep 12, 2023

Changes from 11 commits
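Taken together, the changes in this PR let an LLM config point its adapter at already-trained weights and skip training. Below is a minimal sketch of how that might be used through the Python API — the base model name, adapter repo, and DataFrame contents are placeholders chosen for illustration, not values from this PR:

import pandas as pd
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "my-org/my-base-model",  # hypothetical base model
    "input_features": [{"name": "input", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "adapter": {
        "type": "lora",
        # New in this PR: a HuggingFace link or local path to trained adapter weights.
        "pretrained_adapter_weights": "my-org/my-lora-adapter",  # hypothetical repo/path
    },
    "trainer": {"type": "none"},  # no fine-tuning; load the adapter for inference only
    "backend": {"type": "local"},
}

df = pd.DataFrame({"input": ["What is the capital of France?"], "output": [""]})

model = LudwigModel(config)
model.train(dataset=df)  # with the "none" trainer this performs no weight updates
predictions, _ = model.predict(dataset=df)
print(predictions.head())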
3 changes: 3 additions & 0 deletions ludwig/api.py
@@ -615,6 +615,9 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# auto tune batch size
self._tune_batch_size(trainer, training_set, random_seed=random_seed)

if self.config_obj.adapter.pretrained_adapter_weights and trainer.config.type == "none":
trainer.model.initialize_adapter() # Load pre-trained adapter weights for inference only

# train model
if self.backend.is_coordinator():
print_boxed("TRAINING")
21 changes: 11 additions & 10 deletions ludwig/config_validation/checks.py
@@ -477,7 +477,7 @@ def check_llm_finetuning_output_feature_config(config: "ModelConfig"):  # noqa: F821
if config.model_type != MODEL_LLM:
return

if config.trainer.type != "finetune":
if config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None:
Contributor:

Should this be an OR? Why does specifying pretrained_adapter_weights no longer require that the first output feature be TEXT?

Or is it that we want to make it so that using the none trainer type doesn't require an output feature?

if config.trainer.type == "none":
    return

CC: @arnavgarg1

Contributor (Author):

I think this was an oversight on my part. I was trying to go through the code to see where my change might break something down the line, and I might have gotten a little overzealous.

return

if config.output_features[0].type != TEXT:
@@ -493,6 +493,9 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"):  # noqa: F821
if config.model_type != MODEL_LLM:
return

if config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None:
Contributor:

Should this be, more simply:

if config.trainer.type == "none":
    # The NoneTrainer for ZS is valid.
    return

Contributor (Author):

But in this case, we would load in untrained LoRA weights if pretrained adapter weights weren't specified in the config, right? Would that be a problem?

return

if config.adapter is not None and config.trainer.type != "finetune":
raise ConfigValidationError("LLM finetuning requires trainer type to be finetune.")

@@ -508,7 +511,7 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"):  # noqa: F821
return

# LLM finetuning is only supported by the finetune trainer type
if config.trainer.type != "finetune":
if config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None:
return

# Using local backend, so skip the checks below
@@ -528,9 +531,8 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"):  # noqa: F821
def check_llm_finetuning_adalora_config(config: "ModelConfig"):
"""Checks that the adalora adapter is configured correctly.

It requires a set of target_modules to be specified in the config for the model. If it isn't specified by the user,
we also check against PEFT's predefined target module list for ADALORA to see if this key is present there. If
neither is true, AdaloraModel will run into issues downstream.
We check against PEFT's predefined target module list for ADALORA to see if the base model's type is present
there. If not, AdaloraModel will run into issues downstream.
"""
if config.model_type != MODEL_LLM:
return
@@ -544,10 +546,7 @@ def check_llm_finetuning_adalora_config(config: "ModelConfig"):
from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING

model_config = _get_llm_model_config(config.base_model)
if (
not config.adapter.target_modules
and model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING
):
if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
raise ConfigValidationError(
f"Adalora adapter is not supported for {model_config.model_type} model. "
f"Supported model types are: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING.keys())}. "
@@ -601,7 +600,9 @@ def check_llm_quantization_backend_incompatibility(config: "ModelConfig") -> None:
@register_config_check
def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821
"""Checks that all the necessary settings are in place for QLoRA."""
if config.model_type != MODEL_LLM or config.trainer.type == "none":
if config.model_type != MODEL_LLM or (
config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None
):
return

if config.quantization and (not config.adapter or config.adapter.type != "lora"):
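For reference, the relaxed AdaLoRA check above boils down to verifying that the base model's architecture appears in PEFT's predefined AdaLoRA target-module mapping, whether or not the user supplied target_modules. A standalone sketch, assuming the tiny Llama model used in the tests below and AutoConfig as a stand-in for Ludwig's _get_llm_model_config helper:

from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING
from transformers import AutoConfig

base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
model_config = AutoConfig.from_pretrained(base_model)

# AdaLoRA only knows how to choose target modules for architectures in this mapping.
if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
    raise ValueError(
        f"Adalora adapter is not supported for {model_config.model_type} models. "
        f"Supported model types: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING)}"
    )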
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -282,6 +282,7 @@
GENERATION = "generation"
PROMPT = "prompt"
ADAPTER = "adapter"
PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights"

# CrossEntropyLoss for LLMs
IGNORE_INDEX_TOKEN_ID = -100
45 changes: 39 additions & 6 deletions ludwig/models/llm.py
@@ -216,18 +216,51 @@ def output_feature_decoder(self) -> OutputFeature:
def initialize_adapter(self):
"""If an adapter config is provided, we want to wrap the model with a PEFT model for fine-tuning."""
if self.config_obj.adapter:
if self.config_obj.trainer.type != "finetune":
if self.config_obj.trainer.type != "finetune" and not self.config_obj.adapter.pretrained_adapter_weights:
raise ValueError(
"Adapter config was provided, but trainer type is not set to `finetune`. Either set the trainer to "
"`finetune` or remove the adapter config."
)

from peft import get_peft_model, TaskType
from peft import get_peft_model

if self.config_obj.adapter.pretrained_adapter_weights:
logger.info(f"Using pretrained adapter weights: {self.config_obj.adapter.pretrained_adapter_weights}")
# If pretrained adapter weights are provided, we want to load them into the model
from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig

peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_adapter_weights)
peft_dict = peft_config.to_dict()

# Need to update the peft config with some of the values from config_obj because not all of them are set
for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items():
# Copy over any adapter parameters that are set in the Ludwig config but missing from the
# PEFT config loaded from HF.
if param_value is None:
# param_name and param_value come from the config object and contain default
# values for the adapter. Examples of parameters with missing values might be:
# 'auto_mapping', 'base_model_name_or_path', and 'task_type'.
# Note that some of these values might already be set in peft_config, which comes from HF
# directly (specifically, adapter_config.json in the model repo), and we don't want to override
# those values with None.
continue
if param_name not in peft_dict:
# If any parameters are not set in adapter_config.json in HF, we want to populate them with the
# appropriate default values.
setattr(peft_config, param_name, param_value)

self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained(
self.model, self.config_obj.adapter.pretrained_adapter_weights
)
else:
# If no pretrained adapter is provided, we want to load untrained weights into the model
from peft import TaskType

peft_config = self.config_obj.adapter.to_config(
task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
)
self.model = get_peft_model(self.model, peft_config)
peft_config = self.config_obj.adapter.to_config(
task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
)

self.model = get_peft_model(self.model, peft_config)

logger.info("==================================================")
logger.info("Trainable Parameter Summary For Fine-Tuning")
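Outside of Ludwig, the pretrained-weights branch of initialize_adapter above corresponds roughly to the following PEFT usage. The adapter repo name is a placeholder, and in Ludwig the base model is the already-instantiated self.model rather than one loaded from the adapter's config:

from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig
from transformers import AutoModelForCausalLM

adapter_repo = "my-org/my-lora-adapter"  # hypothetical HF repo ID or local path

# adapter_config.json records the task type and the base model the adapter was trained on.
peft_config = PeftConfig.from_pretrained(adapter_repo)
base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)

# Pick the task-specific PEFT wrapper (e.g. PeftModelForCausalLM) and attach the
# pretrained adapter weights to the frozen base model.
peft_model_cls = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]
model = peft_model_cls.from_pretrained(base_model, adapter_repo)
model.eval()  # inference only; no further adapter training

PeftModel.from_pretrained performs the same task-type dispatch internally; the explicit mapping lookup is kept here only to mirror the code above.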
6 changes: 5 additions & 1 deletion ludwig/schema/llms/peft.py
@@ -30,6 +30,10 @@ def wrap(config: BaseAdapterConfig):
class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
type: str

pretrained_adapter_weights: Optional[str] = schema_utils.String(
default=None, description="HuggingFace link or local path to pretrained adapter weights.", allow_none=True
)

@abstractmethod
def to_config(self, **kwargs) -> "PeftConfig":
pass
@@ -359,7 +363,7 @@ def description(cls) -> str:
@register_adapter("adaption_prompt")
@ludwig_dataclass
class AdaptionPromptConfig(BaseAdapterConfig):
"""Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt.py."""
"""Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt/config.py."""

def __post_init__(self):
if not self.adapter_len:
51 changes: 51 additions & 0 deletions tests/integration_tests/test_llm.py
@@ -18,6 +18,7 @@
MODEL_TYPE,
OUTPUT_FEATURES,
PREPROCESSING,
PRETRAINED_ADAPTER_WEIGHTS,
PROMPT,
TRAINER,
TYPE,
@@ -492,12 +493,62 @@ def test_default_max_sequence_length():
BATCH_SIZE: 8,
EPOCHS: 2,
},
ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"},
BACKEND: {TYPE: "local"},
}
config_obj = ModelConfig.from_dict(config)
assert config_obj.input_features[0].preprocessing.max_sequence_length is None
assert config_obj.output_features[0].preprocessing.max_sequence_length is None


@pytest.mark.parametrize("adapter", ["lora", "adalora", "adaption_prompt"])
def test_load_pretrained_adapter_weights(adapter):
from peft import PeftModel
from transformers import PreTrainedModel

print(f"ADAPTER: {adapter}")
weights = ""
model = ""
if adapter == "lora":
weights = "Infernaught/test_adapter_weights"
base_model = TEST_MODEL_NAME
elif adapter == "adalora":
weights = "Infernaught/test_adalora_weights"
base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
elif adapter == "adaption_prompt":
weights = "Infernaught/test_ap_weights"
base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
else:
raise ValueError(f"Unknown adapter type: {adapter}")

config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: base_model,
INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
OUTPUT_FEATURES: [text_feature(name="output")],
TRAINER: {
TYPE: "none",
BATCH_SIZE: 8,
EPOCHS: 2,
},
ADAPTER: {TYPE: adapter, PRETRAINED_ADAPTER_WEIGHTS: weights},
BACKEND: {TYPE: "local"},
}
config_obj = ModelConfig.from_dict(config)
model = LLM(config_obj)

assert model.config_obj.adapter.pretrained_adapter_weights
assert model.config_obj.adapter.pretrained_adapter_weights == weights

model.prepare_for_training()
assert not isinstance(model.model, PreTrainedModel)
assert isinstance(model.model, PeftModel)

config_obj = ModelConfig.from_dict(config)
assert config_obj.input_features[0].preprocessing.max_sequence_length is None
assert config_obj.output_features[0].preprocessing.max_sequence_length is None


def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
# Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):