Support for ort
raviskolli committed Mar 31, 2021
1 parent 369ef06 commit 3a66e31
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/transformers/integrations.py
@@ -270,7 +270,7 @@ def rewrite_logs(d):
     return new_d
 
 
-def init_deepspeed(trainer, model, num_training_steps):
+def init_deepspeed(trainer, num_training_steps):
     """
     Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration
@@ -286,6 +286,7 @@ def init_deepspeed(trainer, model, num_training_steps):
 
     args = trainer.args
     ds_config_file = args.deepspeed
+    model = trainer.model
 
     with io.open(ds_config_file, "r", encoding="utf-8") as f:
         config = json.load(f)
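
As a rough illustration of what the integrations.py change does, here is a minimal sketch (not the full function) of the reworked setup: init_deepspeed now takes only the trainer and the step count, and reads the model off the trainer, so whatever wrapper the Trainer has installed (for example an ORTModule when args.ort is set) is the module DeepSpeed initializes.

def init_deepspeed(trainer, num_training_steps):
    # Sketch only: the model now comes from the trainer object itself rather than
    # being passed in, so trainer.model may already be an ORTModule wrapper.
    args = trainer.args
    ds_config_file = args.deepspeed  # path to the DeepSpeed JSON config
    model = trainer.model
    ...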
6 changes: 4 additions & 2 deletions src/transformers/trainer.py
@@ -903,14 +903,16 @@ def train(
 
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         model = self.model
-        if self.args.ortmodule:
+        if self.args.ort:
            from onnxruntime.training import ORTModule
            logger.info("Converting to ORTModule ....")
            model = ORTModule(self.model)
            self.model_wrapped = model
         if self.args.deepspeed:
+            if self.args.ort:
+                self.model = model
            model, optimizer, lr_scheduler = init_deepspeed(self, model, num_training_steps=max_steps)
-            self.model = model.module._original_module if self.args.ortmodule else model.module
+            self.model = model.module._original_module if self.args.ort else model.module
            self.model_wrapped = model  # will get further wrapped in DDP
            self.deepspeed = model  # DeepSpeedEngine object
            self.optimizer = optimizer
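
For readers unfamiliar with ORTModule, below is a minimal standalone sketch of the wrap the Trainer performs when args.ort is set. It assumes the onnxruntime training package providing onnxruntime.training.ORTModule is installed; TinyNet is a throwaway model used only for illustration.

import torch
from onnxruntime.training import ORTModule

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = ORTModule(TinyNet())    # forward/backward now execute through ONNX Runtime
out = model(torch.randn(8, 4))  # used like any other torch.nn.Module
out.sum().backward()

When DeepSpeed subsequently wraps the ORT-wrapped model, the expression model.module._original_module in the diff above is what recovers the original module from inside the DeepSpeedEngine and the ORTModule wrapper.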
4 changes: 2 additions & 2 deletions src/transformers/training_args.py
@@ -482,9 +482,9 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json)"},
     )
-    ortmodule: Optional[bool] = field(
+    ort: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable OrtModule"},
+        metadata={"help": "Enable Ort"},
     )
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
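
As a usage sketch for the renamed flag (specific to this patched transformers; the output_dir and deepspeed values below are placeholders):

from transformers import TrainingArguments

# Sketch only: enable ONNX Runtime training via the renamed `ort` flag.
training_args = TrainingArguments(
    output_dir="./outputs",        # placeholder path
    ort=True,                      # was `ortmodule=True` before this commit
    deepspeed="ds_config.json",    # optional: combine ORT with DeepSpeed as in trainer.py above
)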
