From 53c407db24ca7ee5ee53eab413d3185915bdee5a Mon Sep 17 00:00:00 2001
From: Konstantin
Date: Thu, 13 Oct 2022 21:41:34 +0000
Subject: [PATCH] Add recipe_name to default file names

---
 src/transformers/models/bert/modeling_bert.py | 25 +++++++++++++++++--
 src/transformers/utils/__init__.py            |  1 +
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 11664f66cba8e7..fe79c728ee6cfe 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -240,6 +240,22 @@ def forward(
         return embeddings
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class BertSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
@@ -257,6 +273,11 @@ def __init__(self, config, position_embedding_type=None):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
@@ -320,7 +341,7 @@ def forward(
             past_key_value = (key_layer, value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             seq_length = hidden_states.size()[1]
@@ -354,7 +375,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 2269f225485820..caa323cc5e40ed 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -165,6 +165,7 @@
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
+RECIPE_NAME = "recipe.yaml"
 
 SENTENCEPIECE_UNDERLINE = "▁"
 SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility
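
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of what the QATMatMul wrapper does. Outside of quantization-aware training it is numerically identical to torch.matmul; the wrap_qat flag and qat_wrapper_kwargs only mark the module so that a quantization tool (for example a SparseML QuantizationModifier) can re-wrap it later. The TinyAttention module and the discovery loop are hypothetical stand-ins for illustration only; they are not part of transformers or of the SparseML API.

import torch
from torch import nn


class QATMatMul(nn.Module):
    """Copy of the wrapper added by the patch."""

    def __init__(self):
        super().__init__()
        # behaves like normal torch.matmul unless a quantization modifier
        # re-wraps the module during QAT
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


class TinyAttention(nn.Module):
    """Hypothetical stand-in for BertSelfAttention: only the two wrapped matmuls."""

    def __init__(self):
        super().__init__()
        self.attention_scores_matmul = QATMatMul()
        self.context_layer_matmul = QATMatMul()

    def forward(self, q, k, v):
        scores = self.attention_scores_matmul(q, k.transpose(-1, -2))
        probs = scores.softmax(dim=-1)
        return self.context_layer_matmul(probs, v)


if __name__ == "__main__":
    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 4, 8)
    v = torch.randn(2, 4, 8)

    model = TinyAttention()

    # Without QAT the wrapper is numerically identical to plain torch.matmul.
    reference = torch.matmul(torch.matmul(q, k.transpose(-1, -2)).softmax(dim=-1), v)
    assert torch.allclose(model(q, k, v), reference)

    # Illustrative discovery pass (not the SparseML implementation): a QAT tool
    # could locate the wrap points by looking for the wrap_qat attribute.
    wrapped = [name for name, m in model.named_modules() if getattr(m, "wrap_qat", False)]
    print(wrapped)  # ['attention_scores_matmul', 'context_layer_matmul']

Because the wrappers are non-parameterized, replacing the two torch.matmul calls in BertSelfAttention with them does not change model outputs or checkpoints; it only exposes named module boundaries that a quantization pass can target.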