Add greedy CTC evaluator python API #7596

Closed
wanghaoshuang opened this issue Jan 17, 2018 · 0 comments · Fixed by #7655

wanghaoshuang commented Jan 17, 2018

This issue depends on #7527.
CTC evaluator = top-k_op + ctc_align_op + edit_distance_op
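
To make the three-stage composition above concrete, here is a minimal NumPy sketch of the same idea for a single sequence. It is only a conceptual illustration, not the Fluid operators themselves: the function names `greedy_ctc_decode` and `edit_distance` are hypothetical, and the mapping of each step to `top-k_op`, `ctc_align_op`, and `edit_distance_op` is an assumption based on the operator names.

```python
import numpy as np


def greedy_ctc_decode(probs, blank=0):
    """Greedy CTC decoding for one sequence of shape [T, num_classes]."""
    # Step 1 (top-k with k=1): pick the most likely class at each time step.
    best_path = np.argmax(probs, axis=1)
    # Step 2 (CTC align): merge consecutive repeats, then drop blanks.
    decoded = []
    prev = None
    for token in best_path:
        if token != prev and token != blank:
            decoded.append(int(token))
        prev = token
    return decoded


def edit_distance(hyp, ref):
    """Step 3: Levenshtein distance between two token sequences."""
    # dist[j] is the distance between the hyp prefix seen so far and ref[:j].
    dist = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        prev_diag, dist[0] = dist[0], i
        for j, r in enumerate(ref, start=1):
            cur = min(dist[j] + 1,            # drop a token from hyp
                      dist[j - 1] + 1,        # add a token to hyp
                      prev_diag + (h != r))   # substitute (free on a match)
            prev_diag, dist[j] = dist[j], cur
    return dist[-1]


# One-hot scores whose best path is [0, 1, 1, 0, 1, 2, 2, 0] (0 is the blank).
probs = np.eye(3)[[0, 1, 1, 0, 1, 2, 2, 0]]
decoded = greedy_ctc_decode(probs, blank=0)
print(decoded)                         # [1, 1, 2]
print(edit_distance(decoded, [1, 2]))  # 1
```

The blank between the two runs of `1` is what keeps them as separate tokens; without it the repeats would be merged into one.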

Test script:

import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid import core

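# Input scores: each row holds 8 class scores for one time step.
x = fluid.layers.data(name='x', shape=[8], dtype='float32')
# Reference labels: one token id per row.
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# Greedy decoding; class index 0 is the CTC blank.
ctc_result = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
edit_distance = fluid.evaluator.EditDistance(input=ctc_result, label=y)
print("step1")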

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
print("step2")
# Reset the evaluator's accumulated state before the pass starts.
edit_distance.reset(exe)
batch_num = 2
for i in range(batch_num):
    print("step3")
    y_data = np.random.randint(0, 8, [7, 1])
    y_lod = [[0, 2, 4, 7]]
    y_tensor = core.LoDTensor()
    y_tensor.set(y_data, place)
    y_tensor.set_lod(y_lod)

    x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
    x_lod = [[0, 3, 5, 11]]
    x_tensor = core.LoDTensor()
    x_tensor.set(x_data, place)
    x_tensor.set_lod(x_lod)

    # Run one batch and fetch the evaluator's per-batch metric.
    cost, = exe.run(fluid.default_main_program(),
                    feed={'x': x_tensor,
                          'y': y_tensor},
                    fetch_list=edit_distance.metrics)
    # Accumulated result over the batches seen so far.
    pass_error = edit_distance.eval(exe)
    print("cost: %s" % cost)
    print("pass_id=%d; pass_error=%s" % (i, str(pass_error)))

# Final accumulated result after all batches.
pass_error = edit_distance.eval(exe)
print("total_pass_error=%s" % str(pass_error))
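
The `x_lod` and `y_lod` values in the script appear to be cumulative offsets that mark where each sequence starts and ends within the flattened batch. The following sketch, assuming that offset interpretation, shows how `[0, 3, 5, 11]` would split the 11 input rows into three sequences. The helper `split_by_lod` is hypothetical and is not part of the Fluid API.

```python
import numpy as np


def split_by_lod(data, offsets):
    """Split the rows of `data` into sequences using level-0 LoD offsets."""
    return [data[start:end] for start, end in zip(offsets[:-1], offsets[1:])]


# Same shapes and offsets as the test script.
x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
y_data = np.random.randint(0, 8, [7, 1])

x_seqs = split_by_lod(x_data, [0, 3, 5, 11])
y_seqs = split_by_lod(y_data, [0, 2, 4, 7])

print([len(s) for s in x_seqs])  # [3, 2, 6] time steps per input sequence
print([len(s) for s in y_seqs])  # [2, 2, 3] tokens per label sequence
```

Both tensors describe three sequences, so under this reading the evaluator compares each decoded sequence with the label sequence at the same position.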

wanghaoshuang self-assigned this Jan 17, 2018
wanghaoshuang changed the title from "Add CTC evaluator python API" to "Add greedy CTC evaluator python API" Jan 18, 2018