diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 09bdfe3915a..daed5d20931 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -1209,12 +1209,22 @@ def llm_encoder_config() -> dict[str, Any]:
 
 @pytest.mark.parametrize(
     "adapter,quantization",
-    [(None, None), ("lora", None), ("lora", {"bits": 4}), ("lora", {"bits": 8})],
-    ids=["FFT", "LoRA", "LoRA 4-bit", "LoRA 8-bit"],
+    [
+        (None, None),
+        ("lora", None),
+        ("lora", {"bits": 4}),
+        ("lora", {"bits": 8}),
+        ("adalora", None),
+        ("adalora", {"bits": 4}),
+        ("adalora", {"bits": 8}),
+    ],
+    ids=["FFT", "LoRA", "LoRA 4-bit", "LoRA 8-bit", "AdaLoRA", "AdaLoRA 4-bit", "AdaLoRA 8-bit"],
 )
 def test_llm_encoding(llm_encoder_config, adapter, quantization, tmpdir):
     if (
-        _finetune_strategy_requires_cuda(finetune_strategy_name=adapter, quantization_args=quantization)
+        _finetune_strategy_requires_cuda(
+            finetune_strategy_name="lora" if adapter else None, quantization_args=quantization
+        )
         and not (torch.cuda.is_available() and torch.cuda.device_count()) > 0
     ):
         pytest.skip("Skip: quantization requires GPU and none are available.")
diff --git a/tests/ludwig/encoders/test_llm_encoders.py b/tests/ludwig/encoders/test_llm_encoders.py
index 8b5f6b1faee..d7c9121b388 100644
--- a/tests/ludwig/encoders/test_llm_encoders.py
+++ b/tests/ludwig/encoders/test_llm_encoders.py
@@ -7,11 +7,14 @@
 
 from ludwig.encoders.text_encoders import LLMEncoder
 from ludwig.schema.encoders.text_encoders import LLMEncoderConfig
-from ludwig.schema.llms.peft import BaseAdapterConfig, LoraConfig
+from ludwig.schema.llms.peft import AdaloraConfig, BaseAdapterConfig, LoraConfig
 from ludwig.utils.llm_utils import get_context_len
 
 # Mapping of adapter types to test against and their respective config objects.
-ADAPTER_CONFIG_MAP = {"lora": LoraConfig}
+ADAPTER_CONFIG_MAP = {
+    "lora": LoraConfig,
+    "adalora": AdaloraConfig,
+}
 
 
 @pytest.fixture()
@@ -58,13 +61,30 @@ def create_encoder_config_with_adapter(
         new_config.adapter = ADAPTER_CONFIG_MAP[adapter](**kwargs)
         return new_config
 
+    def adapter_param_name_prefix(self, adapter: str) -> str:
+        """Get the PEFT parameter name prefix for a given adapter type.
+
+        Args:
+            adapter: A valid config value for `adapter.type`
+
+        Returns:
+            The PEFT-applied prefix for the adapter's parameter names.
+
+        Raises:
+            KeyError: raised when the provided adapter name is not valid for LLMEncoder.
+ """ + return LLMEncoder.ADAPTER_PARAM_NAME_PREFIX[adapter] + def test_init(self, encoder_config: LLMEncoderConfig, model_config): # Test initializing without an adapter encoder = LLMEncoder(encoder_config=encoder_config) assert encoder.model_name == encoder_config.base_model assert isinstance(encoder.model, PreTrainedModel) - assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys())) # Check adapter was not initialized + # Check adapter was not initialized + for k in ADAPTER_CONFIG_MAP.keys(): + prefix = self.adapter_param_name_prefix(k) + assert all(map(lambda k: prefix not in k, encoder.state_dict().keys())) assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length]) assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size]) @@ -77,7 +97,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config): assert encoder.model_name == encoder_config.base_model assert isinstance(encoder.model, PreTrainedModel) - assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys())) # Check adapter was not initialized + # Check adapter was not initialized + for k in ADAPTER_CONFIG_MAP.keys(): + prefix = self.adapter_param_name_prefix(k) + assert all(map(lambda k: prefix not in k, encoder.state_dict().keys())) assert encoder.input_shape == torch.Size([context_len]) assert encoder.output_shape == torch.Size([context_len, model_config.hidden_size]) @@ -87,10 +110,11 @@ def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter) encoder = LLMEncoder(encoder_config=encoder_config_with_adapter) + prefix = self.adapter_param_name_prefix(adapter) # The adapter should not be initialized until `prepare_for_training` is called assert not isinstance(encoder.model, PeftModel) - assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys())) + assert not any(map(lambda k: prefix in k, encoder.state_dict().keys())) assert encoder.model_name == encoder_config.base_model assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length]) @@ -102,31 +126,36 @@ def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: s encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter) encoder = LLMEncoder(encoder_config=encoder_config_with_adapter) + prefix = self.adapter_param_name_prefix(adapter) # The adapter should not be initialized until `prepare_for_training` is called assert not isinstance(encoder.model, PeftModel) - assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys())) + assert not any(map(lambda k: prefix in k, encoder.state_dict().keys())) # Initialize the adapter encoder.prepare_for_training() # At this point, the adapter should be initialized and the state dict should contain adapter parameters assert isinstance(encoder.model, PeftModel) - assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys())) + assert any(map(lambda k: prefix in k, encoder.state_dict().keys())) def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir): # With no adapter, the state dict should only contain the model parameters encoder = LLMEncoder(encoder_config=encoder_config) - assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys())) + # Check adapter was not initialized + for k in ADAPTER_CONFIG_MAP.keys(): + prefix = self.adapter_param_name_prefix(k) + assert all(map(lambda k: prefix not 
in k, encoder.state_dict().keys())) @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys())) def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, tmpdir): # With an adapter, the state dict should only contain adapter parameters encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter) encoder = LLMEncoder(encoder_config=encoder_config_with_adapter) + prefix = self.adapter_param_name_prefix(adapter) # Initialize the adapters encoder.prepare_for_training() - assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys())) + assert all(map(lambda k: prefix in k, encoder.state_dict().keys())) @pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"]) def test_load_from_state_dict(self, encoder_config: LLMEncoderConfig, wrap: bool): @@ -164,6 +193,8 @@ def weights_init(m): if hasattr(m, "weight") and m.weight.ndim > 1: torch.nn.init.xavier_uniform_(m.weight.data) + prefix = self.adapter_param_name_prefix(adapter) + # Update the config with an adapter encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter) @@ -183,8 +214,8 @@ def weights_init(m): encoder1_sd = encoder1.state_dict() encoder2_sd = encoder2.state_dict() - adapter_keys = [k for k in encoder1_sd.keys() if "lora_" in k and "weight" in k] - model_keys = [k for k in encoder1_sd.keys() if "lora_" not in k] + adapter_keys = [k for k in encoder1_sd.keys() if prefix in k and "weight" in k] + model_keys = [k for k in encoder1_sd.keys() if prefix not in k] # The LoRA weights should no longer be equal assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))
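
Reviewer note: the new `adapter_param_name_prefix` helper reads `LLMEncoder.ADAPTER_PARAM_NAME_PREFIX`, so these tests assume that attribute maps each supported adapter type to the substring PEFT injects into its parameter names. The sketch below is illustrative of that assumption only, not copied from `ludwig/encoders/text_encoders.py`; in PEFT, both LoRA and AdaLoRA layers expose parameters named `lora_A` and `lora_B` (plus `lora_E` for AdaLoRA), so a `lora_`-style prefix check would cover both adapter types.

    # Illustrative sketch only -- the real mapping lives on LLMEncoder.
    ADAPTER_PARAM_NAME_PREFIX = {
        "lora": "lora_",     # peft LoRA layers inject lora_A / lora_B parameters
        "adalora": "lora_",  # peft AdaLoRA layers inject lora_A / lora_B / lora_E parameters
    }

    def adapter_param_name_prefix(adapter: str) -> str:
        # Raises KeyError for adapter types the encoder does not support.
        return ADAPTER_PARAM_NAME_PREFIX[adapter]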