keras复现人群数量估计网络"CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting"。 本工程的实现主要参考crowdcount-cascaded-mtl和keras-mcnn 在ShanghaiTech数据集上训练和测试效果如下:
| | MAE | MSE |
----------------------------
| Part_A | 115.57 | 179.82 |
----------------------------
| Part_B | 26.30 | 48.78 |
-
Clone
git clone https://github.com/embracesource-cv-com/keras-crowdcounting-cmtl
-
安装依赖库
cd keras-crowdcounting-cmtl pip install -r requirements.txt
-
创建数据存放目录$ORIGIN_DATA_PATH
mkdir /opt/dataset/crowd_counting/shanghaitech/original
-
将
part_A_final
和part_B_final
存放到$ORIGIN_DATA_PATH目录下 -
生成测试集的ground truth文件
python create_gt_test_set_shtech.py [A or B] # Part_A or Part_B
生成好的ground-truth文件将会保存在$TEST_GT_PATH/test_data/ground_truth_csv目录下
-
生成训练集和验证集
python create_training_set_shtech.py [A or B]
生成好的数据保存将会在$TRAIN_PATH、$TRAIN_GT_PATH、$VAL_PATH、$VAL_GT_PATH目录下
a)下载训练模型
cmtl-A.235.h5 提取码:prxi、cmtl-B.210.h5 提取码:7if7
b) 如下命令分别测试A和B
python test.py --dataset A --weight_path /tmp/cmtl-A.235.h5 --output_dir /tmp/ctml_A
python test.py --dataset B --weight_path /tmp/cmtl-B.210.h5 --output_dir /tmp/ctml_B
如果你想自己训练模型,很简单:
python train.py [A or B]