Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AdapterFusion config flexibility #216

Merged
merged 4 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions src/transformers/adapters/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,12 @@ def __init__(self, **kwargs):
adapters_list = dict(
map(lambda t: (t[0], t[1][1] or t[1][0] if isinstance(t[1], tuple) else t[1]), adapters_list.items())
)
self.adapters: Mapping[str] = adapters_list
self.adapters: Mapping[str, str] = adapters_list
self.config_map = kwargs.pop("config_map", {})

self.fusions: Mapping[str, str] = kwargs.pop("fusions", {})
self.fusion_config_map = kwargs.pop("fusion_config_map", {})

# TODO-V2 Save this with config?
self.active_setup: Optional[AdapterCompositionBlock] = None
self.skip_layers = None
Expand All @@ -212,7 +216,7 @@ def __iter__(self):
def __len__(self):
    """Number of adapters registered in this config."""
    adapter_count = len(self.adapters)
    return adapter_count

def get(self, adapter_name: str):
def get(self, adapter_name: str) -> Optional[dict]:
"""
Gets the config dictionary for a given adapter.

Expand Down Expand Up @@ -248,7 +252,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
config = DEFAULT_ADAPTER_CONFIG
if isinstance(config, str):
if config not in ADAPTER_CONFIG_MAP and config not in self.config_map:
raise ValueError(f"Invalid adapter config identifier '{config}''")
raise ValueError(f"Invalid adapter config identifier '{config}'.")
config_name = config
# if it's a dict, compute its hash and add a new entry to the config map
elif isinstance(config, Mapping):
Expand All @@ -259,6 +263,55 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
self.adapters[adapter_name] = config_name
logger.info(f"Adding adapter '{adapter_name}'.")

def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
    """Looks up the config dict of a registered AdapterFusion.

    Args:
        fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse.

    Returns:
        Optional[dict]: The AdapterFusion configuration, or None if no matching fusion is registered.
    """
    if isinstance(fusion_name, list):
        fusion_name = ",".join(fusion_name)
    if fusion_name not in self.fusions:
        return None
    config_name = self.fusions[fusion_name]
    # custom configs registered on this instance take precedence over built-ins
    if config_name in self.fusion_config_map:
        return self.fusion_config_map[config_name]
    return ADAPTERFUSION_CONFIG_MAP.get(config_name, None)

def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None):
    """Registers a new AdapterFusion in this config.

    Args:
        fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse.
        config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None.

    Raises:
        ValueError: If an AdapterFusion with this name already exists or the config is invalid.
    """
    if isinstance(fusion_name, list):
        fusion_name = ",".join(fusion_name)
    if fusion_name in self.fusions:
        raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.")
    config = DEFAULT_ADAPTERFUSION_CONFIG if config is None else config
    if isinstance(config, Mapping):
        # raw config dict: key it by its hash and remember it in the config map
        config_name = get_adapter_config_hash(config)
        self.fusion_config_map[config_name] = config
    elif isinstance(config, str):
        # identifier string: must name a known built-in or custom config
        if config not in ADAPTERFUSION_CONFIG_MAP and config not in self.fusion_config_map:
            raise ValueError(f"Invalid AdapterFusion config identifier '{config}'.")
        config_name = config
    else:
        raise ValueError("Invalid AdapterFusion config: {}".format(config))
    self.fusions[fusion_name] = config_name
    logger.info(f"Adding AdapterFusion '{fusion_name}'.")

def common_config_value(self, adapter_names: list, attribute: str):
"""
Checks whether all adapters in a list share the same config setting for a given attribute and returns the
def to_dict(self):
    """Serializes the adapter and fusion registries into a plain dict.

    Returns deep copies so callers cannot mutate the internal maps.
    """
    return {
        "adapters": copy.deepcopy(self.adapters),
        "config_map": copy.deepcopy(self.config_map),
        "fusions": copy.deepcopy(self.fusions),
        "fusion_config_map": copy.deepcopy(self.fusion_config_map),
    }


Expand Down
19 changes: 14 additions & 5 deletions src/transformers/adapters/layer.py
Original file line number Diff line number Diff line change
def add_fusion_layer(self, adapter_names: Union[List, str]):
    """Instantiates a BertFusion module for the given set of adapters and registers it.

    See BertModel.add_fusion_layer.

    Args:
        adapter_names (Union[List, str]): Adapters to fuse, either as a list or a
            comma-separated string.
    """
    adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",")
    # only add a fusion layer if all fused adapters share this layer's config setting
    if self.config.adapters.common_config_value(adapter_names, self.adapter_config_key):
        # per-fusion config looked up by the (comma-joined) adapter names
        fusion_config = self.config.adapters.get_fusion(adapter_names)
        fusion = BertFusion(
            fusion_config,
            self.config.hidden_size,
            self.config.attention_probs_dropout_prob,
        )
        fusion.train(self.training)  # make sure training mode is consistent
        self.adapter_fusion_layer[",".join(adapter_names)] = fusion

Expand Down Expand Up @@ -114,6 +119,7 @@ def get_adapter_preparams(
adapter_config,
hidden_states,
input_tensor,
fusion_config=None,
):
"""
Retrieves the hidden_states, query (for Fusion), and residual connection according to the set configuration
Expand All @@ -131,7 +137,7 @@ def get_adapter_preparams(
if adapter_config["residual_before_ln"]:
residual = hidden_states

if hasattr(self.config, "adapter_fusion") and self.config.adapter_fusion["query_before_ln"]:
if fusion_config is not None and fusion_config["query_before_ln"]:
query = hidden_states

if adapter_config["original_ln_before"]:
Expand All @@ -143,7 +149,7 @@ def get_adapter_preparams(
if not adapter_config["residual_before_ln"]:
residual = hidden_states

if hasattr(self.config, "adapter_fusion") and not self.config.adapter_fusion["query_before_ln"]:
if fusion_config is not None and not fusion_config["query_before_ln"]:
query = hidden_states

return hidden_states, query, residual
Expand Down Expand Up @@ -196,7 +202,10 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0
"""
# config of _last_ fused adapter is significant
adapter_config = self.config.adapters.get(adapter_setup.last())
hidden_states, query, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor)
fusion_config = self.config.adapters.get_fusion(adapter_setup.name)
hidden_states, query, residual = self.get_adapter_preparams(
adapter_config, hidden_states, input_tensor, fusion_config=fusion_config
)

up_list = []

Expand Down Expand Up @@ -239,7 +248,7 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, lvl=0
"""
# config of _first_ of splitted adapters is significant
adapter_config = self.config.adapters.get(adapter_setup.first())
hidden_states, query, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor)
hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor)

# split hidden representations and residuals at split index
split_hidden_states = [
Expand Down
27 changes: 13 additions & 14 deletions src/transformers/adapters/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def rename_func(self, old_name, new_name):
"adapter_fusion_layer.{}".format(old_name), "adapter_fusion_layer.{}".format(new_name)
)

def save(self, save_directory: str, name: str):
def save(self, save_directory: str, name: str, meta_dict=None):
"""
Saves a AdapterFusion module into the given directory.

Expand All @@ -470,20 +470,19 @@ def save(self, save_directory: str, name: str):
name (str, optional): The AdapterFusion name.
"""

if hasattr(self.model.config, "adapter_fusion_models"):
if name not in self.model.config.adapter_fusion_models:
if self.error_on_missing:
raise ValueError(f"Unknown AdapterFusion '{name}'.")
else:
logger.debug(f"No AdapterFusion with name '{name}' available.")
return
if name not in self.model.config.adapters.fusions:
if self.error_on_missing:
raise ValueError(f"Unknown AdapterFusion '{name}'.")
else:
logger.debug(f"No AdapterFusion with name '{name}' available.")
return

if not exists(save_directory):
mkdir(save_directory)
else:
assert isdir(save_directory), "Saving path should be a directory where the head can be saved."

adapter_fusion_config = self.model.config.adapter_fusion
adapter_fusion_config = self.model.config.adapters.get_fusion(name)

# Save the adapter fusion configuration
config_dict = build_full_config(
Expand All @@ -493,7 +492,7 @@ def save(self, save_directory: str, name: str):
model_name=self.model.model_name,
model_class=self.model.__class__.__name__,
)
self.weights_helper.save_weights_config(save_directory, config_dict)
self.weights_helper.save_weights_config(save_directory, config_dict, meta_dict=meta_dict)

# Save head weights
filter_func = self.filter_func(name)
Expand All @@ -519,13 +518,13 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs):
return None, None

config = self.weights_helper.load_weights_config(save_directory)
if not hasattr(self.model.config, "adapter_fusion_models"):
self.model.config.adapter_fusion_models = []

adapter_fusion_name = load_as or config["name"]
if adapter_fusion_name in self.model.config.adapter_fusion_models:
if adapter_fusion_name in self.model.config.adapters.fusions:
logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name))
self.model.add_adapter_fusion(adapter_fusion_name, config["config"], set_active=kwargs.pop("set_active", True))
self.model.add_adapter_fusion(
adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True)
)

# Load AdapterFusion weights
filter_func = self.filter_func(adapter_fusion_name)
Expand Down
Loading