Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update ctc for application of STR #253

Merged
merged 1 commit into from
Oct 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
475 changes: 475 additions & 0 deletions ctc/README.md

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions ctc/data_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import absolute_import
from __future__ import division

import os
from paddle.v2.image import load_image
import cv2


class AsciiDic(object):
UNK = 0

def __init__(self):
self.dic = {
'<unk>': self.UNK,
}
self.chars = [chr(i) for i in range(40, 171)]
for id, c in enumerate(self.chars):
self.dic[c] = id + 1

def lookup(self, w):
return self.dic.get(w, self.UNK)

def id2word(self):
self.id2word = {}
for key, value in self.dic.items():
self.id2word[value] = key

return self.id2word

def word2ids(self, sent):
'''
transform a word to a list of ids.
@sent: str
'''
return [self.lookup(c) for c in list(sent)]

def size(self):
return len(self.dic)


class ImageDataset(object):
def __init__(self,
train_image_paths_generator,
test_image_paths_generator,
infer_image_paths_generator,
fixed_shape=None,
is_infer=False):
'''
@image_paths_generator: function
return a list of images' paths, called like:

for path in image_paths_generator():
load_image(path)
'''
if is_infer == False:
self.train_filelist = [p for p in train_image_paths_generator]
self.test_filelist = [p for p in test_image_paths_generator]
else:
self.infer_filelist = [p for p in infer_image_paths_generator]

self.fixed_shape = fixed_shape
self.ascii_dic = AsciiDic()

def train(self):
for i, (image, label) in enumerate(self.train_filelist):
yield self.load_image(image), self.ascii_dic.word2ids(label)

def test(self):
for i, (image, label) in enumerate(self.test_filelist):
yield self.load_image(image), self.ascii_dic.word2ids(label)

def infer(self):
for i, (image, label) in enumerate(self.infer_filelist):
yield self.load_image(image), label

def load_image(self, path):
'''
load image and transform to 1-dimention vector
'''
image = load_image(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# resize all images to a fixed shape

if self.fixed_shape:
image = cv2.resize(
image, self.fixed_shape, interpolation=cv2.INTER_CUBIC)

image = image.flatten() / 255.
return image


def get_file_list(image_file_list):
pwd = os.path.dirname(image_file_list)
with open(image_file_list) as f:
for line in f:
fs = line.strip().split(',')
file = fs[0].strip()
path = os.path.join(pwd, file)
yield path, fs[1][2:-1]
35 changes: 35 additions & 0 deletions ctc/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import groupby
import numpy as np


def ctc_greedy_decoder(probs_seq, vocabulary):
"""CTC greedy (best path) decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: baseline
"""
# dimension verification
for probs in probs_seq:
if not len(probs) == len(vocabulary) + 1:
raise ValueError("probs_seq dimension mismatchedd with vocabulary")
# argmax to get the best index for each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
# remove consecutive duplicate indexes
index_list = [index_group[0] for index_group in groupby(max_index_list)]
# remove blank indexes
blank_index = len(vocabulary)
index_list = [index for index in index_list if index != blank_index]
# convert index list to string
return ''.join([vocabulary[index] for index in index_list])
Binary file added ctc/images/503.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ctc/images/504.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ctc/images/505.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ctc/images/ctc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ctc/images/feature_vector.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ctc/images/transcription.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
56 changes: 56 additions & 0 deletions ctc/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
import argparse
import paddle.v2 as paddle
import gzip
from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset
from decoder import ctc_greedy_decoder


def infer(inferer, test_batch, labels):
infer_results = inferer.infer(input=test_batch)
num_steps = len(infer_results) // len(test_batch)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(test_batch))
]

results = []
# best path decode
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=AsciiDic().id2word())
results.append(output_transcription)

for result, label in zip(results, labels):
print("\nOutput Transcription: %s\nTarget Transcription: %s" % (result,
label))


if __name__ == "__main__":
model_path = "model.ctc-pass-1-batch-150-test-10.2607016472.tar.gz"
image_shape = "173,46"
batch_size = 50
infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
image_shape = tuple(map(int, image_shape.split(',')))
infer_generator = get_file_list(infer_file_list)

dataset = ImageDataset(None, None, infer_generator, image_shape, True)

paddle.init(use_gpu=True, trainer_count=4)
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
model = Model(AsciiDic().size(), image_shape, is_infer=True)
inferer = paddle.inference.Inference(
output_layer=model.log_probs, parameters=parameters)

test_batch = []
labels = []
for i, (image, label) in enumerate(dataset.infer()):
test_batch.append([image])
labels.append(label)
if len(test_batch) == batch_size:
infer(inferer, test_batch, labels)
test_batch = []
labels = []
if test_batch:
infer(inferer, test_batch, labels)
127 changes: 127 additions & 0 deletions ctc/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2 import evaluator
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru


def conv_groups(input_image, num, with_bn):
'''
a deep CNN.
@input_image: input image
@num: number of CONV filters
@with_bn: whether with batch normal
'''
assert num % 4 == 0

tmp = img_conv_group(
input=input_image,
num_channels=1,
conv_padding=1,
conv_num_filter=[16] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )

tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[32] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )

tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[64] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )

tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[128] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )

return tmp


class Model(object):
def __init__(self, num_classes, shape, is_infer=False):
'''
@num_classes: int
size of the character dict
@shape: tuple of 2 int
size of the input images
'''
self.num_classes = num_classes
self.shape = shape
self.is_infer = is_infer
self.image_vector_size = shape[0] * shape[1]

self.__declare_input_layers__()
self.__build_nn__()

def __declare_input_layers__(self):
# image input as a float vector
self.image = layer.data(
name='image',
type=paddle.data_type.dense_vector(self.image_vector_size),
height=self.shape[0],
width=self.shape[1])

# label input as a ID list
if self.is_infer == False:
self.label = layer.data(
name='label',
type=paddle.data_type.integer_value_sequence(self.num_classes))

def __build_nn__(self):
# CNN output image features, 128 float matrixes
conv_features = conv_groups(self.image, 8, True)

# cutting CNN output into a sequence of feature vectors, which are
# 1 pixel wide and 11 pixel high.
sliced_feature = layer.block_expand(
input=conv_features,
num_channels=128,
stride_x=1,
stride_y=1,
block_x=1,
block_y=11)

# RNNs to capture sequence information forwards and backwards.
gru_forward = simple_gru(input=sliced_feature, size=128, act=Relu())
gru_backward = simple_gru(
input=sliced_feature, size=128, act=Relu(), reverse=True)

# map each step of RNN to character distribution.
self.output = layer.fc(
input=[gru_forward, gru_backward],
size=self.num_classes + 1,
act=Linear())

self.log_probs = paddle.layer.mixed(
input=paddle.layer.identity_projection(input=self.output),
act=paddle.activation.Softmax())

# warp CTC to calculate cost for a CTC task.
if self.is_infer == False:
self.cost = layer.warp_ctc(
input=self.output,
label=self.label,
size=self.num_classes + 1,
norm_by_times=True,
blank=self.num_classes)
Loading