Fix loading adapters together with full model #378

Merged (2 commits) on Jul 5, 2022. The diff below shows changes from all commits.
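For context, the behavior this PR fixes can be sketched as follows (a minimal, illustrative script assuming the adapter-transformers fork this repository provides; the model class, checkpoint path, and adapter name are placeholders):

# Minimal sketch of the workflow addressed by this PR (names are illustrative):
# a model is saved together with an added adapter and restored in a single
# from_pretrained() call.
from transformers import BertModel, PfeifferConfig

model = BertModel.from_pretrained("bert-base-uncased")
model.add_adapter("dummy", config=PfeifferConfig())
model.save_pretrained("checkpoint/")

# Before this fix, adapter tensors stored in the full checkpoint could be
# reported as missing or unexpected keys, because the adapter modules were
# not rebuilt from the saved adapter config during model initialization.
restored, info = BertModel.from_pretrained("checkpoint/", output_loading_info=True)
assert len(info["missing_keys"]) == 0 and len(info["unexpected_keys"]) == 0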
src/transformers/adapters/configuration.py (5 additions, 1 deletion)
@@ -462,6 +462,8 @@ def validate(configs):
     def __getitem__(self, key):
         if isinstance(key, int):
             return self.configs[key]
+        elif hasattr(self, key):
+            return getattr(self, key)
         else:
             i, k = key.split(".")
             return self.configs[int(i)][k]
@@ -600,6 +602,8 @@ def match(
         config = self.get(adapter_name)
         if config is None:
             return None
+        elif not isinstance(config, AdapterConfigBase):
+            config = AdapterConfigBase.load(config)
 
         if isinstance(config, config_type):
             leave_out = config.get("leave_out", [])
@@ -646,7 +650,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
         # if it's a dict, compute it's hash and add a new entry to the config map
         elif isinstance(config, Mapping):
             config_name = get_adapter_config_hash(config)
-            self.config_map[config_name] = config
+            self.config_map[config_name] = AdapterConfigBase.load(config)
         else:
             raise ValueError("Invalid adapter config: {}".format(config))
         self.adapters[adapter_name] = config_name
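The net effect of the configuration.py changes above is that the adapter config map now always holds typed config objects rather than raw dicts, so lookups such as config.get("architecture", None) behave the same regardless of how an adapter was added. A minimal sketch of that normalization (assuming the AdapterConfigBase and LoRAConfig classes exported by this library; the concrete values are illustrative):

# Sketch only: a plain dict config is normalized via AdapterConfigBase.load(),
# mirroring what the add() method above now does for Mapping configs.
from transformers import LoRAConfig
from transformers.adapters.configuration import AdapterConfigBase

raw = LoRAConfig(r=8, alpha=16).to_dict()  # plain dict representation
config = AdapterConfigBase.load(raw)       # back to a typed config object
print(config.get("architecture", None))    # expected output: "lora"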
src/transformers/adapters/model_mixin.py (17 additions, 16 deletions)
@@ -253,7 +253,7 @@ def _init_adapter_modules(self, add_prefix_tuning_pool=True):
 
         # Initialize adapters from config
         for adapter_name in self.config.adapters:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+            self._add_adapter_weights(adapter_name)
         # Initialize fusion from config
         for fusion_name in self.config.adapters.fusions:
             self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(fusion_name))
@@ -387,26 +387,27 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
             self.delete_adapter(adapter_name)
         self.config.adapters.add(adapter_name, config=config)
         try:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
-            # PHM Layer
-            if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
-                self._add_shared_parameters(adapter_name, config)
-            # Prefix Tuning
-            for module in self.modules():
-                if isinstance(module, PrefixTuningPool):
-                    module.confirm_prefix(adapter_name)
-            if isinstance(self, InvertibleAdaptersMixin):
-                self.add_invertible_adapter(adapter_name)
+            self._add_adapter_weights(adapter_name)
         except ValueError as ex:
             self.delete_adapter(adapter_name)
             raise ex
         if set_active:
             self.set_active_adapters(adapter_name)
 
-    def _add_shared_parameters(self, adapter_name, adapter_config: AdapterConfig):
-        self.shared_parameters[adapter_name] = (
-            list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
-        )
+    def _add_adapter_weights(self, adapter_name: str):
+        """Helper method that performs the actual parameter additions when adding a new adapter."""
+        self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+        # PHM Layer
+        if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
+            self.shared_parameters[adapter_name] = (
+                list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
+            )
+        # Prefix Tuning
+        for module in self.modules():
+            if isinstance(module, PrefixTuningPool):
+                module.confirm_prefix(adapter_name)
+        if isinstance(self, InvertibleAdaptersMixin):
+            self.add_invertible_adapter(adapter_name)
 
     def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None):
         warnings.warn(
@@ -812,7 +813,7 @@ def adapter_summary(self, as_dict=False) -> Union[str, dict]:
         # fill in data for adapters
         for name, config_name in self.config.adapters.adapters.items():
             config = self.config.adapters.config_map[config_name]
-            row = {"name": name, "architecture": config.architecture or "bottleneck"}
+            row = {"name": name, "architecture": config.get("architecture", None) or "bottleneck"}
             weights = self.get_adapter(name)
             row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
             # count parameters
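The point of routing both _init_adapter_modules and add_adapter through the shared _add_adapter_weights helper is that all of an adapter's parameter modules (bottleneck layers, PHM shared parameters, prefix tunings, invertible adapters) must exist on the model before a full checkpoint's state dict is applied; otherwise the saved adapter tensors cannot be matched. A simplified, self-contained illustration of that failure mode (plain PyTorch, not the library code; names are made up):

# Simplified illustration (not the library code): parameter modules must be
# registered before load_state_dict(), otherwise their saved tensors are
# reported as unexpected keys.
import torch

class TinyModel(torch.nn.Module):
    def __init__(self, adapter_names=()):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.adapters = torch.nn.ModuleDict()
        for name in adapter_names:  # analogous to _add_adapter_weights()
            self.adapters[name] = torch.nn.Linear(4, 4)

saved = TinyModel(adapter_names=["dummy"]).state_dict()

# Restoring into a model that did not re-create the adapter modules:
result = TinyModel().load_state_dict(saved, strict=False)
print(result.unexpected_keys)  # ['adapters.dummy.weight', 'adapters.dummy.bias']

# Restoring into a model that did re-create them (what the refactored
# _init_adapter_modules now ensures for adapters listed in the saved config):
result = TinyModel(adapter_names=["dummy"]).load_state_dict(saved, strict=False)
print(result.unexpected_keys)  # []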
tests_adapters/methods/base.py (34 additions, 1 deletion)
@@ -144,7 +144,12 @@ def run_load_test(self, adapter_config):
             self.assertTrue(len(weights) > 0)
 
             # also tests that set_active works
-            model2.load_adapter(temp_dir, set_active=True)
+            loading_info = {}
+            model2.load_adapter(temp_dir, set_active=True, loading_info=loading_info)
+
+        # check if all weights were loaded
+        self.assertEqual(0, len(loading_info["missing_keys"]))
+        self.assertEqual(0, len(loading_info["unexpected_keys"]))
 
         # check if adapter was correctly loaded
         self.assertTrue(name in model2.config.adapters)
@@ -158,6 +163,34 @@ def run_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))
 
+    def run_full_model_load_test(self, adapter_config):
+        model1 = self.get_model()
+        model1.eval()
+
+        name = "dummy"
+        model1.add_adapter(name, config=adapter_config)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model1.save_pretrained(temp_dir)
+
+            model2, loading_info = self.model_class.from_pretrained(temp_dir, output_loading_info=True)
+
+        # check if all weights were loaded
+        self.assertEqual(0, len(loading_info["missing_keys"]))
+        self.assertEqual(0, len(loading_info["unexpected_keys"]))
+
+        # check if adapter was correctly loaded
+        self.assertTrue(name in model2.config.adapters)
+
+        # check equal output
+        input_data = self.get_input_samples(config=model1.config)
+        model1.to(torch_device)
+        model2.to(torch_device)
+        with AdapterSetup(name):
+            output1 = model1(**input_data)
+            output2 = model2(**input_data)
+        self.assertEqual(len(output1), len(output2))
+        self.assertTrue(torch.equal(output1[0], output2[0]))
+
     def trainings_run(self, model):
         # setup dataset
         train_dataset = self.dataset()
tests_adapters/methods/test_adapter_common.py (2 additions, 24 deletions)
@@ -145,30 +145,8 @@ def test_load_adapter(self):
     def test_load_mam_adapter(self):
         self.run_load_test(MAMConfig())
 
-    def test_load_full_model(self):
-        model1 = self.get_model()
-        model1.eval()
-
-        name = "dummy"
-        model1.add_adapter(name)
-        model1.set_active_adapters([name])
-        with tempfile.TemporaryDirectory() as temp_dir:
-            model1.save_pretrained(temp_dir)
-
-            model2 = self.model_class.from_pretrained(temp_dir)
-            model2.set_active_adapters([name])
-
-        # check if adapter was correctly loaded
-        self.assertTrue(name in model2.config.adapters)
-
-        # check equal output
-        input_data = self.get_input_samples(config=model1.config)
-        model1.to(torch_device)
-        model2.to(torch_device)
-        output1 = model1(**input_data)
-        output2 = model2(**input_data)
-        self.assertEqual(len(output1), len(output2))
-        self.assertTrue(torch.equal(output1[0], output2[0]))
+    def test_load_full_model_adapter(self):
+        self.run_full_model_load_test(PfeifferConfig())
 
     def test_model_config_serialization(self):
         """PretrainedConfigurations should not raise an Exception when serializing the config dict
tests_adapters/methods/test_lora.py (3 additions)
@@ -28,6 +28,9 @@ def test_forward_lora(self):
     def test_load_lora(self):
         self.run_load_test(LoRAConfig())
 
+    def test_load_full_model_lora(self):
+        self.run_full_model_load_test(LoRAConfig(init_weights="bert"))
+
     def test_train_lora(self):
         self.run_train_test(LoRAConfig(init_weights="bert"), ["loras.{name}."])
 
tests_adapters/methods/test_prefix_tuning.py (3 additions)
@@ -27,6 +27,9 @@ def test_forward_prefix_tuning(self):
     def test_load_prefix_tuning(self):
         self.run_load_test(PrefixTuningConfig())
 
+    def test_load_full_model_prefix_tuning(self):
+        self.run_full_model_load_test(PrefixTuningConfig())
+
     def test_train_prefix_tuning(self):
         self.run_train_test(PrefixTuningConfig(), ["prefix_tunings.{name}."])
 