Add recipe_name to default file names
KSGulin committed Oct 13, 2022
1 parent bd469c4 commit 53c407d
Showing 2 changed files with 24 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -240,6 +240,22 @@ def forward(
return embeddings


class QATMatMul(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)


class BertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -257,6 +273,11 @@ def __init__(self, config, position_embedding_type=None):
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATMatMul()
self.context_layer_matmul = QATMatMul()

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
@@ -320,7 +341,7 @@ def forward(
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
@@ -354,7 +375,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
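
For reference, outside of a SparseML quantization-aware-training run the wrapper added above is a transparent pass-through: per the inline comments, the wrap_qat and qat_wrapper_kwargs attributes are only consumed once a SparseML QuantizationModifier is initialized. A minimal standalone sketch (copying the class as it appears in the diff) showing that, absent any modifier, it matches torch.matmul:

import torch
import torch.nn as nn


class QATMatMul(nn.Module):
    """Copy of the wrapper above; a pass-through around torch.matmul."""

    def __init__(self):
        super().__init__()
        # read by SparseML's QuantizationModifier when QAT is applied;
        # otherwise these attributes are inert
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


# with no modifier applied, the wrapper matches torch.matmul exactly
query = torch.randn(1, 12, 128, 64)
key = torch.randn(1, 12, 128, 64)
scores = QATMatMul()(query, key.transpose(-1, -2))
assert torch.equal(scores, torch.matmul(query, key.transpose(-1, -2)))

The change presumably swaps the functional torch.matmul calls in BertSelfAttention for named submodules so SparseML can locate and wrap them during QAT, while leaving the numerics of an ordinary forward pass untouched.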
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -165,6 +165,7 @@
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
+RECIPE_NAME = "recipe.yaml"

SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
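
RECIPE_NAME joins the existing default file-name constants (CONFIG_NAME, MODEL_CARD_NAME, and so on). A hedged sketch of how downstream code might use it to look for a SparseML recipe saved alongside a checkpoint; the find_recipe helper and the directory name are hypothetical and not part of this commit:

import os

from transformers.utils import RECIPE_NAME  # "recipe.yaml", added by this commit


def find_recipe(model_dir: str):
    """Hypothetical helper: path to a recipe stored with a checkpoint, or None."""
    recipe_path = os.path.join(model_dir, RECIPE_NAME)
    return recipe_path if os.path.isfile(recipe_path) else None


recipe = find_recipe("./my-quantized-bert")  # hypothetical local checkpoint directory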
