diff --git a/src/transformers/adapters/configuration.py b/src/transformers/adapters/configuration.py
index 8420f0b587..9d61de25ed 100644
--- a/src/transformers/adapters/configuration.py
+++ b/src/transformers/adapters/configuration.py
@@ -462,6 +462,8 @@ def validate(configs):
     def __getitem__(self, key):
         if isinstance(key, int):
             return self.configs[key]
+        elif hasattr(self, key):
+            return getattr(self, key)
         else:
             i, k = key.split(".")
             return self.configs[int(i)][k]
@@ -600,6 +602,8 @@ def match(
         config = self.get(adapter_name)
         if config is None:
             return None
+        elif not isinstance(config, AdapterConfigBase):
+            config = AdapterConfigBase.load(config)
 
         if isinstance(config, config_type):
             leave_out = config.get("leave_out", [])
@@ -646,7 +650,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
         # if it's a dict, compute it's hash and add a new entry to the config map
         elif isinstance(config, Mapping):
             config_name = get_adapter_config_hash(config)
-            self.config_map[config_name] = config
+            self.config_map[config_name] = AdapterConfigBase.load(config)
         else:
             raise ValueError("Invalid adapter config: {}".format(config))
         self.adapters[adapter_name] = config_name
diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index b19db36b54..815b2526e4 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -253,7 +253,7 @@ def _init_adapter_modules(self, add_prefix_tuning_pool=True):
 
         # Initialize adapters from config
         for adapter_name in self.config.adapters:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+            self._add_adapter_weights(adapter_name)
         # Initialize fusion from config
         for fusion_name in self.config.adapters.fusions:
             self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(fusion_name))
@@ -387,26 +387,27 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
             self.delete_adapter(adapter_name)
         self.config.adapters.add(adapter_name, config=config)
         try:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
-            # PHM Layer
-            if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
-                self._add_shared_parameters(adapter_name, config)
-            # Prefix Tuning
-            for module in self.modules():
-                if isinstance(module, PrefixTuningPool):
-                    module.confirm_prefix(adapter_name)
-            if isinstance(self, InvertibleAdaptersMixin):
-                self.add_invertible_adapter(adapter_name)
+            self._add_adapter_weights(adapter_name)
         except ValueError as ex:
             self.delete_adapter(adapter_name)
             raise ex
         if set_active:
             self.set_active_adapters(adapter_name)
 
-    def _add_shared_parameters(self, adapter_name, adapter_config: AdapterConfig):
-        self.shared_parameters[adapter_name] = (
-            list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
-        )
+    def _add_adapter_weights(self, adapter_name: str):
+        """Helper method that performs the actual parameter additions when adding a new adapter."""
+        self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+        # PHM Layer
+        if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
+            self.shared_parameters[adapter_name] = (
+                list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
+            )
+        # Prefix Tuning
+        for module in self.modules():
+            if isinstance(module, PrefixTuningPool):
+                module.confirm_prefix(adapter_name)
+        if isinstance(self, InvertibleAdaptersMixin):
+            self.add_invertible_adapter(adapter_name)
 
     def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None):
         warnings.warn(
@@ -814,7 +815,7 @@ def adapter_summary(self, as_dict=False) -> Union[str, dict]:
         # fill in data for adapters
         for name, config_name in self.config.adapters.adapters.items():
             config = self.config.adapters.config_map[config_name]
-            row = {"name": name, "architecture": config.architecture or "bottleneck"}
+            row = {"name": name, "architecture": config.get("architecture", None) or "bottleneck"}
             weights = self.get_adapter(name)
             row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
             # count parameters
diff --git a/tests_adapters/methods/base.py b/tests_adapters/methods/base.py
index 9998b0d58f..e22555f0ed 100644
--- a/tests_adapters/methods/base.py
+++ b/tests_adapters/methods/base.py
@@ -144,7 +144,12 @@ def run_load_test(self, adapter_config):
             self.assertTrue(len(weights) > 0)
 
             # also tests that set_active works
-            model2.load_adapter(temp_dir, set_active=True)
+            loading_info = {}
+            model2.load_adapter(temp_dir, set_active=True, loading_info=loading_info)
+
+            # check if all weights were loaded
+            self.assertEqual(0, len(loading_info["missing_keys"]))
+            self.assertEqual(0, len(loading_info["unexpected_keys"]))
 
             # check if adapter was correctly loaded
             self.assertTrue(name in model2.config.adapters)
@@ -158,6 +163,34 @@ def run_load_test(self, adapter_config):
             self.assertEqual(len(output1), len(output2))
             self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))
 
+    def run_full_model_load_test(self, adapter_config):
+        model1 = self.get_model()
+        model1.eval()
+
+        name = "dummy"
+        model1.add_adapter(name, config=adapter_config)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model1.save_pretrained(temp_dir)
+
+            model2, loading_info = self.model_class.from_pretrained(temp_dir, output_loading_info=True)
+
+            # check if all weights were loaded
+            self.assertEqual(0, len(loading_info["missing_keys"]))
+            self.assertEqual(0, len(loading_info["unexpected_keys"]))
+
+            # check if adapter was correctly loaded
+            self.assertTrue(name in model2.config.adapters)
+
+            # check equal output
+            input_data = self.get_input_samples(config=model1.config)
+            model1.to(torch_device)
+            model2.to(torch_device)
+            with AdapterSetup(name):
+                output1 = model1(**input_data)
+                output2 = model2(**input_data)
+            self.assertEqual(len(output1), len(output2))
+            self.assertTrue(torch.equal(output1[0], output2[0]))
+
     def trainings_run(self, model):
         # setup dataset
         train_dataset = self.dataset()
diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py
index 501b7a8e83..8d3830bf5d 100644
--- a/tests_adapters/methods/test_adapter_common.py
+++ b/tests_adapters/methods/test_adapter_common.py
@@ -145,30 +145,8 @@ def test_load_adapter(self):
     def test_load_mam_adapter(self):
         self.run_load_test(MAMConfig())
 
-    def test_load_full_model(self):
-        model1 = self.get_model()
-        model1.eval()
-
-        name = "dummy"
-        model1.add_adapter(name)
-        model1.set_active_adapters([name])
-        with tempfile.TemporaryDirectory() as temp_dir:
-            model1.save_pretrained(temp_dir)
-
-            model2 = self.model_class.from_pretrained(temp_dir)
-            model2.set_active_adapters([name])
-
-        # check if adapter was correctly loaded
-        self.assertTrue(name in model2.config.adapters)
-
-        # check equal output
-        input_data = self.get_input_samples(config=model1.config)
-        model1.to(torch_device)
-        model2.to(torch_device)
-        output1 = model1(**input_data)
-        output2 = model2(**input_data)
-        self.assertEqual(len(output1), len(output2))
-        self.assertTrue(torch.equal(output1[0], output2[0]))
+    def test_load_full_model_adapter(self):
+        self.run_full_model_load_test(PfeifferConfig())
 
     def test_model_config_serialization(self):
         """PretrainedConfigurations should not raise an Exception when serializing the config dict
diff --git a/tests_adapters/methods/test_lora.py b/tests_adapters/methods/test_lora.py
index c7055fff88..e42d213de7 100644
--- a/tests_adapters/methods/test_lora.py
+++ b/tests_adapters/methods/test_lora.py
@@ -28,6 +28,9 @@ def test_forward_lora(self):
 
     def test_load_lora(self):
         self.run_load_test(LoRAConfig())
 
+    def test_load_full_model_lora(self):
+        self.run_full_model_load_test(LoRAConfig(init_weights="bert"))
+
     def test_train_lora(self):
         self.run_train_test(LoRAConfig(init_weights="bert"), ["loras.{name}."])
diff --git a/tests_adapters/methods/test_prefix_tuning.py b/tests_adapters/methods/test_prefix_tuning.py
index ccce1e7334..eefd68a3f3 100644
--- a/tests_adapters/methods/test_prefix_tuning.py
+++ b/tests_adapters/methods/test_prefix_tuning.py
@@ -27,6 +27,9 @@ def test_forward_prefix_tuning(self):
 
     def test_load_prefix_tuning(self):
         self.run_load_test(PrefixTuningConfig())
 
+    def test_load_full_model_prefix_tuning(self):
+        self.run_full_model_load_test(PrefixTuningConfig())
+
     def test_train_prefix_tuning(self):
         self.run_train_test(PrefixTuningConfig(), ["prefix_tunings.{name}."])
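Usage sketch (not part of the patch): with the configuration.py change above, a plain dict passed to add_adapter() is normalized via AdapterConfigBase.load() before it is stored in config_map. The checkpoint name and the r value below are illustrative assumptions only.

    from transformers import AutoAdapterModel
    from transformers.adapters import LoRAConfig

    # assumption: any small checkpoint works here; "bert-base-uncased" is only illustrative
    model = AutoAdapterModel.from_pretrained("bert-base-uncased")

    # pass the adapter config as a plain dict, e.g. as it would arrive from a JSON file
    model.add_adapter("demo", config=LoRAConfig(r=8).to_dict())

    # the stored config_map entry is now a config object rather than a raw dict,
    # so architecture-dependent reporting such as adapter_summary() works as before
    print(model.adapter_summary())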