
Fix Dataparallel Metric Calculation #992

Merged
39 commits merged into master on Feb 20, 2020

Conversation

@pruksmhc commented Jan 16, 2020

Addressing #947.
The problem is a race condition under data parallelism, involving the scorers and variable updates, that can cause metrics to exceed 1.0. To fix this, this PR moves the task.update_metrics() calls into trainer.py instead of calling them from the various forward functions in jiant/models.py.
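For readers skimming the diff, here is a minimal sketch of the before/after pattern (the function name `train_step` and the dict keys are illustrative, not the actual jiant code): under `nn.DataParallel`, `forward()` runs once per GPU replica, so calling `task.update_metrics()` there can touch the shared scorer state concurrently or more than once per batch; calling it from the trainer's single-process loop, after the replicas' outputs have been gathered, runs it exactly once.

```python
import torch.nn as nn

# Hypothetical sketch, not the actual jiant code.

def train_step(model, task, batch):
    # `model` may be wrapped in nn.DataParallel; forward() then runs once
    # per GPU replica and the outputs are gathered back onto one device.
    out = model(batch)
    # Updating metrics here, in the single main process, happens exactly
    # once per batch. Doing it inside forward() would run it once per
    # replica against shared scorer state, which is the race this PR fixes.
    task.update_metrics(out, batch)
    return out["loss"].mean()

# model = nn.DataParallel(single_gpu_model)  # hypothetical multi-GPU wrapping
```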

@pep8speaks commented Jan 16, 2020

Hello @pruksmhc! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

Line 768:101: E501 line too long (104 > 100 characters)

Line 2787:99: W291 trailing whitespace
Line 3030:71: W291 trailing whitespace

You can repair most issues by installing black and running: black -l 100 ./*. If you contribute often, have a look at the 'Contributing' section of the README for instructions on doing this automatically.

Comment last updated at 2020-02-19 18:50:10 UTC

@sleepinyourhat

@pyeres Mind taking the lead on this?

@pruksmhc requested a review from pyeres as a code owner on January 17, 2020
@pruksmhc changed the title from "Fix Dataparallel Metric Calculation" to "[WIP] Fix Dataparallel Metric Calculation" on Jan 17, 2020
@pyeres commented Jan 22, 2020

Hey @pruksmhc, fyi, running ReCoRD like this:

python main.py --config jiant/config/superglue_bert.conf --overrides "run_name = rte-batch-8-GPU-2, pretrain_tasks = \"record\", target_tasks = \"record\", do_pretrain = 1, do_target_task_training = 0, do_full_eval = 1, batch_size = 8, val_interval = 1000, val_data_limit = -1"

It looks like there's an exception related to DataParallel:

01/21 07:06:21 PM: Beginning training with stopping criteria based on metric: record_avg
01/21 07:06:30 PM: Fatal error in main():
Traceback (most recent call last):
  File "main.py", line 16, in <module>
    main(sys.argv[1:])
  File "/home/pcy214/projects/jiant/jiant/__main__.py", line 553, in main
    phase="pretrain",
  File "/home/pcy214/projects/jiant/jiant/trainer.py", line 579, in train
    output_dict = self._forward(batch, task=task)
  File "/home/pcy214/projects/jiant/jiant/trainer.py", line 1034, in _forward
    model_out = self._model.forward(task, batch)
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.gather(outputs, self.output_device)
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 67, in gather
    return gather_map(outputs)
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
    for k in out))
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
    for k in out))
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/pcy214/.conda/envs/jiant/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration
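Background on this error (my reading of the stack trace, not something stated in the thread): `DataParallel`'s gather step can only merge tensors, or dicts/lists/tuples whose leaves are tensors. If a task's `forward()` puts a plain Python object into its output dict (for example an int count, or the per-example prediction data ReCoRD needs for its metric), `gather_map` falls through to `zip(*outputs)` and raises exactly this `TypeError`. A minimal reproduction sketch, assuming at least two visible GPUs:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        return {
            "logits": x.sum(dim=1),    # tensor leaf: gathers fine
            "n_examples": x.shape[0],  # plain int leaf: breaks gather
        }

model = nn.DataParallel(Toy()).cuda()
out = model(torch.randn(8, 4).cuda())
# TypeError: zip argument #1 must support iteration
```

The fix direction matches this PR: keep only tensors in the dict that crosses the DataParallel boundary, and do metric bookkeeping outside forward().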

@pruksmhc

Hm, let me debug this. Thanks for catching it!

@pruksmhc

Alright, I've now tested it on ReCoRD and STS-B as well as RTE. @pyeres, ready for another look.

@pyeres commented Jan 30, 2020

If I understand these changes correctly, in models.py data from the model's output dict (out) and data from the model's input dict (batch) are stored under special keys prefixed with update_metrics_* in the model's output dict. Then, in trainer.py and evaluate.py, a new dictionary update_metrics_kwargs is created to collect those special keys from the out dict and pass that data as keyword arguments to the task's update_metrics method. This also involves importing a task class into trainer.py to handle metric updates for ReCoRDTask, which has some special requirements.

Instead of adding special keys to the out dict in models.py and special-casing the ReCoRD task in trainer.py, could you instead pass the out and batch dicts directly as arguments to the task's update_metrics method, and handle extracting the required data inside the tasks themselves? If you can do that, all tasks share a common update_metrics interface (update_metrics(out, batch)), this PR reduces to one-line changes in trainer.py and evaluate.py, and there is no need to write special keys into the output dict in models.py.
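A rough sketch of the interface being suggested here (illustrative only; the accuracy bookkeeping is a stand-in, and real jiant tasks differ):

```python
import torch

class ExampleTask:
    """Toy task with a single accuracy metric, standing in for a jiant Task."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_metrics(self, out, batch):
        # Common interface: every task receives the gathered output dict and
        # the input batch, and extracts whatever its own metric needs.
        preds = out["logits"].argmax(dim=-1)
        self.correct += (preds == batch["labels"]).sum().item()
        self.total += batch["labels"].numel()

    def get_metrics(self, reset=False):
        acc = self.correct / max(self.total, 1)
        if reset:
            self.correct, self.total = 0, 0
        return {"accuracy": acc}

# Example usage; the call site in trainer.py / evaluate.py reduces to the
# single line `task.update_metrics(out, batch)`.
task = ExampleTask()
out = {"logits": torch.tensor([[2.0, 1.0], [0.0, 3.0]])}
batch = {"labels": torch.tensor([0, 1])}
task.update_metrics(out, batch)
print(task.get_metrics())  # {'accuracy': 1.0}
```

With this shape, ReCoRD's extra bookkeeping lives inside its own task class rather than being special-cased in trainer.py.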

@pyeres commented Jan 30, 2020

Separate from the implementation details above: because multi-GPU support touches many code paths (and addresses a bug), can you please include with your final PR a performance comparison of multi-GPU vs. single-GPU on some task(s)? For PR #990, @HaokunLiu ran a performance comparison across a set of common tasks; running those should make for a good comparison (what Haokun already reported in #990 vs. your changes on a single GPU vs. your changes on multiple GPUs). If you have trouble getting cluster time, just pass me a script and I can run some of these experiments.

@pruksmhc

Makes sense!

@pruksmhc commented Feb 4, 2020

Hm, for some reason I'm getting an accuracy of 0.56 with 1 GPU on jiant master, but 0.527 with 1 GPU on this branch. Debugging.

@pruksmhc

Ready! @pyeres

jiant/tasks/tasks.py: outdated inline review comment, resolved
@pyeres left a comment


Please see the open comment in the conversation thread for a small change request, and please let me know when you're ready for a final look and approval.

@pyeres commented Feb 19, 2020

Heads up: for CCG it looks like there's an issue with the calculation of training and validation accuracy (accuracy is stuck at zero while training and validation loss is decreasing).

If it would be helpful, I'm happy to debug this further; please let me know.

@pyeres commented Feb 19, 2020

These changes have been successfully benchmarked on tasks from GLUE, SuperGLUE, and the benchmarks provided with PR #990, in both single-GPU and multi-GPU modes:

| Task | albert-xxlarge-v1 (#990) | albert-xxlarge-v2 (#990) | roberta-large (#990) | roberta-large DataParallel (2 GPUs) | roberta-large DataParallel (1 GPU) |
| --- | --- | --- | --- | --- | --- |
| commitbank | 0.943 | 0.932 | 1 | 0.946 | 0.982 |
| copa | 0.93 | 0.95 | 0.94 | 0.95 | 0.94 |
| wsc | 0.721 | 0.76 | 0.692 | 0.74 | 0.721 |
| rte | 0.895 | 0.888 | 0.866 | 0.863 | 0.866 |
| multirc | 0.658 | 0.666 | 0.628 | 0.633 | 0.628 |
| wic | 0.752 | 0.74 | 0.724 | 0.724 | 0.74 |
| boolq | 0.886 | 0.89 | 0.879 | 0.873 | 0.879 |
| commonsenseqa | 0.762 | 0.771 | 0.735 | 0.737 | 0.735 |
| cosmosqa | 0.832 | 0.847 | 0.814 | 0.812 | 0.814 |
| record | 0.843 | 0.858 | 0.863 | 0.859 | 0.863 |
| cola | 0.671 | 0.708 | 0.664 | 0.673 | 0.648 |
| qamr | 0.717 | 0.723 | 0.726 | 0.725 | 0.726 |
| scitail | 0.977 | 0.978 | 0.976 | 0.977 | 0.976 |
| social-iqa | 0.783 | 0.778 | 0.766 | 0.774 | 0.766 |
| ccg | 0.954 | 0.956 | 0.96 | 0.96 | 0.959 |
| hellaswag | 0.885 | 0.896 | 0.851 | 0.851 | 0.851 |
| qasrl | 0.643 | 0.641 | 0.664 | 0.664 | 0.66 |
| sst | 0.95 | 0.955 | 0.961 | 0.963 | 0.961 |
| qqp | 0.893 | 0.898 | 0.903 | 0.898 | 0.903 |
| mnli | 0.899 | 0.898 | 0.896 | 0.897 | 0.896 |
| sts-b | | | 0.92/0.92 | 0.92/0.92 | 0.924/0.922 |
| mrpc | | | 0.92/0.90 | 0.913/0.90 | 0.927/0.90 |
| wnli* | | | 0.775 | 0.746 | 0.775 |
| glue-diagnostic | | | 0.45 | 0.455 | 0.451 |
| winogender-diagnostic | | | 0.961/0.699 | 0.933/0.713 | 0.961/0.699 |
| broadcoverage-diagnostic* | | | 0.442 | 0.419 | 0.442 |
| qnli | | | 0.944 | 0.944 | 0.943 |

*Performance on wnli and broadcoverage-diagnostic is below the public leaderboard. However, the performance of the PR branch in single- and multi-GPU modes closely matches the performance achieved with the same experimental config on master at release 1.2.1.

@pyeres left a comment


The benchmark tests for this branch look good. Thanks for making this work!

@pruksmhc

Thanks for running the extensive benchmark tests!

@pruksmhc merged commit b82ac3e into master on Feb 20, 2020
@pruksmhc mentioned this pull request on Apr 7, 2020
phu-pmh pushed a commit that referenced this pull request on Apr 17, 2020
* moving update_metrics to outside scope of dataparallel

* fixing micro_avg calculation

* undo debugging

* Fixing tests, moving update_metrics out of other tasks

* remove extraneous change

* fix multiple choice dataparallel forward

* adding update_metrics abstraction

* delete update_metrics_ notation

* spelling check

* black formatting

* fixing tests

* Adding batch_size constraints to multi-GPU setting

* adding documentation

* adding batch size test

* black correct version

* Fixing batch size assertion

* generalize batch size assertion for more than 2 GPU setting

* reducing label loops in code

* fixing span forward

* Fixing span prediction forward for multi-GPU

* fix commonsenseQA forward

* adding function documentation

* resolving nits, fixing seq_gen forward

* remove nit

* fixing batch_size assert and SpanPrediction task

* Remove debugging

* Fix batch size mismatch multi-GPU test

* Fix order of assert checking for batch size mismatch

* reverse batch size checking

* fix SpanPrediction update_metrics for single-GPU

* fix update_diagnostic_metric and ccg acc

Co-authored-by: Check your git settings! <chris@chris-laptop>
@jeswan added the jiant-v1-legacy (Relevant to versions <= v1.3.2) label on Sep 17, 2020
@jeswan deleted the fix_dataparallel_metric_calculation branch on September 22, 2020