Add head_mask/decoder_head_mask for BART #9404

Closed · wants to merge 16 commits
Changes from 13 commits
118 changes: 113 additions & 5 deletions src/transformers/models/bart/modeling_bart.py
110 changes: 105 additions & 5 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
109 changes: 104 additions & 5 deletions src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
110 changes: 105 additions & 5 deletions src/transformers/models/marian/modeling_marian.py
104 changes: 100 additions & 4 deletions src/transformers/models/mbart/modeling_mbart.py
110 changes: 105 additions & 5 deletions src/transformers/models/pegasus/modeling_pegasus.py

(Large diffs in the six modeling files above are not rendered by default.)
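For orientation, the substance of each modeling-file change is to accept a head_mask (encoder) and decoder_head_mask (decoder) of shape (num_layers, num_heads) and hand one row to each layer's attention module. A minimal sketch of the usual masking pattern, with illustrative names and shapes rather than the verbatim modeling_bart.py code:

import torch

def apply_layer_head_mask(attn_probs, layer_head_mask, bsz, num_heads, tgt_len, src_len):
    # attn_probs: post-softmax attention, shape (bsz * num_heads, tgt_len, src_len)
    # layer_head_mask: shape (num_heads,); 1.0 keeps a head, 0.0 silences it
    if layer_head_mask is not None:
        attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(
            bsz, num_heads, tgt_len, src_len
        )
        attn_probs = attn_probs.view(bsz * num_heads, tgt_len, src_len)
    return attn_probs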

11 changes: 10 additions & 1 deletion tests/test_modeling_bart.py
@@ -51,18 +51,26 @@ def prepare_bart_inputs_dict(
     config,
     input_ids,
     decoder_input_ids=None,
+    head_mask=None,
     attention_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
+        "head_mask": head_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }
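To make these test inputs concrete, here is a hedged usage sketch of the API being added (the tiny config values and token ids are made up for illustration; head_mask and decoder_head_mask only exist on BartModel once this PR is in):

import torch
from transformers import BartConfig, BartModel

config = BartConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = BartModel(config).eval()

input_ids = torch.tensor([[0, 11, 12, 2]])
decoder_input_ids = torch.tensor([[2, 0, 11, 12]])

# A head mask is a float tensor of shape (num_layers, num_heads):
# 1.0 leaves a head untouched, 0.0 zeroes out its attention weights.
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # disable head 0 of encoder layer 0
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)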


@@ -142,9 +150,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = BartModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
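Beyond the use_cache pass above, head masking is typically verified through the attention weights the model returns. A rough sketch of such a check, reusing model, input_ids, and the masks from the previous sketch (this assumes the mask is applied before attentions are captured, so a zeroed head reads back as exactly zero):

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)
enc_attn_layer0 = outputs.encoder_attentions[0]  # (bsz, num_heads, src_len, src_len)
assert enc_attn_layer0[:, 0].abs().sum().item() == 0.0  # head 0 of layer 0 was masked
assert enc_attn_layer0[:, 1].abs().sum().item() > 0.0   # unmasked heads still attend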
11 changes: 10 additions & 1 deletion tests/test_modeling_blenderbot.py
@@ -39,17 +39,25 @@ def prepare_blenderbot_inputs_dict(
     input_ids,
     decoder_input_ids,
     attention_mask=None,
+    head_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
+        "head_mask": head_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }


@@ -129,9 +137,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = BlenderbotModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
11 changes: 10 additions & 1 deletion tests/test_modeling_blenderbot_small.py
@@ -47,17 +47,25 @@ def prepare_blenderbot_small_inputs_dict(
     input_ids,
     decoder_input_ids,
     attention_mask=None,
+    head_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
+        "head_mask": head_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }


@@ -137,9 +145,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = BlenderbotSmallModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
36 changes: 31 additions & 5 deletions tests/test_modeling_common.py
@@ -206,7 +206,12 @@ def test_forward_signature(self):
                     "decoder_attention_mask",
                     "encoder_outputs",
                 ]
-                self.assertListEqual(arg_names[:5], expected_arg_names)
+                if model.config.model_type in ["bart", "mbart", "marian", "blenderbot", "blenderbot-small", "pegasus"]:
Contributor:
The order of the signature arguments IMO should be as follows:

                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                    "head_mask",
                    "decoder_head_mask",
                    "encoder_outputs",
                ]

The reason is that decoder_input_ids is more important than head_mask for torchscript.

For models like Bart we would still like to be able to use torchscript as follows:

traced_bart(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)

instead of having to do

traced_bart(input_ids, attention_mask, head_mask, decoder_input_ids, decoder_attention_mask)

where head_mask would have to be a tensor of all 1's, since 99% of the time it is not used for torchscript.

So it'd be great if we could slightly change the order in all ...Model and all ...ForConditionalGeneration models to match the list above. head_mask is simply used too little with torchscript, so it is fine to break the (first all encoder inputs, then all decoder inputs) grouping here.

We can adapt the test as follows:

...
            arg_names = [*signature.parameters.keys()]
            if model.config.is_encoder_decoder:
                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]

                expected_arg_names.extend(["head_mask", "decoder_head_mask", "encoder_outputs"] if "head_mask" in arg_names else ["encoder_outputs"])
                ...

+                    expected_arg_names.insert(2, "head_mask")
+                    expected_arg_names.insert(5, "decoder_head_mask")
+                    self.assertListEqual(arg_names[:7], expected_arg_names)
+                else:
+                    self.assertListEqual(arg_names[:5], expected_arg_names)
             else:
                 expected_arg_names = ["input_ids"]
                 self.assertListEqual(arg_names[:1], expected_arg_names)
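The ordering argument above is ultimately about how traced modules are called. A toy sketch (hypothetical module, not the transformers code) showing that the parameter order of forward() is the positional call signature of the traced function:

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
        # Stand-in computation; a real model would run its encoder/decoder here.
        return input_ids + attention_mask + decoder_input_ids + decoder_attention_mask

example_inputs = tuple(torch.ones(2, 4) for _ in range(4))
traced = torch.jit.trace(Toy(), example_inputs)

# Traced modules take positional tensor arguments in forward() order, so if
# head_mask sat at position 2, every caller would have to pass an all-ones
# tensor there just to reach decoder_input_ids.
out = traced(*example_inputs)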
@@ -395,10 +400,31 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                 attention_mask = inputs["attention_mask"]
                 decoder_input_ids = inputs["decoder_input_ids"]
                 decoder_attention_mask = inputs["decoder_attention_mask"]

-                traced_model = torch.jit.trace(
Contributor:

Let's try not to change this test. head_mask is more or less never used with torch.jit.trace. If you're keen, we could overwrite this test in all Bart-like models to include the head_mask, but it's not mandatory at all IMO.

However, we don't really like these if model.config.model_type not in ... statements in the tests.

-                    model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
-                )
+                if model.config.model_type not in [
+                    "bart",
+                    "mbart",
+                    "marian",
+                    "blenderbot",
+                    "blenderbot-small",
+                    "pegasus",
+                ]:
+                    traced_model = torch.jit.trace(
+                        model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
+                    )
+                else:
+                    head_mask = inputs["head_mask"]
+                    decoder_head_mask = inputs["decoder_head_mask"]
+                    traced_model = torch.jit.trace(
+                        model,
+                        (
+                            input_ids,
+                            attention_mask,
+                            head_mask,
+                            decoder_input_ids,
+                            decoder_attention_mask,
+                            decoder_head_mask,
+                        ),
+                    )
             else:
                 input_ids = inputs["input_ids"]
                 traced_model = torch.jit.trace(model, input_ids)
11 changes: 10 additions & 1 deletion tests/test_modeling_marian.py
@@ -53,17 +53,25 @@ def prepare_marian_inputs_dict(
     input_ids,
     decoder_input_ids,
     attention_mask=None,
+    head_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
+        "head_mask": head_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }


@@ -146,9 +154,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = MarianModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
11 changes: 10 additions & 1 deletion tests/test_modeling_mbart.py
@@ -48,17 +48,25 @@ def prepare_mbart_inputs_dict(
     input_ids,
     decoder_input_ids,
     attention_mask=None,
+    head_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
+        "head_mask": head_mask,
         "attention_mask": attention_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }


@@ -138,9 +146,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = MBartModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
11 changes: 10 additions & 1 deletion tests/test_modeling_pegasus.py
@@ -40,17 +40,25 @@ def prepare_pegasus_inputs_dict(
     input_ids,
     decoder_input_ids,
     attention_mask=None,
+    head_mask=None,
     decoder_attention_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
+        "head_mask": head_mask,
         "decoder_attention_mask": attention_mask,
+        "decoder_head_mask": decoder_head_mask,
     }


@@ -130,9 +138,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = PegasusModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()