本项目是进行图片到latex的翻译任务
在本项目中,使用的是典型的Encoder-Decoder模型
- Encoder: 带Global Context(GC) Block的ResNet-31:8,4 times downsampling
- Decoder: 3层Transformer Decoder
网络结构为:
本项目主要需要如下的环境依赖,具体安装方式见快速开始!
- coming soon
-
创建新环境
conda create -n im2latex python=3.7 conda activate im2latex # 安装nltk pip install nltk
-
安装torch1.10.0+cu113
# install torch1.10.0+cu113 pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
-
安装 mmcv-full-1.4.0。点击 here 查看更多细节。
# install mmcv-full-1.4.0 with torch version 1.10.0 cuda_version 11.3 pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
如果遇到网络问题无法下载,可选择手动下载mmcv1.4.0安装包,选择合适的版本进行安装,示例安装命令:
pip install mmcv_full-1.4.0-cp38-cp38-manylinux1_x86_64.whl
-
确保在项目文件夹下(image2latex),安装 mmdetection。点击 here 查看更多细节。
# We embed mmdetection-2.11.0 source code into this project. # You can cd and install it (recommend). cd ./mmdetection-2.11.0 pip install -v -e .
-
安装mmocr. 点击here 查看更多细节。
# install mmocr cd {Path to image2latex} pip install -v -e .
-
ResNet31withGCB + 3*Transformer Decoder
sh im2latex/im2latex_resnet31withGCB.sh
训练过程中的日志文件和checkpoint将会保存在 expr_result/im2latex_res31 中
-
ResNet31withGCB + 3*Transformer Decoder
sh im2latex/im2latex_res31_infer.sh
-
批量预测(batch inference)
python im2latex/im2latex_infer.py
python im2latex/eval_score.py
这里是模型在验证集(1000张图片)上的BLEU@4得分
Models | BLEU@4 | EM |
---|---|---|
ResNet31withGCB + 3*Transformer Decoder | 0.9139 | 0.4749 |