
[Examples] TPU-based training of a language model using TensorFlow #21657

Merged 29 commits into main from examples/tf-tpu on Apr 14, 2023

Conversation

@sayakpaul (Member) commented Feb 16, 2023

This PR adds an example of performing (masked) language model training using TensorFlow and TPUs. The example is meant to act as a reference for the community on this topic. The following are the main components of the PR:

  • Tokenizer training script (for completeness)
  • TFRecords preparation script (recommended practice when using TPUs)
  • Training script
  • Evaluation / inference

The purpose of this separation (as opposed to having everything in a single script) is to allow the community to have isolated reference points for performing TPU-based training of our models, which I think is beneficial.

The artifacts produced during this project can be found here: https://huggingface.co/tf-tpu.

Cc: @Rocketknight1 @gante @amyeroberts

@sayakpaul added the TensorFlow, Examples, and TPU labels on Feb 16, 2023
)
parser.add_argument(
    "-vs",
    "--vocab_size",
Member Author
Maybe we should play around with this a bit to see if a multiple of 64 actually helps improve the efficiency. Reference: https://twitter.com/karpathy/status/1621578354024677377?s=20

Member
I think we can just use a multiple of 64 anyway, it's not really a big change! The next multiple of 64 after 10000 is 10048.
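
For reference, rounding an arbitrary target vocabulary size up to the next multiple of 64 is a one-liner; a minimal sketch (the helper name is illustrative, not from the scripts in this PR):

```python
def next_multiple(value: int, multiple: int = 64) -> int:
    """Round `value` up to the nearest multiple of `multiple`."""
    return ((value + multiple - 1) // multiple) * multiple

# 10000 -> 10048, the next multiple of 64 mentioned above.
print(next_multiple(10000))
```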

Member Author
Do you want me to retrain the tokenizer and redo the TFRecords with it?

@HuggingFaceDocBuilderDev commented Feb 16, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 42 to 47
parser.add_argument(
    "--shard_size",
    type=int,
    default=1000,
    help="Number of entries to go in a single shard.",
)
Member Author
We should likely follow the advice in this guide to decide on this number when running things at full scale.
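
For context, here is a minimal sketch of how a --shard_size value like this is typically consumed when writing TFRecord shards; the feature name, filename pattern, and helper functions are illustrative rather than the exact code in this PR:

```python
import tensorflow as tf

def serialize_example(input_ids):
    # Store one tokenized sequence as a single int64 feature.
    feature = {
        "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def write_shards(tokenized_sequences, shard_size=1000, prefix="train"):
    # Slice the data into shards of `shard_size` entries; the last shard may be smaller.
    for start in range(0, len(tokenized_sequences), shard_size):
        shard = tokenized_sequences[start : start + shard_size]
        filename = f"{prefix}-{start // shard_size:05d}-{len(shard)}.tfrecord"
        with tf.io.TFRecordWriter(filename) as writer:
            for seq in shard:
                writer.write(serialize_example(seq))
```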

@sayakpaul (Member Author)

@Rocketknight1 I incorporated the group_texts() utility that we discussed over Slack. Let me know if the changes look good to you. Most of it is copy-pasted from here.

Here's a Colab notebook where I verified these changes.
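
For readers who haven't seen it, group_texts() concatenates the tokenized texts in a batch and re-slices them into fixed-length blocks so no tokens are wasted on padding. A minimal sketch along the lines of the existing run_mlm examples (the block size of 512 matches the TFRecords mentioned later in this thread):

```python
def group_texts(examples, block_size=512):
    # Concatenate every field (input_ids, attention_mask, ...) across the batch.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder so that every block is exactly `block_size` tokens long.
    total_length = (total_length // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
```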

@sayakpaul (Member Author)

@Rocketknight1 I took a deeper look into the TFRecord preparation script. I don't understand why there's a discrepancy in the following.

While serializing the TFRecords, I make sure each TFRecord shard has a specific number of samples. It's fine when a shard ends up with fewer samples than the specified amount.

But when I load the TFRecords back and create a tf.data.Dataset out of them, the number of entries in the dataset (before batching) is much smaller.

Here is a minimal Colab Notebook that demonstrates the issue: https://colab.research.google.com/gist/sayakpaul/b4b02f3f656c0041c93f6ba78c8e65fd/scratchpad.ipynb.

When you get a moment, could you take a look?
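
One way to sanity-check both sides is to count the raw serialized records before any parsing or batching; a minimal sketch (the file pattern is a placeholder, not necessarily the layout used in this PR):

```python
import tensorflow as tf

def count_records(file_pattern):
    # TFRecordDataset yields exactly one element per serialized record,
    # so counting here is unaffected by parsing or batching.
    files = tf.io.gfile.glob(file_pattern)
    return sum(1 for _ in tf.data.TFRecordDataset(files))

print(count_records("gs://tf-tpu-training-resources/train/*.tfrecord"))
```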

@sayakpaul (Member Author)

Thanks @Rocketknight1 for your help in debugging #21657 (comment) (discussed internally via Slack). I am currently regenerating the TFRecord shards. I will update here once that's done.

@sayakpaul (Member Author)

@Rocketknight1 the corrected TFRecord shards have been pushed to gs://tf-tpu-training-resources.

Here are the record counts per split:

  • Train: 300917
  • Validation: 626
  • Test: 722

The TFRecords were generated with a block size of 512.
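
Since the shard filenames encode their sample counts (see the "Read sample counts from filenames" commits in the log below), totals like the ones above can be recovered without scanning the records. A minimal sketch, assuming each shard name ends in its record count (e.g. train-00012-1000.tfrecord, an illustrative pattern):

```python
import re
import tensorflow as tf

def total_samples(file_pattern):
    # Assumes the trailing integer in each shard name is its record count.
    total = 0
    for path in tf.io.gfile.glob(file_pattern):
        match = re.search(r"-(\d+)\.tfrecord$", path)
        if match:
            total += int(match.group(1))
    return total
```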

@sayakpaul (Member Author)

@Rocketknight1 the training code looks good to me, except for a few things:

  • Maybe we should scale the LR with the batch size? (See the sketch after this comment.)
  • Take mlm_probability as a CLI arg?
  • Modularize the dataset preparation code a bit?

But all of these are non-blockers. Let's do 4-5 training runs varying the number of epochs and the learning rate.
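
On the learning-rate point above, a common heuristic is the linear scaling rule, i.e. growing the LR in proportion to the global batch size. A minimal sketch (the base values are illustrative, not the ones used for this PR's runs):

```python
def scaled_learning_rate(per_replica_batch_size, num_replicas,
                         base_lr=1e-4, base_batch_size=256):
    # Linear scaling rule: LR grows proportionally to the global batch size.
    global_batch_size = per_replica_batch_size * num_replicas
    return base_lr * global_batch_size / base_batch_size

# e.g. per-replica batch size 64 on 8 TPU cores -> global batch size 512.
print(scaled_learning_rate(64, 8))  # 2e-4 with the illustrative base values
```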

@Rocketknight1 (Member)

@sayakpaul MLM probability added as an arg and I modularized the loading!

@sayakpaul (Member Author) commented Mar 25, 2023

@Rocketknight1 started a training run with:

python3 train_model.py \
  --tokenizer tf-tpu/unigram-tokenizer-wikitext \
  --per_replica_batch_size 64 \
  --tpu_name local --tpu_zone us-central1 --gcp_project huggingface-ml --bfloat16 \
  --train_dataset gs://tf-tpu-training-resources/train --eval_dataset gs://tf-tpu-training-resources/validation \
  --num_epochs 100 \
  --output_dir roberta-base-epochs-100 --hub_model_id tf-tpu/roberta-base-epochs-100

@sayakpaul (Member Author) commented Mar 26, 2023

@Rocketknight1 here's the final model trained with the command from here:

https://huggingface.co/tf-tpu/roberta-base-epochs-100

When trying out examples in the widget on the model page above, pass [MASK] instead of the default <mask>. The results are far from perfect, though, as is evident from the validation accuracy.
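
For querying the model outside the widget, a minimal sketch using the fill-mask pipeline with the [MASK] token (the example sentence is illustrative):

```python
from transformers import pipeline

# This tokenizer uses "[MASK]" rather than RoBERTa's usual "<mask>".
fill_mask = pipeline("fill-mask", model="tf-tpu/roberta-base-epochs-100", framework="tf")
print(fill_mask("Paris is the [MASK] of France."))
```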

@sayakpaul (Member Author)

@Rocketknight1 could you review this PR?

@sayakpaul sayakpaul marked this pull request as ready for review April 12, 2023 05:59
@sayakpaul sayakpaul requested review from gante and sgugger April 12, 2023 06:01
@sgugger (Collaborator) left a comment
Thanks a lot for working on this! I left a couple of comments.

examples/tensorflow/tpu/language-modeling/README.md (outdated review comment, resolved)
Comment on lines 257 to 261
special_tokens_mask = (
    ~tf.cast(batch["attention_mask"], tf.bool)
    | (batch["input_ids"] == tokenizer.cls_token_id)
    | (batch["input_ids"] == tokenizer.sep_token_id)
)
Collaborator
Why not have the tokenizer return the special_tokens_mask instead of computing it manually here?
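
For context, the tokenizer can produce this mask at encoding time via return_special_tokens_mask=True; a minimal sketch of the suggestion (the example text is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tf-tpu/unigram-tokenizer-wikitext")
encoded = tokenizer(
    "TPU-based language model training with TensorFlow.",
    return_special_tokens_mask=True,  # 1 for special tokens (CLS/SEP/PAD), 0 otherwise
)
print(encoded["special_tokens_mask"])
```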

Member
I thought I was being clever by not storing all that data in the TFRecords, but you're right that it's probably just extra complexity. Let me fix it!

Member
Hm, on second thought, fixing it would require regenerating and re-uploading the whole dataset and then updating the training loop too. Do you think it's worth it?

Collaborator
Not necessarily, but it would be cleaner if you ever do a v2.

Member
Noted, will do!

examples/tensorflow/tpu/language-modeling/train_unigram.py (outdated review comment, resolved)
@sayakpaul (Member Author)

@sgugger thanks!

I addressed your comments. For #21657 (comment), I will defer to @Rocketknight1.

@gante (Member) left a comment
🔥

@sayakpaul (Member Author)

Merging since the failing tests are unrelated.

@sayakpaul sayakpaul merged commit 390e121 into main Apr 14, 2023
@sayakpaul sayakpaul deleted the examples/tf-tpu branch April 14, 2023 05:11
novice03 pushed a commit to novice03/transformers that referenced this pull request on Jun 23, 2023:
[Examples] TPU-based training of a language model using TensorFlow (huggingface#21657)

* add: tokenizer training script for TF TPU LM training.

* add: script for preparing the TFRecord shards.

* add: sequence of execution to readme.

* remove limit from the tfrecord shard name.

* Add initial train_model.py

* Add basic training arguments and model init

* Get up to the point of writing the data collator

* Pushing progress so far!

* Complete first draft of model training code

* feat: grouping of texts efficiently.

Co-authored-by: Matt <rocketknight1@gmail.com>

* Add proper masking collator and get training loop working

* fix: things.

* Read sample counts from filenames

* Read sample counts from filenames

* Draft README

* Improve TPU warning

* Use distribute instead of distribute.experimental

* Apply suggestions from code review

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Modularize loading and add MLM probability as arg

* minor refactoring to better use the cli args.

* readme fillup.

* include tpu and inference sections in the readme.

* table of contents.

* parallelize maps.

* polish readme.

* change script name to run_mlm.py

* address PR feedback (round I).

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>