TF Examples Rewrite #18451

Merged
merged 36 commits into from
Aug 10, 2022
d1cf71c
Finished QA example
Rocketknight1 Aug 3, 2022
58176d1
Dodge a merge conflict
Rocketknight1 Aug 3, 2022
78f5371
Update text classification and LM examples
Rocketknight1 Aug 3, 2022
70ec8f5
Update NER example
Rocketknight1 Aug 3, 2022
2bd4df9
New Keras metrics WIP, fix NER example
Rocketknight1 Aug 4, 2022
a6d6340
Update NER example
Rocketknight1 Aug 5, 2022
5945d38
Update MC, summarization and translation examples
Rocketknight1 Aug 5, 2022
4f92e88
Add XLA warnings when shapes are variable
Rocketknight1 Aug 5, 2022
69bef1a
Make sure batch_size is consistently scaled by num_replicas
Rocketknight1 Aug 8, 2022
b84bbbb
Add PushToHubCallback to all models
Rocketknight1 Aug 8, 2022
d300fb9
Add docs links for KerasMetricCallback
Rocketknight1 Aug 8, 2022
31e0dfa
Add docs links for prepare_tf_dataset and jit_compile
Rocketknight1 Aug 8, 2022
d98dd9b
Correct inferred model names
Rocketknight1 Aug 8, 2022
0560e76
Don't assume the dataset has 'lang'
Rocketknight1 Aug 8, 2022
1da5d9b
Don't assume the dataset has 'lang'
Rocketknight1 Aug 8, 2022
a1c4a2d
Write metrics in text classification
Rocketknight1 Aug 8, 2022
82047f8
Add 'framework' to TrainingArguments and TFTrainingArguments
Rocketknight1 Aug 9, 2022
d53ec42
Export metrics in all examples and add tests
Rocketknight1 Aug 9, 2022
91feab0
Fix training args for Flax
Rocketknight1 Aug 9, 2022
5ccafe6
Update command line args for translation test
Rocketknight1 Aug 9, 2022
5141255
make fixup
Rocketknight1 Aug 9, 2022
541e9f3
Fix accidentally running other tests in fp16
Rocketknight1 Aug 9, 2022
6f59c12
Remove do_train/do_eval from run_clm.py
Rocketknight1 Aug 9, 2022
8f72fa2
Remove do_train/do_eval from run_mlm.py
Rocketknight1 Aug 9, 2022
d1a95ce
Add tensorflow tests to circleci
Rocketknight1 Aug 10, 2022
77453fb
Fix circleci
Rocketknight1 Aug 10, 2022
46e1998
Update examples/tensorflow/language-modeling/run_mlm.py
Rocketknight1 Aug 10, 2022
0b0bb96
Update examples/tensorflow/test_tensorflow_examples.py
Rocketknight1 Aug 10, 2022
debcaed
Update examples/tensorflow/translation/run_translation.py
Rocketknight1 Aug 10, 2022
c8ca4bb
Update examples/tensorflow/token-classification/run_ner.py
Rocketknight1 Aug 10, 2022
fb238e8
Fix save path for tests
Rocketknight1 Aug 10, 2022
4922a22
Fix some model card kwargs
Rocketknight1 Aug 10, 2022
66a6b8b
Explain the magical -1000
Rocketknight1 Aug 10, 2022
25399bf
Actually enable tests this time
Rocketknight1 Aug 10, 2022
80654eb
Skip text classification PR until we fix shape inference
Rocketknight1 Aug 10, 2022
9e0b471
make fixup
Rocketknight1 Aug 10, 2022
67 changes: 67 additions & 0 deletions .circleci/config.yml
@@ -658,6 +658,71 @@ jobs:
- store_artifacts:
path: ~/transformers/reports

run_examples_tensorflow:
working_directory: ~/transformers
docker:
- image: cimg/python:3.7.12
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
PYTEST_TIMEOUT: 120
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.5-tensorflow_examples-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tensorflow,sentencepiece,testing]
- run: pip install -r examples/tensorflow/_tests_requirements.txt
- save_cache:
key: v0.5-tensorflow_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python utils/tests_fetcher.py --filters examples tests | tee test_preparation.txt
- store_artifacts:
path: ~/transformers/test_preparation.txt
- run: |
if [ -f test_list.txt ]; then
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile -s --make-reports=examples_tensorflow ./examples/tensorflow/ | tee tensorflow_examples_output.txt
fi
- store_artifacts:
path: ~/transformers/tensorflow_examples_output.txt
- store_artifacts:
path: ~/transformers/reports

run_examples_tensorflow_all:
working_directory: ~/transformers
docker:
- image: cimg/python:3.7.12
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
PYTEST_TIMEOUT: 120
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.5-tensorflow_examples-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tensorflow,sentencepiece,testing]
- run: pip install -r examples/tensorflow/_tests_requirements.txt
- save_cache:
key: v0.5-tensorflow_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: |
TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile -s --make-reports=examples_tensorflow ./examples/tensorflow/ | tee tensorflow_examples_output.txt
- store_artifacts:
path: ~/transformers/tensorflow_examples_output.txt
- store_artifacts:
path: ~/transformers/reports

run_examples_flax:
working_directory: ~/transformers
docker:
@@ -1000,6 +1065,7 @@ workflows:
- check_code_quality
- check_repository_consistency
- run_examples_torch
- run_examples_tensorflow
- run_examples_flax
- run_tests_custom_tokenizers
- run_tests_torch_and_tf
@@ -1022,6 +1088,7 @@ workflows:
- main
jobs:
- run_examples_torch_all
- run_examples_tensorflow_all
- run_examples_flax_all
- run_tests_torch_and_tf_all
- run_tests_torch_and_flax_all
25 changes: 25 additions & 0 deletions examples/tensorflow/_tests_requirements.txt
@@ -0,0 +1,25 @@
tensorflow
tensorboard
scikit-learn
seqeval
psutil
sacrebleu >= 1.4.12
git+https://github.com/huggingface/accelerate@main#egg=accelerate
rouge-score
tensorflow_datasets
matplotlib
git-python==1.0.3
faiss-cpu
streamlit
elasticsearch
nltk
pandas
datasets >= 1.13.3
fire
pytest
conllu
sentencepiece != 0.1.92
protobuf
jiwer
librosa
evaluate >= 0.2.0
159 changes: 114 additions & 45 deletions examples/tensorflow/language-modeling/run_clm.py
@@ -22,6 +22,8 @@
"""
# You can also adapt this script on your own clm task. Pointers for this are left as comments.

import json

# region Imports
import logging
import math
@@ -46,8 +48,8 @@
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoTokenizer,
DefaultDataCollator,
HfArgumentParser,
PushToHubCallback,
TFAutoModelForCausalLM,
TFTrainingArguments,
create_optimizer,
@@ -205,21 +207,6 @@ def __post_init__(self):
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


# endregion

# region Helper classes
class SavePretrainedCallback(tf.keras.callbacks.Callback):
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
# that saves the model with this method after each epoch.
def __init__(self, output_dir, **kwargs):
super().__init__()
self.output_dir = output_dir

def on_epoch_end(self, epoch, logs=None):
self.model.save_pretrained(self.output_dir)


# endregion


@@ -299,19 +286,22 @@ def main():
raw_datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
@@ -321,16 +311,39 @@ def main():
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = (
data_args.train_file.split(".")[-1]
if data_args.train_file is not None
else data_args.validation_file.split(".")[-1]
)
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# endregion
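
Note on the hunk above: the new code picks the file extension from `train_file` when it is set and falls back to `validation_file` otherwise, then carves a validation split out of the training data with percentage slices. A minimal sketch of that logic in isolation follows. The helper names `infer_loader_name` and `split_specs` are hypothetical and do not appear in the script; only the string handling is mirrored from the diff.

```python
from typing import Dict, Optional


def infer_loader_name(train_file: Optional[str], validation_file: Optional[str]) -> str:
    """Return the loader name implied by whichever data file is available.

    Hypothetical helper: it mirrors the extension logic in the diff, nothing more.
    """
    source = train_file if train_file is not None else validation_file
    if source is None:
        raise ValueError("Need at least one of train_file or validation_file.")
    extension = source.split(".")[-1]
    # The script maps plain-text files to the loader named "text", not "txt".
    return "text" if extension == "txt" else extension


def split_specs(validation_split_percentage: int) -> Dict[str, str]:
    """Build the split strings used when no validation set exists."""
    pct = validation_split_percentage
    return {
        "validation": f"train[:{pct}%]",  # first pct percent of the training data
        "train": f"train[{pct}%:]",       # the remainder
    }


if __name__ == "__main__":
    print(infer_loader_name(None, "dev.txt"))
    print(split_specs(5))
```

The two slices are complementary, so no example ends up in both splits.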
@@ -446,7 +459,7 @@ def group_texts(examples):
eval_dataset = eval_dataset.select(range(max_eval_samples))

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# endregion
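
Note on the `min(3, len(train_dataset))` change: `random.sample` raises `ValueError` when asked for more items than the population holds, which is what would happen on a tiny test dataset with fewer than three examples. A small sketch of the guarded call, using a hypothetical `sample_indices` helper rather than anything from the script:

```python
import random
from typing import List


def sample_indices(dataset_len: int, k: int = 3) -> List[int]:
    """Pick up to k distinct indices without failing on very small datasets."""
    # Clamp the sample size to the population size, as the diff does.
    return random.sample(range(dataset_len), min(k, dataset_len))


if __name__ == "__main__":
    print(sample_indices(10))  # three distinct indices
    print(sample_indices(2))   # both indices, in random order
    try:
        random.sample(range(2), 3)  # the unguarded call
    except ValueError as err:
        print(f"unguarded call fails: {err}")
```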

@@ -465,44 +478,88 @@ def group_texts(examples):

# region TF Dataset preparation
num_replicas = training_args.strategy.num_replicas_in_sync
data_collator = DefaultDataCollator(return_tensors="tf")
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

tf_train_dataset = train_dataset.to_tf_dataset(
# labels are passed as input, as we will use the model's internal loss
columns=[col for col in train_dataset.features if col != "special_tokens_mask"],
# model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in
# training. This is the recommended way to use a Hugging Face dataset when training with Keras. You can also
# use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names
# yourself if you use this method, whereas they are automatically inferred from the model input names when
# using model.prepare_tf_dataset()
# For more info see the docs:
# https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset
# https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset

tf_train_dataset = model.prepare_tf_dataset(
train_dataset,
shuffle=True,
batch_size=num_replicas * training_args.per_device_train_batch_size,
collate_fn=data_collator,
drop_remainder=True,
).with_options(options)

tf_eval_dataset = eval_dataset.to_tf_dataset(
# labels are passed as input, as we will use the model's internal loss
columns=[col for col in eval_dataset.features if col != "special_tokens_mask"],
tf_eval_dataset = model.prepare_tf_dataset(
eval_dataset,
shuffle=False,
batch_size=num_replicas * training_args.per_device_train_batch_size,
collate_fn=data_collator,
batch_size=num_replicas * training_args.per_device_eval_batch_size,
drop_remainder=True,
).with_options(options)
# endregion

# region Optimizer and loss
batches_per_epoch = len(train_dataset) // (num_replicas * training_args.per_device_train_batch_size)
num_train_steps = len(tf_train_dataset) * int(training_args.num_train_epochs)
if training_args.warmup_steps > 0:
num_warmup_steps = training_args.warmup_steps
elif training_args.warmup_ratio > 0:
num_warmup_steps = int(num_train_steps * training_args.warmup_ratio)
else:
num_warmup_steps = 0

# Bias and layernorm weights are automatically excluded from the decay
optimizer, lr_schedule = create_optimizer(
init_lr=training_args.learning_rate,
num_train_steps=int(training_args.num_train_epochs * batches_per_epoch),
num_warmup_steps=training_args.warmup_steps,
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
adam_beta1=training_args.adam_beta1,
adam_beta2=training_args.adam_beta2,
adam_epsilon=training_args.adam_epsilon,
weight_decay_rate=training_args.weight_decay,
adam_global_clipnorm=training_args.max_grad_norm,
)

# no user-specified loss = will use the model internal loss
model.compile(optimizer=optimizer)
model.compile(optimizer=optimizer, jit_compile=training_args.xla)
# endregion

# region Preparing push_to_hub and model card
push_to_hub_model_id = training_args.push_to_hub_model_id
model_name = model_args.model_name_or_path.split("/")[-1]
if not push_to_hub_model_id:
if data_args.dataset_name is not None:
push_to_hub_model_id = f"{model_name}-finetuned-{data_args.dataset_name}"
else:
push_to_hub_model_id = f"{model_name}-finetuned-clm"

model_card_kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
if data_args.dataset_name is not None:
model_card_kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
model_card_kwargs["dataset_args"] = data_args.dataset_config_name
model_card_kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
else:
model_card_kwargs["dataset"] = data_args.dataset_name

if training_args.push_to_hub:
callbacks = [
PushToHubCallback(
output_dir=training_args.output_dir,
model_id=push_to_hub_model_id,
organization=training_args.push_to_hub_organization,
token=training_args.push_to_hub_token,
tokenizer=tokenizer,
**model_card_kwargs,
)
]
else:
callbacks = []
# endregion
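
Note on the optimizer hunk above: the step counts now come from the length of the batched dataset, and warmup is resolved with an explicit `warmup_steps` taking precedence over `warmup_ratio`. The sketch below isolates that arithmetic. `resolve_schedule` is a hypothetical name, and the inputs are plain numbers standing in for `len(tf_train_dataset)` and the training arguments.

```python
from typing import Tuple


def resolve_schedule(
    num_batches: int,
    num_train_epochs: float,
    warmup_steps: int = 0,
    warmup_ratio: float = 0.0,
) -> Tuple[int, int]:
    """Return (num_train_steps, num_warmup_steps) following the precedence in the diff."""
    num_train_steps = num_batches * int(num_train_epochs)
    if warmup_steps > 0:
        # An explicit step count wins over the ratio.
        num_warmup_steps = warmup_steps
    elif warmup_ratio > 0:
        num_warmup_steps = int(num_train_steps * warmup_ratio)
    else:
        num_warmup_steps = 0
    return num_train_steps, num_warmup_steps


if __name__ == "__main__":
    print(resolve_schedule(100, 3.0, warmup_ratio=0.1))
    print(resolve_schedule(100, 3.0, warmup_steps=50, warmup_ratio=0.1))
```

Because the epoch count is truncated with `int()`, a fractional value such as 2.5 epochs is treated as 2 in both the step count and the `model.fit` call.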

# region Training and validation
@@ -512,33 +569,45 @@ def group_texts(examples):
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size = {training_args.per_device_train_batch_size * num_replicas}")

# For long training runs, you may wish to use the PushToHub() callback here to save intermediate checkpoints
# to the Hugging Face Hub rather than just pushing the finished model.
# See https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.PushToHubCallback

history = model.fit(
tf_train_dataset,
validation_data=tf_eval_dataset,
epochs=int(training_args.num_train_epochs),
steps_per_epoch=len(train_dataset) // (training_args.per_device_train_batch_size * num_replicas),
callbacks=[SavePretrainedCallback(output_dir=training_args.output_dir)],
callbacks=callbacks,
)
train_loss = history.history["loss"][-1]
try:
train_perplexity = math.exp(history.history["loss"][-1])
train_perplexity = math.exp(train_loss)
except OverflowError:
train_perplexity = math.inf
logger.info(f" Final train loss: {train_loss:.3f}")
logger.info(f" Final train perplexity: {train_perplexity:.3f}")
validation_loss = history.history["val_loss"][-1]
try:
validation_perplexity = math.exp(history.history["val_loss"][-1])
validation_perplexity = math.exp(validation_loss)
except OverflowError:
validation_perplexity = math.inf
logger.info(f" Final train loss: {history.history['loss'][-1]:.3f}")
logger.info(f" Final train perplexity: {train_perplexity:.3f}")
logger.info(f" Final validation loss: {history.history['val_loss'][-1]:.3f}")
logger.info(f" Final validation loss: {validation_loss:.3f}")
logger.info(f" Final validation perplexity: {validation_perplexity:.3f}")
# endregion

if training_args.output_dir is not None:
model.save_pretrained(training_args.output_dir)
output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
results_dict = dict()
results_dict["train_loss"] = train_loss
results_dict["train_perplexity"] = train_perplexity
results_dict["eval_loss"] = validation_loss
results_dict["eval_perplexity"] = validation_perplexity
with open(output_eval_file, "w") as writer:
writer.write(json.dumps(results_dict))
# endregion

if training_args.push_to_hub:
# You'll probably want to include some of your own metadata here!
model.push_to_hub()
if training_args.output_dir is not None and not training_args.push_to_hub:
# If we're not pushing to hub, at least save a local copy when we're done
model.save_pretrained(training_args.output_dir)
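
Note on the metrics export above: perplexity is `exp(loss)`, guarded against overflow, and the four values are written to `all_results.json` in the output directory. A self-contained sketch of that pattern follows. The helper names `safe_perplexity` and `write_results` are hypothetical; the dictionary keys and file name are the ones in the diff.

```python
import json
import math
import os
import tempfile


def safe_perplexity(loss: float) -> float:
    """Return exp(loss), or infinity when the exponential overflows a float."""
    try:
        return math.exp(loss)
    except OverflowError:
        return math.inf


def write_results(output_dir: str, train_loss: float, eval_loss: float) -> str:
    """Write losses and perplexities to all_results.json and return its path."""
    results = {
        "train_loss": train_loss,
        "train_perplexity": safe_perplexity(train_loss),
        "eval_loss": eval_loss,
        "eval_perplexity": safe_perplexity(eval_loss),
    }
    path = os.path.join(output_dir, "all_results.json")
    with open(path, "w") as writer:
        writer.write(json.dumps(results))
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(write_results(tmp_dir, train_loss=2.0, eval_loss=2.5)) as reader:
            print(json.load(reader))
```

One caveat: when a perplexity is infinite, Python's `json.dumps` writes the token `Infinity`, which Python reads back but strict JSON parsers may reject.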


if __name__ == "__main__":