From 4ee94bd1f1e6b9d8e0ea0bf41236e2311d4bb4be Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Thu, 31 Aug 2023 17:04:29 -0400 Subject: [PATCH 01/17] Add functionality for pretrained lora weights --- ludwig/constants.py | 1 + ludwig/models/llm.py | 34 ++++++++++++++++++++++++----- ludwig/schema/llms/peft.py | 12 ++++++++++ ludwig/trainers/trainer.py | 1 + tests/integration_tests/test_llm.py | 31 ++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 5 deletions(-) diff --git a/ludwig/constants.py b/ludwig/constants.py index 87a7594b805..be1fcaecf91 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -281,6 +281,7 @@ GENERATION = "generation" PROMPT = "prompt" ADAPTER = "adapter" +PRETRAINED_WEIGHTS = "pretrained_weights" # CrossEntropyLoss for LLMs IGNORE_INDEX_TOKEN_ID = -100 diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index f26228413fb..6316c5527e7 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -221,12 +221,36 @@ def initialize_adapter(self): "`finetune` or remove the adapter config." ) - from peft import get_peft_model, TaskType + from peft import get_peft_model + + pretrained = False + if self.config_obj.adapter.pretrained_weights: + print(f"PRETRAINED_WEIGHTS: {self.config_obj.adapter.pretrained_weights}") + # If pretrained adapter weights are provided, we want to load them into the model + from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig + + pretrained = True + peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_weights) + peft_dict = peft_config.to_dict() + for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items(): + if param_name is None: + continue + + if param_name not in peft_dict: + setattr(peft_config, param_name, param_value) + + self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained( + self.model, self.config_obj.adapter.pretrained_weights + ) + else: + # If no pretrained adapter is provided, we want to load untrained weights into the model + from peft import TaskType - peft_config = self.config_obj.adapter.to_config( - task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name - ) - self.model = get_peft_model(self.model, peft_config) + peft_config = self.config_obj.adapter.to_config( + task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name + ) + + self.model = get_peft_model(self.model, peft_config, pretrained=pretrained) logger.info("==================================================") logger.info("Trainable Parameter Summary For Fine-Tuning") diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index 6e127ee5fbf..34e11b21ab0 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -69,6 +69,18 @@ class LoraConfig(BaseAdapterConfig): description="Bias type for Lora.", ) + pretrained_weights: Optional[str] = schema_utils.String( + default="none", + description="Path to pretrained weights for Lora.", + ) + + target_modules: Optional[list] = schema_utils.List( + str, + default=None, + allow_none=True, + description="List of modules to apply Lora to. 
If None, apply to all modules.", + ) + def to_config(self, task_type: str = None, **kwargs) -> "PeftConfig": from peft import LoraConfig as _LoraConfig diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 7e6bb3d1294..f370c7b5d52 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -266,6 +266,7 @@ def closure(): targets, model_outputs, self.regularization_type, self.regularization_lambda ) loss = loss / self.gradient_accumulation_steps + loss.requires_grad = True # Begin the backward pass variables = self.dist_model.parameters() diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index 9644a9e95a2..392861e5f44 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -18,6 +18,7 @@ MODEL_TYPE, OUTPUT_FEATURES, PREPROCESSING, + PRETRAINED_WEIGHTS, PROMPT, TRAINER, TYPE, @@ -481,6 +482,36 @@ def test_llama_rope_scaling(): assert model.model.config.rope_scaling["factor"] == 2.0 +def test_load_pretrained_adapter_weights(): + from peft import PeftModel + from transformers import PreTrainedModel + + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: TEST_MODEL_NAME, + INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})], + OUTPUT_FEATURES: [text_feature(name="output")], + TRAINER: { + TYPE: "finetune", + BATCH_SIZE: 8, + EPOCHS: 2, + }, + ADAPTER: {TYPE: "lora", PRETRAINED_WEIGHTS: "Infernaught/test_adapter_weights"}, + BACKEND: {TYPE: "local"}, + } + + print(ModelConfig) + config_obj = ModelConfig.from_dict(config) + model = LLM(config_obj) + + assert model.config_obj.adapter.pretrained_weights + assert model.config_obj.adapter.pretrained_weights == "Infernaught/test_adapter_weights" + + model.prepare_for_training() + assert not isinstance(model.model, PreTrainedModel) + assert isinstance(model.model, PeftModel) + + def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool: # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6 for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): From 1c007243544db6b998ec4291c6af77e964b21f6a Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Thu, 31 Aug 2023 18:18:59 -0400 Subject: [PATCH 02/17] Address PR comments --- ludwig/models/llm.py | 9 ++++----- ludwig/schema/llms/peft.py | 3 +-- ludwig/trainers/trainer.py | 5 ++++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 24c4c2a2411..4f784e22d3c 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -224,19 +224,18 @@ def initialize_adapter(self): from peft import get_peft_model - pretrained = False if self.config_obj.adapter.pretrained_weights: - print(f"PRETRAINED_WEIGHTS: {self.config_obj.adapter.pretrained_weights}") + logger.info(f"Using pretrained weights: {self.config_obj.adapter.pretrained_weights}") # If pretrained adapter weights are provided, we want to load them into the model from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig - pretrained = True peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_weights) peft_dict = peft_config.to_dict() + + # Need to update the peft config with some of the values from config_obj because not all of them are set for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items(): if param_name is None: continue - if param_name not in peft_dict: setattr(peft_config, param_name, param_value) @@ -251,7 
+250,7 @@ def initialize_adapter(self): task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name ) - self.model = get_peft_model(self.model, peft_config, pretrained=pretrained) + self.model = get_peft_model(self.model, peft_config) logger.info("==================================================") logger.info("Trainable Parameter Summary For Fine-Tuning") diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index 34e11b21ab0..bf03b8c92e3 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -70,8 +70,7 @@ class LoraConfig(BaseAdapterConfig): ) pretrained_weights: Optional[str] = schema_utils.String( - default="none", - description="Path to pretrained weights for Lora.", + default=None, description="Path to pretrained weights for Lora.", allow_none=True ) target_modules: Optional[list] = schema_utils.List( diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 0013627e88d..2964d214151 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -272,7 +272,10 @@ def closure(): targets, model_outputs, self.regularization_type, self.regularization_lambda ) loss = loss / self.gradient_accumulation_steps - loss.requires_grad = True + try: + loss.requires_grad = True + except RuntimeError: + pass # Begin the backward pass variables = self.dist_model.parameters() From 650e5e6b360fff0fbbb795cf22b1d36179633f1b Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Wed, 6 Sep 2023 19:47:43 -0400 Subject: [PATCH 03/17] Address PR comments --- ludwig/constants.py | 2 +- ludwig/models/llm.py | 12 +++++++----- ludwig/schema/llms/peft.py | 2 +- tests/integration_tests/test_llm.py | 10 +++++----- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ludwig/constants.py b/ludwig/constants.py index f27a3bf7353..d2cc455df24 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -282,7 +282,7 @@ GENERATION = "generation" PROMPT = "prompt" ADAPTER = "adapter" -PRETRAINED_WEIGHTS = "pretrained_weights" +PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights" # CrossEntropyLoss for LLMs IGNORE_INDEX_TOKEN_ID = -100 diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 4f784e22d3c..8c4f0685713 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -224,23 +224,25 @@ def initialize_adapter(self): from peft import get_peft_model - if self.config_obj.adapter.pretrained_weights: - logger.info(f"Using pretrained weights: {self.config_obj.adapter.pretrained_weights}") + if self.config_obj.adapter.pretrained_adapter_weights: + logger.info(f"Using pretrained adapter weights: {self.config_obj.adapter.pretrained_adapter_weights}") # If pretrained adapter weights are provided, we want to load them into the model from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig - peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_weights) + peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_adapter_weights) peft_dict = peft_config.to_dict() # Need to update the peft config with some of the values from config_obj because not all of them are set for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items(): - if param_name is None: + # Not all parameters are supported by all models, so we only add the parameter to the load kwargs + # if it is supported by the model. 
+ if param_value is None: continue if param_name not in peft_dict: setattr(peft_config, param_name, param_value) self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained( - self.model, self.config_obj.adapter.pretrained_weights + self.model, self.config_obj.adapter.pretrained_adapter_weights ) else: # If no pretrained adapter is provided, we want to load untrained weights into the model diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index bf03b8c92e3..282e5e3fa54 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -69,7 +69,7 @@ class LoraConfig(BaseAdapterConfig): description="Bias type for Lora.", ) - pretrained_weights: Optional[str] = schema_utils.String( + pretrained_adapter_weights: Optional[str] = schema_utils.String( default=None, description="Path to pretrained weights for Lora.", allow_none=True ) diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index 33830e1f660..c247097fb3e 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -18,7 +18,7 @@ MODEL_TYPE, OUTPUT_FEATURES, PREPROCESSING, - PRETRAINED_WEIGHTS, + PRETRAINED_ADAPTER_WEIGHTS, PROMPT, TRAINER, TYPE, @@ -493,7 +493,7 @@ def test_default_max_sequence_length(): BATCH_SIZE: 8, EPOCHS: 2, }, - ADAPTER: {TYPE: "lora", PRETRAINED_WEIGHTS: "Infernaught/test_adapter_weights"}, + ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"}, BACKEND: {TYPE: "local"}, } config_obj = ModelConfig.from_dict(config) @@ -515,14 +515,14 @@ def test_load_pretrained_adapter_weights(): BATCH_SIZE: 8, EPOCHS: 2, }, - ADAPTER: {TYPE: "lora", PRETRAINED_WEIGHTS: "Infernaught/test_adapter_weights"}, + ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"}, BACKEND: {TYPE: "local"}, } config_obj = ModelConfig.from_dict(config) model = LLM(config_obj) - assert model.config_obj.adapter.pretrained_weights - assert model.config_obj.adapter.pretrained_weights == "Infernaught/test_adapter_weights" + assert model.config_obj.adapter.pretrained_adapter_weights + assert model.config_obj.adapter.pretrained_adapter_weights == "Infernaught/test_adapter_weights" model.prepare_for_training() assert not isinstance(model.model, PreTrainedModel) From c18e9e8d14acab006cb30b040b04b24ba79d5ed0 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Thu, 7 Sep 2023 18:12:41 -0400 Subject: [PATCH 04/17] Add comments explaining PEFT config update --- ludwig/models/llm.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 8c4f0685713..3894b694ddf 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -237,8 +237,16 @@ def initialize_adapter(self): # Not all parameters are supported by all models, so we only add the parameter to the load kwargs # if it is supported by the model. if param_value is None: + # param_name and param_value come from the config object and contain default + # values for the adapter. Examples of parameters with missing values might be: + # 'auto_mapping', 'base_model_name_or_path', and 'task_type'. + # Note that some of these values might already be set in peft_config, which comes from HF + # directly (specifically, adapter_config.json in the model repo), and we don't want to override + # those values with None. 
continue if param_name not in peft_dict: + # If any parameters are not set in adapter_config.json in HF, we want to populate them with the + # appropriate default values. setattr(peft_config, param_name, param_value) self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained( From d92190ae6b6e653e5621a3be1c2e33aed7b9b924 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Fri, 8 Sep 2023 10:25:35 -0400 Subject: [PATCH 05/17] Allow for adalora and adaption prompt loading --- ludwig/schema/llms/peft.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index 282e5e3fa54..b6d5af6eed3 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -30,6 +30,17 @@ def wrap(config: BaseAdapterConfig): class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC): type: str + pretrained_adapter_weights: Optional[str] = schema_utils.String( + default=None, description="Path to pretrained weights.", allow_none=True + ) + + target_modules: Optional[list] = schema_utils.List( + str, + default=None, + allow_none=True, + description="List of modules to apply adapter to. If None, apply to all modules.", + ) + @abstractmethod def to_config(self, **kwargs) -> "PeftConfig": pass @@ -69,17 +80,6 @@ class LoraConfig(BaseAdapterConfig): description="Bias type for Lora.", ) - pretrained_adapter_weights: Optional[str] = schema_utils.String( - default=None, description="Path to pretrained weights for Lora.", allow_none=True - ) - - target_modules: Optional[list] = schema_utils.List( - str, - default=None, - allow_none=True, - description="List of modules to apply Lora to. If None, apply to all modules.", - ) - def to_config(self, task_type: str = None, **kwargs) -> "PeftConfig": from peft import LoraConfig as _LoraConfig @@ -370,7 +370,7 @@ def description(cls) -> str: @register_adapter("adaption_prompt") @ludwig_dataclass class AdaptionPromptConfig(BaseAdapterConfig): - """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt.py.""" + """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt/config.py.""" def __post_init__(self): if not self.adapter_len: From 1fbd1b27d586d7711cdca4e11fc52f70ef691949 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Fri, 8 Sep 2023 12:18:40 -0400 Subject: [PATCH 06/17] Explain target_modules --- ludwig/schema/llms/peft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index b6d5af6eed3..29c8a94b7f6 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -34,6 +34,8 @@ class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC): default=None, description="Path to pretrained weights.", allow_none=True ) + # This is here for now to address "AttributeError: 'AdaloraConfig' object has no attribute 'target_modules'". 
Will + # continue investigating target_modules: Optional[list] = schema_utils.List( str, default=None, From a8b674fe5891e2a4929b89f44c986fb63c9de5d2 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Fri, 8 Sep 2023 12:26:22 -0400 Subject: [PATCH 07/17] Add tests for adalora and adaption prompt --- tests/integration_tests/test_llm.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index c247097fb3e..1908a3e3895 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -501,13 +501,29 @@ def test_default_max_sequence_length(): assert config_obj.output_features[0].preprocessing.max_sequence_length is None -def test_load_pretrained_adapter_weights(): +@pytest.mark.parametrize("adapter", ["lora", "adalora", "adaption_prompt"]) +def test_load_pretrained_adapter_weights(adapter): from peft import PeftModel from transformers import PreTrainedModel + print(f"ADAPTER: {adapter}") + weights = "" + model = "" + if adapter == "lora": + weights = "Infernaught/test_adapter_weights" + base_model = TEST_MODEL_NAME + elif adapter == "adalora": + weights = "Infernaught/test_adalora_weights" + base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM" + elif adapter == "adaption_prompt": + weights = "Infernaught/test_ap_weights" + base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM" + else: + raise () + config = { MODEL_TYPE: MODEL_LLM, - BASE_MODEL: TEST_MODEL_NAME, + BASE_MODEL: base_model, INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})], OUTPUT_FEATURES: [text_feature(name="output")], TRAINER: { @@ -515,14 +531,14 @@ def test_load_pretrained_adapter_weights(): BATCH_SIZE: 8, EPOCHS: 2, }, - ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"}, + ADAPTER: {TYPE: adapter, PRETRAINED_ADAPTER_WEIGHTS: weights}, BACKEND: {TYPE: "local"}, } config_obj = ModelConfig.from_dict(config) model = LLM(config_obj) assert model.config_obj.adapter.pretrained_adapter_weights - assert model.config_obj.adapter.pretrained_adapter_weights == "Infernaught/test_adapter_weights" + assert model.config_obj.adapter.pretrained_adapter_weights == weights model.prepare_for_training() assert not isinstance(model.model, PreTrainedModel) From d3d66efa2bf76e9a4e1ea6ee44bf33ab15d5309d Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Fri, 8 Sep 2023 12:37:26 -0400 Subject: [PATCH 08/17] Change try-except to conditional for transparency --- ludwig/trainers/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 2964d214151..9f423f24839 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -272,10 +272,11 @@ def closure(): targets, model_outputs, self.regularization_type, self.regularization_lambda ) loss = loss / self.gradient_accumulation_steps - try: + if not loss.requires_grad: + # From torch autograd docs -- "All Tensors that have requires_grad which + # is False will be leaf Tensors by convention." 
+ # We want to avoid "RuntimeError: you can only change requires_grad flags of leaf variables."" loss.requires_grad = True - except RuntimeError: - pass # Begin the backward pass variables = self.dist_model.parameters() From 4578552022aeb7127b7fbf97ed437e849a9f9d44 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Fri, 8 Sep 2023 16:19:50 -0400 Subject: [PATCH 09/17] Allow pretrained weights for inference only --- ludwig/config_validation/checks.py | 13 ++++++------- ludwig/models/llm.py | 2 +- ludwig/schema/llms/peft.py | 9 --------- ludwig/trainers/trainer.py | 5 ----- tests/integration_tests/test_llm.py | 2 +- 5 files changed, 8 insertions(+), 23 deletions(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 9387048519a..fcfdc2b08ae 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -493,6 +493,9 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"): # noqa: F821 if config.model_type != MODEL_LLM: return + if config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None: + return + if config.adapter is not None and config.trainer.type != "finetune": raise ConfigValidationError("LLM finetuning requires trainer type to be finetune.") @@ -528,9 +531,8 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821 def check_llm_finetuning_adalora_config(config: "ModelConfig"): """Checks that the adalora adapter is configured correctly. - It requires a set of target_modules to be specified in the config for the model. If it isn't specified by the user, - we also check against PEFT's predefined target module list for ADALORA to see if this key is present there. If - neither is true, AdaloraModel will run into issues downstream. + We check against PEFT's predefined target module list for ADALORA to see if this target_modules is present there. If + not, AdaloraModel will run into issues downstream. """ if config.model_type != MODEL_LLM: return @@ -544,10 +546,7 @@ def check_llm_finetuning_adalora_config(config: "ModelConfig"): from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING model_config = _get_llm_model_config(config.base_model) - if ( - not config.adapter.target_modules - and model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING - ): + if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: raise ConfigValidationError( f"Adalora adapter is not supported for {model_config.model_type} model. " f"Supported model types are: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING.keys())}. " diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 3894b694ddf..408d8c89cf4 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -216,7 +216,7 @@ def output_feature_decoder(self) -> OutputFeature: def initialize_adapter(self): """If an adapter config is provided, we want to wrap the model with a PEFT model for fine-tuning.""" if self.config_obj.adapter: - if self.config_obj.trainer.type != "finetune": + if self.config_obj.trainer.type != "finetune" and not self.config_obj.adapter.pretrained_adapter_weights: raise ValueError( "Adapter config was provided, but trainer type is not set to `finetune`. Either set the trainer to " "`finetune` or remove the adapter config." 
diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index 29c8a94b7f6..3ce30aeb07c 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -34,15 +34,6 @@ class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC): default=None, description="Path to pretrained weights.", allow_none=True ) - # This is here for now to address "AttributeError: 'AdaloraConfig' object has no attribute 'target_modules'". Will - # continue investigating - target_modules: Optional[list] = schema_utils.List( - str, - default=None, - allow_none=True, - description="List of modules to apply adapter to. If None, apply to all modules.", - ) - @abstractmethod def to_config(self, **kwargs) -> "PeftConfig": pass diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 9f423f24839..40f9abbca90 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -272,11 +272,6 @@ def closure(): targets, model_outputs, self.regularization_type, self.regularization_lambda ) loss = loss / self.gradient_accumulation_steps - if not loss.requires_grad: - # From torch autograd docs -- "All Tensors that have requires_grad which - # is False will be leaf Tensors by convention." - # We want to avoid "RuntimeError: you can only change requires_grad flags of leaf variables."" - loss.requires_grad = True # Begin the backward pass variables = self.dist_model.parameters() diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index 1908a3e3895..0f57df04794 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -527,7 +527,7 @@ def test_load_pretrained_adapter_weights(adapter): INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})], OUTPUT_FEATURES: [text_feature(name="output")], TRAINER: { - TYPE: "finetune", + TYPE: "none", BATCH_SIZE: 8, EPOCHS: 2, }, From 4c63915d2670148bfb96d7efaa9319ccb8109769 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Sat, 9 Sep 2023 16:03:18 -0400 Subject: [PATCH 10/17] Load pre-trained weights when trainer type is none --- ludwig/api.py | 3 +++ ludwig/config_validation/checks.py | 8 +++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ludwig/api.py b/ludwig/api.py index 942034cfc1c..b47d9069b5d 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -615,6 +615,9 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): # auto tune batch size self._tune_batch_size(trainer, training_set, random_seed=random_seed) + if self.config_obj.adapter.pretrained_adapter_weights and trainer.config.type == "none": + trainer.model.initialize_adapter() # Load pre-trained adapter weights for inference only + # train model if self.backend.is_coordinator(): print_boxed("TRAINING") diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index fcfdc2b08ae..432b5647804 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -477,7 +477,7 @@ def check_llm_finetuning_output_feature_config(config: "ModelConfig"): # noqa: if config.model_type != MODEL_LLM: return - if config.trainer.type != "finetune": + if config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None: return if config.output_features[0].type != TEXT: @@ -511,7 +511,7 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821 return # LLM finetuning is only supported by the finetune trainer type - if config.trainer.type != "finetune": + if 
config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None: return # Using local backend, so skip the checks below @@ -600,7 +600,9 @@ def check_llm_quantization_backend_incompatibility(config: "ModelConfig") -> Non @register_config_check def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821 """Checks that all the necessary settings are in place for QLoRA.""" - if config.model_type != MODEL_LLM or config.trainer.type == "none": + if config.model_type != MODEL_LLM or ( + config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None + ): return if config.quantization and (not config.adapter or config.adapter.type != "lora"): From ac439cdbe3409851e41ecafaa098d965de47cf34 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 12:14:38 -0400 Subject: [PATCH 11/17] Fix check for LLM text output features --- ludwig/config_validation/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 432b5647804..613288f32b5 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -477,7 +477,7 @@ def check_llm_finetuning_output_feature_config(config: "ModelConfig"): # noqa: if config.model_type != MODEL_LLM: return - if config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None: + if config.trainer.type != "finetune": return if config.output_features[0].type != TEXT: From 9b94d918ed96a47b9acbbf6a28b228bda4dffdec Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 12:15:38 -0400 Subject: [PATCH 12/17] Remove print statement --- tests/integration_tests/test_llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index 0f57df04794..7d11c5e469d 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -506,7 +506,6 @@ def test_load_pretrained_adapter_weights(adapter): from peft import PeftModel from transformers import PreTrainedModel - print(f"ADAPTER: {adapter}") weights = "" model = "" if adapter == "lora": From 9e5dfec46b98fe7a7a532db6a1b21366af753b5b Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 14:50:29 -0400 Subject: [PATCH 13/17] Fix finetune check and add comment for clarity --- ludwig/config_validation/checks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 613288f32b5..cb4c8e2e902 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -493,7 +493,12 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"): # noqa: F821 if config.model_type != MODEL_LLM: return - if config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None: + if ( + config.trainer.type == "none" + and config.adapter is not None + and config.adapter.pretrained_adapter_weights is not None + ): + # If performing zero-shot, we must specify pretrained adapter weights return if config.adapter is not None and config.trainer.type != "finetune": From f6695bb46bf85f8756e0d5486eced417b4e76e42 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 17:52:52 -0400 Subject: [PATCH 14/17] Fix backend check --- ludwig/config_validation/checks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 
deletion(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index cb4c8e2e902..a64b2cf0b11 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -516,7 +516,11 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821 return # LLM finetuning is only supported by the finetune trainer type - if config.trainer.type != "finetune" and config.adapter.pretrained_adapter_weights is not None: + if ( + config.trainer.type != "finetune" + and config.adapter is not None + and config.adapter.pretrained_adapter_weights is not None + ): return # Using local backend, so skip the checks below From 43553534ca579cbc3b12d61d5f19522661b95125 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 18:02:36 -0400 Subject: [PATCH 15/17] Fix conditional to avoid checking adapter in ECDs --- ludwig/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ludwig/api.py b/ludwig/api.py index b47d9069b5d..6abe4068831 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -615,7 +615,12 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): # auto tune batch size self._tune_batch_size(trainer, training_set, random_seed=random_seed) - if self.config_obj.adapter.pretrained_adapter_weights and trainer.config.type == "none": + if ( + self.config_obj.model_type == "LLM" + and trainer.config.type == "none" + and self.config_obj.adapter is not None + and self.config_obj.adapter.pretrained_adapter_weights is not None + ): trainer.model.initialize_adapter() # Load pre-trained adapter weights for inference only # train model From 1ab20ccf1cdc9ea28ad7bd81cc19e62f9ce41fc1 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 18:28:26 -0400 Subject: [PATCH 16/17] Fix qlora check --- ludwig/config_validation/checks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index a64b2cf0b11..41f9a9eedfb 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -610,7 +610,9 @@ def check_llm_quantization_backend_incompatibility(config: "ModelConfig") -> Non def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821 """Checks that all the necessary settings are in place for QLoRA.""" if config.model_type != MODEL_LLM or ( - config.trainer.type == "none" and config.adapter.pretrained_adapter_weights is not None + config.trainer.type == "none" + and config.adapter is not None + and config.adapter.pretrained_adapter_weights is not None ): return From effcce0df2bf0e57f914021e3e8970ffef45fae4 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 11 Sep 2023 18:58:43 -0400 Subject: [PATCH 17/17] Revert change to qlora check --- ludwig/config_validation/checks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 41f9a9eedfb..02bb399486b 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -609,11 +609,7 @@ def check_llm_quantization_backend_incompatibility(config: "ModelConfig") -> Non @register_config_check def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821 """Checks that all the necessary settings are in place for QLoRA.""" - if config.model_type != MODEL_LLM or ( - config.trainer.type == "none" - and config.adapter is not None - and 
config.adapter.pretrained_adapter_weights is not None - ): + if config.model_type != MODEL_LLM or config.trainer.type == "none": return if config.quantization and (not config.adapter or config.adapter.type != "lora"):
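
Usage sketch (not part of the patches above): a minimal example of the pretrained-adapter flow added in this series, shown through Ludwig's Python API. The config keys mirror the integration test in tests/integration_tests/test_llm.py; "Infernaught/test_adapter_weights" is the test fixture repo, the base model is the tiny Llama used in the tests, and the dataset path is hypothetical.

    # Zero-shot inference with pretrained LoRA adapter weights (illustrative sketch only).
    from ludwig.api import LudwigModel

    config = {
        "model_type": "llm",
        # Assumption: a causal LM compatible with the adapter's original base model.
        "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
        "input_features": [{"name": "input", "type": "text"}],
        "output_features": [{"name": "output", "type": "text"}],
        "adapter": {
            "type": "lora",
            # HF repo (or local path) holding adapter_config.json and the adapter weights.
            "pretrained_adapter_weights": "Infernaught/test_adapter_weights",
        },
        # With trainer type "none", the api.py change in this series loads the pretrained
        # adapter via LLM.initialize_adapter() instead of fine-tuning it.
        "trainer": {"type": "none"},
        "backend": {"type": "local"},
    }

    model = LudwigModel(config)
    # train() performs no weight updates with trainer type "none"; it attaches the adapter
    # so that predict() runs against the adapted model.
    # results = model.train(dataset="examples.csv")            # hypothetical dataset path
    # predictions, _ = model.predict(dataset="examples.csv")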