Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Update call sites missing opt (#3715)
Browse files: browse the repository at this point in the history
* update mha init in polyencoder

* add polyencoder mha test

* hit both multihead init sites in test
  • Loading branch information
spencerp authored Jun 15, 2021
1 parent 02d4ae7 commit fe36d88
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions parlai/agents/transformer/polyencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(self, opt, dict_, null_idx):
# The attention for the codes.
if self.codes_attention_type == 'multihead':
self.code_attention = MultiHeadAttention(
self.codes_attention_num_heads, embed_dim, opt['dropout']
opt, self.codes_attention_num_heads, embed_dim, opt['dropout']
)
elif self.codes_attention_type == 'sqrt':
self.code_attention = PolyBasicAttention(
Expand All @@ -360,7 +360,7 @@ def __init__(self, opt, dict_, null_idx):
# The final attention (the one that takes the candidate as key)
if self.attention_type == 'multihead':
self.attention = MultiHeadAttention(
self.attention_num_heads, opt['embedding_size'], opt['dropout']
opt, self.attention_num_heads, opt['embedding_size'], opt['dropout']
)
else:
self.attention = PolyBasicAttention(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,28 @@ class TestPolyencoder(TestTransformerBase):
def test_resize_embeddings(self):
    """
    Delegate to ``self._test_resize_embeddings`` for the polyencoder model.
    """
    agent_name = 'transformer/polyencoder'
    self._test_resize_embeddings(agent_name)

def test_multi_head_attention(self):
    """
    Train a polyencoder with both attention options set to 'multihead'.

    This is a smoke test: it makes no assertions of its own and passes as
    long as training completes without raising. Both ``poly_attention_type``
    and ``codes_attention_type`` are set to ``'multihead'`` so that each
    multihead attention option is exercised in the same run.
    """
    with testing_utils.tempdir() as tmpdir:
        # Keep the saved model inside the temporary directory so nothing
        # outlives the test.
        model_path = os.path.join(tmpdir, 'model_file')
        train_opt = Opt(
            model='transformer/polyencoder',
            task='integration_tests:short_fixed',
            n_layers=1,
            n_encoder_layers=2,
            n_decoder_layers=4,
            num_epochs=1,
            dict_tokenizer='bytelevelbpe',
            bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
            bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
            bpe_add_prefix_space=False,
            model_file=model_path,
            save_after_valid=True,
            poly_attention_type='multihead',
            codes_attention_type='multihead',
        )
        # The two returned values are not inspected; only a clean run matters.
        _first, _second = testing_utils.train_model(train_opt)


@testing_utils.skipUnlessVision
class TestImagePolyencoder(unittest.TestCase):
Expand Down

0 comments on commit fe36d88

Please sign in to comment.