From 18f685b6b30a98d1513e1c9a05b0a235b34c435f Mon Sep 17 00:00:00 2001 From: harshithapv Date: Fri, 14 May 2021 17:48:32 +0000 Subject: [PATCH 1/6] add ort config for debertav2 model --- .../pytorch/text-classification/run_glue.py | 1 + .../models/deberta_v2/jit_tracing.py | 57 +++++++++++++++++++ .../models/deberta_v2/modeling_deberta_v2.py | 51 ++++++++++++++--- 3 files changed, 100 insertions(+), 9 deletions(-) create mode 100644 src/transformers/models/deberta_v2/jit_tracing.py diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index cd8c6a94aefc..dc81604456cb 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -301,6 +301,7 @@ def main(): cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + ort=True if training_args.ort else None, ) tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, diff --git a/src/transformers/models/deberta_v2/jit_tracing.py b/src/transformers/models/deberta_v2/jit_tracing.py new file mode 100644 index 000000000000..633310c1e85d --- /dev/null +++ b/src/transformers/models/deberta_v2/jit_tracing.py @@ -0,0 +1,57 @@ +""" +Logging util +@Author: penhe@microsoft.com +""" + +""" Utils for torch jit tracing customer operators/functions +""" +import os +import torch +''' +def traceable(cls): + """ Decorator over customer functions + There is an issue for tracing customer python torch Function, using this decorator to work around it. + e.g. + @traceable + class MyOp(torch.autograd.Function): + xxx + """ + class _Function(object): + @staticmethod + def apply(*args): + jit_trace = (os.getenv('JIT_TRACE', 'False').lower() == 'true') + if jit_trace: + return cls.forward(_Function, *args) + else: + return cls.apply(*args) + @staticmethod + def save_for_backward(*args): + pass + return _Function +''' + +def traceable(cls): + class _Function(object): + @staticmethod + def apply(*args): + if torch.onnx.is_in_onnx_export(): + return cls.forward(_Function, *args) + else: + return cls.apply(*args) + + @staticmethod + def save_for_backward(*args): + pass + return _Function + + +class TraceMode(): + """ Trace context used when tracing modules contains customer operators/Functions + """ + def __enter__(self): + os.environ['JIT_TRACE'] = 'True' + return self + + def __exit__(self, exp_value, exp_type, trace): + del os.environ['JIT_TRACE'] + diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 03563b02b913..793c08cfb63b 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -35,6 +35,7 @@ from ...utils import logging from .configuration_deberta_v2 import DebertaV2Config +from .jit_tracing import traceable logger = logging.get_logger(__name__) @@ -55,7 +56,10 @@ class ContextPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) - self.dropout = StableDropout(config.pooler_dropout) + if config.ort: + self.dropout = TorchNNDropout(config.pooler_dropout) + else: + self.dropout = StableDropout(config.pooler_dropout) self.config = config def forward(self, hidden_states): @@ -74,6 +78,7 @@ def output_dim(self): # Copied from 
transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +@traceable class XSoftmax(torch.autograd.Function): """ Masked Softmax which is optimized for saving memory @@ -145,6 +150,7 @@ def get_mask(input, local_context): # Copied from transformers.models.deberta.modeling_deberta.XDropout +@traceable class XDropout(torch.autograd.Function): """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" @@ -167,6 +173,11 @@ def backward(ctx, grad_output): return grad_output, None +class TorchNNDropout(torch.nn.Dropout): + def __init__(self, drop_prob): + print("using torch.nn.Dropout") + super().__init__(drop_prob) + # Copied from transformers.models.deberta.modeling_deberta.StableDropout class StableDropout(torch.nn.Module): """ @@ -177,6 +188,7 @@ class StableDropout(torch.nn.Module): """ def __init__(self, drop_prob): + print("using StableDropout") super().__init__() self.drop_prob = drop_prob self.count = 0 @@ -223,7 +235,11 @@ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) - self.dropout = StableDropout(config.hidden_dropout_prob) + if config.ort: + self.dropout = TorchNNDropout(config.hidden_dropout_prob) + else: + self.dropout = StableDropout(config.hidden_dropout_prob) + def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) @@ -291,7 +307,10 @@ def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) - self.dropout = StableDropout(config.hidden_dropout_prob) + if config.ort: + self.dropout = TorchNNDropout(config.hidden_dropout_prob) + else: + self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config def forward(self, hidden_states, input_tensor): @@ -346,7 +365,10 @@ def __init__(self, config): config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups ) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) - self.dropout = StableDropout(config.hidden_dropout_prob) + if config.ort: + self.dropout = TorchNNDropout(config.hidden_dropout_prob) + else: + self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config def forward(self, hidden_states, residual_states, input_mask): @@ -584,8 +606,10 @@ def __init__(self, config): self.pos_ebd_size = self.max_relative_positions if self.position_buckets > 0: self.pos_ebd_size = self.position_buckets - - self.pos_dropout = StableDropout(config.hidden_dropout_prob) + if config.ort: + self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob) + else: + self.pos_dropout = StableDropout(config.hidden_dropout_prob) if not self.share_att_key: if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: @@ -593,7 +617,10 @@ def __init__(self, config): if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = StableDropout(config.attention_probs_dropout_prob) + if config.ort: + self.dropout = TorchNNDropout(config.attention_probs_dropout_prob) + else: + self.dropout = StableDropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x, attention_heads): new_x_shape = x.size()[:-1] + (attention_heads, -1) @@ -816,7 +843,10 @@ def __init__(self, config): if self.embedding_size != 
config.hidden_size: self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) - self.dropout = StableDropout(config.hidden_dropout_prob) + if config.ort: + self.dropout = TorchNNDropout(config.hidden_dropout_prob) + else: + self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -1247,7 +1277,10 @@ def __init__(self, config): self.classifier = torch.nn.Linear(output_dim, num_labels) drop_out = getattr(config, "cls_dropout", None) drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out - self.dropout = StableDropout(drop_out) + if config.ort: + self.dropout = TorchNNDropout(drop_out) + else: + self.dropout = StableDropout(drop_out) self.init_weights() From 2a1aaf13442701b5763258371a045c6289d39d55 Mon Sep 17 00:00:00 2001 From: harshithapv Date: Fri, 14 May 2021 18:15:41 +0000 Subject: [PATCH 2/6] remove prints --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 793c08cfb63b..57243db2c136 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -175,7 +175,6 @@ def backward(ctx, grad_output): class TorchNNDropout(torch.nn.Dropout): def __init__(self, drop_prob): - print("using torch.nn.Dropout") super().__init__(drop_prob) # Copied from transformers.models.deberta.modeling_deberta.StableDropout @@ -188,7 +187,6 @@ class StableDropout(torch.nn.Module): """ def __init__(self, drop_prob): - print("using StableDropout") super().__init__() self.drop_prob = drop_prob self.count = 0 From 81969f9b7b16f3f34571f5f7f6dbc0c4d885f20d Mon Sep 17 00:00:00 2001 From: harshithapv Date: Fri, 14 May 2021 18:43:16 +0000 Subject: [PATCH 3/6] remove old commented code --- .../models/deberta_v2/jit_tracing.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/src/transformers/models/deberta_v2/jit_tracing.py b/src/transformers/models/deberta_v2/jit_tracing.py index 633310c1e85d..a9c35f1741d5 100644 --- a/src/transformers/models/deberta_v2/jit_tracing.py +++ b/src/transformers/models/deberta_v2/jit_tracing.py @@ -7,28 +7,6 @@ """ import os import torch -''' -def traceable(cls): - """ Decorator over customer functions - There is an issue for tracing customer python torch Function, using this decorator to work around it. - e.g. 
-    @traceable
-    class MyOp(torch.autograd.Function):
-        xxx
-    """
-    class _Function(object):
-        @staticmethod
-        def apply(*args):
-            jit_trace = (os.getenv('JIT_TRACE', 'False').lower() == 'true')
-            if jit_trace:
-                return cls.forward(_Function, *args)
-            else:
-                return cls.apply(*args)
-        @staticmethod
-        def save_for_backward(*args):
-            pass
-    return _Function
-'''
 
 
 def traceable(cls):
     class _Function(object):

From 5202810f67dd1783807b2a1136178d0eb4f499b6 Mon Sep 17 00:00:00 2001
From: harshithapv
Date: Fri, 14 May 2021 22:40:08 +0000
Subject: [PATCH 4/6] fix code style errors

---
 .../models/deberta_v2/jit_tracing.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/deberta_v2/jit_tracing.py b/src/transformers/models/deberta_v2/jit_tracing.py
index a9c35f1741d5..79081dfea6b8 100644
--- a/src/transformers/models/deberta_v2/jit_tracing.py
+++ b/src/transformers/models/deberta_v2/jit_tracing.py
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 Logging util
 @Author: penhe@microsoft.com
@@ -31,5 +46,4 @@ def __enter__(self):
         return self
 
     def __exit__(self, exp_value, exp_type, trace):
-        del os.environ['JIT_TRACE']
-
+        del os.environ['JIT_TRACE']
\ No newline at end of file

From e5eee95a31f9b4136ac6819781b1d01ccb698a1e Mon Sep 17 00:00:00 2001
From: harshithapv
Date: Fri, 14 May 2021 23:09:54 +0000
Subject: [PATCH 5/6] add flake ignore comment

---
 src/transformers/models/deberta_v2/jit_tracing.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/deberta_v2/jit_tracing.py b/src/transformers/models/deberta_v2/jit_tracing.py
index 79081dfea6b8..a6b57ca8787c 100644
--- a/src/transformers/models/deberta_v2/jit_tracing.py
+++ b/src/transformers/models/deberta_v2/jit_tracing.py
@@ -1,3 +1,4 @@
+# flake8: noqa
 # coding=utf-8
 # Copyright 2020, Microsoft and the HuggingFace Inc. team.
 #

From 1723a1a1262a83e113f7b244bce402f28f297818 Mon Sep 17 00:00:00 2001
From: harshithapv
Date: Tue, 18 May 2021 05:09:30 +0000
Subject: [PATCH 6/6] fix black formatting and docstring typos

---
 .../models/deberta_v2/jit_tracing.py          | 55 ++++++++++---------
 .../models/deberta_v2/modeling_deberta_v2.py  | 28 +++++++---
 2 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/src/transformers/models/deberta_v2/jit_tracing.py b/src/transformers/models/deberta_v2/jit_tracing.py
index a6b57ca8787c..c2fd9a0323e1 100644
--- a/src/transformers/models/deberta_v2/jit_tracing.py
+++ b/src/transformers/models/deberta_v2/jit_tracing.py
@@ -15,36 +15,37 @@
""" -Logging util -@Author: penhe@microsoft.com +Logging util @Author: penhe@microsoft.com """ -""" Utils for torch jit tracing customer operators/functions -""" +""" Utils for torch jit tracing customer operators/functions""" import os + import torch + def traceable(cls): - class _Function(object): - @staticmethod - def apply(*args): - if torch.onnx.is_in_onnx_export(): - return cls.forward(_Function, *args) - else: - return cls.apply(*args) - - @staticmethod - def save_for_backward(*args): - pass - return _Function - - -class TraceMode(): - """ Trace context used when tracing modules contains customer operators/Functions - """ - def __enter__(self): - os.environ['JIT_TRACE'] = 'True' - return self - - def __exit__(self, exp_value, exp_type, trace): - del os.environ['JIT_TRACE'] \ No newline at end of file + class _Function(object): + @staticmethod + def apply(*args): + if torch.onnx.is_in_onnx_export(): + return cls.forward(_Function, *args) + else: + return cls.apply(*args) + + @staticmethod + def save_for_backward(*args): + pass + + return _Function + + +class TraceMode: + """Trace context used when tracing modules contains customer operators/Functions""" + + def __enter__(self): + os.environ["JIT_TRACE"] = "True" + return self + + def __exit__(self, exp_value, exp_type, trace): + del os.environ["JIT_TRACE"] diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 57243db2c136..260ea9f91989 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -19,11 +19,21 @@ import numpy as np import torch -from torch import _softmax_backward_data, nn -from torch.nn import CrossEntropyLoss, LayerNorm +from torch import ( + _softmax_backward_data, + nn, +) +from torch.nn import ( + CrossEntropyLoss, + LayerNorm, +) from ...activations import ACT2FN -from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -34,9 +44,9 @@ from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_deberta_v2 import DebertaV2Config - from .jit_tracing import traceable + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DebertaV2Config" @@ -77,8 +87,8 @@ def output_dim(self): return self.config.hidden_size -# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 @traceable +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 class XSoftmax(torch.autograd.Function): """ Masked Softmax which is optimized for saving memory @@ -149,8 +159,8 @@ def get_mask(input, local_context): return mask, dropout -# Copied from transformers.models.deberta.modeling_deberta.XDropout @traceable +# Copied from transformers.models.deberta.modeling_deberta.XDropout class XDropout(torch.autograd.Function): """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" @@ -174,8 +184,9 @@ def backward(ctx, grad_output): class TorchNNDropout(torch.nn.Dropout): - def __init__(self, drop_prob): - super().__init__(drop_prob) + def __init__(self, drop_prob): + super().__init__(drop_prob) + # Copied from transformers.models.deberta.modeling_deberta.StableDropout class 
StableDropout(torch.nn.Module): @@ -237,7 +248,6 @@ def __init__(self, config): self.dropout = TorchNNDropout(config.hidden_dropout_prob) else: self.dropout = StableDropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states)
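
Reviewer note (not part of the patches): the series has two moving parts. The `config.ort` / `training_args.ort` flag comes from the ONNX Runtime training integration, not from stock Transformers, and swaps `StableDropout` for the plain `torch.nn.Dropout` subclass; the `@traceable` decorator reroutes a custom `torch.autograd.Function` while exporting to ONNX. Below is a minimal, self-contained sketch of the `@traceable` pattern: the decorator body mirrors jit_tracing.py from the series, while `ScaleByTwo` and the calls at the bottom are illustrative assumptions added here for review, not code from the patches.

import torch


def traceable(cls):
    """Route Function.apply through forward() during ONNX export (as in jit_tracing.py)."""

    class _Function(object):
        @staticmethod
        def apply(*args):
            if torch.onnx.is_in_onnx_export():
                # Bypass autograd so the tracer records the underlying tensor ops
                # instead of an opaque PythonOp node it cannot export.
                return cls.forward(_Function, *args)
            else:
                # Eager mode: fall back to the real autograd.Function machinery.
                return cls.apply(*args)

        @staticmethod
        def save_for_backward(*args):
            # No-op ctx method: backward never runs during export.
            pass

    return _Function


@traceable
class ScaleByTwo(torch.autograd.Function):  # hypothetical example Function
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


x = torch.ones(3, requires_grad=True)
y = ScaleByTwo.apply(x)  # eager: dispatches to torch.autograd.Function.apply
y.sum().backward()       # gradients still flow outside of export

During export, `forward` receives the `_Function` class itself as `ctx`, and its no-op `save_for_backward` lets forward bodies like those of `XSoftmax` and `XDropout` run unchanged without export-specific branches.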