add ort to debertav2 model config #12

Merged (6 commits, May 18, 2021)
1 change: 1 addition & 0 deletions examples/pytorch/text-classification/run_glue.py
@@ -301,6 +301,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
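The added line reads training_args.ort, which this diff does not declare anywhere; presumably the flag is introduced on the training arguments elsewhere in this branch. A minimal sketch of what such a field could look like, assuming a TrainingArguments subclass (the class name ORTTrainingArguments and the help text are illustrative, not part of this PR):

from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class ORTTrainingArguments(TrainingArguments):
    # Hypothetical flag: this PR only consumes training_args.ort and does not
    # show where the argument is actually declared.
    ort: bool = field(
        default=False,
        metadata={"help": "Enable ONNX Runtime (ORT) specific code paths in the model."},
    )

Parsed through HfArgumentParser, a boolean field like this surfaces as a --ort command-line flag for run_glue.py.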
51 changes: 51 additions & 0 deletions src/transformers/models/deberta_v2/jit_tracing.py
@@ -0,0 +1,51 @@
# flake8: noqa
# coding=utf-8
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Logging util @Author: penhe@microsoft.com
"""

""" Utils for torch jit tracing customer operators/functions"""
import os

import torch


def traceable(cls):
    """Decorate a torch.autograd.Function so that, during ONNX export, its forward
    is called directly instead of going through autograd's apply."""

    class _Function(object):
@staticmethod
def apply(*args):
if torch.onnx.is_in_onnx_export():
return cls.forward(_Function, *args)
else:
return cls.apply(*args)

@staticmethod
def save_for_backward(*args):
pass

return _Function


class TraceMode:
"""Trace context used when tracing modules contains customer operators/Functions"""

def __enter__(self):
os.environ["JIT_TRACE"] = "True"
return self

def __exit__(self, exp_value, exp_type, trace):
del os.environ["JIT_TRACE"]
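A minimal usage sketch of the two helpers above; the ScaledFn class, the doubling op, and the export call are illustrative, not part of this PR, and the import path assumes the module location added here. @traceable rebinds the decorated torch.autograd.Function so that during ONNX export its forward is called directly (with a dummy ctx) instead of going through autograd's apply, which the exporter cannot trace for custom Functions; TraceMode sets the JIT_TRACE environment variable for code that wants to branch on it while tracing.

import torch

from transformers.models.deberta_v2.jit_tracing import TraceMode, traceable


@traceable
class ScaledFn(torch.autograd.Function):
    # Toy custom Function (illustrative only): doubles its input.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # no-op while exporting, real save otherwise
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


x = torch.randn(3)
y = ScaledFn.apply(x)  # outside export: dispatches to autograd's apply as usual

# During torch.onnx.export the same call runs ScaledFn.forward directly, so the
# exported graph records plain tensor ops.  TraceMode can wrap the export to
# flag trace-specific code paths via os.environ["JIT_TRACE"]:
# with TraceMode():
#     torch.onnx.export(model, (dummy_input,), "model.onnx")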
65 changes: 53 additions & 12 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -19,11 +19,21 @@

import numpy as np
import torch
from torch import _softmax_backward_data, nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch import (
_softmax_backward_data,
nn,
)
from torch.nn import (
CrossEntropyLoss,
LayerNorm,
)

from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
@@ -34,6 +44,7 @@
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_deberta_v2 import DebertaV2Config
from .jit_tracing import traceable


logger = logging.get_logger(__name__)
Expand All @@ -55,7 +66,10 @@ class ContextPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
self.dropout = StableDropout(config.pooler_dropout)
if config.ort:
self.dropout = TorchNNDropout(config.pooler_dropout)
else:
self.dropout = StableDropout(config.pooler_dropout)
self.config = config

def forward(self, hidden_states):
@@ -73,6 +87,7 @@ def output_dim(self):
return self.config.hidden_size


@traceable
# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
class XSoftmax(torch.autograd.Function):
"""
@@ -144,6 +159,7 @@ def get_mask(input, local_context):
return mask, dropout


@traceable
# Copied from transformers.models.deberta.modeling_deberta.XDropout
class XDropout(torch.autograd.Function):
"""Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
@@ -167,6 +183,11 @@ def backward(ctx, grad_output):
return grad_output, None


class TorchNNDropout(torch.nn.Dropout):
def __init__(self, drop_prob):
super().__init__(drop_prob)


# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(torch.nn.Module):
"""
@@ -223,7 +244,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -291,7 +315,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, input_tensor):
@@ -346,7 +373,10 @@ def __init__(self, config):
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, residual_states, input_mask):
@@ -584,16 +614,21 @@ def __init__(self, config):
self.pos_ebd_size = self.max_relative_positions
if self.position_buckets > 0:
self.pos_ebd_size = self.position_buckets

self.pos_dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.pos_dropout = StableDropout(config.hidden_dropout_prob)

if not self.share_att_key:
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = StableDropout(config.attention_probs_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.attention_probs_dropout_prob)
else:
self.dropout = StableDropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
@@ -816,7 +851,10 @@ def __init__(self, config):
if self.embedding_size != config.hidden_size:
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -1247,7 +1285,10 @@ def __init__(self, config):
self.classifier = torch.nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
if config.ort:
self.dropout = TorchNNDropout(drop_out)
else:
self.dropout = StableDropout(drop_out)

self.init_weights()

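None of the hunks above show where DebertaV2Config picks up the ort attribute; run_glue.py passes ort=... into AutoConfig.from_pretrained, and the model code simply reads config.ort. A minimal sketch of exercising the new code path by hand, assuming the flag can be attached to the config object before the model is constructed (the checkpoint name is only an example):

from transformers import DebertaV2Config, DebertaV2ForSequenceClassification

config = DebertaV2Config.from_pretrained("microsoft/deberta-v2-xlarge")
config.ort = True  # assumed: set dynamically, not a documented config field upstream
model = DebertaV2ForSequenceClassification(config)

# With config.ort set, the layers above construct TorchNNDropout (a plain
# torch.nn.Dropout) instead of StableDropout, presumably so that ONNX Runtime
# (ORT) training sees a standard dropout op rather than the custom XDropout
# autograd Function.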