Closed
Description
Environment info
- transformers version: 4.6.0.dev0
- Platform: Linux-4.15.0-140-generic-x86_64-with-debian-buster-sid
- Python version: 3.7.9
- PyTorch version (GPU?): 1.8.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
Who can help
@patil-suraj @elgeish @patrickvonplaten
I see the readme is written by @patil-suraj and @elgeish , so any help would be appreciated.
Information
Model I am using (Bert, XLNet ...): wav2vec2
The problem arises when using:
- the official example scripts: (give details below)
Although the fine-tuning week is over, the example is pretty useful.
I am working on a speech recognition problem and want to train on multiple GPUs using distributed training.
I am following Hugging Face's official example:
https://github.com/huggingface/transformers/blob/master/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md
To reproduce
Steps to reproduce the behavior:
- On a clean environment, install requirements and git clone transformers repository.
- Run multi GPU training code as written in the readme.
- Bug reproduces.
The code is
git clone https://github.com/huggingface/transformers.git
cd transformers/examples/research_projects/wav2vec2/
mkdir outputs
python -m torch.distributed.launch \
--nproc_per_node=4 run_common_voice.py \
--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
--dataset_config_name="tr" \
--output_dir=./outputs \
--overwrite_output_dir \
--num_train_epochs="5" \
--per_device_train_batch_size="16" \
--learning_rate="3e-4" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--save_steps="400" \
--eval_steps="400" \
--logging_steps="400" \
--save_total_limit="3" \
--freeze_feature_extractor \
--feat_proj_dropout="0.0" \
--layerdrop="0.1" \
--gradient_checkpointing \
--fp16 \
--group_by_length \
--do_train --do_eval
Error
The following error occurs.
0%| | 0/275 [00:00<?, ?it/s]/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/modules/module.py:760: UserWarning: Using non-full backward hooks on a Module that does not return a single Tensor or a tuple of Tensors is deprecated and will be removed in future versions. This hook will be missing some of the grad_output. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using non-full backward hooks on a Module that does not return a "
/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/modules/module.py:795: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
Traceback (most recent call last):
File "run_common_voice.py", line 512, in <module>
main()
File "run_common_voice.py", line 484, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/transformers/trainer.py", line 1118, in train
tr_loss += self.training_step(model, inputs)
File "run_common_voice.py", line 230, in training_step
loss = self.compute_loss(model, inputs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/transformers/trainer.py", line 1548, in compute_loss
outputs = model(**inputs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 692, in forward
if self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Traceback (most recent call last):
File "run_common_voice.py", line 512, in <module>
main()
File "run_common_voice.py", line 484, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/transformers/trainer.py", line 1118, in train
tr_loss += self.training_step(model, inputs)
File "run_common_voice.py", line 230, in training_step
loss = self.compute_loss(model, inputs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/transformers/trainer.py", line 1548, in compute_loss
outputs = model(**inputs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 692, in forward
if self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Killing subprocess 25001
Killing subprocess 25002
Killing subprocess 25003
Killing subprocess 25004
Traceback (most recent call last):
File "/home/aidealab/.conda/envs/hf/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/distributed/launch.py", line 340, in <module>
main()
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/distributed/launch.py", line 326, in main
sigkill_handler(signal.SIGTERM, None) # not coming back
File "/home/aidealab/.conda/envs/hf/lib/python3.7/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/aidealab/.conda/envs/hf/bin/python', '-u', 'run_common_voice.py', '--local_rank=3', '--model_name_or_path=facebook/wav2vec2-large-xlsr-53', '--dataset_config_name=tr', '--output_dir=/home/aidealab/workspace/transformers/examples/research_projects/wav2vec2/outputs', '--overwrite_output_dir', '--num_train_epochs=5', '--per_device_train_batch_size=16', '--learning_rate=3e-4', '--warmup_steps=500', '--evaluation_strategy=steps', '--save_steps=400', '--eval_steps=400', '--logging_steps=400', '--save_total_limit=3', '--freeze_feature_extractor', '--feat_proj_dropout=0.0', '--layerdrop=0.1', '--gradient_checkpointing', '--fp16', '--group_by_length', '--do_train', '--do_eval']' returned non-zero exit status 1.
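The RuntimeError above says that some parameters did not take part in producing the loss. A minimal sketch for narrowing down which ones, assuming a single non-distributed process in which one forward and backward pass has already run: any trainable parameter whose `.grad` is still `None` received no gradient. The helper name `unused_parameter_names` and the parameter names in the demo are hypothetical, and the stand-in objects only mimic the `requires_grad` and `grad` attributes of a PyTorch parameter so the sketch runs without PyTorch. It has not been verified against the wav2vec2 model in this report.

```python
from types import SimpleNamespace


def unused_parameter_names(named_parameters):
    """Return names of trainable parameters whose .grad is still None.

    Intended to be called after loss.backward() in a single process
    (no DistributedDataParallel wrapper).
    """
    return [
        name
        for name, param in named_parameters
        if param.requires_grad and param.grad is None
    ]


# Hypothetical stand-ins for PyTorch parameters; with a real model,
# pass model.named_parameters() instead.
fake_parameters = [
    ("encoder.layers.0.weight", SimpleNamespace(requires_grad=True, grad=0.1)),
    ("encoder.layers.1.weight", SimpleNamespace(requires_grad=True, grad=None)),
    ("feature_extractor.conv.weight", SimpleNamespace(requires_grad=False, grad=None)),
]

print(unused_parameter_names(fake_parameters))  # ['encoder.layers.1.weight']
```

Frozen parameters (here the hypothetical `feature_extractor.conv.weight`, analogous to `--freeze_feature_extractor`) are skipped because they never expect a gradient. One candidate worth checking with this helper is `--layerdrop`, since a layer skipped in a given step would leave its parameters without gradients, but that is an assumption and not something confirmed here.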
Expected behavior
The script should run without error when launched on multiple GPUs as described in the readme.