Traced models serialization and torchscripting fix #17206

Merged: 16 commits, May 23, 2022
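For context, a minimal sketch of the tracing and serialization workflow these fixes target; the checkpoint name and the example input are illustrative, not taken from the PR:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load a model with the torchscript flag so it returns tuples instead of dict outputs
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True)
model.eval()

# Trace with a concrete example input, then serialize and reload the traced module
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"],))
torch.jit.save(traced, "traced_gpt2.pt")
loaded = torch.jit.load("traced_gpt2.pt")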
@@ -187,7 +187,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author:
Needed for copy consistency.

attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

if attention_mask is not None:
2 changes: 1 addition & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -211,7 +211,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
- scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
+ scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
Member Author:
Needed to be able to TorchScript the traced model.
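To illustrate, a minimal standalone sketch of the masked_fill pattern with a tensor fill value; the shapes and mask content are illustrative:

import torch

scores = torch.randn(1, 2, 4, 4)                       # (bs, n_heads, q_length, k_length)
mask = torch.zeros(1, 2, 4, 4, dtype=torch.bool)
mask[..., -1] = True                                   # hide the last key position
scores = scores.masked_fill(mask, torch.tensor(-float("inf")))
weights = torch.nn.functional.softmax(scores, dim=-1)  # masked positions get zero weight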


weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
4 changes: 2 additions & 2 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author:
Needed to be able to TorchScript the traced model. This should not break anything, since the two calls are equivalent; from the docs:

self.bool() is equivalent to self.to(torch.bool). See to().
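A quick sanity check of that equivalence on a causal-mask buffer; the buffer shape and slice bounds are illustrative:

import torch

bias = torch.tril(torch.ones(1, 1, 8, 8, dtype=torch.uint8))   # lower-triangular causal buffer
assert torch.equal(bias.bool(), bias.to(torch.bool))           # documented equivalence
query_length, key_length = 4, 8
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)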

attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

if attention_mask is not None:
@@ -1410,7 +1410,7 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Member Author:
Same reason.
This should not break things because the tensor should be on the same device as logits anyway, right?

Collaborator:
Yes, and it's better not to rely on self.device anyway for model parallelism (I've made a few PRs hunting down most of those).
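For reference, a minimal sketch of that pooling done on the logits' own device; the batch size, sequence lengths, and label count are illustrative:

import torch

batch_size, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([4, 2])                # last non-padding position per sequence
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
# pooled_logits has shape (batch_size, num_labels)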


loss = None
if labels is not None:
12 changes: 6 additions & 6 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -147,8 +147,8 @@ def __init__(self, config, attention_type):
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))

- self.attn_dropout = nn.Dropout(config.attention_dropout)
- self.resid_dropout = nn.Dropout(config.resid_dropout)
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
Member Author:
TorchScripting fails otherwise; this should not change anything.
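A minimal sketch of the cast in isolation; treating the config value as an int is an assumption made for illustration:

import torch
from torch import nn

attention_dropout = 0                                  # e.g. a config attribute stored as an int
attn_dropout = nn.Dropout(float(attention_dropout))    # nn.Dropout's p is then a plain float
x = torch.randn(2, 4)
y = attn_dropout(x)                                    # identity here, since p == 0.0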

+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))

self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
@@ -188,7 +188,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author:
Same as for GPT-2

attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

if attention_mask is not None:
@@ -290,7 +290,7 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 *
self.c_fc = nn.Linear(embed_dim, intermediate_size)
self.c_proj = nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
- self.dropout = nn.Dropout(config.resid_dropout)
+ self.dropout = nn.Dropout(float(config.resid_dropout))

def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
@@ -475,7 +475,7 @@ def __init__(self, config):
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
- self.drop = nn.Dropout(config.embed_dropout)
+ self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@@ -887,7 +887,7 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
6 changes: 3 additions & 3 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
- x = torch.stack((-x2, x1), axis=-1)
+ x = torch.stack((-x2, x1), dim=-1)
Member Author:
Needed for TorchScript. This should be fine, as torch.stack has accepted the dim argument from the very beginning.
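For reference, a minimal sketch of rotate_every_two with the keyword spelled as dim; the input shape is illustrative:

import torch

x = torch.randn(1, 2, 3, 8)                            # illustrative 4-D input
x1 = x[:, :, :, ::2]                                   # even-indexed features
x2 = x[:, :, :, 1::2]                                  # odd-indexed features
rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)   # same shape as x, pairs rotated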

return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


@@ -163,7 +163,7 @@ def _attn(

# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author:
Same thing as GPT-2.


# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
@@ -971,7 +971,7 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Member Author:
Same thing as GPT-2.


loss = None
if labels is not None:
4 changes: 2 additions & 2 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -226,9 +226,9 @@ def forward(
# dimensional output.
inputs_embeds = torch.cat(
[
- nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
inputs_embeds,
- nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
],
dim=2,
)
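A minimal sketch of the padding calls with an explicit float fill value; the tensor shape is illustrative, and attributing this change to the same TorchScript type strictness as the other fixes is an assumption:

import torch
from torch import nn

inputs_embeds = torch.randn(1, 4, 8)                   # (batch, seq_length, embed_dim), illustrative
shift_left = nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0)
shift_right = nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0)
trigram = torch.cat([shift_left, inputs_embeds, shift_right], dim=2)   # (1, 4, 24)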