diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 530f4fa65ac854..634cea5ff0836f 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -270,7 +270,7 @@ def rewrite_logs(d): return new_d -def init_deepspeed(trainer, model, num_training_steps): +def init_deepspeed(trainer, num_training_steps): """ Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration @@ -286,6 +286,7 @@ def init_deepspeed(trainer, model, num_training_steps): args = trainer.args ds_config_file = args.deepspeed + model = trainer.model with io.open(ds_config_file, "r", encoding="utf-8") as f: config = json.load(f) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d7d6b51768832c..92d84966f5bdd7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -903,14 +903,16 @@ def train( delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE model = self.model - if self.args.ortmodule: + if self.args.ort: from onnxruntime.training import ORTModule logger.info("Converting to ORTModule ....") model = ORTModule(self.model) self.model_wrapped = model if self.args.deepspeed: - model, optimizer, lr_scheduler = init_deepspeed(self, model, num_training_steps=max_steps) - self.model = model.module._original_module if self.args.ortmodule else model.module + if self.args.ort: + self.model = model + model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps) + self.model = model.module._original_module if self.args.ort else model.module self.model_wrapped = model # will get further wrapped in DDP self.deepspeed = model # DeepSpeedEngine object self.optimizer = optimizer diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 594878ed10128b..ce78a3f011c63f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -482,9 +482,9 @@ class 
TrainingArguments: default=None, metadata={"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json)"}, ) - ortmodule: Optional[bool] = field( + ort: Optional[bool] = field( default=False, - metadata={"help": "Enable OrtModule"}, + metadata={"help": "Enable ONNX Runtime (ORT) training via ORTModule"}, ) label_smoothing_factor: float = field( default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}