
ERROR FT DONUT-docvqa: TypeError: prepare_inputs_for_inference() got an unexpected keyword argument 'past_key_values' #132

Open
emigomez opened this issue Feb 1, 2023 · 4 comments


@emigomez

emigomez commented Feb 1, 2023

I want to fine-tune the donut-docvqa model, and I have followed these steps:

git clone https://github.com/clovaai/donut.git
cd donut/
conda create -n donut_official python=3.7
conda activate donut_official
pip install .

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install tensorboardX

For training, I run the following command:

python train.py --config config/train_docvqa_tests.yaml --pretrained_model_name_or_path "naver-clova-ix/donut-base-finetuned-docvqa" --dataset_name_or_paths '["nielsr/docvqa_1200_examples_donut"]' --exp_version "donut-docvqa-ft-nielsrdocvqa"

config/train_docvqa_tests.yaml:

resume_from_checkpoint_path: null
result_path: "./result"
pretrained_model_name_or_path: "naver-clova-ix/donut-base"
dataset_name_or_paths: ["./dataset/docvqa"] # should be prepared from https://rrc.cvc.uab.es/?ch=17
sort_json_key: True
train_batch_sizes: [2]
val_batch_sizes: [4]
input_size: [1280, 960]
max_length: 128
align_long_axis: False
num_nodes: 1
seed: 2022
lr: 3e-5
warmup_steps: 10000
num_training_samples_per_epoch: 39463
max_epochs: 50
max_steps: -1
#num_workers: 1
num_workers: 8
val_check_interval: 1.0
check_val_every_n_epoch: 1
gradient_clip_val: 0.25
verbose: True

Running it gives me the following error:

$python train.py --config config/train_docvqa_tests.yaml --pretrained_model_name_or_path "naver-clova-ix/donut-base-finetuned-docvqa" --dataset_name_or_paths '["nielsr/docvqa_1200_examples_donut"]' --exp_version "donut-docvqa-ft-nielsrdocvqa" 
 
resume_from_checkpoint_path: None
result_path: ./result
pretrained_model_name_or_path: naver-clova-ix/donut-base-finetuned-docvqa
dataset_name_or_paths: 
  - nielsr/docvqa_1200_examples_donut
sort_json_key: True
train_batch_sizes: 
  - 2
val_batch_sizes: 
  - 4
input_size: 
  - 1280
  - 960
max_length: 128
align_long_axis: False
num_nodes: 1
seed: 2022
lr: 3e-05
warmup_steps: 10000
num_training_samples_per_epoch: 39463
max_epochs: 50
max_steps: -1
num_workers: 8
val_check_interval: 1.0
check_val_every_n_epoch: 1
gradient_clip_val: 0.25
verbose: True
exp_name: train_docvqa_tests
exp_version: donut-docvqa-ft-nielsrdocvqa
Config is saved at result/train_docvqa_tests/donut-docvqa-ft-nielsrdocvqa/config.yaml
/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/utilities/seed.py:49: LightningDeprecationWarning: `pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be removed in v2.0.0. Please use `lightning_fabric.utilities.seed.seed_everything` instead.
  "`pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be"
Global seed set to 2022
/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2228.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of DonutModel were not initialized from the model checkpoint at naver-clova-ix/donut-base-finetuned-docvqa and are newly initialized because the shapes did not match:
- encoder.model.layers.0.blocks.1.attn_mask: found shape torch.Size([3072, 100, 100]) in the checkpoint and torch.Size([768, 100, 100]) in the model instantiated
- encoder.model.layers.1.blocks.1.attn_mask: found shape torch.Size([768, 100, 100]) in the checkpoint and torch.Size([192, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.1.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.3.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.5.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.7.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.9.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.11.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.2.blocks.13.attn_mask: found shape torch.Size([192, 100, 100]) in the checkpoint and torch.Size([48, 100, 100]) in the model instantiated
- encoder.model.layers.3.blocks.1.attn_mask: found shape torch.Size([48, 100, 100]) in the checkpoint and torch.Size([12, 100, 100]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
SPLIT:  train
Using custom data configuration nielsr--docvqa_1200_examples_donut-05c02546813a49c7
Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/nielsr___parquet/nielsr--docvqa_1200_examples_donut-05c02546813a49c7/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
SPLIT:  test
Using custom data configuration nielsr--docvqa_1200_examples_donut-05c02546813a49c7
Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/nielsr___parquet/nielsr--docvqa_1200_examples_donut-05c02546813a49c7/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:468: LightningDeprecationWarning: Setting `Trainer(gpus=1)` is deprecated in v1.7 and will be removed in v2.0. Please use `Trainer(accelerator='gpu', devices=1)` instead.
  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
[rank: 0] Global seed set to 2022
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /home/ubuntu/aymane/donut/result/train_docvqa_tests/donut-docvqa-ft-nielsrdocvqa exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | DonutModel | 200 M 
-------------------------------------
200 M     Trainable params
0         Non-trainable params
200 M     Total params
400.841   Total estimated model params size (MB)
/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:229: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  category=PossibleUserWarning,
Epoch 0:   0%|          | 0/550 [00:00<?, ?it/s]
/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:136: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
Epoch 0:  91%|█████████ | 500/550 [02:40<00:16,  3.12it/s, loss=2.6, v_num=cvqa]
| 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 151, in <module>
    train(config)
  File "train.py", line 133, in train
    trainer.fit(model_module, data_module)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 609, in fit
    self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1205, in _run_train
    self.fit_loop.run()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.on_advance_end()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 250, in on_advance_end
    self._run_validation()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 308, in _run_validation
    self.val_loop.run()
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/strategies/ddp.py", line 359, in validation_step
    return self.model(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/ubuntu/aymane/donut/lightning_module.py", line 72, in validation_step
    return_attentions=False,
  File "/home/ubuntu/aymane/donut/donut/model.py", line 475, in inference
    output_attentions=return_attentions,
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/transformers/generation/utils.py", line 1400, in generate
    **model_kwargs,
  File "/home/ubuntu/anaconda3/envs/donut_official/lib/python3.7/site-packages/transformers/generation/utils.py", line 2176, in greedy_search
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
TypeError: prepare_inputs_for_inference() got an unexpected keyword argument 'past_key_values'
Epoch 0:  91%|█████████ | 500/550 [02:43<00:16,  3.05it/s, loss=2.6, v_num=cvqa]  

Does anyone know how to fix it?

@emigomez
Author

emigomez commented Feb 1, 2023

Solved with:

 pip install pytorch-lightning==1.6.4
 pip install transformers==4.11.3
 pip install timm==0.5.4

I thought that running pip install . would install everything needed, with no extra packages required.
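Since pip install . did not bring in these exact versions here, a quick way to see what actually ended up in the environment is to ask importlib.metadata. This is a minimal sketch, assuming Python 3.8+ (or the importlib-metadata backport on 3.7, which is what the environment above uses); the helper names installed_versions and mismatches are hypothetical, and the pins are simply the three from this comment:

```python
# Report which of the pinned packages are installed, and at what version.
# installed_versions() and mismatches() are hypothetical helpers for illustration.
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 needs the importlib-metadata backport
    from importlib_metadata import PackageNotFoundError, version

# The versions that fixed the error for the original poster.
PINS = {
    "pytorch-lightning": "1.6.4",
    "transformers": "4.11.3",
    "timm": "0.5.4",
}


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def mismatches(pins, installed):
    """Return {name: (pinned, installed)} for every package not at its pin."""
    return {
        name: (pinned, installed.get(name))
        for name, pinned in pins.items()
        if installed.get(name) != pinned
    }


if __name__ == "__main__":
    for name, (want, have) in mismatches(PINS, installed_versions(PINS)).items():
        print(f"{name}: pinned {want}, installed {have}")
```

An empty report means the environment matches the pins; any line printed points at a package that pip resolved to a different version.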

@Mohamed-Dhouib

Mohamed-Dhouib commented Feb 3, 2023

Installing those extra packages is indeed not needed: it's the latest transformers release that changed some variable names. You just need to install transformers 4.25.1 (pip install transformers==4.25.1), or rename the corresponding argument of prepare_inputs_for_inference to past_key_values.
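The traceback above shows where this bites: transformers' greedy_search calls self.prepare_inputs_for_generation(input_ids, **model_kwargs), and in the installed release model_kwargs carries the cache under the keyword past_key_values, which Donut's prepare_inputs_for_inference (evidently hooked in as prepare_inputs_for_generation, judging by the error message) does not accept. Here is a self-contained illustration using plain-Python stand-ins; prepare_inputs_old, prepare_inputs_new and generation_step are hypothetical and are not Donut's actual code, and the old keyword name past is an assumption:

```python
# Hypothetical stand-ins that mimic the shape of Donut's hook; not the repo's code.


def prepare_inputs_old(input_ids, encoder_outputs, past=None, use_cache=None, attention_mask=None):
    """Old-style hook: the cache can only be passed as `past`."""
    return {
        "input_ids": input_ids,
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "use_cache": use_cache,
        "attention_mask": attention_mask,
    }


def prepare_inputs_new(input_ids, encoder_outputs, past_key_values=None, past=None,
                       use_cache=None, attention_mask=None):
    """Patched hook: accepts the newer `past_key_values` name, falling back to `past`."""
    if past_key_values is None:
        past_key_values = past
    return {
        "input_ids": input_ids,
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "use_cache": use_cache,
        "attention_mask": attention_mask,
    }


def generation_step(hook, input_ids, **model_kwargs):
    # Mimics the failing line in the traceback:
    # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
    return hook(input_ids, **model_kwargs)


if __name__ == "__main__":
    model_kwargs = {"encoder_outputs": "encoder-states", "past_key_values": None, "use_cache": True}
    try:
        generation_step(prepare_inputs_old, [[0]], **model_kwargs)
    except TypeError as err:
        print(err)  # prepare_inputs_old() got an unexpected keyword argument 'past_key_values'
    print(generation_step(prepare_inputs_new, [[0]], **model_kwargs)["use_cache"])  # True
```

Keeping past as a fallback means the same hook works whether the generation loop passes the old or the new keyword name, which is why this kind of patch is less fragile than pinning transformers.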

@ChrisDelClea

I get the same error, and I am not able to fix it by installing the packages mentioned above. Why is that? I implemented everything exactly as suggested in the source.

@benjaminfh

benjaminfh commented Dec 3, 2023

OK. I just wasted so much time on this that I want to lay out a canonical answer so others can avoid the same fate (others probably include future me 🤦).

Firstly, these requirements are working for me now, with a model I trained with transformers==4.25.1:

timm<=0.6.13
donut-python @ git+https://github.com/clovaai/donut@4cfcf972560e1a0f26eb3e294c8fc88a0d336626

Breaking these down:
timm > 0.6.13 will introduce the following error:

NotImplementedError: Make sure _init_weights is implemented for <class 'donut.model.DonutModel'>

donut-python <= 1.0.9 (i.e. via pypi / pip install) will introduce:

TypeError: prepare_inputs_for_inference() got an unexpected keyword argument 'past_key_values'
This is an incompatibility with later transformers versions (>4.25.1). It is fixed in the master branch of this repo (https://github.com/clovaai/donut), but that fix has not been released to PyPI as a new version.

When using donut-python <= 1.0.9, changing the transformers version just flips you between the _init_weights error with early versions of transformers (#184) and TypeError: prepare_inputs_for_inference() got an unexpected keyword argument 'past_key_values' with recent versions. NOTE: this is a bit of a red herring: until you upgrade donut-python to the GitHub version (as above), you will be stuck in this Kafkaesque loop forever.
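To make the breakdown above concrete, here is a small, hypothetical checker that encodes only the two rules stated explicitly in this comment: timm > 0.6.13 triggers the _init_weights error, and donut-python <= 1.0.9 from PyPI combined with transformers > 4.25.1 triggers the past_key_values error. The "early transformers" side of the loop is left out because no exact boundary is given. The version parsing is a deliberately crude assumption that keeps only the leading numeric components; packaging.version would be the safer choice in a real project:

```python
# Hypothetical checker for the two version rules described in this comment.
TIMM_INIT_WEIGHTS = ("NotImplementedError: Make sure _init_weights is implemented "
                     "for <class 'donut.model.DonutModel'>")
PAST_KEY_VALUES = ("TypeError: prepare_inputs_for_inference() got an unexpected "
                   "keyword argument 'past_key_values'")


def parse_version(text):
    """'0.6.13' -> (0, 6, 13); stops at the first non-numeric component (e.g. 'dev0')."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def known_failures(timm_version, transformers_version, donut_version, donut_from_git=False):
    """Return the errors predicted for this combination of versions.

    donut_from_git=True means donut-python was installed from the GitHub
    commit rather than from PyPI.
    """
    failures = []
    if parse_version(timm_version) > parse_version("0.6.13"):
        failures.append(TIMM_INIT_WEIGHTS)
    if (not donut_from_git
            and parse_version(donut_version) <= parse_version("1.0.9")
            and parse_version(transformers_version) > parse_version("4.25.1")):
        failures.append(PAST_KEY_VALUES)
    return failures


if __name__ == "__main__":
    # The working combination from this comment: expect no predicted failures.
    print(known_failures("0.6.13", "4.26.1", "1.0.9", donut_from_git=True))  # []
```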

As far as I can tell, no other versions are make-or-break, but for the sanity of other readers, here's my full requirements.txt:

absl-py==2.0.0
aiohttp==3.8.6
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.1.0
cachetools==5.3.2
certifi==2023.7.22
charset-normalizer==3.3.2
click==8.1.7
datasets==2.14.6
dill==0.3.7
donut-python @ git+https://github.com/clovaai/donut@4cfcf972560e1a0f26eb3e294c8fc88a0d336626
filelock==3.13.1
frozenlist==1.4.0
fsspec==2023.10.0
google-auth==2.23.4
google-auth-oauthlib==1.1.0
grpcio==1.59.2
huggingface-hub==0.18.0
idna==3.4
importlib-metadata==6.8.0
Jinja2==3.1.2
joblib==1.3.2
lightning-utilities==0.9.0
Markdown==3.5.1
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
munch==4.0.0
networkx==3.2.1
nltk==3.8.1
numpy==1.26.1
oauthlib==3.2.2
packaging==23.2
pandas==2.1.2
Pillow==10.1.0
protobuf==3.20.1
pyarrow==14.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pyDeprecate==0.3.2
python-dateutil==2.8.2
pytorch-lightning==1.8.5
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
ruamel.yaml==0.18.5
ruamel.yaml.clib==0.2.8
safetensors==0.4.1
sconf==0.2.5
sentencepiece==0.1.99
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
timm==0.6.13
tokenizers==0.12.1
torch==2.1.1
torchmetrics==1.2.0
torchvision==0.16.1
tqdm==4.66.1
transformers==4.26.1
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.7
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.2
zipp==3.17.0
zss==1.2.0
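If you want to confirm that an environment really picked up the pins that matter here (timm, transformers and the donut-python commit), a requirements file in this one-per-line format can be read with a few lines of Python. This is a minimal sketch that only understands exact name==version pins and name @ url direct references, the two forms used in the list above; parse_requirements and git_ref are hypothetical helpers, and git_ref assumes the ref contains no slashes:

```python
# Hypothetical helpers for pulling the key pins out of a pip-freeze-style file.


def parse_requirements(text):
    """Map package name -> pinned version or direct-reference URL.

    Only 'name==version' and 'name @ url' lines are understood; comments,
    blank lines and range specifiers such as 'timm<=0.6.13' are skipped.
    """
    pins = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if " @ " in line:
            name, target = line.split(" @ ", 1)
        elif "==" in line:
            name, target = line.split("==", 1)
        else:
            continue
        pins[name.strip()] = target.strip()
    return pins


def git_ref(url):
    """Return the commit/tag after '@' in the last path segment of a git+ URL, or None."""
    if not url.startswith("git+"):
        return None
    last = url.rsplit("/", 1)[-1]  # e.g. 'donut@4cfcf97...'
    return last.split("@", 1)[1] if "@" in last else None


if __name__ == "__main__":
    reqs = """
donut-python @ git+https://github.com/clovaai/donut@4cfcf972560e1a0f26eb3e294c8fc88a0d336626
timm==0.6.13
transformers==4.26.1
"""
    pins = parse_requirements(reqs)
    print(pins["timm"], pins["transformers"], git_ref(pins["donut-python"]))
```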
