enable adalora in tests
jeffkinnison committed Jan 17, 2024
1 parent b0795e7 commit dfa8257
53 changes: 42 additions & 11 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -7,11 +7,14 @@

from ludwig.encoders.text_encoders import LLMEncoder
from ludwig.schema.encoders.text_encoders import LLMEncoderConfig
from ludwig.schema.llms.peft import BaseAdapterConfig, LoraConfig
from ludwig.schema.llms.peft import AdaloraConfig, BaseAdapterConfig, LoraConfig
from ludwig.utils.llm_utils import get_context_len

# Mapping of adapter types to test against and their respective config objects.
ADAPTER_CONFIG_MAP = {"lora": LoraConfig}
ADAPTER_CONFIG_MAP = {
"lora": LoraConfig,
"adalora": AdaloraConfig,
}


@pytest.fixture()
@@ -58,13 +61,30 @@ def create_encoder_config_with_adapter(
new_config.adapter = ADAPTER_CONFIG_MAP[adapter](**kwargs)
return new_config

def adapter_param_name_prefix(self, adapter: str) -> str:
"""Get the PEFT paramter name prefix for a given adapter type.
Args:
adapter: A valid config value for `adapter.type`
Returns:
The PEFT-applied prefix for the adapter's parameter names.
Raises:
KeyError: If the provided adapter name is not valid for LLMEncoder.
"""
return LLMEncoder.ADAPTER_PARAM_NAME_PREFIX[adapter]

def test_init(self, encoder_config: LLMEncoderConfig, model_config):
# Test initializing without an adapter
encoder = LLMEncoder(encoder_config=encoder_config)

assert encoder.model_name == encoder_config.base_model
assert isinstance(encoder.model, PreTrainedModel)
assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys())) # Check adapter was not initialized
# Check adapter was not initialized
for k in ADAPTER_CONFIG_MAP.keys():
prefix = self.adapter_param_name_prefix(k)
assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))
assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

@@ -77,7 +97,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):

assert encoder.model_name == encoder_config.base_model
assert isinstance(encoder.model, PreTrainedModel)
assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys())) # Check adapter was not initialized
# Check adapter was not initialized
for k in ADAPTER_CONFIG_MAP.keys():
prefix = self.adapter_param_name_prefix(k)
assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))
assert encoder.input_shape == torch.Size([context_len])
assert encoder.output_shape == torch.Size([context_len, model_config.hidden_size])

@@ -87,10 +110,11 @@ def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str,

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
prefix = self.adapter_param_name_prefix(adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
assert not any(map(lambda k: prefix in k, encoder.state_dict().keys()))

assert encoder.model_name == encoder_config.base_model
assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
@@ -102,31 +126,36 @@ def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: s

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
prefix = self.adapter_param_name_prefix(adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
assert not any(map(lambda k: prefix in k, encoder.state_dict().keys()))

# Initialize the adapter
encoder.prepare_for_training()

# At this point, the adapter should be initialized and the state dict should contain adapter parameters
assert isinstance(encoder.model, PeftModel)
assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
assert any(map(lambda k: prefix in k, encoder.state_dict().keys()))

def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir):
# With no adapter, the state dict should only contain the model parameters
encoder = LLMEncoder(encoder_config=encoder_config)
assert all(map(lambda k: "lora_" not in k, encoder.state_dict().keys()))
# Check adapter was not initialized
for k in ADAPTER_CONFIG_MAP.keys():
prefix = self.adapter_param_name_prefix(k)
assert all(map(lambda k: prefix not in k, encoder.state_dict().keys()))

@pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, tmpdir):
# With an adapter, the state dict should only contain adapter parameters
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
prefix = self.adapter_param_name_prefix(adapter)
# Initialize the adapters
encoder.prepare_for_training()
assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
assert all(map(lambda k: prefix in k, encoder.state_dict().keys()))

@pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"])
def test_load_from_state_dict(self, encoder_config: LLMEncoderConfig, wrap: bool):
@@ -164,6 +193,8 @@ def weights_init(m):
if hasattr(m, "weight") and m.weight.ndim > 1:
torch.nn.init.xavier_uniform_(m.weight.data)

prefix = self.adapter_param_name_prefix(adapter)

# Update the config with an adapter
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)

@@ -183,8 +214,8 @@ def weights_init(m):

encoder1_sd = encoder1.state_dict()
encoder2_sd = encoder2.state_dict()
adapter_keys = [k for k in encoder1_sd.keys() if "lora_" in k and "weight" in k]
model_keys = [k for k in encoder1_sd.keys() if "lora_" not in k]
adapter_keys = [k for k in encoder1_sd.keys() if prefix in k and "weight" in k]
model_keys = [k for k in encoder1_sd.keys() if prefix not in k]

# The LoRA weights should no longer be equal
assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))
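For readers skimming the change: every updated assertion follows the same pattern, namely look up the adapter's parameter-name prefix via `LLMEncoder.ADAPTER_PARAM_NAME_PREFIX` and scan the encoder's `state_dict()` keys for it, which is what lets the same tests now cover both "lora" and "adalora". The sketch below restates that pattern in isolation. The prefix values, the stand-in mapping, the helper `has_adapter_params`, and the example key names are assumptions for illustration only, not taken from the Ludwig source.

from typing import Dict, Iterable

# Hypothetical stand-in for LLMEncoder.ADAPTER_PARAM_NAME_PREFIX; the real mapping
# lives on the encoder class and is what the tests read via adapter_param_name_prefix().
ADAPTER_PARAM_NAME_PREFIX: Dict[str, str] = {
    "lora": "lora_",     # assumed prefix; PEFT registers tensors such as "...lora_A.weight"
    "adalora": "lora_",  # assumed; AdaLoRA extends LoRA-style modules, so a shared prefix is plausible
}


def has_adapter_params(state_dict_keys: Iterable[str], adapter: str) -> bool:
    """Return True if any key carries the given adapter's parameter-name prefix."""
    prefix = ADAPTER_PARAM_NAME_PREFIX[adapter]  # KeyError for unknown adapter types, as in the tests
    return any(prefix in key for key in state_dict_keys)


# Before prepare_for_training(), a plain base model exposes no adapter-prefixed keys...
base_keys = ["transformer.h.0.attn.c_attn.weight"]
assert not has_adapter_params(base_keys, "adalora")

# ...and after PEFT wraps the model, adapter parameters appear and the same check flips.
peft_keys = base_keys + ["transformer.h.0.attn.c_attn.lora_A.weight"]
assert has_adapter_params(peft_keys, "adalora")

Locally, something like `pytest tests/ludwig/encoders/test_llm_encoders.py -k adalora` should select just the new AdaLoRA cases.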
