Skip to content

Commit

Permalink
resolved comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Louly committed Feb 23, 2023
1 parent 040d77e commit c45bd53
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
Expand Down Expand Up @@ -51,7 +50,6 @@
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.file_utils import (
is_apex_available,
is_sagemaker_dp_enabled,
Expand Down Expand Up @@ -146,6 +144,8 @@ def __init__(self, model, args) -> None:
super().__init__()
self._original_model = model
self.args = args

# Creating an instance of the HuggingFace Trainer so we can reuse its compute_loss() logic and avoid duplicated code.
self.hf_trainer = Trainer(model)
# Label smoothing
if self.args.label_smoothing_factor != 0:
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ORTTrainingArguments(TrainingArguments):

loss_in_train: Optional[bool] = field(
default=False,
metadata={"help": "Use ModuleWithLoss Wrapper to compute loss inside the training loop."},
metadata={"help": "Use the ModuleWithLoss wrapper to compute the loss inside the training loop; when the label smoother is NOT None, this helps save memory for ORTModule runs."},
)

# This method will not need to be overridden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
Expand Down

0 comments on commit c45bd53

Please sign in to comment.