Skip to content

Commit

Permalink
Refine ctc model code for English dataset. (#991)
Browse files Browse the repository at this point in the history
* Refine code for English dataset.
1. Remove a pooling layer.
2. Change classes_num to 94.
3. Modify some arguments in ctc_train.py
4. Add learning rate decay policy.

* Fix readme.

* Fix README.

* Remove consine decay.

* Remove eval.sh
  • Loading branch information
wanghaoshuang authored Jun 21, 2018
1 parent 2cb27d0 commit db1edc2
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 69 deletions.
47 changes: 23 additions & 24 deletions fluid/ocr_recognition/README.md
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@


运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照安装文档中的说明更新PaddlePaddle安装版本
运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本

# Optical Character Recognition
## 代码结构
```
├── ctc_reader.py # 下载、读取、处理数据。
├── crnn_ctc_model.py # 定义了训练网络、预测网络和evaluate网络。
├── ctc_train.py # 用于模型的训练。
├── infer.py # 加载训练好的模型文件,对新数据进行预测。
├── eval.py # 评估模型在指定数据集上的效果。
└── utils.py # 定义通用的函数。
```

这里将介绍如何在PaddlePaddle Fluid下使用CRNN-CTC 和 CRNN-Attention模型对图片中的文字内容进行识别。

## 1. CRNN-CTC
## 简介

本章的任务是识别含有单行汉语字符图片,首先采用卷积将图片转为特征图, 然后使用`im2sequence op`将特征图转为序列,通过`双向GRU`学习到序列特征。训练过程选用的损失函数为CTC(Connectionist Temporal Classification) loss,最终的评估指标为样本级别的错误率。

本路径下各个文件的作用如下:

- **ctc_reader.py :** 下载、读取、处理数据。提供方法`train()``test()` 分别产生训练集和测试集的数据迭代器。
- **crnn_ctc_model.py :** 在该脚本中定义了训练网络、预测网络和evaluate网络。
- **ctc_train.py :** 用于模型的训练,可通过命令`python train.py --help` 获得使用方法。
- **infer.py :** 加载训练好的模型文件,对新数据进行预测。可通过命令`python infer.py --help` 获得使用方法。
- **eval.py :** 评估模型在指定数据集上的效果。可通过命令`python infer.py --help` 获得使用方法。
- **utility.py :** 实现的一些通用方法,包括参数配置、tensor的构造等。


### 1.1 数据
## 数据

数据的下载和简单预处理都在`ctc_reader.py`中实现。

#### 1.1.1 数据格式
### 数据示例

我们使用的训练和测试数据如`图1`所示,每张图片包含单行不定长的中文字符串,这些图片都是经过检测算法进行预框选处理的。
我们使用的训练和测试数据如`图1`所示,每张图片包含单行不定长的英文字符串,这些图片都是经过检测算法进行预框选处理的。

<p align="center">
<img src="images/demo.jpg" width="620" hspace='10'/> <br/>
Expand All @@ -35,12 +34,12 @@

在训练集中,每张图片对应的label是汉字在词典中的索引。 `图1` 对应的label如下所示:
```
3835,8371,7191,2369,6876,4162,1938,168,1517,4590,3793
80,84,68,82,83,72,78,77,68,67
```
在上边这个label中,`3835` 表示字符‘两’的索引,`4590` 表示中文字符逗号的索引
在上边这个label中,`80` 表示字符`Q`的索引,`67` 表示英文字符`D`的索引


#### 1.1.2 数据准备
### 数据准备

**A. 训练集**

Expand Down Expand Up @@ -105,7 +104,9 @@ data/test_images/00003.jpg

第三种:从stdin读入一张图片的path,然后进行一次inference.

#### 1.2 训练
## 模型训练与预测

### 训练

使用默认数据在GPU单卡上训练:

Expand All @@ -121,7 +122,7 @@ env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True

执行`python ctc_train.py --help`可查看更多使用方式和参数详细说明。

图2为使用默认参数和默认数据集训练的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为样本级错误率。其中,蓝线为训练集上的样本错误率,红线为测试集上的样本错误率。在45轮迭代训练中,测试集上最低错误率为第60轮的21.11%.
图2为使用默认参数和默认数据集训练的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为样本级错误率。其中,蓝线为训练集上的样本错误率,红线为测试集上的样本错误率。在60轮迭代训练中,测试集上最低错误率为第32轮的22.0%.

<p align="center">
<img src="images/train.jpg" width="620" hspace='10'/> <br/>
Expand All @@ -130,7 +131,7 @@ env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True



### 1.3 评估
## 测试

通过以下命令调用评估脚本用指定数据集对模型进行评估:

Expand All @@ -144,7 +145,7 @@ env CUDA_VISIBLE_DEVICE=0 python eval.py \
执行`python ctc_train.py --help`可查看参数详细说明。


### 1.4 预测
### 预测

从标准输入读取一张图片的路径,并对齐进行预测:

Expand Down Expand Up @@ -176,5 +177,3 @@ env CUDA_VISIBLE_DEVICE=0 python infer.py \
--model_path="models/model_00044_15000" \
--input_images_list="data/test.list"
```

>注意:因为版权原因,我们暂时停止提供中文数据集的下载和使用服务,你通过`ctc_reader.py`自动下载的数据将是含有30W图片的英文数据集。在英文数据集上的训练结果会稍后发布。
47 changes: 26 additions & 21 deletions fluid/ocr_recognition/crnn_ctc_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import paddle.fluid as fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.initializer import init_on_cpu
import math


def conv_bn_pool(input,
Expand All @@ -8,7 +11,8 @@ def conv_bn_pool(input,
param=None,
bias=None,
param_0=None,
is_test=False):
is_test=False,
pooling=True):
tmp = input
for i in xrange(group):
tmp = fluid.layers.conv2d(
Expand All @@ -19,32 +23,25 @@ def conv_bn_pool(input,
param_attr=param if param_0 is None else param_0,
act=None, # LinearActivation
use_cudnn=True)
#tmp = fluid.layers.Print(tmp)
tmp = fluid.layers.batch_norm(
input=tmp,
act=act,
param_attr=param,
bias_attr=bias,
is_test=is_test)
tmp = fluid.layers.pool2d(
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
ceil_mode=True)
if pooling:
tmp = fluid.layers.pool2d(
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
ceil_mode=True)

return tmp


def ocr_convs(input,
num,
with_bn,
regularizer=None,
gradient_clip=None,
is_test=False):
assert (num % 4 == 0)

def ocr_convs(input, regularizer=None, gradient_clip=None, is_test=False):
b = fluid.ParamAttr(
regularizer=regularizer,
gradient_clip=gradient_clip,
Expand All @@ -63,7 +60,8 @@ def ocr_convs(input,

tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(
tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test, pooling=False)
return tmp


Expand All @@ -75,8 +73,6 @@ def encoder_net(images,
is_test=False):
conv_features = ocr_convs(
images,
8,
True,
regularizer=regularizer,
gradient_clip=gradient_clip,
is_test=is_test)
Expand Down Expand Up @@ -143,6 +139,7 @@ def ctc_train_net(images, label, args, num_classes):
L2_RATE = 0.0004
LR = 1.0e-3
MOMENTUM = 0.9
learning_rate_decay = None
regularizer = fluid.regularizer.L2Decay(L2_RATE)

fc_out = encoder_net(images, num_classes, regularizer=regularizer)
Expand All @@ -155,7 +152,15 @@ def ctc_train_net(images, label, args, num_classes):
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
inference_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Momentum(learning_rate=LR, momentum=MOMENTUM)
if learning_rate_decay == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay([
args.total_step / 4, args.total_step / 2, args.total_step * 3 / 4
], [LR, LR * 0.1, LR * 0.01, LR * 0.001])
else:
learning_rate = LR

optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate, momentum=MOMENTUM)
_, params_grads = optimizer.minimize(sum_cost)
model_average = None
if args.average_window > 0:
Expand Down
2 changes: 1 addition & 1 deletion fluid/ocr_recognition/ctc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from paddle.v2.image import load_image
import paddle.v2 as paddle

NUM_CLASSES = 10784
NUM_CLASSES = 95
DATA_SHAPE = [1, 48, 512]

DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
Expand Down
47 changes: 25 additions & 22 deletions fluid/ocr_recognition/ctc_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pass_num', int, 100, "Number of training epochs.")
add_arg('total_step', int, 720000, "Number of training iterations.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 15625, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
# yapf: enable
Expand Down Expand Up @@ -90,54 +90,57 @@ def train_one_batch(data):
results = [result[0] for result in results]
return results

def test(pass_id, batch_id):
def test(iter_num):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
print "\nTime: %s; Pass[%d]-batch[%d]; Test seq error: %s.\n" % (
time.time(), pass_id, batch_id, str(test_seq_error[0]))
print "\nTime: %s; Iter[%d]; Test seq error: %s.\n" % (
time.time(), iter_num, str(test_seq_error[0]))

def save_model(args, exe, pass_id, batch_id):
filename = "model_%05d_%d" % (pass_id, batch_id)
def save_model(args, exe, iter_num):
filename = "model_%05d" % iter_num
fluid.io.save_params(
exe, dirname=args.save_model_dir, filename=filename)
print "Saved model to: %s/%s." % (args.save_model_dir, filename)

for pass_id in range(args.pass_num):
batch_id = 1
iter_num = 0
while True:
total_loss = 0.0
total_seq_error = 0.0
# train a pass
for data in train_reader():
iter_num += 1
if iter_num > args.total_step:
return
results = train_one_batch(data)
total_loss += results[0]
total_seq_error += results[2]
# training log
if batch_id % args.log_period == 0:
print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq err: %s" % (
time.time(), pass_id, batch_id,
total_loss / (batch_id * args.batch_size),
total_seq_error / (batch_id * args.batch_size))
if iter_num % args.log_period == 0:
print "\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
time.time(), iter_num,
total_loss / (args.log_period * args.batch_size),
total_seq_error / (args.log_period * args.batch_size))
sys.stdout.flush()
total_loss = 0.0
total_seq_error = 0.0

# evaluate
if batch_id % args.eval_period == 0:
if iter_num % args.eval_period == 0:
if model_average:
with model_average.apply(exe):
test(pass_id, batch_id)
test(iter_num)
else:
test(pass_id, batch_d)
test(iter_num)

# save model
if batch_id % args.save_model_period == 0:
if iter_num % args.save_model_period == 0:
if model_average:
with model_average.apply(exe):
save_model(args, exe, pass_id, batch_id)
save_model(args, exe, iter_num)
else:
save_model(args, exe, pass_id, batch_id)

batch_id += 1
save_model(args, exe, iter_num)


def main():
Expand Down
2 changes: 1 addition & 1 deletion fluid/ocr_recognition/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):

# prepare environment
place = fluid.CPUPlace()
if use_gpu:
if args.use_gpu:
place = fluid.CUDAPlace(0)

exe = fluid.Executor(place)
Expand Down
Binary file modified fluid/ocr_recognition/images/demo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified fluid/ocr_recognition/images/train.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit db1edc2

Please sign in to comment.