Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training #18486

Merged
merged 8 commits on Aug 11, 2022
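Why the signature change matters: per-sample gradient engines such as Opacus attach forward hooks to submodules and read their inputs, so every input needs to be a tensor whose first dimension is the batch. The old forward received a torch.Size, which carries no batch dimension and no per-sample information. Below is a minimal standalone sketch of the new behaviour; the class body mirrors BartLearnedPositionalEmbedding from the diffs in this PR, while the toy sizes and the inspection hook are illustrative assumptions only (the hook merely stands in for the kind of hook Opacus registers).

```python
import torch
import torch.nn as nn


class LearnedPositionalEmbedding(nn.Embedding):
    """Mirrors the modified module below: positions are derived from the
    input_ids tensor itself and expanded across the batch."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Hard-coded offset of 2, as in the diffs below.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_len,
            dtype=torch.long,
            device=self.weight.device,
        ).expand(bsz, -1)
        return super().forward(positions + self.offset)


def inspect_inputs(module, inputs):
    # Stand-in for a per-sample-gradient forward hook: it only has something to
    # work with if the module input is a batched tensor, not a torch.Size.
    print([tuple(x.shape) for x in inputs if isinstance(x, torch.Tensor)])


emb = LearnedPositionalEmbedding(num_embeddings=128, embedding_dim=16)  # toy sizes
emb.register_forward_pre_hook(inspect_inputs)

input_ids = torch.randint(0, 100, (4, 7))  # [bsz, seq_len]
out = emb(input_ids)                       # hook prints [(4, 7)]; out is [4, 7, 16]
```
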
26 changes: 15 additions & 11 deletions src/transformers/models/bart/modeling_bart.py
@@ -128,12 +128,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids`' shape is expected to be [bsz x seqlen]."""

bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
).expand(bsz, -1)

return super().forward(positions + self.offset)
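One detail worth calling out in the hunk above: expand(bsz, -1) broadcasts the single row of position indices across the batch as a view, so every sample sees the same positions and no extra memory is copied. A quick check of this (toy sizes, illustrative only):

```python
import torch

bsz, seq_len, past = 4, 7, 0
positions = torch.arange(past, past + seq_len, dtype=torch.long).expand(bsz, -1)

print(positions.shape)    # torch.Size([4, 7])
print(positions.stride())  # (0, 1): stride 0 over the batch dim, i.e. a broadcast view
```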


@@ -788,17 +790,17 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input = input_ids
input_ids = input_ids.view(-1, input_ids.shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
embed_pos = self.embed_positions(input)

hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@@ -1013,18 +1015,20 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input.shape
Collaborator:

It looks like input_shape is not used in this forward anymore, and the other variables are not strictly necessary either, so you could remove those two elif branches and just do an elif input_ids is None and inputs_embeds is None for the last ValueError.

Contributor Author:

input_shape is used for the _prepare_decoder_attention_mask and _expand_mask methods within this forward.

Collaborator:

Ah yes, was looking at the wrong forward 🤦‍♂️

input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input) * self.embed_scale
Collaborator:

Why replace here, as input will be input_ids?

Contributor Author:

Just for clarity that input is set after the if/elif checks. Happy to keep it as input_ids if you feel strongly about it.


attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1036,7 +1040,7 @@ def forward(
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
positions = self.embed_positions(input, past_key_values_length)
Collaborator:

Looks like it could always be inputs_embeds[:, :, -1] here since inputs_embeds is now defined.

Contributor Author:

Perhaps I'm missing something, but if inputs_embeds is not provided, the input here will be input_ids. If inputs_embeds is provided (but input_ids is not), this input will be inputs_embeds[:, :, -1]. So the dependency on what is provided means we cannot always pass inputs_embeds[:, :, -1].

Collaborator:

Yes, but inputs_embeds is now defined and its shape will match. Anyhow, it's not important; it was to go along with the suggestion above to remove input entirely.
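To make the shape bookkeeping in this exchange concrete: input only has to expose the batch and sequence dimensions to embed_positions, so the slice inputs_embeds[:, :, -1] works as a stand-in when no token ids are given. A small sketch with toy sizes (illustrative assumption only):

```python
import torch

bsz, seq_len, d_model = 2, 5, 8

# Branch 1: token ids were provided.
input_ids = torch.randint(0, 100, (bsz, seq_len))
inp_from_ids = input_ids                   # [bsz, seq_len]

# Branch 2: only precomputed embeddings were provided.
inputs_embeds = torch.randn(bsz, seq_len, d_model)
inp_from_embeds = inputs_embeds[:, :, -1]  # last feature channel, [bsz, seq_len]

# embed_positions never indexes with these values; it only reads the first two
# dimensions to build position ids, so either branch hands it [bsz, seq_len].
print(inp_from_ids.shape, inp_from_embeds.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
```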


hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
23 changes: 14 additions & 9 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids`' shape is expected to be [bsz x seqlen]."""

bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
).expand(bsz, -1)

return super().forward(positions + self.offset)


@@ -783,17 +785,18 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input.shape
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
embed_pos = self.embed_positions(input)

hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@@ -1011,10 +1014,12 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1034,7 +1039,7 @@ def forward(
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
positions = self.embed_positions(input, past_key_values_length)

hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
22 changes: 14 additions & 8 deletions src/transformers/models/mvp/modeling_mvp.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids`' shape is expected to be [bsz x seqlen]."""

bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
).expand(bsz, -1)

return super().forward(positions + self.offset)


@@ -895,17 +897,19 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input.shape
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
embed_pos = self.embed_positions(input)

hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@@ -1144,10 +1148,12 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input_ids.shape
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1167,7 +1173,7 @@ def forward(
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
positions = self.embed_positions(input, past_key_values_length)

hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
26 changes: 15 additions & 11 deletions src/transformers/models/plbart/modeling_plbart.py
@@ -131,12 +131,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids`' shape is expected to be [bsz x seqlen]."""

bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
).expand(bsz, -1)

return super().forward(positions + self.offset)


@@ -759,17 +761,17 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input = input_ids
input_ids = input_ids.view(-1, input_ids.shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
embed_pos = self.embed_positions(input)

hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
@@ -985,18 +987,20 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input = input_ids
input_shape = input.shape
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input) * self.embed_scale

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1008,7 +1012,7 @@ def forward(
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
positions = self.embed_positions(input, past_key_values_length)

hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
19 changes: 12 additions & 7 deletions src/transformers/models/trocr/modeling_trocr.py
@@ -87,12 +87,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids`' shape is expected to be [bsz x seqlen]."""

bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
).expand(bsz, -1)

return super().forward(positions + self.offset)


@@ -626,10 +628,11 @@ def forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input = input_ids
input_ids = input_ids.view(-1, input.shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -640,7 +643,7 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

if self.config.use_learned_position_embeddings:
embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
else:
embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

@@ -651,6 +654,8 @@

hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

input_shape = input.shape

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)