diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index d8f4aedd85..29b3183862 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -20,6 +20,7 @@
 import shutil
 import sys
 import time
+import types
 import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -134,6 +135,32 @@
 SCALER_NAME = "scaler.pt"
 
 
+class ModuleWithLoss(nn.Module):
+    def __init__(self, model, args, label_smoother):
+        super().__init__()
+        self._original_model = model
+        self.args = args
+        # Label smoothing
+        self.label_smoother = label_smoother
+
+    def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
+        # compute_model_plus_loss_internal is assigned once the class is instantiated.
+        # It should have the same signature as Trainer.compute_loss().
+        # We do this to avoid the potentially un-synced state that duplicating the compute-loss code could introduce.
+        return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)
+
+    @property
+    def module(self):
+        """The original `torch.nn.Module` that this module wraps.
+        This property provides access to methods and properties on the original module."""
+
+        return self._original_model.module
+
+    @property
+    def config(self):
+        return self._original_model.config
+
+
 class ORTFeaturesManager:
     _TASKS_TO_ORTMODELS = {
         "default": ORTModelForFeatureExtraction,
@@ -279,12 +306,55 @@ def __init__(
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # We leverage both a training model and an inference model, in conjunction with `model`.
+        # `_training_model` will be wrapped so that it uses ORT and the overridden functions in ModuleWithLoss.
+        # `self.model` keeps the default version of the model and is unwrapped again for eval/test.
+
+        # Only wrap the model if the --use_module_with_loss flag is passed.
+        if args.use_module_with_loss:
+            self._training_model = self.create_model_with_loss()
+
+        self.model = model
+
         self.feature = feature
         self.onnx_model_path = onnx_model_path
         self.exported_with_loss = False
         if self.args.local_rank:
             torch.cuda.set_device(self.args.local_rank)
 
+    # This method creates a ModuleWithLoss instance to use when the --use_module_with_loss flag is passed.
+    # It helps reduce peak memory usage by computing the loss inside the training loop.
+    def create_model_with_loss(self):
+        model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother)
+        model_with_loss.compute_model_plus_loss_internal = types.MethodType(Trainer.compute_loss, model_with_loss)
+
+        return model_with_loss
+
+    # We assume that the training model and the inference model have the same forward signature columns;
+    # the self._signature_columns attribute only stores the signature parsed the first time.
+    def _set_signature_columns_if_needed(self):
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            import inspect
+
+            if isinstance(self.model, ModuleWithLoss):
+                signature = inspect.signature(self.model._original_model.forward)
+            else:
+                signature = inspect.signature(self.model.forward)
+
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids, the default data collator handles that.
+            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+    def compute_loss(self, model_with_loss, inputs, return_outputs=False):
+        # Run the model forward pass and compute the loss.
+        if isinstance(self.model, ModuleWithLoss):
+            # ORTModule does not support the BatchEncoding type, so we have to convert the inputs to a dict.
+            dict_inputs = dict(inputs.items())
+            return model_with_loss(dict_inputs, return_outputs)
+        else:
+            return super().compute_loss(model_with_loss, inputs, return_outputs)
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -313,6 +383,8 @@ def train(
                 "You need to install `onnxruntime-training` to use `ORTTrainer` for training. Check out "
                 "https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime."
             )
+        if self.args.use_module_with_loss:
+            self.model = self._training_model
         if resume_from_checkpoint is False:
             resume_from_checkpoint = None
 
@@ -801,6 +873,8 @@ def evaluate(
             dictionary also contains the epoch number which comes from the training state.
         """
         # memory metrics - must set up as early as possible
+        # TODO: We need to enable evaluation using the ORT backend.
+        self.model = unwrap_model(self.model)
         self._memory_tracker.start()
 
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
@@ -892,6 +966,9 @@ def predict(
             - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained labels).
         """
+        # TODO: We need to enable evaluation using the ORT backend.
+        self.model = unwrap_model(self.model)
+
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
@@ -909,10 +986,7 @@
 
         try:
             output = eval_loop(
-                test_dataloader,
-                description="Prediction",
-                ignore_keys=ignore_keys,
-                metric_key_prefix=metric_key_prefix,
+                test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
             )
         except Exception as error:
             logger.error(error)
@@ -1522,13 +1596,7 @@ def _export(
         opset = max(opset, 12)  # Operators like `nll_loss`are added for opset>=12
         output_path = model_path / ONNX_WEIGHTS_NAME
-        _ = export(
-            model=model,
-            config=onnx_config,
-            opset=opset,
-            output=output_path,
-            device=device,
-        )
+        _ = export(model=model, config=onnx_config, opset=opset, output=output_path, device=device)
 
         model.config.save_pretrained(model_path)
 
@@ -1697,11 +1765,7 @@ def create_optimizer(self):
         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
 
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-            self.optimizer = OSS(
-                params=optimizer_grouped_parameters,
-                optim=optimizer_cls,
-                **optimizer_kwargs,
-            )
+            self.optimizer = OSS(params=optimizer_grouped_parameters, optim=optimizer_cls, **optimizer_kwargs)
         else:
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
             if optimizer_cls.__name__ == "Adam8bit":
@@ -1731,10 +1795,7 @@ def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, A
             The training arguments for the training session.
         """
         optimizer_kwargs = {"lr": args.learning_rate}
-        adam_kwargs = {
-            "betas": (args.adam_beta1, args.adam_beta2),
-            "eps": args.adam_epsilon,
-        }
+        adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}
         if args.optim == ORTOptimizerNames.ADAMW_ORT_FUSED:
             try:
                 from onnxruntime.training.optim import FusedAdam
diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py
index 155974d4d3..bf16b37b72 100644
--- a/optimum/onnxruntime/training_args.py
+++ b/optimum/onnxruntime/training_args.py
@@ -60,9 +60,13 @@ class ORTTrainingArguments(TrainingArguments):
             The optimizer to use, including optimizers in Transformers: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
             And optimizers implemented by ONNX Runtime: adamw_ort_fused.
         """
-    optim: Optional[str] = field(
-        default="adamw_hf",
-        metadata={"help": "The optimizer to use."},
-    )
+    optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer to use."})
+
+    use_module_with_loss: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Use the ModuleWithLoss wrapper to compute the loss inside the training loop; this helps reduce peak memory usage for ORTModule runs."
+        },
+    )
 
     # This method will not need to be overriden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
@@ -336,3 +340,10 @@ def __post_init__(self):
                     f"{self.hub_model_id}).",
                     FutureWarning,
                 )
+        if self.use_module_with_loss is True:
+            logger.info(
+                "Using the ModuleWithLoss wrapper: "
+                "the loss will be computed inside the training loop, which reduces peak memory usage."
+            )
+        else:
+            logger.info("Not using the ModuleWithLoss wrapper.")
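
Usage sketch (editor's note, not part of the patch): assuming an already-prepared `model`, `train_dataset`, and `eval_dataset` from a standard `transformers`-style fine-tuning setup, the new `use_module_with_loss` flag introduced by this diff would be enabled roughly as follows; `output_dir` and the other hyperparameter values are placeholders.

from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# Placeholder training arguments; `use_module_with_loss=True` turns on the ModuleWithLoss
# wrapper so the loss is computed inside the wrapped module's forward pass, which is what
# reduces peak memory for ORTModule runs.
training_args = ORTTrainingArguments(
    output_dir="./results",           # placeholder output directory
    optim="adamw_ort_fused",          # optional: the ORT fused AdamW mentioned in the docstring above
    use_module_with_loss=True,
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

trainer = ORTTrainer(
    model=model,                      # assumed to be defined elsewhere
    args=training_args,
    train_dataset=train_dataset,      # assumed to be defined elsewhere
    eval_dataset=eval_dataset,        # assumed to be defined elsewhere
)

trainer.train()     # trains with the wrapped _training_model
trainer.evaluate()  # the model is unwrapped again for evaluation, per the TODOs above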