From a60b3a5d557ca5e040e8f360709d0136fd6d1645 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 14 Dec 2017 11:47:32 +0800 Subject: [PATCH 1/8] fix doc of seq_expand_op --- paddle/operators/seq_expand_op.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc index ede9754697429a..8051ddd702f7e3 100644 --- a/paddle/operators/seq_expand_op.cc +++ b/paddle/operators/seq_expand_op.cc @@ -59,7 +59,7 @@ This operator expands input(X) according to LOD of input(Y). Following are cases to better explain how this works: Case 1: -Given 2-level a LoDTensor input(X) +Given a 2-level LoDTensor input(X) X.lod = [[0, 2, 3], [0, 1, 3, 4]] X.data = [a, b, c, d] @@ -76,9 +76,8 @@ then we get 2-level LoDTensor Case 2: -Given a 0-level LoDTensor input(X) +Given a common Tensor input(X) X.data = [a, b, c] - X.lod = NULL X.dims = [3, 1] and input(Y) Y.lod = [[0, 2, 3, 6]] @@ -90,9 +89,8 @@ then we get 1-level LoDTensor Case 3: -Given a 0-level LoDTensor input(X) +Given a common Tensor input(X) X.data = [[a, b], [c, d], [e, f]] - X.lod = NULL X.dims = [3, 2] and input(Y) Y.lod = [[0, 2, 3, 6]] From 579f684661d1badf34957b8a48ffe7d713547ead Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 15 Jan 2018 15:49:46 +0800 Subject: [PATCH 2/8] Add ctc_greedy_decode_op --- paddle/operators/ctc_greedy_decode_op.cc | 88 +++++++++++ paddle/operators/ctc_greedy_decode_op.cu | 138 ++++++++++++++++++ paddle/operators/ctc_greedy_decode_op.h | 82 +++++++++++ .../v2/fluid/tests/test_ctc_greedy_decode.py | 56 +++++++ 4 files changed, 364 insertions(+) create mode 100644 paddle/operators/ctc_greedy_decode_op.cc create mode 100644 paddle/operators/ctc_greedy_decode_op.cu create mode 100644 paddle/operators/ctc_greedy_decode_op.h create mode 100644 python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py diff --git a/paddle/operators/ctc_greedy_decode_op.cc b/paddle/operators/ctc_greedy_decode_op.cc new file mode 100644 index 00000000000000..3c9b705f7f61c3 --- /dev/null +++ b/paddle/operators/ctc_greedy_decode_op.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/ctc_greedy_decode_op.h" + +namespace paddle { +namespace operators { + +class CTCGreedyDecodeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input of CTCGreedyDecodeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output of CTCGreedyDecodeOp should not be null."); + + auto input_dims = ctx->GetInputDim("Input"); + + int sequence_width = + static_cast(framework::product(input_dims) / input_dims[0]); + int blank = ctx->Attrs().Get("blank"); + PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width), + "The value of Attr(blank) should be in interval [0, %d).", + sequence_width); + // TODO(wanghaoshuang): it is tricky to set the wrong dimension here. + ctx->SetOutputDim("Output", {input_dims[0], 1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LodTensor, default: LoDTensor), the unscaled " + "probabilities of variable-length sequences, which is a 2-D " + "Tensor with LoD information. It's shape is " + "[Lp, num_classes + 1], where Lp is the sum of all input " + "sequences' length and num_classes is the true number of classes " + "(not including the blank label)."); + AddOutput("Output", "(Tensor, default: Tensor), the decode result "); + AddAttr("blank", + "(int, default: 0), the blank label setted in Connectionist " + "Temporal Classification (CTC) op, and it is in the " + "half-opened interval [0, num_classes + 1).") + .SetDefault(0); + AddAttr("merge_repeated", + "(bool, default: true), whether to " + "merge repeated elements between two blanks. ") + .SetDefault(true); + AddComment(R"DOC( +CTCGreedyDecoder is an implementation of the simple best path decoding +algorithm, selecting at each timestep the most likely class at each timestep. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, + ops::CTCGreedyDecodeOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + ctc_greedy_decode, + ops::CTCGreedyDecodeKernel); diff --git a/paddle/operators/ctc_greedy_decode_op.cu b/paddle/operators/ctc_greedy_decode_op.cu new file mode 100644 index 00000000000000..43a78745aca2f6 --- /dev/null +++ b/paddle/operators/ctc_greedy_decode_op.cu @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; + +__device__ static float atomicMaxF(float* address, float val) { + int* address_as_i = (int*)address; + int old = *address_as_i, assumed; + do { + assumed = old; + old = ::atomicCAS(address_as_i, assumed, + __float_as_int(::fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} + +template +__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits, + int* output) { + T local_max_value = 0; + int local_max_index = 0; + __shared__ T max_value; + if (threadIdx.x == 0) { + max_value = 0; + } + __syncthreads(); + + for (int i = threadIdx.x; i < seq_width; i += BlockSize) { + T value = logits[blockIdx.x * seq_width + i]; + if (value > local_max_value) { + local_max_value = value; + local_max_index = i; + } + } + + atomicMaxF(&max_value, local_max_value); + + __syncthreads(); + + if (local_max_value == max_value) { + output[blockIdx.x] = local_max_index; + } +} + +template +__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens, + const size_t num_seq, size_t* lod0, + const int blank, const int merge_repeated, + size_t* out_lod0, int* output) { + int ouput_idx = 0; + out_lod0[0] = 0; + + for (int i = 0; i < num_seq; ++i) { + int pre_token = -1; + for (int j = lod0[i]; j < lod0[i + 1]; ++j) { + if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) { + output[ouput_idx] = tokens[j]; + ++ouput_idx; + } + pre_token = tokens[j]; + } + out_lod0[i + 1] = ouput_idx; + } +} + +template +class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + + const int64_t num_tokens = input->dims()[0]; + const size_t seq_width = input->numel() / num_tokens; + const T* logits = input->data(); + Tensor tmp; + int* tokens = tmp.mutable_data({num_tokens, 1}, ctx.GetPlace()); + // get argmax + // platform::GpuMemsetAsync(args, 0, sizeof(float), stream); + + auto stream = ctx.cuda_device_context().stream(); + ArgmaxCudaKernel<<< + num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits, + tokens); + + const size_t level = 0; + auto input_lod = framework::ToAbsOffset(input->lod()); + const size_t num_seq = input_lod[level].size() - 1; + const int blank = ctx.Attr("blank"); + const int merge_repeated = + static_cast(ctx.Attr("merge_repeated")); + + thrust::device_vector dev_out_lod0(input_lod[level].size()); + size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); + + int* output_data = + output->mutable_data({num_tokens, 1}, ctx.GetPlace()); + MergeAndDelCudaKernel<<<1, 1, 0, stream>>>( + num_tokens, tokens, num_seq, input_lod[level].data(), blank, + merge_repeated, dev_out_lod0_ptr, output_data); + + thrust::host_vector host_out_lod0(dev_out_lod0.begin(), + dev_out_lod0.end()); + framework::LoD out_lod; + out_lod.push_back(host_out_lod0); + output->set_lod(out_lod); + + output->Resize({static_cast(host_out_lod0.back()), 1}); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, + paddle::operators::CTCGreedyDecodeOpCUDAKernel); diff --git a/paddle/operators/ctc_greedy_decode_op.h b/paddle/operators/ctc_greedy_decode_op.h new file mode 100644 index 00000000000000..f12ea6c541b699 --- /dev/null +++ b/paddle/operators/ctc_greedy_decode_op.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/op_registry.h" +#include "unsupported/Eigen/CXX11/Tensor" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class CTCGreedyDecodeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + const size_t level = 0; + + auto input_lod = framework::ToAbsOffset(input->lod()); + auto input_dims = input->dims(); + PADDLE_ENFORCE_EQ(input_dims[0], + static_cast(input_lod[level].back()), + "The first dimension of Input(Input) should be equal to " + "the sum of all sequences' lengths."); + + const size_t num_sequences = input_lod[level].size() - 1; + const size_t sequence_width = input->numel() / input_dims[0]; + size_t blank = static_cast(ctx.Attr("blank")); + bool merge_repeated = ctx.Attr("merge_repeated"); + std::vector> pathes(num_sequences); + std::vector output_lod0(1, 0); + + const T* input_data = input->data(); + Eigen::Map< + Eigen::Matrix> + input_mat(const_cast(input_data), input->numel() / sequence_width, + sequence_width); + + size_t max_class_idx; + size_t prev_class_idx = -1; + for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + for (size_t i = input_lod[level][seq_idx]; + i < input_lod[level][seq_idx + 1]; ++i) { + input_mat.row(i).maxCoeff(&max_class_idx); + if (max_class_idx != blank && + !(merge_repeated && max_class_idx == prev_class_idx)) { + pathes[seq_idx].push_back(max_class_idx); + } + prev_class_idx = max_class_idx; + } + output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); + } + framework::LoD output_lod; + output_lod.push_back(output_lod0); + output->set_lod(output_lod); + int64_t num_step = static_cast(output_lod0.back()); + int* output_data = output->mutable_data({num_step, 1}, ctx.GetPlace()); + + for (int i = 0; i < num_sequences; ++i) { + memcpy(output_data + output_lod0[i], pathes[i].data(), + sizeof(int) * pathes[i].size()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py b/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py new file mode 100644 index 00000000000000..23fceb6dcdbc90 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py @@ -0,0 +1,56 @@ +import sys +import unittest +import numpy as np +from op_test import OpTest +from test_softmax_op import stable_softmax + + +def CTCGreedyDecode(softmax, blank, merge_repeated): + prev_token = -1 + result = [] + for token in np.argmax(softmax, axis=1): + if (token != blank) and not (merge_repeated and token == prev_token): + result.append(token) + return np.array(result).reshape([len(result), 1]) + + +class TestCTCGreedyDecodeOp(OpTest): + def config(self): + self.op_type = "ctc_greedy_decode" + self.batch_size = 4 + self.num_classes = 8 + self.input_lod = [[0, 4, 5, 8, 11]] + self.blank = 7 + self.merge_repeated = True + + def setUp(self): + self.config() + input = np.random.uniform( + 0.1, 1.0, + [self.input_lod[0][-1], self.num_classes]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, input) + output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated) + + self.inputs = {"Input": (softmax, self.input_lod), } + self.outputs = {"Output": output} + self.attrs = { + "blank": self.blank, + "merge_repeated": self.merge_repeated + } + + def test_check_output(self): + self.check_output() + + +class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp): + def config(self): + self.op_type = "ctc_greedy_decode" + self.batch_size = 4 + self.num_classes = 1025 + self.input_lod = [[0, 4, 5, 8, 11]] + self.blank = 0 + self.merge_repeated = True + + +if __name__ == "__main__": + unittest.main() From 281e93bcbb3e67996f3b7a2f76df1da0071969db Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 16 Jan 2018 15:15:10 +0800 Subject: [PATCH 3/8] Remove 'top 1' from CPU and GPU kernel 1. Remove 'top 1'(or argmax) from CPU and GPU kernel 2. Add a new test case 3. Refine doc --- ...c_greedy_decode_op.cc => ctc_decode_op.cc} | 44 ++++++----- ...c_greedy_decode_op.cu => ctc_decode_op.cu} | 77 ++++--------------- ...ctc_greedy_decode_op.h => ctc_decode_op.h} | 34 ++++---- .../paddle/v2/fluid/tests/test_ctc_decode.py | 62 +++++++++++++++ .../v2/fluid/tests/test_ctc_greedy_decode.py | 56 -------------- 5 files changed, 118 insertions(+), 155 deletions(-) rename paddle/operators/{ctc_greedy_decode_op.cc => ctc_decode_op.cc} (67%) rename paddle/operators/{ctc_greedy_decode_op.cu => ctc_decode_op.cu} (60%) rename paddle/operators/{ctc_greedy_decode_op.h => ctc_decode_op.h} (75%) create mode 100644 python/paddle/v2/fluid/tests/test_ctc_decode.py delete mode 100644 python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py diff --git a/paddle/operators/ctc_greedy_decode_op.cc b/paddle/operators/ctc_decode_op.cc similarity index 67% rename from paddle/operators/ctc_greedy_decode_op.cc rename to paddle/operators/ctc_decode_op.cc index 3c9b705f7f61c3..b290b11d1d1e80 100644 --- a/paddle/operators/ctc_greedy_decode_op.cc +++ b/paddle/operators/ctc_decode_op.cc @@ -29,14 +29,8 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { auto input_dims = ctx->GetInputDim("Input"); - int sequence_width = - static_cast(framework::product(input_dims) / input_dims[0]); - int blank = ctx->Attrs().Get("blank"); - PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width), - "The value of Attr(blank) should be in interval [0, %d).", - sequence_width); // TODO(wanghaoshuang): it is tricky to set the wrong dimension here. - ctx->SetOutputDim("Output", {input_dims[0], 1}); + ctx->SetOutputDim("Output", input_dims); } protected: @@ -53,25 +47,37 @@ class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", - "(LodTensor, default: LoDTensor), the unscaled " - "probabilities of variable-length sequences, which is a 2-D " - "Tensor with LoD information. It's shape is " - "[Lp, num_classes + 1], where Lp is the sum of all input " - "sequences' length and num_classes is the true number of classes " - "(not including the blank label)."); - AddOutput("Output", "(Tensor, default: Tensor), the decode result "); + "(LodTensor, default: LoDTensor), Its shape is " + "[Lp, 1], where Lp is the sum of all input sequences' length."); + AddOutput("Output", "(Tensor, default: Tensor), The decode result."); AddAttr("blank", "(int, default: 0), the blank label setted in Connectionist " - "Temporal Classification (CTC) op, and it is in the " - "half-opened interval [0, num_classes + 1).") + "Temporal Classification (CTC) op.") .SetDefault(0); AddAttr("merge_repeated", "(bool, default: true), whether to " "merge repeated elements between two blanks. ") .SetDefault(true); AddComment(R"DOC( -CTCGreedyDecoder is an implementation of the simple best path decoding -algorithm, selecting at each timestep the most likely class at each timestep. +CTCDecoder is used to merge repeated elements between two blanks +and then delete all blanks in sequence. + +Given: + Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, + 6, 0, 0, 7, 7, 7, 0] + Input.dims = {18, 1} + Input.LoD = [[0, 11, 18]] + +And: + blank = 0 + merge_repeated = True + +Then: + Output.data = [1, 2, 4, 4, 5, 6, + 6, 7] + Output.dims = {8, 1} + Output.LoD = [[0, 6, 8]] + )DOC"); } }; @@ -85,4 +91,4 @@ REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( ctc_greedy_decode, - ops::CTCGreedyDecodeKernel); + ops::CTCGreedyDecodeKernel); diff --git a/paddle/operators/ctc_greedy_decode_op.cu b/paddle/operators/ctc_decode_op.cu similarity index 60% rename from paddle/operators/ctc_greedy_decode_op.cu rename to paddle/operators/ctc_decode_op.cu index 43a78745aca2f6..e9cdad7c26b179 100644 --- a/paddle/operators/ctc_greedy_decode_op.cu +++ b/paddle/operators/ctc_decode_op.cu @@ -16,62 +16,20 @@ limitations under the License. */ #include #include #include "paddle/operators/ctc_greedy_decode_op.h" -#include "paddle/platform/cuda_helper.h" -#include "paddle/platform/gpu_info.h" namespace paddle { namespace operators { -using platform::PADDLE_CUDA_NUM_THREADS; - -__device__ static float atomicMaxF(float* address, float val) { - int* address_as_i = (int*)address; - int old = *address_as_i, assumed; - do { - assumed = old; - old = ::atomicCAS(address_as_i, assumed, - __float_as_int(::fmaxf(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); -} - -template -__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits, - int* output) { - T local_max_value = 0; - int local_max_index = 0; - __shared__ T max_value; - if (threadIdx.x == 0) { - max_value = 0; - } - __syncthreads(); - - for (int i = threadIdx.x; i < seq_width; i += BlockSize) { - T value = logits[blockIdx.x * seq_width + i]; - if (value > local_max_value) { - local_max_value = value; - local_max_index = i; - } - } - - atomicMaxF(&max_value, local_max_value); - - __syncthreads(); - - if (local_max_value == max_value) { - output[blockIdx.x] = local_max_index; - } -} template -__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens, +__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, const size_t num_seq, size_t* lod0, const int blank, const int merge_repeated, - size_t* out_lod0, int* output) { + size_t* out_lod0, T* output) { int ouput_idx = 0; out_lod0[0] = 0; for (int i = 0; i < num_seq; ++i) { - int pre_token = -1; + T pre_token = -1; for (int j = lod0[i]; j < lod0[i + 1]; ++j) { if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) { output[ouput_idx] = tokens[j]; @@ -89,44 +47,39 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); + const size_t level = 0; auto* input = ctx.Input("Input"); auto* output = ctx.Output("Output"); + auto input_lod = framework::ToAbsOffset(input->lod()); + const T* tokens = input->data(); const int64_t num_tokens = input->dims()[0]; - const size_t seq_width = input->numel() / num_tokens; - const T* logits = input->data(); - Tensor tmp; - int* tokens = tmp.mutable_data({num_tokens, 1}, ctx.GetPlace()); - // get argmax - // platform::GpuMemsetAsync(args, 0, sizeof(float), stream); - - auto stream = ctx.cuda_device_context().stream(); - ArgmaxCudaKernel<<< - num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits, - tokens); - - const size_t level = 0; - auto input_lod = framework::ToAbsOffset(input->lod()); const size_t num_seq = input_lod[level].size() - 1; + const int blank = ctx.Attr("blank"); const int merge_repeated = static_cast(ctx.Attr("merge_repeated")); + // prepare a lod to record lod information while merging elements thrust::device_vector dev_out_lod0(input_lod[level].size()); size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); - int* output_data = - output->mutable_data({num_tokens, 1}, ctx.GetPlace()); + // merge elements and delete blank + T* output_data = output->mutable_data({num_tokens, 1}, ctx.GetPlace()); + + auto stream = ctx.cuda_device_context().stream(); MergeAndDelCudaKernel<<<1, 1, 0, stream>>>( num_tokens, tokens, num_seq, input_lod[level].data(), blank, merge_repeated, dev_out_lod0_ptr, output_data); + // set output lod thrust::host_vector host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end()); framework::LoD out_lod; out_lod.push_back(host_out_lod0); output->set_lod(out_lod); + // resize output dims output->Resize({static_cast(host_out_lod0.back()), 1}); } }; @@ -135,4 +88,4 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, - paddle::operators::CTCGreedyDecodeOpCUDAKernel); + paddle::operators::CTCGreedyDecodeOpCUDAKernel); diff --git a/paddle/operators/ctc_greedy_decode_op.h b/paddle/operators/ctc_decode_op.h similarity index 75% rename from paddle/operators/ctc_greedy_decode_op.h rename to paddle/operators/ctc_decode_op.h index f12ea6c541b699..30bb53e157f19e 100644 --- a/paddle/operators/ctc_greedy_decode_op.h +++ b/paddle/operators/ctc_decode_op.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/framework/op_registry.h" -#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { @@ -30,8 +29,9 @@ class CTCGreedyDecodeKernel : public framework::OpKernel { auto* input = ctx.Input("Input"); auto* output = ctx.Output("Output"); const size_t level = 0; - auto input_lod = framework::ToAbsOffset(input->lod()); + + // check input dims and lod auto input_dims = input->dims(); PADDLE_ENFORCE_EQ(input_dims[0], static_cast(input_lod[level].back()), @@ -39,38 +39,36 @@ class CTCGreedyDecodeKernel : public framework::OpKernel { "the sum of all sequences' lengths."); const size_t num_sequences = input_lod[level].size() - 1; - const size_t sequence_width = input->numel() / input_dims[0]; size_t blank = static_cast(ctx.Attr("blank")); bool merge_repeated = ctx.Attr("merge_repeated"); + + // merge repeated tokens and delete blank std::vector> pathes(num_sequences); std::vector output_lod0(1, 0); - const T* input_data = input->data(); - Eigen::Map< - Eigen::Matrix> - input_mat(const_cast(input_data), input->numel() / sequence_width, - sequence_width); - - size_t max_class_idx; - size_t prev_class_idx = -1; for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + T prev_token = -1; for (size_t i = input_lod[level][seq_idx]; i < input_lod[level][seq_idx + 1]; ++i) { - input_mat.row(i).maxCoeff(&max_class_idx); - if (max_class_idx != blank && - !(merge_repeated && max_class_idx == prev_class_idx)) { - pathes[seq_idx].push_back(max_class_idx); + if (input_data[i] != blank && + !(merge_repeated && input_data[i] == prev_token)) { + pathes[seq_idx].push_back(input_data[i]); } - prev_class_idx = max_class_idx; + prev_token = input_data[i]; } output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); } + + // set output lod framework::LoD output_lod; output_lod.push_back(output_lod0); output->set_lod(output_lod); - int64_t num_step = static_cast(output_lod0.back()); - int* output_data = output->mutable_data({num_step, 1}, ctx.GetPlace()); + // resize output dims + T* output_data = output->mutable_data( + {static_cast(output_lod0.back()), 1}, ctx.GetPlace()); + + // copy result to output for (int i = 0; i < num_sequences; ++i) { memcpy(output_data + output_lod0[i], pathes[i].data(), sizeof(int) * pathes[i].size()); diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_decode.py new file mode 100644 index 00000000000000..3b7486cfb98098 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_ctc_decode.py @@ -0,0 +1,62 @@ +import sys +import unittest +import numpy as np +from op_test import OpTest +from test_softmax_op import stable_softmax + + +def CTCDecode(input, lod, blank, merge_repeated): + lod0 = lod[0] + result = [] + for i in range(len(lod0) - 1): + prev_token = -1 + for j in range(lod0[i], lod0[i + 1]): + token = input[j][0] + if (token != blank) and not (merge_repeated and + token == prev_token): + result.append(token) + prev_token = token + result = np.array(result).reshape([len(result), 1]).astype("int32") + return result + + +class TestCTCDecodeOp(OpTest): + def config(self): + self.op_type = "ctc_greedy_decode" + self.input_lod = [[0, 11, 18]] + self.blank = 0 + self.merge_repeated = False + self.input = np.array( + [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape( + [18, 1]).astype("int32") + + def setUp(self): + self.config() + output = CTCDecode(self.input, self.input_lod, self.blank, + self.merge_repeated) + + self.inputs = {"Input": (self.input, self.input_lod), } + self.outputs = {"Output": output} + self.attrs = { + "blank": self.blank, + "merge_repeated": self.merge_repeated + } + + def test_check_output(self): + self.check_output() + pass + + +class TestCTCDecodeOpCase1(TestCTCDecodeOp): + def config(self): + self.op_type = "ctc_greedy_decode" + self.input_lod = [[0, 11, 18]] + self.blank = 0 + self.merge_repeated = True + self.input = np.array( + [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape( + [18, 1]).astype("int32") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py b/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py deleted file mode 100644 index 23fceb6dcdbc90..00000000000000 --- a/python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py +++ /dev/null @@ -1,56 +0,0 @@ -import sys -import unittest -import numpy as np -from op_test import OpTest -from test_softmax_op import stable_softmax - - -def CTCGreedyDecode(softmax, blank, merge_repeated): - prev_token = -1 - result = [] - for token in np.argmax(softmax, axis=1): - if (token != blank) and not (merge_repeated and token == prev_token): - result.append(token) - return np.array(result).reshape([len(result), 1]) - - -class TestCTCGreedyDecodeOp(OpTest): - def config(self): - self.op_type = "ctc_greedy_decode" - self.batch_size = 4 - self.num_classes = 8 - self.input_lod = [[0, 4, 5, 8, 11]] - self.blank = 7 - self.merge_repeated = True - - def setUp(self): - self.config() - input = np.random.uniform( - 0.1, 1.0, - [self.input_lod[0][-1], self.num_classes]).astype("float32") - softmax = np.apply_along_axis(stable_softmax, 1, input) - output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated) - - self.inputs = {"Input": (softmax, self.input_lod), } - self.outputs = {"Output": output} - self.attrs = { - "blank": self.blank, - "merge_repeated": self.merge_repeated - } - - def test_check_output(self): - self.check_output() - - -class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp): - def config(self): - self.op_type = "ctc_greedy_decode" - self.batch_size = 4 - self.num_classes = 1025 - self.input_lod = [[0, 4, 5, 8, 11]] - self.blank = 0 - self.merge_repeated = True - - -if __name__ == "__main__": - unittest.main() From 10dd632659012374f827ae0208c05b0eb5c17fb6 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 16 Jan 2018 15:56:52 +0800 Subject: [PATCH 4/8] Rename 'ctc_greedy_decode' to 'ctc_decode' --- paddle/operators/ctc_decode_op.cc | 18 ++++++++---------- paddle/operators/ctc_decode_op.cu | 8 ++++---- paddle/operators/ctc_decode_op.h | 2 +- .../paddle/v2/fluid/tests/test_ctc_decode.py | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/paddle/operators/ctc_decode_op.cc b/paddle/operators/ctc_decode_op.cc index b290b11d1d1e80..480c9ae133ce1d 100644 --- a/paddle/operators/ctc_decode_op.cc +++ b/paddle/operators/ctc_decode_op.cc @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { -class CTCGreedyDecodeOp : public framework::OperatorWithKernel { +class CTCDecodeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input of CTCGreedyDecodeOp should not be null."); + "Input of CTCDecodeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output of CTCGreedyDecodeOp should not be null."); + "Output of CTCDecodeOp should not be null."); auto input_dims = ctx->GetInputDim("Input"); @@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { } }; -class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { +class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(LodTensor, default: LoDTensor), Its shape is " @@ -86,9 +86,7 @@ and then delete all blanks in sequence. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, - ops::CTCGreedyDecodeOpMaker, +REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_greedy_decode, - ops::CTCGreedyDecodeKernel); + ctc_decode, ops::CTCDecodeKernel); diff --git a/paddle/operators/ctc_decode_op.cu b/paddle/operators/ctc_decode_op.cu index e9cdad7c26b179..b10db100f7fb71 100644 --- a/paddle/operators/ctc_decode_op.cu +++ b/paddle/operators/ctc_decode_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { @@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, } template -class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { +class CTCDecodeOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, - paddle::operators::CTCGreedyDecodeOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_decode, + paddle::operators::CTCDecodeOpCUDAKernel); diff --git a/paddle/operators/ctc_decode_op.h b/paddle/operators/ctc_decode_op.h index 30bb53e157f19e..bc8dfab9f62153 100644 --- a/paddle/operators/ctc_decode_op.h +++ b/paddle/operators/ctc_decode_op.h @@ -23,7 +23,7 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -class CTCGreedyDecodeKernel : public framework::OpKernel { +class CTCDecodeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_decode.py index 3b7486cfb98098..6e798a8465c36a 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_decode.py @@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated): class TestCTCDecodeOp(OpTest): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = False @@ -49,7 +49,7 @@ def test_check_output(self): class TestCTCDecodeOpCase1(TestCTCDecodeOp): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = True From adcfde3eab274b76029f3efb13ca9b3627273e7a Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 17 Jan 2018 10:08:41 +0800 Subject: [PATCH 5/8] Modify unitest --- python/paddle/v2/fluid/tests/test_ctc_decode.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_decode.py index 6e798a8465c36a..1efacab4b3bf75 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_decode.py @@ -50,12 +50,12 @@ def test_check_output(self): class TestCTCDecodeOpCase1(TestCTCDecodeOp): def config(self): self.op_type = "ctc_decode" - self.input_lod = [[0, 11, 18]] + self.input_lod = [[0, 11, 19]] self.blank = 0 self.merge_repeated = True self.input = np.array( - [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape( - [18, 1]).astype("int32") + [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0]).reshape( + [19, 1]).astype("int32") if __name__ == "__main__": From 7150289b5cad76d3347a268b54c31e13a0e49f42 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 17 Jan 2018 16:34:38 +0800 Subject: [PATCH 6/8] Refine CPU kernel 1. Allocate memory for output before compute. 2. Rename 'ctc_decode' to 'ctc_align' --- .../{ctc_decode_op.cc => ctc_align_op.cc} | 20 +++++++++---------- .../{ctc_decode_op.cu => ctc_align_op.cu} | 8 ++++---- .../{ctc_decode_op.h => ctc_align_op.h} | 19 +++++++----------- .../{test_ctc_decode.py => test_ctc_align.py} | 14 ++++++------- 4 files changed, 28 insertions(+), 33 deletions(-) rename paddle/operators/{ctc_decode_op.cc => ctc_align_op.cc} (78%) rename paddle/operators/{ctc_decode_op.cu => ctc_align_op.cu} (93%) rename paddle/operators/{ctc_decode_op.h => ctc_align_op.h} (80%) rename python/paddle/v2/fluid/tests/{test_ctc_decode.py => test_ctc_align.py} (82%) diff --git a/paddle/operators/ctc_decode_op.cc b/paddle/operators/ctc_align_op.cc similarity index 78% rename from paddle/operators/ctc_decode_op.cc rename to paddle/operators/ctc_align_op.cc index 480c9ae133ce1d..3fa8d2af7424ff 100644 --- a/paddle/operators/ctc_decode_op.cc +++ b/paddle/operators/ctc_align_op.cc @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_decode_op.h" +#include "paddle/operators/ctc_align_op.h" namespace paddle { namespace operators { -class CTCDecodeOp : public framework::OperatorWithKernel { +class CTCAlignOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input of CTCDecodeOp should not be null."); + "Input of CTCAlignOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output of CTCDecodeOp should not be null."); + "Output of CTCAlignOp should not be null."); auto input_dims = ctx->GetInputDim("Input"); @@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel { } }; -class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { +class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(LodTensor, default: LoDTensor), Its shape is " "[Lp, 1], where Lp is the sum of all input sequences' length."); - AddOutput("Output", "(Tensor, default: Tensor), The decode result."); + AddOutput("Output", "(Tensor, default: Tensor), The align result."); AddAttr("blank", "(int, default: 0), the blank label setted in Connectionist " "Temporal Classification (CTC) op.") @@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { "merge repeated elements between two blanks. ") .SetDefault(true); AddComment(R"DOC( -CTCDecoder is used to merge repeated elements between two blanks +CTCAlign op is used to merge repeated elements between two blanks and then delete all blanks in sequence. Given: @@ -86,7 +86,7 @@ and then delete all blanks in sequence. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, +REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_decode, ops::CTCDecodeKernel); + ctc_align, ops::CTCAlignKernel); diff --git a/paddle/operators/ctc_decode_op.cu b/paddle/operators/ctc_align_op.cu similarity index 93% rename from paddle/operators/ctc_decode_op.cu rename to paddle/operators/ctc_align_op.cu index b10db100f7fb71..99e716e989f064 100644 --- a/paddle/operators/ctc_decode_op.cu +++ b/paddle/operators/ctc_align_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include "paddle/operators/ctc_decode_op.h" +#include "paddle/operators/ctc_align_op.h" namespace paddle { namespace operators { @@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, } template -class CTCDecodeOpCUDAKernel : public framework::OpKernel { +class CTCAlignOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_decode, - paddle::operators::CTCDecodeOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_align, + paddle::operators::CTCAlignOpCUDAKernel); diff --git a/paddle/operators/ctc_decode_op.h b/paddle/operators/ctc_align_op.h similarity index 80% rename from paddle/operators/ctc_decode_op.h rename to paddle/operators/ctc_align_op.h index bc8dfab9f62153..589413feb3dcbb 100644 --- a/paddle/operators/ctc_decode_op.h +++ b/paddle/operators/ctc_align_op.h @@ -23,7 +23,7 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -class CTCDecodeKernel : public framework::OpKernel { +class CTCAlignKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); @@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel { bool merge_repeated = ctx.Attr("merge_repeated"); // merge repeated tokens and delete blank - std::vector> pathes(num_sequences); + T* output_data = output->mutable_data(ctx.GetPlace()); + size_t output_idx = 0; std::vector output_lod0(1, 0); const T* input_data = input->data(); for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { @@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel { i < input_lod[level][seq_idx + 1]; ++i) { if (input_data[i] != blank && !(merge_repeated && input_data[i] == prev_token)) { - pathes[seq_idx].push_back(input_data[i]); + output_data[output_idx] = input_data[i]; + ++output_idx; } prev_token = input_data[i]; } - output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); + output_lod0.push_back(output_idx); } // set output lod @@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel { output->set_lod(output_lod); // resize output dims - T* output_data = output->mutable_data( - {static_cast(output_lod0.back()), 1}, ctx.GetPlace()); - - // copy result to output - for (int i = 0; i < num_sequences; ++i) { - memcpy(output_data + output_lod0[i], pathes[i].data(), - sizeof(int) * pathes[i].size()); - } + output->Resize({static_cast(output_lod0.back()), 1}); } }; diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_align.py similarity index 82% rename from python/paddle/v2/fluid/tests/test_ctc_decode.py rename to python/paddle/v2/fluid/tests/test_ctc_align.py index 1efacab4b3bf75..96f45890ee9baf 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_align.py @@ -5,7 +5,7 @@ from test_softmax_op import stable_softmax -def CTCDecode(input, lod, blank, merge_repeated): +def CTCAlign(input, lod, blank, merge_repeated): lod0 = lod[0] result = [] for i in range(len(lod0) - 1): @@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated): return result -class TestCTCDecodeOp(OpTest): +class TestCTCAlignOp(OpTest): def config(self): - self.op_type = "ctc_decode" + self.op_type = "ctc_align" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = False @@ -32,8 +32,8 @@ def config(self): def setUp(self): self.config() - output = CTCDecode(self.input, self.input_lod, self.blank, - self.merge_repeated) + output = CTCAlign(self.input, self.input_lod, self.blank, + self.merge_repeated) self.inputs = {"Input": (self.input, self.input_lod), } self.outputs = {"Output": output} @@ -47,9 +47,9 @@ def test_check_output(self): pass -class TestCTCDecodeOpCase1(TestCTCDecodeOp): +class TestCTCAlignOpCase1(TestCTCAlignOp): def config(self): - self.op_type = "ctc_decode" + self.op_type = "ctc_align" self.input_lod = [[0, 11, 19]] self.blank = 0 self.merge_repeated = True From e4695457571ea6bb80b2ebbaadc2fa0551d83af7 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 18 Jan 2018 09:14:47 +0800 Subject: [PATCH 7/8] Add Copyright to test_ctc_align.py --- python/paddle/v2/fluid/tests/test_ctc_align.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/paddle/v2/fluid/tests/test_ctc_align.py b/python/paddle/v2/fluid/tests/test_ctc_align.py index 96f45890ee9baf..5a7c16997c19fe 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_align.py +++ b/python/paddle/v2/fluid/tests/test_ctc_align.py @@ -1,3 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + import sys import unittest import numpy as np From 6089b50c4b01f507e0fe7200a68a1972fd1505c0 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 19 Jan 2018 11:24:04 +0800 Subject: [PATCH 8/8] Registry int64_t kernels --- paddle/operators/ctc_align_op.cc | 3 ++- paddle/operators/ctc_align_op.cu | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/operators/ctc_align_op.cc b/paddle/operators/ctc_align_op.cc index 3fa8d2af7424ff..eeecbd32127d2c 100644 --- a/paddle/operators/ctc_align_op.cc +++ b/paddle/operators/ctc_align_op.cc @@ -89,4 +89,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_align, ops::CTCAlignKernel); + ctc_align, ops::CTCAlignKernel, + ops::CTCAlignKernel); diff --git a/paddle/operators/ctc_align_op.cu b/paddle/operators/ctc_align_op.cu index 99e716e989f064..45635f16745346 100644 --- a/paddle/operators/ctc_align_op.cu +++ b/paddle/operators/ctc_align_op.cu @@ -87,5 +87,5 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_align, - paddle::operators::CTCAlignOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_align, paddle::operators::CTCAlignOpCUDAKernel, + paddle::operators::CTCAlignOpCUDAKernel);