Fix TF Longformer #9348

Merged · 7 commits · Jan 5, 2021
Changes from 6 commits
54 changes: 26 additions & 28 deletions src/transformers/models/longformer/modeling_tf_longformer.py
@@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
True` else after `sep_token_id`.
"""

assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
# bool attention mask with True in locations of global attention
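The change above is the first of several in this diff that replace static `.shape` indexing with `shape_list`, so shape lookups keep working when dimensions are only known at run time. A rough sketch of the idea behind such a helper follows; it is an assumed paraphrase, not necessarily the exact implementation of `shape_list` in `transformers.modeling_tf_utils`:

from typing import List, Union

import tensorflow as tf


def shape_list_sketch(x: tf.Tensor) -> List[Union[int, tf.Tensor]]:
    # Return static dimensions where they are known and dynamic tf.shape values
    # otherwise, so indexing like shape_list(t)[1] never yields None in graph mode.
    # Assumes the rank of x is known.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

With a symbolic input such as tf.keras.Input(shape=(None, 2)), tensor.shape[1] is None, while the helper falls back to the runtime value.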
@@ -1028,7 +1028,7 @@ def _mask_invalid_locations(input_tensor, window_overlap):
)

# pad to full matrix
padding = tf.constant(
padding = tf.convert_to_tensor(
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
)
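In the same spirit, tf.constant is replaced by tf.convert_to_tensor: the padding widths computed from shape_list can be symbolic tensors under tf.function, which tf.constant cannot embed, while tf.convert_to_tensor packs them into a tensor. A minimal hypothetical illustration (names and shapes invented, not code from the PR):

import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def pad_to_four_columns(x):
    # The column count is only known at run time, so `missing` is a symbolic
    # tensor here; tf.constant would reject it, tf.convert_to_tensor accepts it.
    missing = 4 - tf.shape(x)[1]
    paddings = tf.convert_to_tensor([[0, 0], [0, missing]])
    return tf.pad(x, paddings)


print(pad_to_four_columns(tf.ones((2, 3))).shape)  # (2, 4)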

@@ -1523,8 +1523,7 @@ def call(
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_global_attentions = () if (output_attentions and is_global_attn) else None
all_attentions = all_global_attentions = () if output_attentions else None
Contributor: nice!
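For context on the new initialization above: a small standalone sketch (not from the PR) of why chaining both accumulators to one empty tuple is safe; tuples are immutable, and each later concatenation rebinds only the name on its own left-hand side.

output_attentions = True

# Both names start out bound to the same immutable empty tuple.
all_attentions = all_global_attentions = () if output_attentions else None

for layer_output in ("attn_0", "attn_1"):  # hypothetical per-layer outputs
    all_attentions = all_attentions + (layer_output,)
    all_global_attentions = all_global_attentions + (layer_output + "_global",)

print(all_attentions)         # ('attn_0', 'attn_1')
print(all_global_attentions)  # ('attn_0_global', 'attn_1_global')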


for i, layer_module in enumerate(self.layer):
if output_hidden_states:
@@ -1547,9 +1546,8 @@ def call(
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

if is_global_attn:
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))

# Add last layer
if output_hidden_states:
@@ -1766,24 +1764,26 @@ def _pad_to_window_size(
)
)

paddings = tf.constant([[0, 0], [0, padding_len]])
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
Contributor: When padding_len == 0, this won't change the input_ids, correct?

Contributor Author: Correct!
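A quick way to check that claim (hypothetical values, not from the test suite): tf.pad is a no-op when every padding width is zero.

import tensorflow as tf

input_ids = tf.constant([[5, 6, 7], [8, 9, 10]])

padding_len = 0  # sequence length already a multiple of the attention window
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

padded = tf.pad(input_ids, paddings, constant_values=1)
print(bool(tf.reduce_all(padded == input_ids)))  # True: the tensor is unchanged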


if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)

if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
if inputs_embeds is not None:

if inputs_embeds is not None:
def pad_embeddings():
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

attention_mask = tf.pad(
attention_mask, paddings, constant_values=False
) # no attention on the padding tokens
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)

attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0

return (
padding_len,
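The reworked block above has pad_embeddings return the concatenated tensor and selects the result with tf.cond, instead of assigning to inputs_embeds inside the closure, which in Python only rebinds a local name and never reaches the enclosing variable. A small hypothetical sketch of that pattern (shapes and values invented):

import tensorflow as tf

inputs_embeds = tf.ones((1, 3, 4))  # batch x seq_len x hidden
padding_len = tf.constant(2)
pad_row = tf.zeros((1, 1, 4))       # stand-in for the embedding of the pad token


def pad_embeddings():
    # Return the padded tensor; assigning to inputs_embeds here would not
    # update the variable outside this function.
    padding = tf.tile(pad_row, (1, padding_len, 1))
    return tf.concat([inputs_embeds, padding], axis=-2)


inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
print(inputs_embeds.shape)  # (1, 5, 4)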
@@ -2171,16 +2171,14 @@ def call(

# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
if inputs["input_ids"] is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
Contributor: Can we leave this warning?

Contributor Author: We can, but the problem is that we would be testing whether inputs["input_ids"] is None inside an if that already checks that inputs["input_ids"] is not None, which seems strange.

Collaborator: Agreed with @jplu here ;-)

)
elif (
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
if (
shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
!= 3 * shape_list(inputs["input_ids"])[0]
):
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
)
inputs["global_attention_mask"] = tf.zeros(shape_list(inputs["input_ids"]))
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
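The rewritten condition counts separator tokens with tf.where and compares the count to three per sample through shape_list, which also works when the batch size is dynamic. A hypothetical example of that counting step, using tf.shape as a stand-in for shape_list in standalone code:

import tensorflow as tf

sep_token_id = 2
# Two question-answering samples, each containing exactly three separator tokens.
input_ids = tf.constant([
    [0, 5, 6, 2, 2, 7, 8, 2],
    [0, 9, 2, 2, 10, 11, 12, 2],
])

sep_positions = tf.where(input_ids == sep_token_id)  # shape: (num_sep_tokens, 2)
num_sep_tokens = tf.shape(sep_positions)[0]
batch_size = tf.shape(input_ids)[0]

print(bool(num_sep_tokens == 3 * batch_size))  # True: three separators per sample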
@@ -2317,8 +2315,8 @@ def call(
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
inputs["global_attention_mask"],
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
[1 for _ in range(inputs["input_ids"].shape[0])],
[[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
[1 for _ in range(shape_list(inputs["input_ids"])[0])],
)

outputs = self.longformer(
@@ -2443,7 +2441,7 @@ def call(
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_global_attention_mask = (
tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
if inputs["global_attention_mask"] is not None
else None
)