
Commit

Merge pull request huggingface#18 from ROCmSoftwarePlatform/pnunna_ort_HF

Enable ORT for HuggingFace workloads
amathews-amd authored Oct 14, 2022
2 parents 24b288f + f36ff0c commit 1003b4b
Showing 11 changed files with 127 additions and 27 deletions.
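Editor's note: the patch threads a single `ort` switch from the command line through the model config and into the Trainer. A minimal sketch of the intended flow, assuming this fork plus the torch-ort / onnxruntime-training packages are installed (the model name is only illustrative):

from transformers import AutoConfig, AutoModelForSequenceClassification, TrainingArguments

# 1. --ort becomes TrainingArguments.ort (see training_args.py below)
training_args = TrainingArguments(output_dir="/tmp/out", ort=True)

# 2. the example scripts forward the flag into the model config (configuration_utils.py below)
config = AutoConfig.from_pretrained("bert-base-uncased", ort=training_args.ort)

# 3. model code can branch on config.ort (DeBERTa-v2 dropout and T5 clamping below)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", config=config)

# 4. Trainer._inner_training_loop then wraps the model in torch_ort.ORTModule (trainer.py below),
#    so forward and backward passes execute through ONNX Runtime.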
1 change: 1 addition & 0 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -347,6 +347,7 @@ def main():
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"ort": True if training_args.ort else None,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
1 change: 1 addition & 0 deletions examples/pytorch/question-answering/run_qa.py
@@ -319,6 +319,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=training_args.ort,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
1 change: 1 addition & 0 deletions examples/pytorch/summarization/run_summarization.py
@@ -405,6 +405,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
1 change: 1 addition & 0 deletions examples/pytorch/text-classification/run_glue.py
@@ -358,6 +358,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
1 change: 1 addition & 0 deletions examples/pytorch/translation/run_translation.py
@@ -363,6 +363,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -224,6 +224,10 @@ class PretrainedConfig(PushToHubMixin):
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
v5.
ONNX Runtime specific parameters
- **ort** (`bool`, *optional*, defaults to `False`) -- Whether or not the model should use ONNX Runtime (ORT).
"""
model_type: str = ""
is_composition: bool = False
@@ -249,6 +253,7 @@ def __init__(self, **kwargs):
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
self.ort = kwargs.pop("ort", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
"tie_word_embeddings", True
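Editor's note: a quick sanity check of the new attribute (a sketch, assuming this fork is installed). The flag defaults to False, an explicit value is stored as passed, and the ort=None that the example scripts send when the flag is off is still falsy, so downstream `if config.ort:` checks behave exactly like the default.

from transformers import BertConfig  # any PretrainedConfig subclass inherits the new attribute

assert BertConfig().ort is False          # new default added above
assert BertConfig(ort=True).ort is True   # popped from kwargs and stored on the config
assert not BertConfig(ort=None).ort       # the "True if ... else None" idiom used in run_mlm.py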
45 changes: 35 additions & 10 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -58,7 +58,10 @@ class ContextPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
self.dropout = StableDropout(config.pooler_dropout)
if config.ort:
self.dropout = TorchNNDropout(config.pooler_dropout)
else:
self.dropout = StableDropout(config.pooler_dropout)
self.config = config

def forward(self, hidden_states):
@@ -168,7 +171,6 @@ def get_mask(input, local_context):

return mask, dropout


# Copied from transformers.models.deberta.modeling_deberta.XDropout
class XDropout(torch.autograd.Function):
"""Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
@@ -208,6 +210,9 @@ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, D
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
return symbolic_opset12.dropout(g, input, dropout_p, train)

class TorchNNDropout(torch.nn.Dropout):
def __init__(self, drop_prob):
super().__init__(drop_prob)

# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module):
@@ -265,7 +270,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -333,7 +341,10 @@ def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, input_tensor):
@@ -388,7 +399,10 @@ def __init__(self, config):
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

def forward(self, hidden_states, residual_states, input_mask):
@@ -645,16 +659,21 @@ def __init__(self, config):
self.pos_ebd_size = self.max_relative_positions
if self.position_buckets > 0:
self.pos_ebd_size = self.position_buckets

self.pos_dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.pos_dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.pos_dropout = StableDropout(config.hidden_dropout_prob)

if not self.share_att_key:
if "c2p" in self.pos_att_type:
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
if "p2c" in self.pos_att_type:
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = StableDropout(config.attention_probs_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.attention_probs_dropout_prob)
else:
self.dropout = StableDropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
@@ -845,7 +864,10 @@ def __init__(self, config):
if self.embedding_size != config.hidden_size:
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.ort:
self.dropout = TorchNNDropout(config.hidden_dropout_prob)
else:
self.dropout = StableDropout(config.hidden_dropout_prob)
self.config = config

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -1258,7 +1280,10 @@ def __init__(self, config):
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
if config.ort:
self.dropout = TorchNNDropout(drop_out)
else:
self.dropout = StableDropout(drop_out)

# Initialize weights and apply final processing
self.post_init()
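Editor's note: TorchNNDropout above is a thin subclass of torch.nn.Dropout, presumably chosen so that dropout maps onto ops ORTModule can capture in a static training graph instead of StableDropout's custom XDropout autograd function. The same `if config.ort:` selection is repeated in every DeBERTa-v2 module touched above; a hypothetical helper (not part of the commit) would express the pattern once:

def build_dropout(config, drop_prob):
    """Pick the ORT-friendly dropout when config.ort is set, StableDropout otherwise."""
    if getattr(config, "ort", False):
        return TorchNNDropout(drop_prob)
    return StableDropout(drop_prob)

# e.g. inside ContextPooler.__init__:
#     self.dropout = build_dropout(config, config.pooler_dropout)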
3 changes: 2 additions & 1 deletion src/transformers/models/roberta/modeling_roberta.py
@@ -1472,6 +1472,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.ort = config.ort

self.roberta = RobertaModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
@@ -1541,7 +1542,7 @@ def forward(
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
ignored_index = start_logits.size(1) if not self.ort else 344
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

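Editor's note: the only behavioural change here is the ignored_index bound. In eager mode it is a Python int taken from the logits shape, which can vary with the padded length of each batch; with --ort it becomes the constant 344, which appears to assume a fixed maximum sequence length for the QA runs. A toy sketch of why clamping to that bound plus ignore_index keeps out-of-range positions out of the loss, mirroring the unchanged lines that follow this hunk:

import torch

seq_len = 8
ignored_index = seq_len                   # with --ort this would be the constant 344
start_logits = torch.randn(2, seq_len)
start_positions = torch.tensor([3, 500]).clamp(0, ignored_index)   # 500 -> 8, i.e. ignored
loss = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)(start_logits, start_positions)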
76 changes: 61 additions & 15 deletions src/transformers/models/t5/modeling_t5.py
@@ -277,13 +277,37 @@ def forward(self, hidden_states):

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)

class T5ClampedDropout(nn.Module):
def __init__(self, config):
super().__init__()
self.ort = config.ort
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout_rate = config.dropout_rate

def forward(self, hidden_states):
# clamp inf values to enable fp16 training
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
clamp_value = (1.0-self.dropout_rate)*clamp_value
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

hidden_states = self.dropout(hidden_states)
return hidden_states
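
Editor's note: T5ClampedDropout folds the fp16 inf-clamp into the dropout module. On the ORT path the clamp bound is computed with torch.where rather than branching on torch.isinf(...).any(), so no Python control flow depends on tensor values and the captured graph runs the same ops every step; the bound also appears to be pre-scaled by (1 - dropout_rate) to leave headroom for dropout's 1/(1 - p) rescaling during training. A minimal standalone sketch of the torch.where trick:

import torch

x = torch.tensor([1.0, float("inf")], dtype=torch.float16)

# Eager style: a Python bool derived from tensor data decides whether to clamp at all.
if x.dtype == torch.float16 and torch.isinf(x).any():
    bound = torch.finfo(x.dtype).max - 1000
    x_eager = torch.clamp(x, min=-bound, max=bound)

# ORT style: the bound itself is computed with tensor ops, so the same graph runs
# whether or not an inf is present.
clamp_value = torch.where(torch.isinf(x).any(),
                          torch.finfo(x.dtype).max - 1000,
                          torch.finfo(x.dtype).max)
x_ort = torch.clamp(x, min=-clamp_value, max=clamp_value)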


class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)
self.act = ACT2FN[config.dense_act_fn]

def forward(self, hidden_states):
@@ -300,7 +324,7 @@ def __init__(self, config: T5Config):
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)
self.act = ACT2FN[config.dense_act_fn]

def forward(self, hidden_states):
@@ -321,7 +345,7 @@ def __init__(self, config: T5Config):
self.DenseReluDense = T5DenseActDense(config)

self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)

def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
@@ -556,7 +580,7 @@ def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)

def forward(
self,
@@ -588,7 +612,7 @@ def __init__(self, config):
super().__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)

def forward(
self,
@@ -623,6 +647,7 @@ class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.ort = config.ort
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
@@ -676,9 +701,16 @@ def forward(
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
@@ -703,9 +735,16 @@ def forward(
hidden_states = cross_attention_outputs[0]

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

# Combine self attn and cross attn key value states
if present_key_value_state is not None:
@@ -718,9 +757,16 @@ def forward(
hidden_states = self.layer[-1](hidden_states)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
# Remove data-based control flow for static graph
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

outputs = (hidden_states,)

@@ -841,7 +887,7 @@ def __init__(self, config, embed_tokens=None):
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
)
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.dropout = T5ClampedDropout(config)

# Initialize weights and apply final processing
self.post_init()
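Editor's note: the same two-branch clamp now appears three times in T5Block.forward above (after self-attention, after cross-attention, and after the feed-forward layer). A hypothetical refactor, not part of the commit, that keeps both the eager and the static-graph behaviour while stating the logic once:

import torch

def clamp_fp16_inf(hidden_states: torch.Tensor, ort: bool) -> torch.Tensor:
    """Clamp inf values so fp16 training does not overflow (mirrors T5Block.forward)."""
    if hidden_states.dtype != torch.float16:
        return hidden_states
    if ort:
        # Static-graph friendly: no Python branch on tensor data.
        clamp_value = torch.where(torch.isinf(hidden_states).any(),
                                  torch.finfo(hidden_states.dtype).max - 1000,
                                  torch.finfo(hidden_states.dtype).max)
    elif torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    else:
        return hidden_states
    return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)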
14 changes: 13 additions & 1 deletion src/transformers/trainer.py
@@ -1316,7 +1316,12 @@ def _wrap_model(self, model, training=True, dataloader=None):

# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
if unwrap_model(model) is not model:
return model
if self.args.ort:
from torch_ort import ORTModule
if type(model) is not ORTModule:
return model
else:
return model

# Mixed precision training with apex (torch < 1.6)
if self.use_apex and training:
@@ -1570,7 +1575,14 @@ def _inner_training_loop(
or is_sagemaker_mp_enabled()
or self.fsdp is not None
)
if args.ort:
from torch_ort import ORTModule
logger.info("Converting to ORTModule ....")
model = ORTModule(self.model)
self.model_wrapped = model
if args.deepspeed:
if args.ort:
self.model = model
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
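Editor's note: two things happen in the Trainer. _inner_training_loop converts the model to ORTModule before any DeepSpeed or DDP wrapping (and, under DeepSpeed, hands the converted module back via self.model), while _wrap_model no longer treats an ORTModule as "already wrapped", so distributed wrappers are still applied on top of it. A standalone sketch of what the conversion does, assuming the torch-ort / onnxruntime-training packages are installed:

import torch
from torch_ort import ORTModule  # same import the Trainer now uses

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
model = ORTModule(model)        # forward/backward now execute through ONNX Runtime

x = torch.randn(16, 4)
loss = model(x).sum()           # first call exports to ONNX and builds the training graph
loss.backward()                 # gradients flow back into the original torch parameters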
6 changes: 6 additions & 0 deletions src/transformers/training_args.py
@@ -373,6 +373,8 @@ class TrainingArguments:
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
`ds_config.json`) or an already loaded json file as a `dict`"
ort (`bool`, *optional*, defaults to `False`):
Use [ORTModule](https://github.com/microsoft/onnxruntime) to run training through ONNX Runtime.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -823,6 +825,10 @@ class TrainingArguments:
)
},
)
ort: Optional[bool] = field(
default=False,
metadata={"help": "Enable Ort"},
)
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
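Editor's note: the new dataclass field is picked up by HfArgumentParser like any other TrainingArguments boolean, so every example script touched above gains an `--ort` switch with no further changes. A quick sketch, assuming this fork:

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(["--output_dir", "/tmp/out", "--ort"])
assert args.ort is True   # e.g. python run_glue.py ... --ort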
