基于paddle框架的Attention on Attention for Image Captioning实现
本项目基于paddle复现Attention on Attention for Image Captioning中所提出的Attention on Attention
模型。该模型在传统的self-attention
注意力机制的基础上,添加了gate
机制以过滤和query
不相关的attention
信息。同时,作者还引入multi-head attention
用于建模不同目标之间的关系。
注: AI Studio项目地址: https://aistudio.baidu.com/aistudio/projectdetail/3290242.
您可以使用AI Studio平台在线运行该项目!
论文:
- [1] L. Huang, W. Wang, J. Chen, X. Wei, "Attention on Attention for Image Captioning", ICCV, 2019.
参考项目:
所有指标均为模型在COCO2014的测试集评估而得
指标 | BlEU-1 | BlEU-2 | BlEU-3 | BlEU-4 | METEOR | ROUGE-L | CIDEr-D | SPICE |
---|---|---|---|---|---|---|---|---|
原论文 | 0.805 | 0.652 | 0.510 | 0.391 | 0.290 | 0.589 | 1.289 | 0.227 |
复现精度 | 0.802 | 0.648 | 0.504 | 0.385 | 0.286 | 0.585 | 1.271 | 0.222 |
本项目所使用的数据集为COCO2014。该数据集共包含123287张图像,每张图像对应5个标题。训练集、验证集和测试集分别为113287、5000、5000张图像及其对应的标题。本项目使用预提取的bottom-up
特征,可以从这里下载得到(我们提供了脚本下载该数据集的标题以及图像特征,见download_dataset.sh)。
-
硬件:CPU、GPU ( > 11G )
-
软件:
- Python 3.8
- Java 1.8.0
- PaddlePaddle == 2.1.0
# clone this repo
git clone https://github.com/fuqianya/AoANet-Paddle.git --recursive
cd AoANet-Paddle
pip install -r requirements.txt
# 下载数据集及特征
bash ./download_dataset.sh
# 下载与计算评价指标相关的文件
bash ./coco-caption/get_google_word2vec_model.sh
bash ./coco-caption/get_stanford_models.sh
python prepro.py
训练过程过程分为两步(详情见论文3.3节):
-
Training with Cross Entropy (XE) Loss
bash ./train_xe.sh
-
CIDEr-D Score Optimization
bash ./train_rl.sh
-
测试
train_xe
阶段的模型python eval.py --model log/log_aoa/model.pdparams --infos_path log/log_aoa/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test
-
测试
train_rl
阶段的模型python eval.py --model log/log_aoa_rl/model.pdparams --infos_path log/log_aoa_rl/infos_aoa.pkl --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test
模型下载: 谷歌云盘
将下载的模型权重以及训练信息放到log
目录下, 运行step6
的指令进行测试。
我们提供一个Demo样例,详情见demo.ipynb
├── cider # 计算评价指标工具
├── coco-caption # 计算评价指标工具
├── config
│ └── config.py # 模型的参数设置
├── data # 预处理的数据
├── log # 存储训练模型及历史信息
├── model
│ └── AoAModel.py # 定义模型结构
│ └── dataloader.py # 加载训练数据
│ └── loss.py # 定义损失函数
├── utils
│ └── eval_utils.py # 测试工具
│ └── utils.py # 其他工具
├── download_dataset.sh # 数据集下载脚本
├── prepro.py # 数据预处理
├── train.py # 训练主函数
├── eval.py # 测试主函数
├── train_xe.sh # 训练脚本
├── train_rl.sh # 训练脚本
└── requirement.txt # 依赖包
模型、训练的所有参数信息都在config.py
中进行了详细注释,详情见config/config.py
。
关于模型的其他信息,可以参考下表:
信息 | 说明 |
---|---|
发布者 | fuqianya |
时间 | 2021.08 |
框架版本 | Paddle 2.1.0 |
应用场景 | 多模态 |
支持硬件 | GPU、CPU |
下载链接 | 预训练模型 | 训练日志 |