Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use barriers to reduce duplicate work/resources #9

Merged
merged 1 commit into from
Apr 1, 2020

Conversation

jysohn23
Copy link

@jysohn23 jysohn23 commented Apr 1, 2020

No description provided.

@jysohn23 jysohn23 requested a review from taylanbil April 1, 2020 00:50
Copy link

@taylanbil taylanbil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much cleaner this way, thanks Daniel.

if xm.is_master_ordinal():
model_to_save.config.save_pretrained(save_directory)
# xm.save takes care of saving only from master
xm.save(model_to_save.state_dict(), output_model_file)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea for the future; you could use the checkpoint tagger later, if they mark which chpt is the best etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, they don't have such tagger but let me see if they do. Good idea we could even upstream if they don't. Thanks.

@@ -341,6 +355,9 @@ def main(args):
label_list = processor.get_labels()
num_labels = len(label_list)

if not xm.is_master_ordinal():
xm.rendezvous('download_only_once') # Make sure only the first process in distributed training will download model & vocab

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming no tensor work is done in the methods b/w download_only_once rendezvous's

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, correct.

@jysohn23 jysohn23 merged commit 6d17e91 into pytorch-tpu:tpu Apr 1, 2020
@jysohn23 jysohn23 deleted the tpu branch April 1, 2020 17:22
jysohn23 added a commit to jysohn23/transformers that referenced this pull request Apr 10, 2020
* Initial commit to get BERT + run_glue.py on TPU

* Add README section for TPU and address comments.

* Cleanup TPU bits from run_glue.py (pytorch-tpu#3)

TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.

* Cleanup TPU bits from run_glue.py

TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.

* No need to call `xm.mark_step()` explicitly (pytorch-tpu#4)

Since for gradient accumulation we're accumulating on batches from
`ParallelLoader` instance which on next() marks the step itself.

* Resolve R/W conflicts from multiprocessing (pytorch-tpu#5)

* Add XLNet in list of models for `run_glue_tpu.py` (pytorch-tpu#6)

* Add RoBERTa to list of models in TPU GLUE (pytorch-tpu#7)

* Add RoBERTa and DistilBert to list of models in TPU GLUE (pytorch-tpu#8)

* Use barriers to reduce duplicate work/resources (pytorch-tpu#9)

* Shard eval dataset and aggregate eval metrics (pytorch-tpu#10)

* Shard eval dataset and aggregate eval metrics

Also, instead of calling `eval_loss.item()` every time do summation with
tensors on device.

* Change defaultdict to float

* Reduce the pred, label tensors instead of metrics

As brought up during review some metrics like f1 cannot be aggregated
via averaging. GLUE task metrics depends largely on the dataset, so
instead we sync the prediction and label tensors so that the metrics can
be computed accurately on those instead.

* Only use tb_writer from master (pytorch-tpu#11)

* Apply huggingface black code formatting

* Style

* Remove `--do_lower_case` as example uses cased

* Add option to specify tensorboard logdir

This is needed for our testing framework which checks regressions
against key metrics writtern by the summary writer.

* Using configuration for `xla_device`

* Prefix TPU specific comments.

* num_cores clarification and namespace eval metrics

* Cache features file under `args.cache_dir`

Instead of under `args.data_dir`. This is needed as our test infra uses
data_dir with a read-only filesystem.

* Rename `run_glue_tpu` to `run_tpu_glue`

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
alanwaketan pushed a commit that referenced this pull request Apr 17, 2024
* Cohere Model Release (#1)

Cohere Model Release

* Remove unnecessary files and code (#2)

Some cleanup

* Delete cohere-model directory (#3)

* Make Fix (#5)

* Pr fixes (#6)

* fixes for pr

* pr fixes for the format

* pr fixes for the format

* src/transformers/models/auto/tokenization_auto.py

* Tokenizer test (#8)

* tokenizer test

* format fix

* Adding Docs and other minor changes (#7)

* Add modeling tests (#9)

* Smol Fix (#11)

* tokenization tests are fixed

* format fixes

* fix pr doc tests

* fix pr doc tests

* fix pr doc tests

* fix pr style check

* small changes in cohere.md

* FIX: Address final comments for transformers integration (#13)

* fix modeling final nits and add proper test file

* for now leave empty tests

* add integration test

* push new test

* fix modeling cohere (#14)

* Update chat templates to use the new API (#15)

---------

Co-authored-by: ahmetustun <ahmetustun89@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants