From 3a66e316c1066ff48735c3b9e73e03928a8e32ec Mon Sep 17 00:00:00 2001
From: Ravi shankar Kolli
Date: Mon, 22 Mar 2021 13:48:30 -0700
Subject: [PATCH 1/2] Support for ort

---
 src/transformers/integrations.py  | 3 ++-
 src/transformers/trainer.py       | 6 ++++--
 src/transformers/training_args.py | 4 ++--
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 530f4fa65ac8..634cea5ff083 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -270,7 +270,7 @@ def rewrite_logs(d):
     return new_d
 
 
-def init_deepspeed(trainer, model, num_training_steps):
+def init_deepspeed(trainer, num_training_steps):
     """
     Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration
 
@@ -286,6 +286,7 @@ def init_deepspeed(trainer, model, num_training_steps):
 
     args = trainer.args
     ds_config_file = args.deepspeed
+    model = trainer.model
 
     with io.open(ds_config_file, "r", encoding="utf-8") as f:
         config = json.load(f)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d7d6b5176883..f2bc1b1be690 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -903,14 +903,16 @@ def train(
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         model = self.model
-        if self.args.ortmodule:
+        if self.args.ort:
             from onnxruntime.training import ORTModule
 
             logger.info("Converting to ORTModule ....")
             model = ORTModule(self.model)
             self.model_wrapped = model
         if self.args.deepspeed:
+            if self.args.ort:
+                self.model = model
             model, optimizer, lr_scheduler = init_deepspeed(self, model, num_training_steps=max_steps)
-            self.model = model.module._original_module if self.args.ortmodule else model.module
+            self.model = model.module._original_module if self.args.ort else model.module
             self.model_wrapped = model  # will get further wrapped in DDP
             self.deepspeed = model  # DeepSpeedEngine object
             self.optimizer = optimizer
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 594878ed1012..ce78a3f011c6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -482,9 +482,9 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json)"},
     )
-    ortmodule: Optional[bool] = field(
+    ort: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable OrtModule"},
+        metadata={"help": "Enable Ort"},
     )
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}

From 1506ee58318a3408383f0ce4037757dfe8c5c461 Mon Sep 17 00:00:00 2001
From: Ravi shankar Kolli
Date: Thu, 1 Apr 2021 21:43:53 -0700
Subject: [PATCH 2/2] Update init_deepspeed api

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f2bc1b1be690..92d84966f5bd 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -911,7 +911,7 @@ def train(
         if self.args.deepspeed:
             if self.args.ort:
                 self.model = model
-            model, optimizer, lr_scheduler = init_deepspeed(self, model, num_training_steps=max_steps)
+            model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
             self.model = model.module._original_module if self.args.ort else model.module
             self.model_wrapped = model  # will get further wrapped in DDP
             self.deepspeed = model  # DeepSpeedEngine object