Add ORT fused adam optimizer #295

Merged: 2 commits, Jul 22, 2022
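This PR switches the ONNX Runtime training examples from transformers' TrainingArguments to ORTTrainingArguments (and its seq2seq variant), which is where the ORT fused Adam optimizer is exposed. As orientation before the per-file diffs, here is a minimal sketch of the intended usage; the optimizer option name "adamw_ort_fused", the model and dataset choices, and the exact ORTTrainer keyword set are illustrative assumptions, not taken from this diff:

    # Sketch only: selecting the ORT fused Adam through ORTTrainingArguments.
    # Requires an onnxruntime-training build; "adamw_ort_fused" is an assumed
    # option name inferred from the PR title.
    from datasets import load_dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    from optimum.onnxruntime import ORTTrainer
    from optimum.onnxruntime.training_args import ORTTrainingArguments

    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # A tiny slice of SST-2 keeps the sketch cheap to run.
    train_dataset = load_dataset("glue", "sst2", split="train[:1%]")
    train_dataset = train_dataset.map(
        lambda batch: tokenizer(batch["sentence"], truncation=True), batched=True
    )

    training_args = ORTTrainingArguments(
        output_dir="ort_fused_adam_out",
        optim="adamw_ort_fused",  # assumed name of the fused Adam option
        per_device_train_batch_size=8,
        num_train_epochs=1,
    )

    trainer = ORTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()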
examples/onnxruntime/training/language-modeling/run_mlm.py (5 changes: 2 additions & 3 deletions)

@@ -40,8 +40,6 @@
     AutoTokenizer,
     DataCollatorForLanguageModeling,
     HfArgumentParser,
-    Trainer,
-    TrainingArguments,
     is_torch_tpu_available,
     set_seed,
 )
@@ -50,6 +48,7 @@
 from transformers.utils.versions import require_version
 
 from optimum.onnxruntime import ORTTrainer
+from optimum.onnxruntime.training_args import ORTTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -217,7 +216,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
examples/onnxruntime/training/question-answering/run_qa.py (19 changes: 14 additions & 5 deletions)

@@ -36,20 +36,20 @@
     EvalPrediction,
     HfArgumentParser,
     PreTrainedTokenizerFast,
-    TrainingArguments,
     default_data_collator,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
+from optimum.onnxruntime.training_args import ORTTrainingArguments
 from trainer_qa import QuestionAnsweringORTTrainer
 from utils_qa import postprocess_qa_predictions
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.12.0")
+check_min_version("4.20.0")
 
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/onnxruntime/training/question-answering/requirements.txt"
@@ -206,7 +206,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -265,7 +265,10 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         data_files = {}
@@ -278,7 +281,13 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(
+            extension,
+            data_files=data_files,
+            field="data",
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
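The change repeated across these scripts forwards the user's Hub credentials to load_dataset only when --use_auth_token is passed. A standalone sketch of the pattern, with a hypothetical local flag standing in for model_args.use_auth_token and an illustrative public dataset:

    # Sketch of the use_auth_token pattern added above. Passing None (rather
    # than False) lets datasets fall back to its default behaviour; True
    # reuses the token stored by `huggingface-cli login`.
    from datasets import load_dataset

    use_auth_token = False  # hypothetical stand-in for model_args.use_auth_token

    raw_datasets = load_dataset(
        "squad",  # public dataset for illustration; private ones are where the token matters
        use_auth_token=True if use_auth_token else None,
    )
    print(raw_datasets)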
examples/onnxruntime/training/text-classification/run_glue.py (32 changes: 25 additions & 7 deletions)

@@ -37,7 +37,6 @@
     EvalPrediction,
     HfArgumentParser,
     PretrainedConfig,
-    TrainingArguments,
     default_data_collator,
     set_seed,
 )
@@ -46,10 +45,11 @@
 from transformers.utils.versions import require_version
 
 from optimum.onnxruntime import ORTTrainer
+from optimum.onnxruntime.training_args import ORTTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.15.0")
+check_min_version("4.20.0")
 
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/onnxruntime/training/text-classification/requirements.txt"
@@ -195,7 +195,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -256,11 +256,19 @@ def main():
     # download the dataset.
     if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(
+            "glue",
+            data_args.task_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     elif data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         # Loading a dataset from your local files.
@@ -285,10 +293,20 @@ def main():
 
         if data_args.train_file.endswith(".csv"):
             # Loading a dataset from local csv files
-            raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
+            raw_datasets = load_dataset(
+                "csv",
+                data_files=data_files,
+                cache_dir=model_args.cache_dir,
+                use_auth_token=True if model_args.use_auth_token else None,
+            )
         else:
             # Loading a dataset from local json files
-            raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
+            raw_datasets = load_dataset(
+                "json",
+                data_files=data_files,
+                cache_dir=model_args.cache_dir,
+                use_auth_token=True if model_args.use_auth_token else None,
+            )
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
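Because HfArgumentParser builds its command line from dataclass fields, swapping TrainingArguments for ORTTrainingArguments in the parser tuple is enough to surface any ORT-specific options as flags. A small sketch of that effect; the --optim value is our assumption from the PR title, and the flag list is illustrative:

    # Sketch: parsing CLI-style args into the new dataclass. Any field that
    # ORTTrainingArguments adds or overrides becomes a valid command-line flag.
    from transformers import HfArgumentParser

    from optimum.onnxruntime.training_args import ORTTrainingArguments

    parser = HfArgumentParser(ORTTrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "/tmp/glue_out", "--optim", "adamw_ort_fused"]
    )
    print(type(training_args).__name__)  # ORTTrainingArguments
    print(training_args.optim)           # the optimizer selected above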
examples/onnxruntime/training/token-classification/run_ner.py (11 changes: 7 additions & 4 deletions)

@@ -38,18 +38,18 @@
     HfArgumentParser,
     PretrainedConfig,
     PreTrainedTokenizerFast,
-    TrainingArguments,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
 from optimum.onnxruntime import ORTTrainer
+from optimum.onnxruntime.training_args import ORTTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.15.0")
+check_min_version("4.20.0")
 
 require_version(
     "datasets>=1.18.0", "To fix: pip install -r examples/onnxruntime/training/token-classification/requirements.txt"
@@ -193,7 +193,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -252,7 +252,10 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         data_files = {}
examples/onnxruntime/training/translation/run_translation.py (18 changes: 13 additions & 5 deletions)

@@ -39,7 +39,6 @@
     MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
-    Seq2SeqTrainingArguments,
     default_data_collator,
     set_seed,
 )
@@ -48,10 +47,11 @@
 from transformers.utils.versions import require_version
 
 from optimum.onnxruntime import ORTSeq2SeqTrainer
+from optimum.onnxruntime.training_args_seq2seq import ORTSeq2SeqTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.15.0")
+check_min_version("4.20.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
 
@@ -235,7 +235,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTSeq2SeqTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -306,7 +306,10 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         data_files = {}
@@ -319,7 +322,12 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(
+            extension,
+            data_files=data_files,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
optimum/onnxruntime/__init__.py (4 changes: 4 additions & 0 deletions)

@@ -32,6 +32,8 @@
     "quantization": ["ORTQuantizer"],
     "trainer": ["ORTTrainer"],
    "trainer_seq2seq": ["ORTSeq2SeqTrainer"],
+    "training_args": ["ORTTrainingArguments"],
+    "training_args_seq2seq": ["ORTSeq2SeqTrainingArguments"],
     "utils": [
         "ONNX_DECODER_NAME",
         "ONNX_DECODER_WITH_PAST_NAME",
@@ -59,6 +61,8 @@
     from .quantization import ORTQuantizer
     from .trainer import ORTTrainer
     from .trainer_seq2seq import ORTSeq2SeqTrainer
+    from .training_args import ORTTrainingArguments
+    from .training_args_seq2seq import ORTSeq2SeqTrainingArguments
     from .utils import (
         ONNX_DECODER_NAME,
         ONNX_DECODER_WITH_PAST_NAME,
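With these __init__ entries in place, the new argument classes should resolve from the package root alongside the trainers they pair with, so downstream code can shorten its imports. A quick check:

    # The training-argument classes should now be importable from the
    # package root, next to ORTTrainer and ORTSeq2SeqTrainer.
    from optimum.onnxruntime import (
        ORTSeq2SeqTrainer,
        ORTSeq2SeqTrainingArguments,
        ORTTrainer,
        ORTTrainingArguments,
    )

    print(ORTTrainingArguments, ORTSeq2SeqTrainingArguments)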