Traced models serialization and torchscripting fix #17206
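For context, here is a minimal sketch (not part of the diff) of the trace-then-serialize workflow the PR title refers to. The checkpoint name and inputs are illustrative assumptions; any of the models touched below could stand in.

```python
import torch
from transformers import DistilBertModel, DistilBertTokenizer

# Illustrative checkpoint; the `torchscript=True` config flag prepares the model for tracing.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torchscript=True)
model.eval()

inputs = tokenizer("Hello, world!", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

torch.jit.save(traced, "traced_distilbert.pt")     # serialize the traced graph
reloaded = torch.jit.load("traced_distilbert.pt")  # round-trips only if every traced op serializes cleanly
```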
DistilBERT attention:

@@ -211,7 +211,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
 q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
 scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
 mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
-scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
+scores = scores.masked_fill(mask, torch.tensor(-float("inf")))  # (bs, n_heads, q_length, k_length)

 weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
 weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

Reviewer comment: Needed to be able to TorchScript the traced model.
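A standalone sketch of the pattern above (not from the PR, shapes are illustrative): `masked_fill` also accepts a zero-dimensional tensor as the fill value, which produces the same result as a bare Python `-inf` while staying friendlier to scripting.

```python
import torch

scores = torch.randn(1, 2, 4, 4)                  # (bs, n_heads, q_length, k_length)
mask = torch.zeros(1, 2, 4, 4, dtype=torch.bool)
mask[..., 2:] = True                              # pretend the last two key positions are padding

# A Python float works in eager mode...
eager = scores.masked_fill(mask, -float("inf"))
# ...and a 0-dim tensor gives the identical result while keeping the op scriptable.
script_friendly = scores.masked_fill(mask, torch.tensor(-float("inf")))

assert torch.equal(eager, script_friendly)
```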
GPT-2:

@@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 if not self.is_cross_attention:
     # if only "normal" attention layer implements causal mask
     query_length, key_length = query.size(-2), key.size(-2)
-    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
     attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

 if attention_mask is not None:
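For reference, a small sketch (my own, not from the diff) showing that `.to(torch.bool)` yields exactly the same boolean mask as the old `.bool()` spelling; the buffer construction below mirrors the registered causal-mask bias but the sizes are made up.

```python
import torch

# Illustrative causal-mask buffer like the `self.bias` used above.
max_positions = 6
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)

query_length, key_length = 3, 5
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

# Same values as `.bool()`; only the casting method differs.
assert torch.equal(
    causal_mask,
    bias[:, :, key_length - query_length : key_length, :key_length].bool(),
)
```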
@@ -1410,7 +1410,7 @@ def forward(
         "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
     )

-pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

 loss = None
 if labels is not None:

Reviewer comments: Same reason. / Yes, and better not rely on `self.device`.
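A sketch of the device fix (illustrative, not from the PR): the batch-index tensor is built on the device of the tensor being indexed instead of on the module-level `self.device`, which the review notes is better not to rely on.

```python
import torch

batch_size, seq_len, num_labels = 2, 7, 3
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([4, 6])  # index of the last non-padding token per sequence

# Derive the device from the tensor itself rather than from a module attribute.
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```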
GPT-Neo:

@@ -147,8 +147,8 @@ def __init__(self, config, attention_type):
 self.register_buffer("bias", bias)
 self.register_buffer("masked_bias", torch.tensor(-1e9))

-self.attn_dropout = nn.Dropout(config.attention_dropout)
-self.resid_dropout = nn.Dropout(config.resid_dropout)
+self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+self.resid_dropout = nn.Dropout(float(config.resid_dropout))

 self.embed_dim = config.hidden_size
 self.num_heads = config.num_heads

Reviewer comment: TorchScripting fails otherwise, this should not change anything.
@@ -188,7 +188,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 attn_weights = torch.matmul(query, key.transpose(-1, -2))

 query_length, key_length = query.size(-2), key.size(-2)
-causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
 attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

 if attention_mask is not None:

Reviewer comment: Same as for GPT-2.

@@ -290,7 +290,7 @@ def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 *
     self.c_fc = nn.Linear(embed_dim, intermediate_size)
     self.c_proj = nn.Linear(intermediate_size, embed_dim)
     self.act = ACT2FN[config.activation_function]
-    self.dropout = nn.Dropout(config.resid_dropout)
+    self.dropout = nn.Dropout(float(config.resid_dropout))

 def forward(self, hidden_states):
     hidden_states = self.c_fc(hidden_states)

@@ -475,7 +475,7 @@ def __init__(self, config):
 self.embed_dim = config.hidden_size
 self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
 self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-self.drop = nn.Dropout(config.embed_dropout)
+self.drop = nn.Dropout(float(config.embed_dropout))
 self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
 self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@@ -887,7 +887,7 @@ def forward(
         "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
     )

-pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

 loss = None
 if labels is not None:
GPT-J:

@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
 def rotate_every_two(x):
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), axis=-1)
+    x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

Reviewer comment: Needed for TorchScript. This should be ok since stack can take the `dim` keyword.
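For context, a standalone copy of the helper above (simplified, runnable as-is): the rotation interleaves `(-x2, x1)` pairs along the last dimension, and `dim` is the canonical keyword for `torch.stack`, whereas the numpy-style `axis` spelling only works in eager mode per the reviewer note.

```python
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)  # `dim` instead of the eager-only `axis` alias
    return x.flatten(-2)                # rearrange(x, '... d j -> ... (d j)')

x = torch.arange(8.0).view(1, 1, 1, 8)
print(rotate_every_two(x))  # tensor([[[[-1., 0., -3., 2., -5., 4., -7., 6.]]]])
```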
@@ -163,7 +163,7 @@ def _attn(
 # compute causal mask from causal mask buffer
 query_length, key_length = query.size(-2), key.size(-2)
-causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

 # Keep the attention weights computation in fp32 to avoid overflow issues
 query = query.to(torch.float32)

Reviewer comment: Same thing as GPT-2.

@@ -971,7 +971,7 @@ def forward(
         "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
     )

-pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

 loss = None
 if labels is not None:

Reviewer comment: Same thing as GPT-2.
Reviewer comment: Needed for copy consistency.