Fix loading adapters together with full model (#378)
calpt authored Jul 5, 2022
1 parent 8c12bae commit d56e9a5
Showing 6 changed files with 64 additions and 42 deletions.
src/transformers/adapters/configuration.py (6 changes: 5 additions & 1 deletion)
@@ -462,6 +462,8 @@ def validate(configs):
     def __getitem__(self, key):
         if isinstance(key, int):
             return self.configs[key]
+        elif hasattr(self, key):
+            return getattr(self, key)
         else:
             i, k = key.split(".")
             return self.configs[int(i)][k]
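
The new branch lets plain attribute names be used as lookup keys alongside the existing int and "index.key" forms. A minimal hedged sketch of the three lookup styles, assuming this __getitem__ belongs to ConfigUnion and using the predefined MAMConfig union (import path and attribute names are assumptions, not shown in this hunk):

# Hedged sketch of the extended key lookup; ConfigUnion/MAMConfig usage is assumed
# from the adapter-transformers API, not taken from this diff.
from transformers.adapters import MAMConfig

union = MAMConfig()                     # a predefined ConfigUnion (prefix tuning + parallel adapter)
sub_config = union[0]                   # int key: the first sub-config of the union
prefix_len = union["0.prefix_length"]   # "index.key": an attribute of a specific sub-config
arch = union["architecture"]            # new: attribute names on the union itself resolve via getattr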
@@ -600,6 +602,8 @@ def match(
         config = self.get(adapter_name)
         if config is None:
             return None
+        elif not isinstance(config, AdapterConfigBase):
+            config = AdapterConfigBase.load(config)

         if isinstance(config, config_type):
             leave_out = config.get("leave_out", [])
@@ -646,7 +650,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
         # if it's a dict, compute it's hash and add a new entry to the config map
         elif isinstance(config, Mapping):
             config_name = get_adapter_config_hash(config)
-            self.config_map[config_name] = config
+            self.config_map[config_name] = AdapterConfigBase.load(config)
         else:
             raise ValueError("Invalid adapter config: {}".format(config))
         self.adapters[adapter_name] = config_name
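
With this change, dict configs passed to the add method above are normalized into typed config objects before being stored in config_map, so later lookups such as match() see a config object rather than a raw dict. A minimal hedged sketch (import path assumed from the adapter-transformers API):

# Hedged sketch: a plain dict config round-trips into a typed config object.
from transformers.adapters import AdapterConfigBase, PfeifferConfig

config_dict = PfeifferConfig().to_dict()          # a plain dict, as a user might pass to add_adapter
config_obj = AdapterConfigBase.load(config_dict)  # what is now stored in config_map
assert isinstance(config_obj, AdapterConfigBase)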
src/transformers/adapters/model_mixin.py (33 changes: 17 additions & 16 deletions)
@@ -253,7 +253,7 @@ def _init_adapter_modules(self, add_prefix_tuning_pool=True):

         # Initialize adapters from config
         for adapter_name in self.config.adapters:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+            self._add_adapter_weights(adapter_name)
         # Initialize fusion from config
         for fusion_name in self.config.adapters.fusions:
             self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(fusion_name))
@@ -387,26 +387,27 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
             self.delete_adapter(adapter_name)
         self.config.adapters.add(adapter_name, config=config)
         try:
-            self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
-            # PHM Layer
-            if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
-                self._add_shared_parameters(adapter_name, config)
-            # Prefix Tuning
-            for module in self.modules():
-                if isinstance(module, PrefixTuningPool):
-                    module.confirm_prefix(adapter_name)
-            if isinstance(self, InvertibleAdaptersMixin):
-                self.add_invertible_adapter(adapter_name)
+            self._add_adapter_weights(adapter_name)
         except ValueError as ex:
             self.delete_adapter(adapter_name)
             raise ex
         if set_active:
             self.set_active_adapters(adapter_name)

-    def _add_shared_parameters(self, adapter_name, adapter_config: AdapterConfig):
-        self.shared_parameters[adapter_name] = (
-            list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
-        )
+    def _add_adapter_weights(self, adapter_name: str):
+        """Helper method that performs the actual parameter additions when adding a new adapter."""
+        self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
+        # PHM Layer
+        if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
+            self.shared_parameters[adapter_name] = (
+                list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
+            )
+        # Prefix Tuning
+        for module in self.modules():
+            if isinstance(module, PrefixTuningPool):
+                module.confirm_prefix(adapter_name)
+        if isinstance(self, InvertibleAdaptersMixin):
+            self.add_invertible_adapter(adapter_name)

     def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None):
         warnings.warn(
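
To make the effect of this refactoring concrete: previously _init_adapter_modules only re-added the adapter layer weights, so prefix tunings, shared PHM parameters and invertible adapters belonging to adapters stored in a full-model checkpoint were never re-created before the state dict was loaded. A minimal hedged sketch of the now-working round trip, mirroring the new run_full_model_load_test below (model and config class names, import paths and the tiny config values are assumptions based on the adapter-transformers API, not part of this diff):

# Hedged end-to-end sketch of the behaviour this commit fixes.
import tempfile

from transformers import BertConfig
from transformers.adapters import BertAdapterModel, PrefixTuningConfig  # names assumed

model = BertAdapterModel(
    BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
)
model.add_adapter("dummy", config=PrefixTuningConfig())

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    # With _add_adapter_weights also used in _init_adapter_modules, prefix tunings, shared
    # PHM parameters and invertible adapters are re-created before the weights are loaded,
    # so the full-model checkpoint now loads without missing or unexpected keys.
    restored, info = BertAdapterModel.from_pretrained(tmp, output_loading_info=True)
    assert len(info["missing_keys"]) == 0 and len(info["unexpected_keys"]) == 0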
@@ -814,7 +815,7 @@ def adapter_summary(self, as_dict=False) -> Union[str, dict]:
         # fill in data for adapters
         for name, config_name in self.config.adapters.adapters.items():
             config = self.config.adapters.config_map[config_name]
-            row = {"name": name, "architecture": config.architecture or "bottleneck"}
+            row = {"name": name, "architecture": config.get("architecture", None) or "bottleneck"}
             weights = self.get_adapter(name)
             row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
             # count parameters
tests_adapters/methods/base.py (35 changes: 34 additions & 1 deletion)
@@ -144,7 +144,12 @@ def run_load_test(self, adapter_config):
             self.assertTrue(len(weights) > 0)

             # also tests that set_active works
-            model2.load_adapter(temp_dir, set_active=True)
+            loading_info = {}
+            model2.load_adapter(temp_dir, set_active=True, loading_info=loading_info)
+
+        # check if all weights were loaded
+        self.assertEqual(0, len(loading_info["missing_keys"]))
+        self.assertEqual(0, len(loading_info["unexpected_keys"]))

         # check if adapter was correctly loaded
         self.assertTrue(name in model2.config.adapters)
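
As an aside, the loading_info argument used above gives callers a direct way to verify that an adapter checkpoint was fully matched; a brief hedged usage sketch (the checkpoint name and adapter directory are placeholders, not values from this commit):

# Hedged usage sketch of load_adapter with loading_info.
from transformers.adapters import BertAdapterModel  # class name assumed

model = BertAdapterModel.from_pretrained("bert-base-uncased")
loading_info = {}
model.load_adapter("./saved_adapter", set_active=True, loading_info=loading_info)
if loading_info["missing_keys"] or loading_info["unexpected_keys"]:
    raise RuntimeError(f"adapter checkpoint did not fully match the model: {loading_info}")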
@@ -158,6 +163,34 @@ def run_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

+    def run_full_model_load_test(self, adapter_config):
+        model1 = self.get_model()
+        model1.eval()
+
+        name = "dummy"
+        model1.add_adapter(name, config=adapter_config)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model1.save_pretrained(temp_dir)
+
+            model2, loading_info = self.model_class.from_pretrained(temp_dir, output_loading_info=True)
+
+        # check if all weights were loaded
+        self.assertEqual(0, len(loading_info["missing_keys"]))
+        self.assertEqual(0, len(loading_info["unexpected_keys"]))
+
+        # check if adapter was correctly loaded
+        self.assertTrue(name in model2.config.adapters)
+
+        # check equal output
+        input_data = self.get_input_samples(config=model1.config)
+        model1.to(torch_device)
+        model2.to(torch_device)
+        with AdapterSetup(name):
+            output1 = model1(**input_data)
+            output2 = model2(**input_data)
+        self.assertEqual(len(output1), len(output2))
+        self.assertTrue(torch.equal(output1[0], output2[0]))
+
     def trainings_run(self, model):
         # setup dataset
         train_dataset = self.dataset()
tests_adapters/methods/test_adapter_common.py (26 changes: 2 additions & 24 deletions)
@@ -145,30 +145,8 @@ def test_load_adapter(self):
     def test_load_mam_adapter(self):
         self.run_load_test(MAMConfig())

-    def test_load_full_model(self):
-        model1 = self.get_model()
-        model1.eval()
-
-        name = "dummy"
-        model1.add_adapter(name)
-        model1.set_active_adapters([name])
-        with tempfile.TemporaryDirectory() as temp_dir:
-            model1.save_pretrained(temp_dir)
-
-            model2 = self.model_class.from_pretrained(temp_dir)
-            model2.set_active_adapters([name])
-
-        # check if adapter was correctly loaded
-        self.assertTrue(name in model2.config.adapters)
-
-        # check equal output
-        input_data = self.get_input_samples(config=model1.config)
-        model1.to(torch_device)
-        model2.to(torch_device)
-        output1 = model1(**input_data)
-        output2 = model2(**input_data)
-        self.assertEqual(len(output1), len(output2))
-        self.assertTrue(torch.equal(output1[0], output2[0]))
+    def test_load_full_model_adapter(self):
+        self.run_full_model_load_test(PfeifferConfig())

     def test_model_config_serialization(self):
         """PretrainedConfigurations should not raise an Exception when serializing the config dict
tests_adapters/methods/test_lora.py (3 changes: 3 additions & 0 deletions)
@@ -28,6 +28,9 @@ def test_forward_lora(self):
     def test_load_lora(self):
         self.run_load_test(LoRAConfig())

+    def test_load_full_model_lora(self):
+        self.run_full_model_load_test(LoRAConfig(init_weights="bert"))
+
     def test_train_lora(self):
         self.run_train_test(LoRAConfig(init_weights="bert"), ["loras.{name}."])

tests_adapters/methods/test_prefix_tuning.py (3 changes: 3 additions & 0 deletions)
@@ -27,6 +27,9 @@ def test_forward_prefix_tuning(self):
     def test_load_prefix_tuning(self):
         self.run_load_test(PrefixTuningConfig())

+    def test_load_full_model_prefix_tuning(self):
+        self.run_full_model_load_test(PrefixTuningConfig())
+
     def test_train_prefix_tuning(self):
         self.run_train_test(PrefixTuningConfig(), ["prefix_tunings.{name}."])
