An unofficial PyTorch implementation of Lin et al., ViBERTgrid: A Jointly Trained Multi-Modal 2D Document Representation for Key Information Extraction from Documents, ICDAR 2021.
To learn more about Visual Information Extraction and Document AI, please refer to Document-AI-Recommendations.
- data: data loaders
- model: ViBERTgrid net architecture
- pipeline: preprocessing, trainer, evaluation metrics
- utils: data visualization, dataset splitting
- deployment: examples for inference and deployment
- train_*.py: main train scripts
- eval_*.py: evaluation scripts
pip install -r requirements.txt
The following components are required for the processed dataset:
- the original image [bs, 3, h, w]
- label_csv: a label file in CSV format. Each row describes one text segment and should contain its text, left-coor, right-coor, top-coor, bot-coor, class type, and pos_neg type.
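A minimal sketch of reading one such label file with Python's csv module, assuming the columns appear in the order listed above and the file has no header row. The Segment class and read_label_csv function are hypothetical helpers for illustration, not part of this repository.

```python
import csv
from dataclasses import dataclass
from typing import List, TextIO


@dataclass
class Segment:
    """One labeled text segment (hypothetical field names)."""
    text: str
    left: int
    right: int
    top: int
    bottom: int
    class_type: str
    pos_neg: str


def read_label_csv(f: TextIO) -> List[Segment]:
    """Parse rows assumed to be ordered as:
    text, left, right, top, bottom, class type, pos_neg type (no header)."""
    segments = []
    for row in csv.reader(f):
        if not row:
            continue  # skip blank lines
        text, left, right, top, bottom, class_type, pos_neg = row
        # Accept both "12" and "12.0" style coordinates.
        l, r, t, b = (int(float(v)) for v in (left, right, top, bottom))
        # Sanity check: every box needs a positive width and height.
        if r <= l or b <= t:
            raise ValueError(f"degenerate box for segment {text!r}")
        segments.append(Segment(text, l, r, t, b, class_type, pos_neg))
    return segments
```

Because csv.reader honors quoting, a text field such as "1,250.00" keeps its embedded comma instead of shifting the coordinate columns.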
It is worth noting that the labeling of the dataset strongly affects the final result. The original SROIE dataset (https://rrc.cvc.uab.es/?ch=13&com=tasks) only contains text labels for the key information. Coordinates, however, are necessary for constructing the grid, so we re-labeled the dataset to obtain them. Unfortunately, our re-labeled data cannot be made public, but the model trained on it is available here (97+ entity-level F1).
Another way to obtain the coordinates is to match them using regular expressions and cosine similarity, following (https://github.com/antoinedelplace/Chargrid). The matching results are unsatisfactory, however, and only reach an entity-level F1 of around 60.
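The general idea of such matching can be sketched as follows: normalize strings with a regular expression, try a substring hit first, and otherwise pick the OCR line whose bag-of-characters vector has the highest cosine similarity to the key value. This is a simplified illustration of the technique, not the code from the Chargrid repository; the function names and the 0.8 threshold are hypothetical.

```python
import math
import re
from collections import Counter
from typing import List, Optional, Tuple

Box = Tuple[int, int, int, int]


def normalize(s: str) -> str:
    # Lowercase and keep only letters, digits and spaces.
    return re.sub(r"[^a-z0-9 ]", "", s.lower())


def cosine_similarity(a: str, b: str) -> float:
    # Cosine similarity between bag-of-characters count vectors.
    ca, cb = Counter(normalize(a)), Counter(normalize(b))
    dot = sum(n * cb[ch] for ch, n in ca.items())
    norm = math.sqrt(sum(n * n for n in ca.values())) * math.sqrt(sum(n * n for n in cb.values()))
    return dot / norm if norm else 0.0


def match_key_to_ocr(key_text: str, ocr_lines: List[Tuple[str, Box]],
                     threshold: float = 0.8) -> Optional[Box]:
    """Return the box of the OCR line that best matches key_text, or None."""
    key = normalize(key_text)
    if not key:
        return None
    # 1) Regex-normalized substring match.
    for text, box in ocr_lines:
        if key in normalize(text):
            return box
    # 2) Fall back to the most similar line if it clears the threshold.
    best_score, best_box = 0.0, None
    for text, box in ocr_lines:
        score = cosine_similarity(key_text, text)
        if score > best_score:
            best_score, best_box = score, box
    return best_box if best_score >= threshold else None
```

A bag-of-characters vector ignores character order: after normalization, 9.00 and 0.90 produce identical vectors. Ambiguities like this are one reason automatic matching falls short of manual labels.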
You can preprocess the data by following the steps below, or simply download the preprocessed version from here.
1. Download the official SROIE dataset here. 0325updated.task1train(626p) contains the images and OCR results of the training set: put the images in the img folder and the txt files in the box folder. The txt files in 0325updated.task2train(626p) are key-type labels: put them in the key folder.
2. Run sroie_data_preprocessing.py. The arg train_raw is the dir to the root of the three folders mentioned above. The arg train_processed is the dir to the processed csv labels generated by sroie_data_preprocessing.py, named label. Put the img, key and label folders in the same root named train.
3. Download the raw data of the test set from the link provided at the bottom of the official page (see here and here). Follow the instructions in step 2, but name the folder test. Put it in the same root as the train folder and name that root SROIE.
4. The processed data should be organized as shown below:

.
├── test
│   ├── image
│   ├── key
│   └── label
└── train
    ├── image
    ├── key
    └── label
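Before training, a quick check that the folders match the tree above can save a failed run. A minimal sketch using pathlib; the folder names come from the tree, while the helper names and the assumption that an image and its label file share a file stem are ours.

```python
from pathlib import Path
from typing import List

SPLITS = ("train", "test")
SUBDIRS = ("image", "key", "label")


def missing_dirs(root: str) -> List[str]:
    """Return the expected split/sub-folder paths that are missing under root."""
    return [f"{split}/{sub}" for split in SPLITS for sub in SUBDIRS
            if not (Path(root) / split / sub).is_dir()]


def unpaired_images(split_dir: str) -> List[str]:
    """Return stems of images with no same-named file in label/.
    Assumes an image and its label file share a file stem."""
    d = Path(split_dir)
    label_stems = {p.stem for p in (d / "label").iterdir()}
    return sorted(p.stem for p in (d / "image").iterdir() if p.stem not in label_stems)
```

Run missing_dirs on the dataset root first; unpaired_images only makes sense once both sub-folders exist.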
We recommend re-labeling the dataset yourself: it contains only around 1k images, so this does not take much time. Alternatively, find a better way to match the coordinates.
The dataset can be obtained from (https://github.com/HCIILAB/EPHOIE). An unzip password will be provided after submitting the application form.
EPHOIE provides labels in txt format, so you should first convert them into JSON format yourself. Then run the following command:
python ./pipeline/ephoie_data_preprocessing.py
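Since the txt-to-JSON conversion is left to the user, here is a hedged sketch of the batch-conversion part. We do not know the exact layout of the EPHOIE txt files or the JSON schema ephoie_data_preprocessing.py expects, so the comma-separated line format in parse_line and the box/transcript keys are assumptions to adapt; both helper names are hypothetical.

```python
import json
from pathlib import Path
from typing import Dict


def parse_line(line: str) -> Dict:
    """HYPOTHETICAL input format: x1,y1,x2,y2,x3,y3,x4,y4,transcript.
    Replace this with a parser for the real EPHOIE txt layout."""
    parts = line.split(",", 8)  # maxsplit keeps commas inside the transcript
    return {"box": [int(float(v)) for v in parts[:8]], "transcript": parts[8]}


def convert_dir(txt_dir: str, json_dir: str) -> int:
    """Write one .json file per .txt file; return the number of files converted."""
    out = Path(json_dir)
    out.mkdir(parents=True, exist_ok=True)
    count = 0
    for txt_path in sorted(Path(txt_dir).glob("*.txt")):
        lines = txt_path.read_text(encoding="utf-8").splitlines()
        records = [parse_line(ln) for ln in lines if ln.strip()]
        # ensure_ascii=False keeps Chinese text human-readable in the output.
        (out / f"{txt_path.stem}.json").write_text(
            json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8")
        count += 1
    return count
```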
Images and labels can be found here.
The FUNSD dataset contains two subtasks, entity labeling and entity linking. The ViBERTgrid model can only perform KIE on the first task, in which text contents are labeled with 3 key types (header, question, answer). Run the following command to generate formatted labels:
python ./pipeline/funsd_data_preprocessing.py
First, set up the configuration. An example config file, example_config.yaml, is provided. Then run the following command, replacing * with SROIE, EPHOIE, or FUNSD:
torchrun --nnodes 1 --nproc_per_node 2 ./train_*.py -c dir_to_config_file.yaml
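For orientation, this is roughly how a script launched this way can pick up its arguments. A minimal sketch, assuming only the -c flag shown above; launch_context is hypothetical and not taken from train_*.py. LOCAL_RANK, RANK and WORLD_SIZE are environment variables that torchrun sets for each worker process.

```python
import argparse
import os
from typing import Mapping, Optional, Sequence


def launch_context(argv: Optional[Sequence[str]] = None,
                   env: Optional[Mapping[str, str]] = None) -> dict:
    """Read the -c config path and the rank variables that torchrun exports
    to every worker process."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", dest="config", required=True, help="path to the YAML config")
    args = parser.parse_args(argv)
    return {
        "config": args.config,
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "rank": int(env.get("RANK", 0)),              # global process index
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }
```

With --nnodes 1 --nproc_per_node 2, torchrun starts two processes with WORLD_SIZE=2 and LOCAL_RANK 0 and 1. A real script would typically call torch.cuda.set_device(local_rank) and torch.distributed.init_process_group next. The global batch size is the per-GPU batch size times WORLD_SIZE, e.g. 2 × 2 = 4.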
Scripts for inference are provided in the deployment folder. Run inference_* to get the VIE result in JSON format.
In the paper, the authors apply classification at the word level: the model predicts the key type of each word, and the words belonging to the same class are joined to form the final entity-level result.
In fact, ViBERTgrid can work at any data level, such as line level or char level, and choosing a proper data level may significantly boost the final score. According to our experiments, line level is the most suitable choice for the SROIE dataset, char level for EPHOIE, and segment level for FUNSD.
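A minimal sketch of the joining step, assuming the units are already in reading order, that each class yields a single entity per document (true for SROIE fields such as company or total, but not for FUNSD, where a page holds many questions and answers), and that the background class is named other. merge_entities and that class name are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def merge_entities(units: Sequence[Tuple[str, str]], sep: str = " ",
                   background: str = "other") -> Dict[str, str]:
    """Join the texts of all units (words, lines, chars, ...) that were predicted
    as the same class, in reading order, into one string per class.
    Units labeled with the background class are dropped."""
    grouped: Dict[str, List[str]] = {}
    for text, label in units:
        if label != background:
            grouped.setdefault(label, []).append(text)
    return {label: sep.join(texts) for label, texts in grouped.items()}
```

Word-level units are joined with a space; for char-level units, as used for EPHOIE, pass sep="". The resulting dict can be written out with json.dumps(..., ensure_ascii=False).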
The authors of the paper used an ImageNet-pretrained ResNet18-D to initialize the weights of the CNN backbone. However, pretrained ResNet-D weights are not available in PyTorch's model zoo, so we use an ImageNet-pretrained ResNet34 instead.
The CNN backbone can be changed by setting a different value in the config file. Supported backbones are:
- resnet_18_fpn
- resnet_34_fpn
- resnet_18_fpn_pretrained
- resnet_34_fpn_pretrained
- resnet_18_D_fpn
- resnet_34_D_fpn
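The backbone names follow a regular pattern: depth, an optional ResNet-D variant, an FPN neck, and an optional ImageNet-pretrained suffix. A hypothetical helper that decodes a name and rejects anything outside the six listed values could look like this; it illustrates the naming scheme and is not code from the repository.

```python
import re
from typing import NamedTuple

_BACKBONE_RE = re.compile(r"^resnet_(18|34)(_D)?_fpn(_pretrained)?$")


class BackboneSpec(NamedTuple):
    depth: int        # 18 or 34 layers
    resnet_d: bool    # ResNet-D variant
    pretrained: bool  # ImageNet-pretrained weights


def parse_backbone(name: str) -> BackboneSpec:
    """Decode a backbone name from the list above into its components."""
    m = _BACKBONE_RE.match(name)
    if m is None:
        raise ValueError(f"unsupported backbone: {name!r}")
    depth, d_flag, pretrained_flag = m.groups()
    spec = BackboneSpec(int(depth), d_flag is not None, pretrained_flag is not None)
    if spec.resnet_d and spec.pretrained:
        # No pretrained ResNet-D weights are available in PyTorch's model zoo.
        raise ValueError(f"unsupported backbone: {name!r}")
    return spec
```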
The paper notes that "some words could be labeled with more than one field type tags (similar to the nested named entity recognition task)", and designs two classifiers to perform field type classification for each word.
To solve this problem, the authors designed a complicated two-stage classifier. We found that this classifier does not work well and is hard to fine-tune. Since the multi-label problem does not occur in the SROIE, EPHOIE, and FUNSD datasets, we replace the original design with a one-stage multi-class classifier trained with multi-class cross-entropy loss.
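Concretely, each text segment receives a vector of C class logits and exactly one target class, and the loss is the softmax cross-entropy. Below is a dependency-free sketch of that loss; in the training code this role would typically be played by torch.nn.CrossEntropyLoss on [N, C] logits, and the function names here are illustrative.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of one segment: -log(softmax(logits)[target]).
    Computed via log-sum-exp so large logits do not overflow."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits: Sequence[Sequence[float]],
                       targets: Sequence[int]) -> float:
    """Average over segments, like torch.nn.CrossEntropyLoss with its default
    reduction='mean' and no class weights."""
    losses = [cross_entropy(lg, t) for lg, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)
```

Because softmax forces the class probabilities of a segment to sum to 1, this formulation cannot assign two field types to the same word. That multi-label case is exactly what the paper's two-stage classifier was designed for.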
Experiments show that an additional, independent key information binary classifier may improve the final F1 score. This classifier indicates whether a text segment belongs to key information or not, which may boost recall.
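The auxiliary binary head can be trained with a standard binary cross-entropy on a single logit per segment. This sketch rests on two assumptions: the binary target is derived from the multi-class label (1 for any key class, 0 for the background class, hypothetically named other), and losses are averaged over segments. The formula is the numerically stable logit form, equivalent to what torch.nn.BCEWithLogitsLoss computes.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    """Stable binary cross-entropy on a raw logit:
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def key_info_loss(logits: Sequence[float], class_labels: Sequence[str],
                  background: str = "other") -> float:
    """Mean loss of the auxiliary 'is this key information?' head.
    Targets: 1.0 for any key class, 0.0 for the background class."""
    targets = [0.0 if c == background else 1.0 for c in class_labels]
    return sum(bce_with_logits(z, t) for z, t in zip(logits, targets)) / len(targets)
```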
In our case, the auxiliary semantic segmentation head does not help on either the SROIE or the EPHOIE dataset. You can remove this branch by setting loss_control_lambda to zero in the configuration file.
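Here is a sketch of how such a weight typically gates an auxiliary loss. It assumes, without confirmation from the code, that the total is the classification loss plus loss_control_lambda times the segmentation loss; total_loss is a hypothetical name. Passing the segmentation loss as a callable lets the branch be skipped entirely when the weight is zero, rather than computed and then multiplied by zero.

```python
from typing import Callable


def total_loss(cls_loss: float, seg_loss_fn: Callable[[], float],
               loss_control_lambda: float) -> float:
    """Combine losses under an ASSUMED rule:
        total = cls_loss + loss_control_lambda * seg_loss
    With loss_control_lambda == 0 the segmentation loss is never evaluated."""
    if loss_control_lambda == 0:
        return cls_loss
    return cls_loss + loss_control_lambda * seg_loss_fn()
```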
The model can either predict the category of each text line/char/segment directly, or predict BIO tags under the constraint of a CRF layer. We found that the representational capacity of ViBERTgrid is strong enough for direct prediction to work best; BIO tagging with CRF layers is unnecessary and actually hurts performance.
Dataset | Configuration | # of Parameters | F1 |
---|---|---|---|
SROIE | original paper, BERT-Base, ResNet18-D-pretrained | 142M | 96.25 |
SROIE | original paper, RoBERTa-Base, ResNet18-D-pretrained | 147M | 96.40 |
SROIE | BERT-Base uncased, ResNet34-pretrained | 151M | 97.16 |
EPHOIE | BERT-Base chinese, ResNet34-pretrained | 145M | 96.55 |
FUNSD | BERT-Base uncased, ResNet34-pretrained | 151M | 87.63 |
- Due to resource limitations, I used 2 NVIDIA TITAN X GPUs for training, which could only afford a batch size of 4 (2 per GPU). The loss curve is unstable at this batch size, which may affect performance.