Skip to content

Commit

Permalink
Merge pull request #7655 from wanghaoshuang/ctc_evaluator_py
Browse files Browse the repository at this point in the history
Add python wrapper for  CTC greedy decoder and edit distance evaluator
  • Loading branch information
wanghaoshuang authored Jan 22, 2018
2 parents b156bbc + d9d9be1 commit 44561a2
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 6 deletions.
10 changes: 10 additions & 0 deletions doc/api/v2/fluid/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,16 @@ swish
.. autofunction:: paddle.v2.fluid.layers.swish
:noindex:

edit_distance
-------------
.. autofunction:: paddle.v2.fluid.layers.edit_distance
:noindex:

ctc_greedy_decoder
------------------
.. autofunction:: paddle.v2.fluid.layers.ctc_greedy_decoder
:noindex:

l2_normalize
------------
.. autofunction:: paddle.v2.fluid.layers.l2_normalize
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ op_library(parallel_do_op DEPS executor)
# Register multiple kernels to pybind
if (WITH_GPU)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
Expand Down
4 changes: 4 additions & 0 deletions paddle/operators/edit_distance_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class EditDistanceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
"Output(SequenceNum) shouldn't be null.");
auto hyp_dims = ctx->GetInputDim("Hyps");
auto ref_dims = ctx->GetInputDim("Refs");
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
Expand All @@ -34,6 +36,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
ctx->SetOutputDim("SequenceNum", {1});
}

protected:
Expand All @@ -54,6 +57,7 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Refs",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for reference strings.");
AddOutput("SequenceNum", "The sequence count of current batch");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize "
"the edit distance by the length of reference string.")
Expand Down
9 changes: 8 additions & 1 deletion paddle/operators/edit_distance_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include <algorithm>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"

Expand Down Expand Up @@ -72,6 +73,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {

auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
sequence_num->mutable_data<int64_t>(ctx.GetPlace());

auto normalized = ctx.Attr<bool>("normalized");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
Expand All @@ -88,7 +91,11 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
"Reference string %d is empty.", i);
}

auto num_strs = hyp_lod.size() - 1;
const size_t num_strs = hyp_lod.size() - 1;
math::SetConstant<platform::CUDADeviceContext, int64_t> set_constant;
set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
sequence_num, static_cast<int64_t>(num_strs));

out_t->Resize({static_cast<int64_t>(num_strs), 1});
out_t->mutable_data<T>(ctx.GetPlace());
auto out = out_t->data<T>();
Expand Down
4 changes: 3 additions & 1 deletion paddle/operators/edit_distance_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */
#include <algorithm>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

Expand All @@ -28,6 +27,8 @@ class EditDistanceKernel : public framework::OpKernel<T> {

auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
int64_t* seq_num_data = sequence_num->mutable_data<int64_t>(ctx.GetPlace());

auto normalized = ctx.Attr<bool>("normalized");

Expand All @@ -41,6 +42,7 @@ class EditDistanceKernel : public framework::OpKernel<T> {
"Reference string %d is empty.", i);
}
auto num_strs = hyp_lod.size() - 1;
*seq_num_data = static_cast<int64_t>(num_strs);

out_t->Resize({static_cast<int64_t>(num_strs), 1});
out_t->mutable_data<float>(ctx.GetPlace());
Expand Down
60 changes: 60 additions & 0 deletions python/paddle/v2/fluid/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,63 @@ def eval(self, executor, eval_program=None):
[precision], dtype='float32'), np.array(
[recall], dtype='float32'), np.array(
[f1_score], dtype='float32')


class EditDistance(Evaluator):
    """
    Evaluator reporting the average edit distance over mini-batches.

    Every run of the main program adds the batch's summed edit distance to
    the ``total_error`` state and the batch's sequence count to the
    ``seq_num`` state. ``eval`` divides the former by the latter.

    Args:
        input: the sequences predicted by the network.
        label: the target sequences, which must have the same sequence
            count as input.
        ignored_tokens(list of int): Tokens that should be removed before
            calculating edit distance.

    Example:

        exe = fluid.Executor(place)
        distance_evaluator = fluid.evaluator.EditDistance(input, label)
        for epoch in range(PASS_NUM):
            distance_evaluator.reset(exe)
            for data in batches:
                loss, sum_distance = exe.run(
                    fetch_list=[cost] + distance_evaluator.metrics)
                avg_distance = distance_evaluator.eval(exe)
            pass_distance = distance_evaluator.eval(exe)

        In the above example:
        'sum_distance' is the sum of the batch's edit distance.
        'avg_distance' is the average edit distance from the first batch
        to the current batch.
        'pass_distance' is the average edit distance over the whole pass.
    """

    def __init__(self, input, label, ignored_tokens=None, **kwargs):
        super(EditDistance, self).__init__("edit_distance", **kwargs)
        # The accumulation ops below are only valid in the root block.
        if self.helper.main_program.current_block().idx != 0:
            raise ValueError("You can only invoke Evaluator in root block")

        # Running totals that the ops appended below keep updating.
        self.total_error = self.create_state(
            suffix='total_error', dtype='float32', shape=[1])
        self.seq_num = self.create_state(
            suffix='seq_num', dtype='int64', shape=[1])

        batch_distances, batch_seq_num = layers.edit_distance(
            input=input, label=label, ignored_tokens=ignored_tokens)
        batch_error = layers.reduce_sum(batch_distances)

        # Accumulate in place: state <- state + this batch's contribution.
        layers.sums(
            input=[self.total_error, batch_error], out=self.total_error)
        layers.sums(input=[self.seq_num, batch_seq_num], out=self.seq_num)
        self.metrics.append(batch_error)

    def eval(self, executor, eval_program=None):
        """
        Return the accumulated edit distance divided by the accumulated
        sequence count, as a numpy array.

        Args:
            executor: the executor used to run the evaluation program.
            eval_program: program to build the division in; a new
                ``Program`` is created when it is None.
        """
        program = Program() if eval_program is None else eval_program
        root_block = program.current_block()
        with program_guard(main_program=program):
            error_total = _clone_var_(root_block, self.total_error)
            # The count is stored as int64; cast it so it can divide the
            # float32 error total.
            count_total = layers.cast(
                x=_clone_var_(root_block, self.seq_num), dtype='float32')
            average = layers.elementwise_div(x=error_total, y=count_total)
        fetched = executor.run(program, fetch_list=[average])
        return np.array(fetched[0])
145 changes: 143 additions & 2 deletions python/paddle/v2/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
'l2_normalize', 'matmul', 'warpctc', 'sequence_reshape'
'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc',
'sequence_reshape'
]


Expand Down Expand Up @@ -1866,6 +1867,146 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
return out


def edit_distance(input,
                  label,
                  normalized=False,
                  ignored_tokens=None,
                  name=None):
    """
    EditDistance operator computes the edit distances between a batch of
    hypothesis strings and their references. Edit distance, also called
    Levenshtein distance, measures how dissimilar two strings are by
    counting the minimum number of operations needed to transform one
    string into another. Here the operations include insertion, deletion,
    and substitution.

    For example, given hypothesis string A = "kitten" and reference
    B = "sitting", the edit distance is 3, because A is transformed into B
    with at least two substitutions and one insertion:

        "kitten" -> "sitten" -> "sittin" -> "sitting"

    Input(Hyps) is a LoDTensor consisting of all the hypothesis strings
    with the total number denoted by `batch_size`, and the separation is
    specified by the LoD information. The `batch_size` reference strings
    are arranged in order in the same way in the LoDTensor Input(Refs).

    Output(Out) contains the `batch_size` results and each stands for the
    edit distance for a pair of strings respectively. If Attr(normalized)
    is true, the edit distance will be divided by the length of the
    reference string.

    Args:
        input(Variable): The indices for hypothesis strings.
        label(Variable): The indices for reference strings.
        normalized(bool): Indicates whether to normalize the edit distance
            by the length of the reference string.
        ignored_tokens(list of int): Tokens that should be removed from
            both input and label before calculating edit distance.
        name(str): Forwarded to LayerHelper together with the other
            arguments.

    Returns:
        tuple: (edit_distance_out, sequence_num), where
            edit_distance_out is the sequence-to-sequence edit distance in
            shape [batch_size, 1] and sequence_num is the sequence count
            of the current batch in shape [1].

    Examples:
        .. code-block:: python

            x = fluid.layers.data(name='x', shape=[1], dtype='int64')
            y = fluid.layers.data(name='y', shape=[1], dtype='int64')
            cost, seq_num = fluid.layers.edit_distance(input=x, label=y)
    """
    # Must stay the first statement: locals() has to contain only the
    # function arguments when it is handed to LayerHelper.
    helper = LayerHelper("edit_distance", **locals())

    # remove some tokens from input and labels
    if ignored_tokens is not None and len(ignored_tokens) > 0:
        erased_input = helper.create_tmp_variable(dtype="int64")
        erased_label = helper.create_tmp_variable(dtype="int64")

        helper.append_op(
            type="sequence_erase",
            inputs={"X": [input]},
            outputs={"Out": [erased_input]},
            attrs={"tokens": ignored_tokens})
        input = erased_input

        # Fixed: the output was previously bound to the undefined name
        # `erase_label`, raising NameError whenever tokens were ignored.
        helper.append_op(
            type="sequence_erase",
            inputs={"X": [label]},
            outputs={"Out": [erased_label]},
            attrs={"tokens": ignored_tokens})
        label = erased_label

    # edit distance op
    # NOTE(review): the CPU kernel fills Out with float data while this
    # variable is declared int64 -- confirm whether the declared dtype
    # should be float32.
    edit_distance_out = helper.create_tmp_variable(dtype="int64")
    sequence_num = helper.create_tmp_variable(dtype="int64")
    helper.append_op(
        type="edit_distance",
        inputs={"Hyps": [input],
                "Refs": [label]},
        outputs={"Out": [edit_distance_out],
                 "SequenceNum": [sequence_num]},
        attrs={"normalized": normalized})

    return edit_distance_out, sequence_num


def ctc_greedy_decoder(input, blank, name=None):
    """
    Decode sequences with a greedy policy in two steps:

    1. Get the index of the max value for each row in input, a.k.a.
       numpy.argmax(input, axis=1).
    2. For each sequence in the result of step 1, merge repeated tokens
       between two blanks and delete all blanks.

    A simple example as below:

    .. code-block:: text

        Given:

        input.data = [[0.6, 0.1, 0.3, 0.1],
                      [0.3, 0.2, 0.4, 0.1],
                      [0.1, 0.5, 0.1, 0.3],
                      [0.5, 0.1, 0.3, 0.1],

                      [0.5, 0.1, 0.3, 0.1],
                      [0.2, 0.2, 0.2, 0.4],
                      [0.2, 0.2, 0.1, 0.5],
                      [0.5, 0.1, 0.3, 0.1]]

        input.lod = [[0, 4, 8]]

        Then:

        output.data = [[2],
                       [1],
                       [3]]

        output.lod = [[0, 2, 3]]

    Args:
        input(Variable): (LoDTensor<float>), the probabilities of
            variable-length sequences, which is a 2-D Tensor with LoD
            information. Its shape is [Lp, num_classes + 1], where Lp is
            the sum of all input sequences' lengths and num_classes is the
            true number of classes (not including the blank label).
        blank(int): the blank label index of Connectionist Temporal
            Classification (CTC) loss, which is in the half-opened
            interval [0, num_classes + 1).
        name(str): Forwarded to LayerHelper together with the other
            arguments.

    Returns:
        Variable: CTC greedy decode result.

    Examples:
        .. code-block:: python

            x = fluid.layers.data(name='x', shape=[8], dtype='float32')
            decoded = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
    """
    # Keep this first so locals() holds nothing but the call arguments.
    helper = LayerHelper("ctc_greedy_decoder", **locals())

    # Step 1: top_k with k=1 yields the best-scoring class per row.
    best_scores = helper.create_tmp_variable(dtype=input.dtype)
    best_indices = helper.create_tmp_variable(dtype="int64")
    top1_outputs = {"Out": [best_scores], "Indices": [best_indices]}
    helper.append_op(
        type="top_k",
        inputs={"X": [input]},
        outputs=top1_outputs,
        attrs={"k": 1})

    # Step 2: ctc_align turns the per-row indices into the decoded output.
    decoded = helper.create_tmp_variable(dtype="int64")
    align_attrs = {"merge_repeated": True, "blank": blank}
    helper.append_op(
        type="ctc_align",
        inputs={"Input": [best_indices]},
        outputs={"Output": [decoded]},
        attrs=align_attrs)
    return decoded


def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
"""
An operator integrating the open source Warp-CTC library
Expand All @@ -1890,7 +2031,7 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1).
norm_by_times: (bool, default: false), whether to normalize
the gradients by the number of time-step,which is also the
the gradients by the number of time-step, which is also the
sequence's length. There is no need to normalize the gradients
if warpctc layer was followed by a mean_op.
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/v2/fluid/tests/test_edit_distance_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def setUp(self):

num_strs = len(x1_lod) - 1
distance = np.zeros((num_strs, 1)).astype("float32")
sequence_num = np.array(2).astype("int64")
for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
Expand All @@ -70,7 +71,7 @@ def setUp(self):
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized}
self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
self.outputs = {'Out': distance}
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}

def test_check_output(self):
self.check_output()
Expand All @@ -89,6 +90,7 @@ def setUp(self):

num_strs = len(x1_lod) - 1
distance = np.zeros((num_strs, 1)).astype("float32")
sequence_num = np.array(3).astype("int64")
for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
Expand All @@ -98,7 +100,7 @@ def setUp(self):
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized}
self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
self.outputs = {'Out': distance}
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}

def test_check_output(self):
self.check_output()
Expand Down

0 comments on commit 44561a2

Please sign in to comment.