A simple PyTorch implementation of an OCR model. The model uses a Transformer encoder-decoder architecture in a seq2seq fashion.
Clone the repository:
git clone https://github.com/dali92002/OCR-TR
cd OCR-TR
Create the environment named vit with Anaconda from environment.yml, then activate it:
conda env create -f environment.yml
conda activate vit
For this task we will create synthetic data that simulates handwritten text. I chose to build the dataset from the EMNIST dataset (digits+characters). I created 80000 images for training, 10000 for validation and 10000 for testing. Each image is composed of randomly concatenated characters, with a length between 3 and 10 characters.
The code for preparing the dataset is in the file prepare_data.py. To execute it, use the following command:
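The idea of building a word image from randomly chosen characters can be sketched as follows. This is only an illustration under stated assumptions, not the code in prepare_data.py: the character set `ALPHABET`, the helper names, and `placeholder_glyph` (which stands in for sampling a real EMNIST image of a character) are all hypothetical.

```python
import random
import string

# Assumed character set (digits + letters); the repository may use a different one.
ALPHABET = string.digits + string.ascii_letters
GLYPH_SIZE = 28  # EMNIST glyphs are 28x28 pixels


def random_word(rng, min_len=3, max_len=10):
    """Draw a random string whose length lies in [min_len, max_len]."""
    length = rng.randint(min_len, max_len)
    return "".join(rng.choice(ALPHABET) for _ in range(length))


def placeholder_glyph(char):
    """Hypothetical stand-in for sampling an EMNIST image of `char`."""
    value = ord(char) % 256
    return [[value] * GLYPH_SIZE for _ in range(GLYPH_SIZE)]


def concat_glyphs(glyphs):
    """Concatenate equally tall glyphs (lists of pixel rows) left to right."""
    height = len(glyphs[0])
    return [[px for glyph in glyphs for px in glyph[row]] for row in range(height)]


rng = random.Random(0)
word = random_word(rng)
image = concat_glyphs([placeholder_glyph(c) for c in word])
print(word, len(image), len(image[0]))  # transcription, image height, image width
```

The resulting image is 28 pixels tall and 28 pixels wide per character, so it would still need resizing or padding to the fixed input size used for training.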
python prepare_data.py --train_words 100000 --valid_words 10000 --test_words 10000
This generates random word images in the folder ./data/words/ and writes the transcription of each word to the files ./data/train.txt, ./data/valid.txt and ./data/test.txt.
In each txt file, every line contains an image name and its transcription, separated by a space.
After creating the data, we can train the model using this command:
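Reading such a label file could look like the sketch below. It assumes only the format described above (image name, one space, transcription); the function names and the sample file names are hypothetical and are not taken from the repository.

```python
def parse_label_line(line):
    """Split 'image_name transcription' on the first space."""
    name, _, text = line.rstrip("\n").partition(" ")
    return name, text


def parse_labels(lines):
    """Parse every non-empty line into an (image_name, transcription) pair."""
    return [parse_label_line(line) for line in lines if line.strip()]


# Hypothetical sample lines in the described format.
sample = ["0.png a7Xk\n", "1.png 42b\n"]
pairs = parse_labels(sample)
print(pairs)
```

With a real file, the same function can be fed the open file object, for example `parse_labels(open("./data/train.txt"))`.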
python train.py --data_path ./data/ --img_height 32 --img_width 256 --train_type htr_Augm --batch_size 64 --vit_patch_size 8
Here I enabled data augmentation for better training, set the image size to 32x256, and set the ViT patch size to 8x8. You can also use your own configuration; check Config.py for the available options.
During training, validation runs at each epoch. The best weights are saved in a folder named ./weights/ and the predictions are saved in a folder named ./pred_logs/.
To test the model, run the following command. It recognizes the testing data using the trained model. You should specify which model you want to use by providing its path (here I am using ./weights/best-seq2seq_htr_Augm_32_256_8.pt, which will have been created if you already ran the training):
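To get a feel for how the image size and patch size interact, the sketch below counts the patches an image is split into. It assumes non-overlapping square patches, as in a standard ViT; `num_patches` is a hypothetical helper, and the repository may handle the image differently.

```python
def num_patches(img_height, img_width, patch_size):
    """Number of non-overlapping square patches covering the image."""
    if img_height % patch_size or img_width % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    return (img_height // patch_size) * (img_width // patch_size)


# 32x256 image with 8x8 patches: 4 rows of 32 patches each.
print(num_patches(32, 256, 8))  # 128
```

Under this assumption, a custom configuration should keep both image sides divisible by the patch size.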
python test.py --data_path ./data/ --img_height 32 --img_width 256 --train_type htr_Augm --batch_size 64 --vit_patch_size 8 --test_model ./weights/best-seq2seq_htr_Augm_32_256_8.pt
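The checkpoint name in the command above appears to encode the training options. The sketch below rebuilds it from those options; the naming pattern is inferred from this single example and is an assumption, and `checkpoint_path` is a hypothetical helper, so check the contents of ./weights/ for the actual file name.

```python
def checkpoint_path(train_type, img_height, img_width, patch_size, weights_dir="./weights"):
    """Build the checkpoint path under the assumed naming pattern."""
    return f"{weights_dir}/best-seq2seq_{train_type}_{img_height}_{img_width}_{patch_size}.pt"


print(checkpoint_path("htr_Augm", 32, 256, 8))
# ./weights/best-seq2seq_htr_Augm_32_256_8.pt
```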
I have already trained a model. You can download the weights from here and use them directly for testing.
After running the test, you will find the predictions in the folder ./pred_logs/, along with the CER and WER.
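For reference, CER (character error rate) and WER (word error rate) are commonly computed from the Levenshtein edit distance between each prediction and its ground truth. The sketch below shows one usual definition; the function names are hypothetical, and the exact computation in this repository may differ.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or lists)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def cer(refs, hyps):
    """Character edits divided by the total number of reference characters."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return errors / sum(len(r) for r in refs)


def wer(refs, hyps):
    """Word edits divided by the total number of reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    return errors / sum(len(r.split()) for r in refs)


refs = ["abc12", "hello"]
hyps = ["abc12", "hallo"]
print(cer(refs, hyps), wer(refs, hyps))  # 0.1 0.5
```

Because every image here contains a single word, WER under this definition reduces to the fraction of words that are not recognized exactly.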