diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 830134e57e0e72..2868c3697eda19 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/gather_tree_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -73,5 +73,3 @@ selected ids. namespace ops = paddle::operators; REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker); -REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel, - ops::GatherTreeOpKernel); diff --git a/paddle/fluid/operators/gather_tree_op.cu b/paddle/fluid/operators/gather_tree_op.cu deleted file mode 100644 index 829682764a674d..00000000000000 --- a/paddle/fluid/operators/gather_tree_op.cu +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/gather_tree_op.h" - -namespace paddle { -namespace operators { - -template -__global__ void GatherTree(const T *ids_data, const T *parents_data, - T *out_data, const int64_t max_length, - const int64_t batch_size, const int64_t beam_size) { - CUDA_KERNEL_LOOP(i, batch_size * beam_size) { - int batch = i / beam_size; - int beam = i % beam_size; - auto idx = - (max_length - 1) * batch_size * beam_size + batch * beam_size + beam; - out_data[idx] = ids_data[idx]; - auto parent = parents_data[idx]; - for (int step = max_length - 2; step >= 0; step--) { - idx = step * batch_size * beam_size + batch * beam_size; - out_data[idx + beam] = ids_data[idx + parent]; - parent = parents_data[idx + parent]; - } - } -} - -template -class GatherTreeOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *ids = ctx.Input("Ids"); - auto *parents = ctx.Input("Parents"); - auto *out = ctx.Output("Out"); - - const auto *ids_data = ids->data(); - const auto *parents_data = parents->data(); - auto *out_data = out->mutable_data(ctx.GetPlace()); - - PADDLE_ENFORCE_NOT_NULL( - ids_data, platform::errors::InvalidArgument( - "Input(Ids) of gather_tree should not be null.")); - - PADDLE_ENFORCE_NOT_NULL( - parents_data, platform::errors::InvalidArgument( - "Input(Parents) of gather_tree should not be null.")); - - auto &ids_dims = ids->dims(); - int64_t max_length = ids_dims[0]; - int64_t batch_size = ids_dims[1]; - int64_t beam_size = ids_dims[2]; - - auto &dev_ctx = ctx.cuda_device_context(); - - const int block = 512; - int max_threads = - std::min(static_cast(dev_ctx.GetMaxPhysicalThreadCount()), - batch_size * beam_size); - const int grid = std::max(max_threads / block, 1); - GatherTree<<>>(ids_data, parents_data, out_data, max_length, - batch_size, beam_size); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel, - ops::GatherTreeOpCUDAKernel); diff --git a/paddle/fluid/operators/gather_tree_op.h b/paddle/fluid/operators/gather_tree_op.h deleted file mode 100644 index e035a30e7954fe..00000000000000 --- a/paddle/fluid/operators/gather_tree_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class GatherTreeOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *ids = ctx.Input("Ids"); - auto *parents = ctx.Input("Parents"); - auto *out = ctx.Output("Out"); - - const auto *ids_data = ids->data(); - const auto *parents_data = parents->data(); - auto *out_data = out->mutable_data(ctx.GetPlace()); - - auto &ids_dims = ids->dims(); - auto max_length = ids_dims[0]; - auto batch_size = ids_dims[1]; - auto beam_size = ids_dims[2]; - - PADDLE_ENFORCE_NOT_NULL( - ids_data, platform::errors::InvalidArgument( - "Input(Ids) of gather_tree should not be null.")); - - PADDLE_ENFORCE_NOT_NULL( - parents_data, platform::errors::InvalidArgument( - "Input(Parents) of gather_tree should not be null.")); - - for (int batch = 0; batch < batch_size; batch++) { - for (int beam = 0; beam < beam_size; beam++) { - auto idx = (max_length - 1) * batch_size * beam_size + - batch * beam_size + beam; - out_data[idx] = ids_data[idx]; - auto parent = parents_data[idx]; - for (int step = max_length - 2; step >= 0; step--) { - idx = step * batch_size * beam_size + batch * beam_size; - out_data[idx + beam] = ids_data[idx + parent]; - parent = parents_data[idx + parent]; - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc index 50df75d9ad3fd7..eb745ab9c56c5b 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc @@ -27,15 +27,7 @@ class CPUDeviceContext; } // namespace paddle REGISTER_REDUCE_OP(reduce_prod); -REGISTER_OP_CPU_KERNEL(reduce_prod, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); + REGISTER_OP_CPU_KERNEL(reduce_prod_grad, ops::ReduceGradKernel, diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.h b/paddle/fluid/operators/reduce_ops/reduce_prod_op.h index 103e108e4bda1c..60dedf8d6ffb07 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.h @@ -19,13 +19,6 @@ namespace paddle { namespace operators { -struct ProdFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->prod(dim); - } -}; - struct ProdGradFunctor { template diff --git a/paddle/phi/kernels/cpu/gather_tree_kernel.cc b/paddle/phi/kernels/cpu/gather_tree_kernel.cc new file mode 100644 index 00000000000000..25fb870d851f67 --- /dev/null +++ b/paddle/phi/kernels/cpu/gather_tree_kernel.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gather_tree_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GatherTreeKernel(const Context &dev_ctx, + const DenseTensor &ids, + const DenseTensor &parents, + DenseTensor *out) { + const auto *ids_data = ids.data(); + const auto *parents_data = parents.data(); + + T *out_data = dev_ctx.template Alloc(out); + + auto &ids_dims = ids.dims(); + auto max_length = ids_dims[0]; + auto batch_size = ids_dims[1]; + auto beam_size = ids_dims[2]; + + PADDLE_ENFORCE_NOT_NULL(ids_data, + phi::errors::InvalidArgument( + "Input(Ids) of gather_tree should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + parents_data, + phi::errors::InvalidArgument( + "Input(Parents) of gather_tree should not be null.")); + + for (int batch = 0; batch < batch_size; batch++) { + for (int beam = 0; beam < beam_size; beam++) { + auto idx = + (max_length - 1) * batch_size * beam_size + batch * beam_size + beam; + out_data[idx] = ids_data[idx]; + auto parent = parents_data[idx]; + for (int step = max_length - 2; step >= 0; step--) { + idx = step * batch_size * beam_size + batch * beam_size; + out_data[idx + beam] = ids_data[idx + parent]; + parent = parents_data[idx + parent]; + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + gather_tree, CPU, ALL_LAYOUT, phi::GatherTreeKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/cpu/reduce_prod_kernel.cc b/paddle/phi/kernels/cpu/reduce_prod_kernel.cc new file mode 100644 index 00000000000000..cf0179124ebdfc --- /dev/null +++ b/paddle/phi/kernels/cpu/reduce_prod_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/reduce_prod_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" + +namespace phi { + +template +void ReduceProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(reduce_prod, + CPU, + ALL_LAYOUT, + phi::ReduceProdKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/reduce_functor.h b/paddle/phi/kernels/funcs/reduce_functor.h index ce8e095e8ac6c2..aebd155ac59cb2 100644 --- a/paddle/phi/kernels/funcs/reduce_functor.h +++ b/paddle/phi/kernels/funcs/reduce_functor.h @@ -33,5 +33,13 @@ struct MeanFunctor { } }; +//////// Prod Functor /////// +struct ProdFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->prod(dim); + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu b/paddle/phi/kernels/gather_tree_kernel.h similarity index 51% rename from paddle/fluid/operators/reduce_ops/reduce_prod_op.cu rename to paddle/phi/kernels/gather_tree_kernel.h index 2de647df8b182b..e5a1a684daef09 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu +++ b/paddle/phi/kernels/gather_tree_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h" +#pragma once -REGISTER_OP_CUDA_KERNEL( - reduce_prod, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); +#include "paddle/phi/core/dense_tensor.h" +namespace phi { + +template +void GatherTreeKernel(const Context &dev_ctx, + const DenseTensor &ids, + const DenseTensor &parents, + DenseTensor *out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/gather_tree_kernel.cu b/paddle/phi/kernels/gpu/gather_tree_kernel.cu new file mode 100644 index 00000000000000..a9e73ec37c8ed5 --- /dev/null +++ b/paddle/phi/kernels/gpu/gather_tree_kernel.cu @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gather_tree_kernel.h" + +namespace phi { + +template +__global__ void GatherTree(const T *ids_data, + const T *parents_data, + T *out_data, + const int64_t max_length, + const int64_t batch_size, + const int64_t beam_size) { + CUDA_KERNEL_LOOP(i, batch_size * beam_size) { + int batch = i / beam_size; + int beam = i % beam_size; + auto idx = + (max_length - 1) * batch_size * beam_size + batch * beam_size + beam; + out_data[idx] = ids_data[idx]; + auto parent = parents_data[idx]; + for (int step = max_length - 2; step >= 0; step--) { + idx = step * batch_size * beam_size + batch * beam_size; + out_data[idx + beam] = ids_data[idx + parent]; + parent = parents_data[idx + parent]; + } + } +} + +template +void GatherTreeKernel(const Context &dev_ctx, + const DenseTensor &ids, + const DenseTensor &parents, + DenseTensor *out) { + const auto *ids_data = ids.data(); + const auto *parents_data = parents.data(); + T *out_data = dev_ctx.template Alloc(out); + + PADDLE_ENFORCE_NOT_NULL(ids_data, + phi::errors::InvalidArgument( + "Input(Ids) of gather_tree should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + parents_data, + phi::errors::InvalidArgument( + "Input(Parents) of gather_tree should not be null.")); + + auto &ids_dims = ids.dims(); + int64_t max_length = ids_dims[0]; + int64_t batch_size = ids_dims[1]; + int64_t beam_size = ids_dims[2]; + + const int block = 512; + int max_threads = + std::min(static_cast(dev_ctx.GetMaxPhysicalThreadCount()), + batch_size * beam_size); + const int grid = std::max(max_threads / block, 1); + GatherTree<<>>( + ids_data, parents_data, out_data, max_length, batch_size, beam_size); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + gather_tree, GPU, ALL_LAYOUT, phi::GatherTreeKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/reduce_prod_kernel.cu b/paddle/phi/kernels/gpu/reduce_prod_kernel.cu new file mode 100644 index 00000000000000..14084d0f4f3c6f --- /dev/null +++ b/paddle/phi/kernels/gpu/reduce_prod_kernel.cu @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/reduce_prod_kernel.h" + +namespace phi { + +template +void ReduceProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { + auto out_dtype = x.dtype(); + phi::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(reduce_prod, + GPU, + ALL_LAYOUT, + phi::ReduceProdKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/reduce_prod_kernel.h b/paddle/phi/kernels/reduce_prod_kernel.h new file mode 100644 index 00000000000000..5e92b6c4db14e7 --- /dev/null +++ b/paddle/phi/kernels/reduce_prod_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ReduceProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index 74704671f8b5d2..097502d64c3883 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -43,6 +43,11 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("unregistered", {}, {}, {}); } +KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "reduce_prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); @@ -50,3 +55,4 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean); PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping);