Use default head dropout prob if not provided by model (#685)
Fixes #666 and the issue described in #523.
calpt committed Apr 25, 2024
1 parent 43ccd9f commit 25797a0
Showing 3 changed files with 13 additions and 6 deletions.

16 changes: 12 additions & 4 deletions src/adapters/heads/base.py
@@ -65,15 +65,23 @@ def __init__(self, name):
         self.config = {}
         self.name = name
 
-    def build(self, model):
-        model_config = model.config
-        pred_head = []
+    def _get_dropout_prob(self, model_config):
+        # try to infer dropout prob from various sources, default to 0.0
         if "dropout_prob" in self.config and self.config["dropout_prob"] is not None:
             dropout_prob = self.config["dropout_prob"]
         elif hasattr(model_config, "classifier_dropout") and model_config.classifier_dropout is not None:
             dropout_prob = model_config.classifier_dropout
-        else:
+        elif hasattr(model_config, "hidden_dropout_prob") and model_config.hidden_dropout_prob is not None:
             dropout_prob = model_config.hidden_dropout_prob
+        else:
+            dropout_prob = 0.0
+
+        return dropout_prob
+
+    def build(self, model):
+        model_config = model.config
+        pred_head = []
+        dropout_prob = self._get_dropout_prob(model_config)
         bias = self.config.get("bias", True)
         for l_id in range(self.config["layers"]):
             pred_head.append(nn.Dropout(dropout_prob))
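
The new _get_dropout_prob helper resolves the head dropout probability in order: an explicit "dropout_prob" in the head config, then model_config.classifier_dropout, then model_config.hidden_dropout_prob, and finally a default of 0.0. Below is a minimal standalone sketch of that fallback order; the function name and config objects are illustrative only, not part of the library.

from types import SimpleNamespace

def resolve_dropout_prob(head_config, model_config):
    # Mirrors the fallback chain added in PredictionHead._get_dropout_prob.
    if head_config.get("dropout_prob") is not None:
        return head_config["dropout_prob"]
    if getattr(model_config, "classifier_dropout", None) is not None:
        return model_config.classifier_dropout
    if getattr(model_config, "hidden_dropout_prob", None) is not None:
        return model_config.hidden_dropout_prob
    return 0.0

# BERT-style config: hidden_dropout_prob is still picked up, as before.
assert resolve_dropout_prob({}, SimpleNamespace(hidden_dropout_prob=0.1)) == 0.1
# Config without hidden_dropout_prob: now defaults to 0.0 instead of failing.
assert resolve_dropout_prob({}, SimpleNamespace()) == 0.0
# An explicit value in the head config always wins.
assert resolve_dropout_prob({"dropout_prob": 0.2}, SimpleNamespace(hidden_dropout_prob=0.1)) == 0.2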

2 changes: 1 addition & 1 deletion src/adapters/heads/dependency_parsing.py
@@ -81,7 +81,7 @@ def build(self, model):
             n_in=model.config.hidden_size, n_out=self.config["num_labels"], bias_x=True, bias_y=True
         )
 
-        self.dropout = nn.Dropout(model.config.hidden_dropout_prob)
+        self.dropout = nn.Dropout(self._get_dropout_prob(model.config))
 
         self.loss_fn = CrossEntropyLoss()
 

1 change: 0 additions & 1 deletion tests/test_llama.py
@@ -29,7 +29,6 @@ class LlamaAdapterTestBase(AdapterTestBase):
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
-        hidden_dropout_prob=0.1,
         pad_token_id=0,
     )
     tokenizer_name = "openlm-research/open_llama_13b"
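
The removed hidden_dropout_prob=0.1 override was presumably masking the problem in the LLaMA tests: LlamaConfig does not define hidden_dropout_prob, so the old unconditional attribute lookup could fail on such configs. A hedged usage sketch of the path this commit unblocks; the tiny config values and head name below are illustrative, not taken from the test suite.

from transformers import LlamaConfig
from adapters import LlamaAdapterModel

# A tiny LLaMA config; it defines neither classifier_dropout nor
# hidden_dropout_prob, so the new fallback yields a dropout prob of 0.0.
config = LlamaConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    vocab_size=1000,
    pad_token_id=0,
)
model = LlamaAdapterModel(config)
# Head name and num_labels are arbitrary; building the head no longer
# requires hidden_dropout_prob on the model config.
model.add_classification_head("dummy_head", num_labels=2)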
