diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index 3cfe9564b7..82dd8097a2 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -65,15 +65,23 @@ def __init__(self, name):
         self.config = {}
         self.name = name
 
-    def build(self, model):
-        model_config = model.config
-        pred_head = []
+    def _get_dropout_prob(self, model_config):
+        # try to infer dropout prob from various sources, default to 0.0
         if "dropout_prob" in self.config and self.config["dropout_prob"] is not None:
             dropout_prob = self.config["dropout_prob"]
         elif hasattr(model_config, "classifier_dropout") and model_config.classifier_dropout is not None:
             dropout_prob = model_config.classifier_dropout
-        else:
+        elif hasattr(model_config, "hidden_dropout_prob") and model_config.hidden_dropout_prob is not None:
             dropout_prob = model_config.hidden_dropout_prob
+        else:
+            dropout_prob = 0.0
+
+        return dropout_prob
+
+    def build(self, model):
+        model_config = model.config
+        pred_head = []
+        dropout_prob = self._get_dropout_prob(model_config)
         bias = self.config.get("bias", True)
         for l_id in range(self.config["layers"]):
             pred_head.append(nn.Dropout(dropout_prob))
diff --git a/src/adapters/heads/dependency_parsing.py b/src/adapters/heads/dependency_parsing.py
index 2b91d5290e..d568f356b0 100644
--- a/src/adapters/heads/dependency_parsing.py
+++ b/src/adapters/heads/dependency_parsing.py
@@ -81,7 +81,7 @@ def build(self, model):
             n_in=model.config.hidden_size, n_out=self.config["num_labels"], bias_x=True, bias_y=True
         )
 
-        self.dropout = nn.Dropout(model.config.hidden_dropout_prob)
+        self.dropout = nn.Dropout(self._get_dropout_prob(model.config))
 
         self.loss_fn = CrossEntropyLoss()
 
diff --git a/tests/test_llama.py b/tests/test_llama.py
index e8cf0557a0..7b99697122 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -29,7 +29,6 @@ class LlamaAdapterTestBase(AdapterTestBase):
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
-        hidden_dropout_prob=0.1,
         pad_token_id=0,
     )
     tokenizer_name = "openlm-research/open_llama_13b"
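
For reference, here is a minimal standalone sketch of the fallback order that `_get_dropout_prob` implements: an explicit `dropout_prob` in the head config wins, then `classifier_dropout`, then `hidden_dropout_prob`, and finally `0.0` for configs (such as Llama's) that define neither. The `get_dropout_prob` function and the `SimpleNamespace` stand-in configs below are illustrative only, not part of the library API.

```python
from types import SimpleNamespace


def get_dropout_prob(head_config: dict, model_config) -> float:
    """Resolve the dropout probability for a prediction head (sketch of the fallback chain)."""
    if head_config.get("dropout_prob") is not None:
        return head_config["dropout_prob"]            # explicit head setting wins
    if getattr(model_config, "classifier_dropout", None) is not None:
        return model_config.classifier_dropout        # BERT-style classifier dropout
    if getattr(model_config, "hidden_dropout_prob", None) is not None:
        return model_config.hidden_dropout_prob       # BERT-style hidden dropout
    return 0.0                                        # configs without any dropout attribute


# Hypothetical stand-ins for transformers model configs; only attribute names matter here.
bert_like = SimpleNamespace(classifier_dropout=None, hidden_dropout_prob=0.1)
llama_like = SimpleNamespace()  # no dropout attributes at all

print(get_dropout_prob({}, bert_like))                      # -> 0.1
print(get_dropout_prob({}, llama_like))                     # -> 0.0
print(get_dropout_prob({"dropout_prob": 0.2}, llama_like))  # -> 0.2
```

This is also why the `hidden_dropout_prob=0.1` override can be dropped from the Llama test config: the heads no longer assume that attribute exists on the model config.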