diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 280f24bdd6fa6..6dcdae7691310 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -24,7 +24,6 @@ add_subdirectory(optimizers) add_subdirectory(reduce_ops) add_subdirectory(sequence_ops) add_subdirectory(string) -add_subdirectory(prim_ops) if(WITH_DISTRIBUTE) diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt deleted file mode 100644 index 7a1278219bb6d..0000000000000 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -include(operators) -if(WITH_UNITY_BUILD) - # Load Unity Build rules for operators in paddle/fluid/operators/prim_ops. - include(unity_build_rule.cmake) -endif() -register_operators() diff --git a/paddle/fluid/operators/prim_ops/abs_p_op.cc b/paddle/fluid/operators/prim_ops/abs_p_op.cc deleted file mode 100644 index 87b5243d6afe7..0000000000000 --- a/paddle/fluid/operators/prim_ops/abs_p_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class AbsPrimOp : public framework::OperatorBase { - public: - AbsPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator abs_p should not be executed directly")); - } -}; - -class AbsPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of abs_p op."); - AddOutput("Y", "(Tensor), The output tensor of abs_p op."); - AddComment(R"DOC(Autograd primitive abs_p operator.)DOC"); - } -}; - -class AbsPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class AbsPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(abs_p, - paddle::operators::AbsPrimOp, - paddle::operators::AbsPrimOpMaker, - paddle::operators::AbsPrimOpShapeInference, - paddle::operators::AbsPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/add_p_op.cc b/paddle/fluid/operators/prim_ops/add_p_op.cc deleted file mode 100644 index 7fbbdf136929c..0000000000000 --- a/paddle/fluid/operators/prim_ops/add_p_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class AddPrimOp : public framework::OperatorBase { - public: - AddPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator add_p should not be executed directly")); - } -}; - -class AddPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of add_p op."); - AddInput("Y", "(Tensor), The input tensor of add_p op."); - AddOutput("Z", "(Tensor), The output tensor of add_p op."); - AddComment(R"DOC( -Autograd primitive add_p operator. -)DOC"); - } -}; - -class AddPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class AddPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(add_p, - paddle::operators::AddPrimOp, - paddle::operators::AddPrimOpMaker, - paddle::operators::AddPrimOpShapeInference, - paddle::operators::AddPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/broadcast_p_op.cc b/paddle/fluid/operators/prim_ops/broadcast_p_op.cc deleted file mode 100644 index d2c391f7a9bc6..0000000000000 --- a/paddle/fluid/operators/prim_ops/broadcast_p_op.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class BroadcastPrimOp : public framework::OperatorBase { - public: - BroadcastPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator broadcast_p should not be executed directly")); - } -}; - -class BroadcastPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of broadcast_p op."); - AddOutput("Y", "(Tensor), The output tensor of broadcast_p op."); - AddAttr>( - "shape", - "(std::vector) Target shape of broadcast_p operator."); - AddComment(R"DOC( -Autograd primitive broadcast_p operator. -)DOC"); - } -}; - -static void CheckShapeValid(const std::vector &x_shape, - const std::vector &target_shape) { - size_t x_rank = x_shape.size(); - size_t target_rank = target_shape.size(); - PADDLE_ENFORCE_GE(target_rank, - x_rank, - platform::errors::InvalidArgument( - "The rank of target shape should be greater than or " - "equal to input tensor's dimensions, " - "but received %d and %d", - target_rank, - x_rank)); - std::vector::const_iterator it = target_shape.begin(); - for (size_t i = 0; i < x_rank; i++, it++) { - if (x_shape[i] != 1) { - it = std::find(it, target_shape.end(), x_shape[i]); - } - PADDLE_ENFORCE_EQ( - it != target_shape.end(), - true, - platform::errors::InvalidArgument( - "Invalid shape, can not broadcast input tensor into target shape," - "the first dismatching shape %d is shape of input tensor at " - "dimension %d", - x_shape[i], - i)); - } -} - -class BroadcastPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto target_shape = ctx->Attrs().Get>("shape"); - CheckShapeValid(x_shape, target_shape); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(target_shape); - } -}; - -class BroadcastPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(broadcast_p, - paddle::operators::BroadcastPrimOp, - paddle::operators::BroadcastPrimOpMaker, - paddle::operators::BroadcastPrimOpShapeInference, - paddle::operators::BroadcastPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/cast_p_op.cc b/paddle/fluid/operators/prim_ops/cast_p_op.cc deleted file mode 100644 index ead6cc53ceea7..0000000000000 --- a/paddle/fluid/operators/prim_ops/cast_p_op.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class CastPrimOp : public framework::OperatorBase { - public: - CastPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator cast_p should not be executed directly")); - } -}; - -class CastPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of cast_p op."); - AddOutput("Y", "(Tensor), The output tensor of cast_p op."); - AddAttr("dtype", "output data type"); - AddComment(R"DOC(Autograd primitive cast_p operator.)DOC"); - } -}; - -class CastPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class CastPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto out_type = static_cast( - PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); - ctx->SetOutputDataType("Y", out_type); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(cast_p, - paddle::operators::CastPrimOp, - paddle::operators::CastPrimOpMaker, - paddle::operators::CastPrimOpShapeInference, - paddle::operators::CastPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/concat_p_op.cc b/paddle/fluid/operators/prim_ops/concat_p_op.cc deleted file mode 100644 index 6b8d6c0a3322a..0000000000000 --- a/paddle/fluid/operators/prim_ops/concat_p_op.cc +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class ConcatPrimOp : public framework::OperatorBase { - public: - ConcatPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator concat_p should not be executed directly")); - } -}; - -class ConcatPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("XS", "(Tensor), The input tensors of concat_p op.") - .AsDuplicable(); - AddOutput("Y", "(Tensor), The output tensor of concat_p op."); - AddAttr("axis", "(int64_t), The axis along which to concat."); - AddComment(R"DOC( -Autograd primitive concat_p operator. -)DOC"); - } -}; - -class ConcatPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - auto x_var_ptrs = ctx->GetInputVarPtrs("XS"); - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - auto axis = ctx->Attrs().Get("axis"); - int64_t cnt_along_axis = 0; - framework::VarDesc *first_x_var = - PADDLE_GET(framework::VarDesc *, x_var_ptrs[0]); - auto first_x_shape = first_x_var->GetShape(); - cnt_along_axis += first_x_shape[axis]; - size_t first_x_rank = first_x_shape.size(); - for (size_t i = 1; i < x_var_ptrs.size(); ++i) { - framework::VarDesc *x_var = - PADDLE_GET(framework::VarDesc *, x_var_ptrs[i]); - auto x_shape = x_var->GetShape(); - cnt_along_axis += x_shape[axis]; - size_t x_rank = x_shape.size(); - PADDLE_ENFORCE_EQ( - x_rank, - first_x_rank, - platform::errors::InvalidArgument("The dimensions of %d input tensor " - "should be same as the dimensions " - "of 1st input tensor's, " - "but get %d and %d", - i + 1, - x_rank, - first_x_rank)); - for (size_t j = 0; j < x_rank; ++j) { - if (j != size_t(axis)) { - PADDLE_ENFORCE_EQ(x_shape[j], - first_x_shape[j], - platform::errors::InvalidArgument( - "The shape of %d input tensor at dimension %d " - "should be same as the 1st input tensor's, " - "but get %d and %d", - i + 1, - j, - x_shape[j], - first_x_shape[j])); - } - } - } - - std::vector y_shape(first_x_shape); - y_shape[axis] = cnt_along_axis; - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape); - } -}; - -class ConcatPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_names = Input(ctx, "XS"); - auto y_name = Output(ctx, "Y")[0]; - auto first_x_name = x_names[0]; - auto first_x_type = GetType(ctx, first_x_name); - auto first_x_dtype = GetDataType(ctx, first_x_name); - for (size_t i = 1; i < x_names.size(); ++i) { - auto x_name = x_names[i]; - auto x_type = GetType(ctx, x_name); - auto x_dtype = GetDataType(ctx, x_name); - PADDLE_ENFORCE_EQ(x_type, - first_x_type, - platform::errors::InvalidArgument( - "The type of %d input tensor should be same as the " - "first input tensor's, " - "but get %d and %d", - i + 1, - x_type, - first_x_type)); - PADDLE_ENFORCE_EQ(x_dtype, - first_x_dtype, - platform::errors::InvalidArgument( - "The datatype of %d input tensor should be same as " - "the first input tensor's, " - "but get %d and %d", - i + 1, - x_dtype, - first_x_dtype)); - } - SetType(ctx, y_name, GetType(ctx, first_x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, first_x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(concat_p, - paddle::operators::ConcatPrimOp, - paddle::operators::ConcatPrimOpMaker, - paddle::operators::ConcatPrimOpShapeInference, - paddle::operators::ConcatPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/cos_p_op.cc b/paddle/fluid/operators/prim_ops/cos_p_op.cc deleted file mode 100644 index c8acc30ba6107..0000000000000 --- a/paddle/fluid/operators/prim_ops/cos_p_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class CosPrimOp : public framework::OperatorBase { - public: - CosPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator cos_p should not be executed directly")); - } -}; - -class CosPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of cos_p op."); - AddOutput("Y", "(Tensor), The output tensor of cos_p op."); - AddComment(R"DOC(Autograd primitive cos_p operator.)DOC"); - } -}; - -class CosPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class CosPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(cos_p, - paddle::operators::CosPrimOp, - paddle::operators::CosPrimOpMaker, - paddle::operators::CosPrimOpShapeInference, - paddle::operators::CosPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/div_p_op.cc b/paddle/fluid/operators/prim_ops/div_p_op.cc deleted file mode 100644 index c046c63b8abad..0000000000000 --- a/paddle/fluid/operators/prim_ops/div_p_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class DivPrimOp : public framework::OperatorBase { - public: - DivPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator div_p should not be executed directly")); - } -}; - -class DivPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of div_p op."); - AddInput("Y", "(Tensor), The input tensor of div_p op."); - AddOutput("Z", "(Tensor), The output tensor of div_p op."); - AddComment(R"DOC( -Autograd primitive div_p operator. -)DOC"); - } -}; - -class DivPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class DivPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(div_p, - paddle::operators::DivPrimOp, - paddle::operators::DivPrimOpMaker, - paddle::operators::DivPrimOpShapeInference, - paddle::operators::DivPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/eq_p_op.cc b/paddle/fluid/operators/prim_ops/eq_p_op.cc deleted file mode 100644 index 389fd548677d6..0000000000000 --- a/paddle/fluid/operators/prim_ops/eq_p_op.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class EqPrimOp : public framework::OperatorBase { - public: - EqPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator eq_p should not be executed directly")); - } -}; - -class EqPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of eq_p op."); - AddInput("Y", "(Tensor), The input tensor of eq_p op."); - AddOutput("Z", "(Tensor), The output tensor of eq_p op."); - AddComment(R"DOC( -Autograd primitive eq_p operator. -)DOC"); - } -}; - -class EqPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class EqPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(eq_p, - paddle::operators::EqPrimOp, - paddle::operators::EqPrimOpMaker, - paddle::operators::EqPrimOpShapeInference, - paddle::operators::EqPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/erf_p_op.cc b/paddle/fluid/operators/prim_ops/erf_p_op.cc deleted file mode 100644 index 95bbeadfd6798..0000000000000 --- a/paddle/fluid/operators/prim_ops/erf_p_op.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class ErfPrimOp : public framework::OperatorBase { - public: - ErfPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator erf_p should not be executed directly")); - } -}; - -class ErfPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of erf_p op."); - AddOutput("Y", "(Tensor), The output tensor of erf_p op."); - AddComment(R"DOC(Autograd primitive erf_p operator.)DOC"); - } -}; - -class ErfPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class ErfPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(erf_p, - paddle::operators::ErfPrimOp, - paddle::operators::ErfPrimOpMaker, - paddle::operators::ErfPrimOpShapeInference, - paddle::operators::ErfPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/exp_p_op.cc b/paddle/fluid/operators/prim_ops/exp_p_op.cc deleted file mode 100644 index 220ed7672ab25..0000000000000 --- a/paddle/fluid/operators/prim_ops/exp_p_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class ExpPrimOp : public framework::OperatorBase { - public: - ExpPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator exp_p should not be executed directly")); - } -}; - -class ExpPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of exp_p op."); - AddOutput("Y", "(Tensor), The output tensor of exp_p op."); - AddComment(R"DOC(Autograd primitive exp_p operator.)DOC"); - } -}; - -class ExpPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class ExpPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(exp_p, - paddle::operators::ExpPrimOp, - paddle::operators::ExpPrimOpMaker, - paddle::operators::ExpPrimOpShapeInference, - paddle::operators::ExpPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc b/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc deleted file mode 100644 index a570ccd1cecba..0000000000000 --- a/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class FillConstantPrimOp : public framework::OperatorBase { - public: - FillConstantPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator fill_constant_p should not be executed directly")); - } -}; - -class FillConstantPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddOutput("Y", "(Tensor), The output tensor of fill_constant_p op."); - AddAttr("value", "(float) The value of output tensor."); - AddAttr>( - "shape", "(std::vector) The shape of output tensor."); - AddAttr("dtype", "(int) The dtype of output tensor."); - AddComment(R"DOC( -Autograd primitive fill_constant_p operator. -)DOC"); - } -}; - -class FillConstantPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - auto shape = ctx->Attrs().Get>("shape"); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); - } -}; - -class FillConstantPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto y_name = Output(ctx, "Y")[0]; - auto data_type = static_cast( - PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); - SetDataType(ctx, y_name, data_type); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(fill_constant_p, - paddle::operators::FillConstantPrimOp, - paddle::operators::FillConstantPrimOpMaker, - paddle::operators::FillConstantPrimOpShapeInference, - paddle::operators::FillConstantPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/gather_p_op.cc b/paddle/fluid/operators/prim_ops/gather_p_op.cc deleted file mode 100644 index 23d8349f22eee..0000000000000 --- a/paddle/fluid/operators/prim_ops/gather_p_op.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class GatherPrimOp : public framework::OperatorBase { - public: - GatherPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator gather_p should not be executed directly")); - } -}; - -class GatherPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of gather_p op."); - AddInput("IndexTensor", - "(Tensor), The index tensor of gather_p op, which is a 1D tensor.") - .AsDispensable(); - AddOutput("Y", "(Tensor), The output tensor of gather_p op."); - AddAttr("axis", "(int64_t), The axis along which to gather."); - AddAttr>( - "index", "(std::vector) The index of gather_p op") - .SetDefault({0}); - AddComment(R"DOC( -Autograd primitive gather_p operator. -)DOC"); - } -}; - -class GatherPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - int64_t num_index = 0; - if (ctx->HasInput("IndexTensor")) { - framework::InferShapeVarPtr index_var_ptr = - ctx->GetInputVarPtrs("IndexTensor")[0]; - framework::VarDesc *index_var = - PADDLE_GET(framework::VarDesc *, index_var_ptr); - auto index_shape = index_var->GetShape(); - PADDLE_ENFORCE_EQ(index_shape.size(), - 1, - platform::errors::InvalidArgument( - "The index tensor should be a 1D tensor," - "but get rank %d", - index_shape.size())); - num_index = index_shape[0]; - } else { - num_index = static_cast( - ctx->Attrs().Get>("index").size()); - } - auto axis = ctx->Attrs().Get("axis"); - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - x_shape[axis] = num_index; - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); - } -}; - -class GatherPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - if (ctx->HasInput("IndexTensor")) { - auto index_name = Input(ctx, "IndexTensor")[0]; - auto index_dtype = GetDataType(ctx, index_name); - PADDLE_ENFORCE_EQ( - index_dtype, - framework::proto::VarType_Type_INT32, - platform::errors::InvalidArgument( - "The datatype of input tensor should be VarType_Type_INT32(%d), " - "but get %d", - framework::proto::VarType_Type_INT32, - index_dtype)); - } - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(gather_p, - paddle::operators::GatherPrimOp, - paddle::operators::GatherPrimOpMaker, - paddle::operators::GatherPrimOpShapeInference, - paddle::operators::GatherPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/ge_p_op.cc b/paddle/fluid/operators/prim_ops/ge_p_op.cc deleted file mode 100644 index 20a6496158611..0000000000000 --- a/paddle/fluid/operators/prim_ops/ge_p_op.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class GePrimOp : public framework::OperatorBase { - public: - GePrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator ge_p should not be executed directly")); - } -}; - -class GePrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of ge_p op."); - AddInput("Y", "(Tensor), The input tensor of ge_p op."); - AddOutput("Z", "(Tensor), The output tensor of ge_p op."); - AddComment(R"DOC( -Autograd primitive ge_p operator. -)DOC"); - } -}; - -class GePrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class GePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(ge_p, - paddle::operators::GePrimOp, - paddle::operators::GePrimOpMaker, - paddle::operators::GePrimOpShapeInference, - paddle::operators::GePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/gt_p_op.cc b/paddle/fluid/operators/prim_ops/gt_p_op.cc deleted file mode 100644 index 01e8c1612cc43..0000000000000 --- a/paddle/fluid/operators/prim_ops/gt_p_op.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class GtPrimOp : public framework::OperatorBase { - public: - GtPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator gt_p should not be executed directly")); - } -}; - -class GtPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of gt_p op."); - AddInput("Y", "(Tensor), The input tensor of gt_p op."); - AddOutput("Z", "(Tensor), The output tensor of gt_p op."); - AddComment(R"DOC( -Autograd primitive gt_p operator. -)DOC"); - } -}; - -class GtPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class GtPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(gt_p, - paddle::operators::GtPrimOp, - paddle::operators::GtPrimOpMaker, - paddle::operators::GtPrimOpShapeInference, - paddle::operators::GtPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/log_p_op.cc b/paddle/fluid/operators/prim_ops/log_p_op.cc deleted file mode 100644 index d077510fd5c46..0000000000000 --- a/paddle/fluid/operators/prim_ops/log_p_op.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class LogPrimOp : public framework::OperatorBase { - public: - LogPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator log_p should not be executed directly")); - } -}; - -class LogPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of log_p op."); - AddOutput("Y", "(Tensor), The output tensor of log_p op."); - AddComment(R"DOC( -Autograd primitive log_p operator. -)DOC"); - } -}; - -class LogPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class LogPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(log_p, - paddle::operators::LogPrimOp, - paddle::operators::LogPrimOpMaker, - paddle::operators::LogPrimOpShapeInference, - paddle::operators::LogPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/matmul_p_op.cc b/paddle/fluid/operators/prim_ops/matmul_p_op.cc deleted file mode 100644 index 6a53dda16f71c..0000000000000 --- a/paddle/fluid/operators/prim_ops/matmul_p_op.cc +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class MatmulPrimOp : public framework::OperatorBase { - public: - MatmulPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator matmul_p should not be executed directly")); - } -}; - -class MatmulPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of matmul_p op."); - AddInput("Y", "(Tensor), The input tensor of matmul_p op."); - AddOutput("Z", "(Tensor), The output tensor of matmul_p op."); - AddComment(R"DOC( -Autograd primitive matmul_p operator. -)DOC"); - } -}; - -class MatmulPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The two input tensor's dimension should be equal" - "But received first input tensor's dimension is %d, " - "and another input tensor's dimension is %d", - x_rank, - y_rank)); - - PADDLE_ENFORCE_EQ(x_rank == 2 || x_rank == 3, - true, - platform::errors::InvalidArgument( - "The input tensor's dimension should be 2 or 3" - "But received input tensor's dimension is %d", - x_rank)); - - PADDLE_ENFORCE_EQ( - x_shape[x_rank - 1], - y_shape[y_rank - 2], - platform::errors::InvalidArgument( - "Invalid shape for matmul, the last dimension of first input and " - "the penultimate dimension for the second input should be same." - "But received %d and %d.", - x_shape[x_rank - 1], - y_shape[y_rank - 2])); - if (x_rank == 2) { - std::vector z_shape{x_shape[x_rank - 2], y_shape[y_rank - 1]}; - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape); - } else { - PADDLE_ENFORCE_EQ(x_shape[0], - y_shape[0], - platform::errors::InvalidArgument( - "Invalid shape for matmul when input tensor's " - "dimension is 3, the first dimension of first " - "input and the second input should be same." - "But received %d and %d.", - x_shape[0], - y_shape[0])); - - std::vector z_shape{ - x_shape[0], x_shape[x_rank - 2], y_shape[y_rank - 1]}; - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape); - } - } -}; - -class MatmulPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(matmul_p, - paddle::operators::MatmulPrimOp, - paddle::operators::MatmulPrimOpMaker, - paddle::operators::MatmulPrimOpShapeInference, - paddle::operators::MatmulPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/max_p_op.cc b/paddle/fluid/operators/prim_ops/max_p_op.cc deleted file mode 100644 index 782925b748eac..0000000000000 --- a/paddle/fluid/operators/prim_ops/max_p_op.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class MaxPrimOp : public framework::OperatorBase { - public: - MaxPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator max_p should not be executed directly")); - } -}; - -class MaxPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of max_p op."); - AddInput("Y", "(Tensor), The input tensor of max_p op."); - AddOutput("Z", "(Tensor), The output tensor of max_p op."); - AddComment(R"DOC( -Autograd primitive max_p operator. -)DOC"); - } -}; - -class MaxPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class MaxPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(max_p, - paddle::operators::MaxPrimOp, - paddle::operators::MaxPrimOpMaker, - paddle::operators::MaxPrimOpShapeInference, - paddle::operators::MaxPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/mul_p_op.cc b/paddle/fluid/operators/prim_ops/mul_p_op.cc deleted file mode 100644 index fd655e887be90..0000000000000 --- a/paddle/fluid/operators/prim_ops/mul_p_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class MulPrimOp : public framework::OperatorBase { - public: - MulPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator mul_p should not be executed directly")); - } -}; - -class MulPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of mul_p op."); - AddInput("Y", "(Tensor), The input tensor of mul_p op."); - AddOutput("Z", "(Tensor), The output tensor of mul_p op."); - AddComment(R"DOC( -Autograd primitive mul_p operator. -)DOC"); - } -}; - -class MulPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class MulPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(mul_p, - paddle::operators::MulPrimOp, - paddle::operators::MulPrimOpMaker, - paddle::operators::MulPrimOpShapeInference, - paddle::operators::MulPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/ne_p_op.cc b/paddle/fluid/operators/prim_ops/ne_p_op.cc deleted file mode 100644 index 0d65d1a7e33d9..0000000000000 --- a/paddle/fluid/operators/prim_ops/ne_p_op.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class NePrimOp : public framework::OperatorBase { - public: - NePrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator ne_p should not be executed directly")); - } -}; - -class NePrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of ne_p op."); - AddInput("Y", "(Tensor), The input tensor of ne_p op."); - AddOutput("Z", "(Tensor), The output tensor of ne_p op."); - AddComment(R"DOC( -Autograd primitive ne_p operator. -)DOC"); - } -}; - -class NePrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class NePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, framework::proto::VarType::BOOL); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(ne_p, - paddle::operators::NePrimOp, - paddle::operators::NePrimOpMaker, - paddle::operators::NePrimOpShapeInference, - paddle::operators::NePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/pow_p_op.cc b/paddle/fluid/operators/prim_ops/pow_p_op.cc deleted file mode 100644 index 50e625a328e58..0000000000000 --- a/paddle/fluid/operators/prim_ops/pow_p_op.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class PowPrimOp : public framework::OperatorBase { - public: - PowPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator pow_p should not be executed directly")); - } -}; - -class PowPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The base of pow_p op."); - AddInput("Y", "(Tensor), The exponents of pow_p op."); - AddOutput("Z", "(Tensor), The output tensor of pow_p op."); - AddComment(R"DOC( -Autograd primitive pow_p operator. -)DOC"); - } -}; - -class PowPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class PowPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(pow_p, - paddle::operators::PowPrimOp, - paddle::operators::PowPrimOpMaker, - paddle::operators::PowPrimOpShapeInference, - paddle::operators::PowPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc b/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc deleted file mode 100644 index dbb33a98b108c..0000000000000 --- a/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class ReduceSumPrimOp : public framework::OperatorBase { - public: - ReduceSumPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator reduce_sum_p should not be executed directly")); - } -}; - -class ReduceSumPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of reduce_sum_p op."); - AddOutput("Y", "(Tensor), The output tensor of reduce_sum_p op."); - AddAttr>( - "axis", - "(std::vector) The axis along which to reduce on. Must be in " - "range [-rank(input), rank(input)]. If `axis[i] < 0`, the axis[i] to " - "reduce is `rank + axis[i]`."); - AddAttr("keepdim", - "(bool, default false) " - "If true, retain the reduced axis with length 1.") - .SetDefault(false); - AddComment(R"DOC( -Autograd primitive reduce_sum_p operator. -)DOC"); - } -}; - -class ReduceSumPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto axis = ctx->Attrs().Get>("axis"); - auto keepdim = ctx->Attrs().Get("keepdim"); - if (keepdim) { - for (auto item : axis) { - x_shape[item] = 1; - } - } else { - const int kDelFlag = -2; - for (auto item : axis) { - x_shape[item] = kDelFlag; - } - x_shape.erase(remove(x_shape.begin(), x_shape.end(), kDelFlag), - x_shape.end()); - } - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); - } -}; - -class ReduceSumPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(reduce_sum_p, - paddle::operators::ReduceSumPrimOp, - paddle::operators::ReduceSumPrimOpMaker, - paddle::operators::ReduceSumPrimOpShapeInference, - paddle::operators::ReduceSumPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/reshape_p_op.cc b/paddle/fluid/operators/prim_ops/reshape_p_op.cc deleted file mode 100644 index 8137dfd629b01..0000000000000 --- a/paddle/fluid/operators/prim_ops/reshape_p_op.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class ReshapePrimOp : public framework::OperatorBase { - public: - ReshapePrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator reshape_p should not be executed directly")); - } -}; - -class ReshapePrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of reshape_p op."); - AddOutput("Y", "(Tensor), The output tensor of reshape_p op."); - AddAttr>( - "shape", "(std::vector) Target shape of reshape_p operator."); - AddComment(R"DOC( -Autograd primitive reshape_p operator. -)DOC"); - } -}; - -static int64_t product(const std::vector &shape) { - int64_t rslt = 1; - for (auto item : shape) { - rslt *= item; - } - return rslt; -} - -class ReshapePrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE_EQ(product(x_shape), - product(shape), - platform::errors::InvalidArgument( - "The input tensor can't be reshaped to target shape, " - "the input tensor has %d elements but target shape " - "contains %d elements", - product(x_shape), - product(shape))); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); - } -}; - -class ReshapePrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(reshape_p, - paddle::operators::ReshapePrimOp, - paddle::operators::ReshapePrimOpMaker, - paddle::operators::ReshapePrimOpShapeInference, - paddle::operators::ReshapePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc b/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc deleted file mode 100644 index d2401c6d4e40f..0000000000000 --- a/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class RsqrtPrimOp : public framework::OperatorBase { - public: - RsqrtPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator rsqrt_p should not be executed directly")); - } -}; - -class RsqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of rsqrt_p op."); - AddOutput("Y", "(Tensor), The output tensor of rsqrt_p op."); - AddComment(R"DOC( -Autograd primitive rsqrt_p operator. -)DOC"); - } -}; - -class RsqrtPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class RsqrtPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(rsqrt_p, - paddle::operators::RsqrtPrimOp, - paddle::operators::RsqrtPrimOpMaker, - paddle::operators::RsqrtPrimOpShapeInference, - paddle::operators::RsqrtPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc b/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc deleted file mode 100644 index 2b116d5224073..0000000000000 --- a/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class ScatterAddPrimOp : public framework::OperatorBase { - public: - ScatterAddPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator scatter_add_p should not be executed directly")); - } -}; - -class ScatterAddPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The tensor to apply scatter rule and add on."); - AddInput("Y", "(Tensor), The source tensor of scatter_add_p op."); - AddInput( - "IndexTensor", - "(Tensor), The index tensor of scatter_add_p op, which is a 1D tensor.") - .AsDispensable(); - AddOutput("Z", "(Tensor), The output tensor of scatter_add_p op."); - AddAttr("axis", - "(int64_t), The axis along which to scatter and add."); - AddAttr>( - "index", "(std::vector) The index of scatter_add_p op") - .SetDefault({0}); - AddComment(R"DOC( -Autograd primitive scatter_add_p operator. -)DOC"); - } -}; - -class ScatterAddPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - int64_t num_index = 0; - if (ctx->HasInput("IndexTensor")) { - framework::InferShapeVarPtr index_var_ptr = - ctx->GetInputVarPtrs("IndexTensor")[0]; - framework::VarDesc *index_var = - PADDLE_GET(framework::VarDesc *, index_var_ptr); - auto index_shape = index_var->GetShape(); - PADDLE_ENFORCE_EQ(index_shape.size(), - 1, - platform::errors::InvalidArgument( - "The index tensor should be a 1D tensor," - "but get rank %d", - index_shape.size())); - num_index = index_shape[0]; - } else { - num_index = static_cast( - ctx->Attrs().Get>("index").size()); - } - auto axis = ctx->Attrs().Get("axis"); - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - PADDLE_ENFORCE_EQ(y_shape[axis], - num_index, - platform::errors::InvalidArgument( - "The shape of source input tensor at scatter axis " - "should be equal to num_index, " - "but get %d and %d", - y_shape[axis], - num_index)); - for (size_t i = 0; i < x_rank; ++i) { - if (i != size_t(axis)) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_rank, - y_rank)); - } - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class ScatterAddPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - if (ctx->HasInput("IndexTensor")) { - auto index_name = Input(ctx, "IndexTensor")[0]; - auto index_dtype = GetDataType(ctx, index_name); - PADDLE_ENFORCE_EQ( - index_dtype, - framework::proto::VarType_Type_INT32, - platform::errors::InvalidArgument( - "The datatype of input tensor should be VarType_Type_INT32(%d), " - "but get %d", - framework::proto::VarType_Type_INT32, - index_dtype)); - } - SetType(ctx, z_name, GetType(ctx, x_name)); - SetDataType(ctx, z_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(scatter_add_p, - paddle::operators::ScatterAddPrimOp, - paddle::operators::ScatterAddPrimOpMaker, - paddle::operators::ScatterAddPrimOpShapeInference, - paddle::operators::ScatterAddPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/select_p_op.cc b/paddle/fluid/operators/prim_ops/select_p_op.cc deleted file mode 100644 index 69253da41d7d2..0000000000000 --- a/paddle/fluid/operators/prim_ops/select_p_op.cc +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class SelectPrimOp : public framework::OperatorBase { - public: - SelectPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator select_p should not be executed directly")); - } -}; - -class SelectPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Condition", "(Tensor), The input condition of select_p op."); - AddInput("X", "(Tensor), The input tensor of select_p op."); - AddInput("Y", "(Tensor), The input tensor of select_p op."); - AddOutput("Z", "(Tensor), The output tensor of select_p op."); - AddComment(R"DOC( -Autograd primitive select_p operator. -)DOC"); - } -}; - -class SelectPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr condition_var_ptr = - ctx->GetInputVarPtrs("Condition")[0]; - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *condition_var = - PADDLE_GET(framework::VarDesc *, condition_var_ptr); - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - - auto condition_shape = condition_var->GetShape(); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - - size_t condition_rank = condition_shape.size(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - - PADDLE_ENFORCE_EQ( - condition_rank, - x_rank, - platform::errors::InvalidArgument( - "The dimensions of condtion and Inputs(X) should be same, " - "but get %d and %d", - condition_rank, - x_rank)); - PADDLE_ENFORCE_EQ( - x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of Inputs(X) and Inputs(Y) should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < condition_rank; ++i) { - PADDLE_ENFORCE_EQ(condition_shape[i], - x_shape[i], - platform::errors::InvalidArgument( - "The shape of condition and Inputs(X) at dimension " - "%d should be same, " - "but get %d and %d", - i, - condition_shape[i], - x_shape[i])); - } - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ(x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of Inputs(X) and Inputs(Y) at dimension " - "%d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(condition_shape); - } -}; - -class SelectPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(select_p, - paddle::operators::SelectPrimOp, - paddle::operators::SelectPrimOpMaker, - paddle::operators::SelectPrimOpShapeInference, - paddle::operators::SelectPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/sin_p_op.cc b/paddle/fluid/operators/prim_ops/sin_p_op.cc deleted file mode 100644 index 95b413acc77af..0000000000000 --- a/paddle/fluid/operators/prim_ops/sin_p_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { -class SinPrimOp : public framework::OperatorBase { - public: - SinPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator sin_p should not be executed directly")); - } -}; - -class SinPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of sin_p op."); - AddOutput("Y", "(Tensor), The output tensor of sin_p op."); - AddComment(R"DOC(Autograd primitive sin_p operator.)DOC"); - } -}; - -class SinPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class SinPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(sin_p, - paddle::operators::SinPrimOp, - paddle::operators::SinPrimOpMaker, - paddle::operators::SinPrimOpShapeInference, - paddle::operators::SinPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc b/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc deleted file mode 100644 index 9485d621aa5d4..0000000000000 --- a/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class SliceAssignPrimOp : public framework::OperatorBase { - public: - SliceAssignPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator slice_assign_p should not be executed directly")); - } -}; - -class SliceAssignPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The tensor to slice from and assign on."); - AddInput("Y", "(Tensor), The source tensor of slice_assign_p op."); - AddOutput("Z", "(Tensor), The output tensor of slice_assign_p op."); - AddAttr>( - "axis", "(std::vector), The axis along which to gather."); - AddAttr>( - "starts", - "(std::vector) The slice starts of slice_assign_p op"); - AddAttr>( - "ends", "(std::vector) The slice ends of slice_assign_p op"); - AddAttr>( - "strides", - "(std::vector) The slice strides of slice_assign_p op"); - AddComment(R"DOC( -Autograd primitive slice_assign_p operator. -)DOC"); - } -}; - -class SliceAssignPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - auto axis = ctx->Attrs().Get>("axis"); - auto starts = ctx->Attrs().Get>("starts"); - auto ends = ctx->Attrs().Get>("ends"); - auto strides = ctx->Attrs().Get>("strides"); - PADDLE_ENFORCE_EQ( - starts.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of starts attribute and axis attribute should be same, " - "but get %d and %d", - starts.size(), - axis.size())); - PADDLE_ENFORCE_EQ( - ends.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of ends attribute and axis attribute should be same, " - "but get %d and %d", - ends.size(), - axis.size())); - PADDLE_ENFORCE_EQ( - strides.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of strides attribute and axis attribute should be same, " - "but get %d and %d", - strides.size(), - axis.size())); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - std::vector y_target_shape(x_shape); - for (size_t i = 0; i < axis.size(); ++i) { - y_target_shape[axis[i]] = - (ends[i] - starts[i] + strides[i] - 1) / strides[i]; - } - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ(y_target_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of source tensor of slice_assign_p op " - "at dimension %d should be %d, " - "but get %d", - i, - y_target_shape[i], - y_shape[i])); - } - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class SliceAssignPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, GetType(ctx, x_name)); - SetDataType(ctx, z_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(slice_assign_p, - paddle::operators::SliceAssignPrimOp, - paddle::operators::SliceAssignPrimOpMaker, - paddle::operators::SliceAssignPrimOpShapeInference, - paddle::operators::SliceAssignPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/slice_select_p_op.cc b/paddle/fluid/operators/prim_ops/slice_select_p_op.cc deleted file mode 100644 index dd2242368b739..0000000000000 --- a/paddle/fluid/operators/prim_ops/slice_select_p_op.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class SliceSelectPrimOp : public framework::OperatorBase { - public: - SliceSelectPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator slice_select_p should not be executed directly")); - } -}; - -class SliceSelectPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of slice_select_p op."); - AddOutput("Y", "(Tensor), The output tensor of slice_select_p op."); - AddAttr>( - "axis", "(std::vector), The axis along which to gather."); - AddAttr>( - "starts", - "(std::vector) The slice starts of slice_select_p op"); - AddAttr>( - "ends", "(std::vector) The slice ends of slice_select_p op"); - AddAttr>( - "strides", - "(std::vector) The slice strides of slice_select_p op"); - AddComment(R"DOC( -Autograd primitive slice_select_p operator. -)DOC"); - } -}; - -class SliceSelectPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto axis = ctx->Attrs().Get>("axis"); - auto starts = ctx->Attrs().Get>("starts"); - auto ends = ctx->Attrs().Get>("ends"); - auto strides = ctx->Attrs().Get>("strides"); - PADDLE_ENFORCE_EQ( - starts.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of starts attribute and axis attribute should be same, " - "but get %d and %d", - starts.size(), - axis.size())); - PADDLE_ENFORCE_EQ( - ends.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of ends attribute and axis attribute should be same, " - "but get %d and %d", - ends.size(), - axis.size())); - PADDLE_ENFORCE_EQ( - strides.size(), - axis.size(), - platform::errors::InvalidArgument( - "Number of strides attribute and axis attribute should be same, " - "but get %d and %d", - strides.size(), - axis.size())); - for (size_t i = 0; i < axis.size(); ++i) { - x_shape[axis[i]] = (ends[i] - starts[i] + strides[i] - 1) / strides[i]; - } - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); - } -}; - -class SliceSelectPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(slice_select_p, - paddle::operators::SliceSelectPrimOp, - paddle::operators::SliceSelectPrimOpMaker, - paddle::operators::SliceSelectPrimOpShapeInference, - paddle::operators::SliceSelectPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/split_p_op.cc b/paddle/fluid/operators/prim_ops/split_p_op.cc deleted file mode 100644 index bc0f8b8a31cda..0000000000000 --- a/paddle/fluid/operators/prim_ops/split_p_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class SplitPrimOp : public framework::OperatorBase { - public: - SplitPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator split_p should not be executed directly")); - } -}; - -class SplitPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of split_p op."); - AddOutput("YS", "(Tensor), The output tensors of split_p op.") - .AsDuplicable(); - AddAttr("axis", "(int64_t), The axis along which to split."); - AddAttr>( - "num_or_sections", - "(std::vector) If num_or_sections has only one element, then " - "num_or_sections indicates the number of equal sized sub-Tensors that " - "the input will be divided into. If num_or_sections has more then one " - "element, the length of it indicates the number of sub-Tensors and the " - "elements in it indicate the sizes of sub-Tensors' dimension orderly. " - "The length of the vector must not be larger than the input's size of " - "specified axis."); - AddComment(R"DOC( -Autograd primitive split_p operator. -)DOC"); - } -}; - -class SplitPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - auto y_var_ptrs = ctx->GetOutputVarPtrs("YS"); - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto axis = ctx->Attrs().Get("axis"); - auto num_or_sections = - ctx->Attrs().Get>("num_or_sections"); - std::vector y_shape(x_shape); - if (num_or_sections.size() == 1) { - PADDLE_ENFORCE_EQ(x_shape[axis] % num_or_sections[0], - 0, - platform::errors::InvalidArgument( - "The input tensor can't be devided equally into %d " - "parts equally along axis %d", - num_or_sections[0], - axis)); - y_shape[axis] = x_shape[axis] / num_or_sections[0]; - for (size_t i = 0; i < size_t(num_or_sections[0]); ++i) { - PADDLE_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape); - } - } else { - int64_t cnt_along_axis = 0; - for (size_t i = 0; i < num_or_sections.size(); ++i) { - y_shape[axis] = num_or_sections[i]; - cnt_along_axis += num_or_sections[i]; - PADDLE_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape); - } - PADDLE_ENFORCE_EQ( - x_shape[axis], - cnt_along_axis, - platform::errors::InvalidArgument( - "The input tensor has %d elements along axis %d, thus can't be " - "devided into %d tensor with %d elements totally.", - x_shape[axis], - axis, - num_or_sections.size(), - cnt_along_axis)); - } - } -}; - -class SplitPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_names = Output(ctx, "YS"); - for (auto const &y_name : y_names) { - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(split_p, - paddle::operators::SplitPrimOp, - paddle::operators::SplitPrimOpMaker, - paddle::operators::SplitPrimOpShapeInference, - paddle::operators::SplitPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/sqrt_p_op.cc b/paddle/fluid/operators/prim_ops/sqrt_p_op.cc deleted file mode 100644 index caebfd388f68f..0000000000000 --- a/paddle/fluid/operators/prim_ops/sqrt_p_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class SqrtPrimOp : public framework::OperatorBase { - public: - SqrtPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator sqrt_p should not be executed directly")); - } -}; - -class SqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of sqrt_p op."); - AddOutput("Y", "(Tensor), The output tensor of sqrt_p op."); - AddComment(R"DOC( -Autograd primitive sqrt_p operator. -)DOC"); - } -}; - -class SqrtPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class SqrtPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(sqrt_p, - paddle::operators::SqrtPrimOp, - paddle::operators::SqrtPrimOpMaker, - paddle::operators::SqrtPrimOpShapeInference, - paddle::operators::SqrtPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/sub_p_op.cc b/paddle/fluid/operators/prim_ops/sub_p_op.cc deleted file mode 100644 index 4497978093f4f..0000000000000 --- a/paddle/fluid/operators/prim_ops/sub_p_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class SubPrimOp : public framework::OperatorBase { - public: - SubPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator sub_p should not be executed directly")); - } -}; - -class SubPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of sub_p op."); - AddInput("Y", "(Tensor), The input tensor of sub_p op."); - AddOutput("Z", "(Tensor), The output tensor of sub_p op."); - AddComment(R"DOC( -Autograd primitive sub_p operator. -)DOC"); - } -}; - -class SubPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; - framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); - auto x_shape = x_var->GetShape(); - auto y_shape = y_var->GetShape(); - size_t x_rank = x_shape.size(); - size_t y_rank = y_shape.size(); - PADDLE_ENFORCE_EQ(x_rank, - y_rank, - platform::errors::InvalidArgument( - "The dimensions of two input tensor should be same, " - "but get %d and %d", - x_rank, - y_rank)); - for (size_t i = 0; i < x_rank; ++i) { - PADDLE_ENFORCE_EQ( - x_shape[i], - y_shape[i], - platform::errors::InvalidArgument( - "The shape of two input tensor at dimension %d should be same, " - "but get %d and %d", - i, - x_shape[i], - y_shape[i])); - } - - PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); - } -}; - -class SubPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Input(ctx, "Y")[0]; - auto z_name = Output(ctx, "Z")[0]; - auto x_type = GetType(ctx, x_name); - auto y_type = GetType(ctx, y_name); - auto x_dtype = GetDataType(ctx, x_name); - auto y_dtype = GetDataType(ctx, y_name); - PADDLE_ENFORCE_EQ(x_type, - y_type, - platform::errors::InvalidArgument( - "The type of two input tensor should be same, " - "but get %d and %d", - x_type, - y_type)); - PADDLE_ENFORCE_EQ(x_dtype, - y_dtype, - platform::errors::InvalidArgument( - "The datatype of two input tensor should be same, " - "but get %d and %d", - x_dtype, - y_dtype)); - - SetType(ctx, z_name, x_type); - SetDataType(ctx, z_name, x_dtype); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(sub_p, - paddle::operators::SubPrimOp, - paddle::operators::SubPrimOpMaker, - paddle::operators::SubPrimOpShapeInference, - paddle::operators::SubPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/tanh_p_op.cc b/paddle/fluid/operators/prim_ops/tanh_p_op.cc deleted file mode 100644 index 042394aa15068..0000000000000 --- a/paddle/fluid/operators/prim_ops/tanh_p_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class TanhPrimOp : public framework::OperatorBase { - public: - TanhPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator tanh_p should not be executed directly")); - } -}; - -class TanhPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of tanh_p op."); - AddOutput("Y", "(Tensor), The output tensor of tanh_p op."); - AddComment(R"DOC( -Autograd primitive tanh_p operator. -)DOC"); - } -}; - -class TanhPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); - } -}; - -class TanhPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(tanh_p, - paddle::operators::TanhPrimOp, - paddle::operators::TanhPrimOpMaker, - paddle::operators::TanhPrimOpShapeInference, - paddle::operators::TanhPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/transpose_p_op.cc b/paddle/fluid/operators/prim_ops/transpose_p_op.cc deleted file mode 100644 index cb76f81ef0901..0000000000000 --- a/paddle/fluid/operators/prim_ops/transpose_p_op.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class TransposePrimOp : public framework::OperatorBase { - public: - TransposePrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator transpose_p should not be executed directly")); - } -}; - -class TransposePrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of transpose_p op."); - AddOutput("Y", "(Tensor), The output tensor of transpose_p op."); - AddAttr>("axis", - "(std::vector) Tanspose axis."); - AddComment(R"DOC( -Autograd primitive transpose_p operator. -)DOC"); - } -}; - -class TransposePrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; - framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); - auto x_shape = x_var->GetShape(); - auto axis = ctx->Attrs().Get>("axis"); - size_t x_rank = x_shape.size(); - size_t axis_size = axis.size(); - PADDLE_ENFORCE_EQ(x_rank, - axis_size, - platform::errors::InvalidArgument( - "The input tensor's dimension " - "should be equal to the axis's size. " - "But received input tensor's dimension is %d, " - "axis's size is %d", - x_rank, - axis_size)); - - std::vector count(axis_size, 0); - for (size_t i = 0; i < axis_size; i++) { - PADDLE_ENFORCE_GE(axis[i], - 0, - platform::errors::InvalidArgument( - "The axis should be greater than or equal to 0." - "But received %d of axis[%d]", - axis[i], - i)); - - PADDLE_ENFORCE_EQ( - axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, - true, - platform::errors::InvalidArgument( - "Each element of Attribute axis should " - "be a unique value range from 0 to (dims - 1), " - "where the dims is the axis's size, " - "unique value means this axis value can appear only once. " - "But received axis[%d] is %d, axis_size is %d, " - "count[axis[%d]] is %d", - i, - axis[i], - axis_size, - i, - count[axis[i]])); - } - std::vector y_shape(axis_size); - for (size_t i = 0; i < axis_size; i++) { - y_shape[i] = x_shape[axis[i]]; - } - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape); - } -}; - -class TransposePrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto x_name = Input(ctx, "X")[0]; - auto y_name = Output(ctx, "Y")[0]; - SetType(ctx, y_name, GetType(ctx, x_name)); - SetDataType(ctx, y_name, GetDataType(ctx, x_name)); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(transpose_p, - paddle::operators::TransposePrimOp, - paddle::operators::TransposePrimOpMaker, - paddle::operators::TransposePrimOpShapeInference, - paddle::operators::TransposePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/uniform_random_p_op.cc b/paddle/fluid/operators/prim_ops/uniform_random_p_op.cc deleted file mode 100644 index 3a06459d33798..0000000000000 --- a/paddle/fluid/operators/prim_ops/uniform_random_p_op.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace framework { -class InferShapeContext; -class VarDesc; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { -class UniformRandomPrimOp : public framework::OperatorBase { - public: - UniformRandomPrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : framework::OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator uniform_randrom_p should not be executed directly")); - } -}; - -class UniformRandomPrimOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddOutput("Out", "(Tensor), The output tensor of uniform_random_p op."); - AddAttr>("shape", "The shape of the output tensor") - .SetDefault({}); - AddAttr("min", "Minimum value of uniform_random_p. [default -1.0]."); - AddAttr("max", "Maximun value of uniform_random_p. [default 1.0]."); - AddAttr("seed", - "Random seed used for generating samples. " - "0 means use a seed generated by the system." - "Note that if seed is not 0, this operator will always " - "generate the same random numbers every time. "); - AddAttr("dtype", "Output tensor data type. "); - AddComment(R"DOC( -Autograd primitive uniform_random_p operator. -)DOC"); - } -}; - -class UniformRandomPrimOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override { - framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Out")[0]; - auto shape = ctx->Attrs().Get>("shape"); - PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); - } -}; - -class UniformRandomPrimOpVarTypeInference - : public framework::StaticGraphVarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto y_name = Output(ctx, "Out")[0]; - auto data_type = static_cast( - PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); - SetDataType(ctx, y_name, data_type); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(uniform_random_p, - paddle::operators::UniformRandomPrimOp, - paddle::operators::UniformRandomPrimOpMaker, - paddle::operators::UniformRandomPrimOpShapeInference, - paddle::operators::UniformRandomPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/unity_build_rule.cmake b/paddle/fluid/operators/prim_ops/unity_build_rule.cmake deleted file mode 100644 index 73340d33c1091..0000000000000 --- a/paddle/fluid/operators/prim_ops/unity_build_rule.cmake +++ /dev/null @@ -1,19 +0,0 @@ -register_unity_group( - cc - reshape_p_op.cc - broadcast_p_op.cc - transpose_p_op.cc - split_p_op.cc - concat_p_op.cc - slice_select_p_op.cc - slice_assign_p_op.cc - gather_p_op.cc - scatter_add_p_op.cc - add_p_op.cc - sub_p_op.cc - mul_p_op.cc - div_p_op.cc - sqrt_p_op.cc - tanh_p_op.cc - matmul_p_op.cc - fill_constant_p_op.cc) diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py deleted file mode 100644 index 0633ba28ce57f..0000000000000 --- a/python/paddle/incubate/autograd/primops.py +++ /dev/null @@ -1,536 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import operator - -import paddle -from paddle.base.layer_helper import LayerHelper - -from .primreg import REGISTER_FN - - -def _simple_unop(helper): - optype = helper.layer_type - x, out = tuple(map(helper.kwargs.get, ('x', 'out'))) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op(type=optype, inputs={'X': x}, outputs={'Y': out}, attrs={}) - return out - - -def _simple_binop(helper): - optype = helper.layer_type - x, y, out = tuple(map(helper.kwargs.get, ('x', 'y', 'out'))) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type=optype, inputs={'X': x, 'Y': y}, outputs={'Z': out}, attrs={} - ) - return out - - -def _manipulation_unop(helper): - optype = helper.layer_type - x, out = tuple(map(helper.kwargs.get, ('x', 'out'))) - - attrs = { - k: helper.kwargs[k] - for k in ('shape', 'axis', 'index') - if k in helper.kwargs - } - - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type=optype, inputs={'X': x}, outputs={'Y': out}, attrs=attrs - ) - return out - - -# Each primitive op is given a Python constructor for sake of convenience. -def fill_const(value, shape, dtype, out=None): - attrs = {'value': value, 'shape': shape, 'dtype': dtype} - helper = LayerHelper('fill_constant_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype) - helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs) - return out - - -def neg(x, out=None): - zero = fill_const(0.0, x.shape, x.dtype) - return sub(zero, x) - - -def set_value(x, y, axis, starts, ends, strides, out): - assert x is out, "x and out should be the same Tensor in set_value" - attrs = {'axes': axis, 'starts': starts, 'ends': ends, 'steps': strides} - helper = LayerHelper('set_value', **locals()) - helper.append_op( - type=helper.layer_type, - inputs={'Input': x, 'ValueTensor': y}, - outputs={'Out': out}, - attrs=attrs, - ) - return out - - -def mean(x, axis=None, keepdim=False): - axes = axis or tuple(range(0, len(x.shape))) - sum = reduce_sum(x, axis=axes, keepdim=keepdim) - norm = fill_const( - shape=sum.shape, - value=functools.reduce(operator.mul, [x.shape[axis] for axis in axes]), - dtype=sum.dtype, - ) - return div(sum, norm) - - -def ones(shape, dtype): - return fill_const(1, shape, dtype) - - -def zeros(shape, dtype): - return fill_const(0, shape, dtype) - - -def batch_norm( - x, - axis, - gamma, - beta, - run_mean, - run_var, - eps=1e-5, - momentum=0.9, - use_run_stat=False, - reserve_space=None, -): - """batch normalizer. - - Args: - x (Tensor): A tensor to be normalized. - axis (int): The features axis. - gamma (Tensor): The scale factor. - beta (float): The shift factor. - run_mean (Tensor): Running mean. - run_var (Tensor): Running variance. - eps (float, optional): A value added to the denominator for numerical - stability. Defaults to 1e-5. - momentum (float, optional): The value used for the running_mean and - running_var computation. Can be set to None for cumulative moving - average (i.e. simple average). Defaults to 0.9. - use_run_stat (bool, optional): Whether or not using running statistics. - Defaults to False. - """ - reduce_axes = tuple(i for i in range(len(x.shape)) if i != axis) - stats_shape = tuple( - 1 if i in reduce_axes else s for i, s in enumerate(x.shape) - ) - - batch_mean = zeros(run_mean.shape, run_mean.dtype) - batch_var = zeros(run_var.shape, run_var.dtype) - - if not use_run_stat: - batch_mean = mean(x, reduce_axes, keepdim=True) - batch_var = mean( - square(sub(x, broadcast(batch_mean, x.shape))), - reduce_axes, - keepdim=True, - ) - x_hat = div( - sub(x, broadcast(batch_mean, x.shape)), - sqrt( - add( - broadcast(batch_var, x.shape), - fill_const(eps, x.shape, batch_var.dtype), - ) - ), - ) - - momentum = fill_const(momentum, run_mean.shape, run_mean.dtype) - run_mean = add( - mul(momentum, run_mean), - mul( - sub(ones(run_mean.shape, run_mean.dtype), momentum), - reshape(batch_mean, run_mean.shape), - ), - ) - run_var = add( - mul(momentum, run_var), - mul( - sub(ones(run_var.shape, run_var.dtype), momentum), - reshape(batch_var, run_var.shape), - ), - ) - else: - x_hat = div( - sub(x, broadcast(reshape(run_mean, stats_shape), x.shape)), - sqrt( - add( - broadcast(reshape(run_var, stats_shape), x.shape), - fill_const(eps, x.shape, x.dtype), - ) - ), - ) - y = add( - mul(broadcast(reshape(gamma, stats_shape), x_hat.shape), x_hat), - broadcast(reshape(beta, stats_shape), x_hat.shape), - ) - - if reserve_space: - return run_mean, reserve_space, batch_mean, batch_var, run_var, y - else: - return run_mean, batch_mean, batch_var, run_var, y - - -def square(x): - return pow(x, fill_const(2.0, x.shape, x.dtype)) - - -@REGISTER_FN('add_p', 'X', 'Y', 'Z') -def add(x, y, out=None): - return _simple_binop(LayerHelper('add_p', **locals())) - - -@REGISTER_FN('sub_p', 'X', 'Y', 'Z') -def sub(x, y, out=None): - return _simple_binop(LayerHelper('sub_p', **locals())) - - -@REGISTER_FN('mul_p', 'X', 'Y', 'Z') -def mul(x, y, out=None): - return _simple_binop(LayerHelper('mul_p', **locals())) - - -@REGISTER_FN('div_p', 'X', 'Y', 'Z') -def div(x, y, out=None): - return _simple_binop(LayerHelper('div_p', **locals())) - - -@REGISTER_FN('sqrt_p', 'X', 'Y') -def sqrt(x, out=None): - return _simple_unop(LayerHelper('sqrt_p', **locals())) - - -@REGISTER_FN('tanh_p', 'X', 'Y') -def tanh(x, out=None): - return _simple_unop(LayerHelper('tanh_p', **locals())) - - -@REGISTER_FN('sin_p', 'X', 'Y') -def sin(x, out=None): - return _simple_unop(LayerHelper('sin_p', **locals())) - - -@REGISTER_FN('cos_p', 'X', 'Y') -def cos(x, out=None): - return _simple_unop(LayerHelper('cos_p', **locals())) - - -@REGISTER_FN('exp_p', 'X', 'Y') -def exp(x, out=None): - return _simple_unop(LayerHelper('exp_p', **locals())) - - -@REGISTER_FN('abs_p', 'X', 'Y') -def abs(x, out=None): - return _simple_unop(LayerHelper('abs_p', **locals())) - - -@REGISTER_FN('reshape_p', 'X', 'Y') -def reshape(x, shape, out=None): - return _manipulation_unop(LayerHelper('reshape_p', **locals())) - - -@REGISTER_FN('broadcast_p', 'X', 'Y') -def broadcast(x, shape, out=None): - return _manipulation_unop(LayerHelper('broadcast_p', **locals())) - - -@REGISTER_FN('transpose_p', 'X', 'Y') -def transpose(x, axis=None, out=None): - return _manipulation_unop(LayerHelper('transpose_p', **locals())) - - -@REGISTER_FN('split_p', 'X', 'YS') -def split(x, num_or_sections, axis=0, outs=None): - if isinstance(num_or_sections, (list, tuple)): - n = len(num_or_sections) - else: - if not isinstance(num_or_sections, int): - raise TypeError( - f'num_or_sections must be int, but got {type(num_or_sections)}.' - ) - n = num_or_sections - - attrs = {'num_or_sections': num_or_sections, 'axis': axis} - - helper = LayerHelper('split_p', **locals()) - if outs is None: - outs = [ - helper.create_variable_for_type_inference(dtype=x.dtype) - for i in range(n) - ] - helper.append_op( - type=helper.layer_type, - inputs={'X': x}, - outputs={'YS': outs}, - attrs=attrs, - ) - return outs - - -@REGISTER_FN('concat_p', 'XS', 'Y') -def concat(xs, axis=0, out=None): - if isinstance(xs, paddle.base.framework.Variable): - xs = [xs] - attrs = {'axis': axis} - helper = LayerHelper('concat_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=xs[0].dtype) - helper.append_op( - type=helper.layer_type, - inputs={'XS': xs}, - outputs={'Y': out}, - attrs=attrs, - ) - return out - - -@REGISTER_FN('reduce_sum_p', 'X', 'Y') -def reduce_sum(x, axis=None, keepdim=False, out=None): - axes = axis or tuple(range(0, len(x.shape))) - axes = (axes,) if isinstance(axes, int) else axes - if not isinstance(axis, (tuple, list)): - raise TypeError(f'axis must be tuple or list, but got {type(axis)}') - if not isinstance(keepdim, bool): - raise TypeError(f'keepdim must be bool, but got {type(keepdim)}') - - attrs = {'axis': axis, 'keepdim': keepdim} - helper = LayerHelper('reduce_sum_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type=helper.layer_type, inputs={'X': x}, outputs={'Y': out}, attrs=attrs - ) - return out - - -@REGISTER_FN('matmul_p', 'X', 'Y', 'Z') -def matmul(x, y, out=None): - return _simple_binop(LayerHelper('matmul_p', **locals())) - - -@REGISTER_FN('slice_select_p', 'X', 'Y') -def slice_select(x, axis, starts, ends, strides, out=None): - if not isinstance(axis, (list, tuple)): - raise TypeError( - f'Argument type error. `axis` is supposed to be list or' - f' tuple but found {type(axis)}.' - ) - if not isinstance(starts, (list, tuple)): - raise TypeError( - f'Argument type error. `starts` is supposed to be list or' - f' tuple but found {type(starts)}.' - ) - if not isinstance(ends, (list, tuple)): - raise TypeError( - f'Argument type error. `ends` is supposed to be list or' - f' tuple but found {type(ends)}.' - ) - assert len(axis) == len(starts) == len(ends) == len(strides), ( - f'len(axis), len(starts), len(ends) and len(strides) should be equal, ' - f'but len(axis)={len(axis)}, len(starts)={len(starts)}, ' - f'len(ends)={len(ends)} and len(strides)={len(strides)}' - ) - - attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides} - helper = LayerHelper('slice_select_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=helper.layer_type, inputs={'X': x}, outputs={'Y': out}, attrs=attrs - ) - return out - - -@REGISTER_FN('slice_assign_p', 'X', 'Y', 'Z') -def slice_assign(x, y, axis, starts, ends, strides, out=None): - assert len(starts) == len(ends) == len(strides) == len(axis), ( - f'len(starts), len(ends), len(strides) and len(axis) should be equal, ' - f'but len(starts)={len(starts)}, len(ends)={len(ends)}, ' - f'len(strides)={len(strides)} and len(axis)={len(axis)}' - ) - assert len(y.shape) == len(x.shape), ( - f'len(y.shape) should be equal to len(x.shape), ' - f'but len(y.shape)={len(y.shape)} and len(x.shape)={len(x.shape)}.' - ) - - attrs = {'axis': axis, 'starts': starts, 'ends': ends, 'strides': strides} - helper = LayerHelper('slice_assign_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=helper.layer_type, - inputs={'X': x, 'Y': y}, - outputs={'Z': out}, - attrs=attrs, - ) - return out - - -@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y') -def gather(x, indextensor, axis, out=None): - attrs = {'axis': axis} - helper = LayerHelper('gather_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=helper.layer_type, - inputs={'X': x, 'IndexTensor': indextensor}, - outputs={'Y': out}, - attrs=attrs, - ) - return out - - -@REGISTER_FN('scatter_add_p', 'X', 'Y', 'IndexTensor', 'Z') -def scatter_add(x, y, indextensor, axis, out=None): - assert len(x.shape) == len(y.shape), ( - f'len(x.shape) should be equal to len(y.shape), ' - f'but len(x.shape)={len(x.shape)} and len(y.shape)={len(y.shape)}.' - ) - assert ( - len(indextensor.shape) == 1 - ), f'len(indextensor.shape) must be equal to 1, but got {len(indextensor.shape)}.' - assert y.shape[axis] == indextensor.shape[0], ( - f'y.shape[axis] should be equal to indextensor.shape[0], ' - f'but y.shape[axis]={y.shape[axis]} and ' - f'indextensor.shape[0]={indextensor.shape[0]}.' - ) - attrs = {'axis': axis} - helper = LayerHelper('scatter_add_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=helper.layer_type, - inputs={'X': x, 'Y': y, 'IndexTensor': indextensor}, - outputs={'Z': out}, - attrs=attrs, - ) - return out - - -@REGISTER_FN('log_p', 'X', 'Y') -def log(x, out=None): - return _simple_unop(LayerHelper('log_p', **locals())) - - -@REGISTER_FN('select_p', 'Condition', 'X', 'Y', 'Z') -def select(cond, x, y, out=None): - if len(cond.shape) != len(x.shape): - raise ValueError( - f"len(cond.shape) should be equal to len(x.shape), but len(cond.shape)={len(cond.shape)} and len(x.shape)={len(x.shape)}." - ) - - if len(x.shape) != len(y.shape): - raise ValueError( - f"len(x.shape) should be equal to len(y.shape), but len(x.shape)={len(x.shape)} and len(y.shape)={len(y.shape)}." - ) - - helper = LayerHelper('select_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=helper.layer_type, - inputs={'Condition': cond, 'X': x, 'Y': y}, - outputs={'Z': out}, - ) - return out - - -@REGISTER_FN('eq_p', 'X', 'Y', 'Z') -def eq(x, y, out=None): - return _simple_binop(LayerHelper('eq_p', **locals())) - - -@REGISTER_FN('gt_p', 'X', 'Y', 'Z') -def gt(x, y, out=None): - return _simple_binop(LayerHelper('gt_p', **locals())) - - -@REGISTER_FN('ge_p', 'X', 'Y', 'Z') -def ge(x, y, out=None): - return _simple_binop(LayerHelper('ge_p', **locals())) - - -@REGISTER_FN('ne_p', 'X', 'Y', 'Z') -def ne(x, y, out=None): - return _simple_binop(LayerHelper('ne_p', **locals())) - - -@REGISTER_FN('pow_p', 'X', 'Y', 'Z') -def pow(x, y, out=None): - return _simple_binop(LayerHelper('pow_p', **locals())) - - -@REGISTER_FN('max_p', 'X', 'Y', 'Z') -def max(x, y, out=None): - return _simple_binop(LayerHelper('max_p', **locals())) - - -@REGISTER_FN('erf_p', 'X', 'Y') -def erf(x, out=None): - return _simple_unop(LayerHelper('erf_p', **locals())) - - -@REGISTER_FN('cast_p', 'X', 'Y') -def cast(x, dtype, out=None): - helper = LayerHelper('cast_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type=helper.layer_type, - inputs={'X': x}, - outputs={'Y': out}, - attrs={'dtype': dtype}, - ) - return out - - -@REGISTER_FN('rsqrt_p', 'X', 'Y') -def rsqrt(x, out=None): - return _simple_unop(LayerHelper('rsqrt_p', **locals())) - - -@REGISTER_FN('uniform_random_p', 'Out') -def uniform_random(dtype, min_value, max_value, seed, shape=None, out=None): - attrs = { - 'shape': shape, - 'dtype': dtype, - 'min': min_value, - 'max': max_value, - 'seed': seed, - } - helper = LayerHelper('uniform_random_p', **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype) - helper.append_op(type=helper.layer_type, outputs={'Out': out}, attrs=attrs) - return out diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index a0ff9d3471c1a..28792863325a3 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -11,62 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import math -import operator -import typing -import paddle - -from . import primops -from .primops import ( - add, - broadcast, - concat, - cos, - div, - eq, - erf, - exp, - fill_const, - gather, - ge, - gt, - log, - matmul, - mul, - ne, - neg, - reduce_sum, - reshape, - rsqrt, - scatter_add, - select, - set_value, - sin, - slice_assign, - slice_select, - split, - sqrt, - sub, - tanh, - transpose, - uniform_random, -) from .primreg import ( - REGISTER_JVP, - REGISTER_ORIG2PRIM, - REGISTER_PRIM2ORIG, - REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, lookup_orig2prim, lookup_prim2orig, lookup_transpose, - op_position_inputs, - op_position_output, ) -from .utils import INT_DTYPE_2_STRING, get_output_var_list def _orig2prim(op, *args): @@ -125,1224 +77,3 @@ def linear_jvp(op, *args, **kwargs): slice p_norm """ - - -@REGISTER_ORIG2PRIM('elementwise_add') -def elementwise_add_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return add(x, y) - - -@REGISTER_ORIG2PRIM('elementwise_sub') -def elementwise_sub_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return sub(x, y) - - -@REGISTER_ORIG2PRIM('elementwise_mul') -def elementwise_mul_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return mul(x, y) - - -@REGISTER_ORIG2PRIM('elementwise_div') -def elementwise_div_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return primops.div(x, y) - - -@REGISTER_ORIG2PRIM('tanh') -def tanh_orig2prim(op, x): - return tanh(x) - - -@REGISTER_ORIG2PRIM('sin') -def sin_orig2prim(op, x): - return sin(x) - - -@REGISTER_ORIG2PRIM('cos') -def cos_orig2prim(op, x): - return cos(x) - - -@REGISTER_ORIG2PRIM('exp') -def exp_orig2prim(op, x): - return exp(x) - - -@REGISTER_ORIG2PRIM('erf') -def erf_orig2prim(op, x): - return erf(x) - - -@REGISTER_ORIG2PRIM('abs') -def abs_orig2prim(op, x): - return primops.abs(x) - - -@REGISTER_ORIG2PRIM('log') -def log_orig2prim(op, x): - return log(x) - - -@REGISTER_ORIG2PRIM('fill_zeros_like') -def fill_zeros_like_orig2prim(op, x): - return fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - - -@REGISTER_ORIG2PRIM('fill_any_like') -def fill_any_like_orig2prim(op, x): - if op.attr('dtype') == -1: - return fill_const(value=op.attr('value'), shape=x.shape, dtype=x.dtype) - return fill_const( - value=op.attr('value'), - shape=x.shape, - dtype=paddle.dtype(op.attr('dtype')), - ) - - -@REGISTER_ORIG2PRIM('fill_constant') -def fill_const_orig2prim( - op, shape_tensor=None, shape_tensor_list=None, value_tensor=None -): - if shape_tensor or shape_tensor_list or value_tensor: - raise TypeError( - 'fill_const_orig2prim currently not support Tensor input of shape and value.' - ) - return fill_const( - value=op.attr('value'), - shape=op.attr('shape'), - dtype=paddle.dtype(op.attr('dtype')), - ) - - -@REGISTER_ORIG2PRIM('sum') -def sum_orig2prim(op, xs): - x0 = xs[0] - for x in xs[1:]: - x0 = add(x0, x) - return x0 - - -@REGISTER_ORIG2PRIM('index_select') -def index_select_orig2prim(op, index_t, x): - return gather(x, indextensor=index_t, axis=op.attr('dim')) - - -@REGISTER_ORIG2PRIM('scale') -def scale_orig2prim(op, scale_t, x): - if scale_t is None: - scale_t = fill_const( - shape=x.shape, dtype=x.dtype, value=op.attr('scale') - ) - bias_t = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('bias')) - if op.attr('bias_after_scale'): - return add(mul(x, scale_t), bias_t) - else: - return mul(add(x, bias_t), scale_t) - - -@REGISTER_ORIG2PRIM('assign') -def assign_orig2prim(op, x): - zero_t = fill_const(shape=x.shape, dtype=x.dtype, value=0.0) - return add(x, zero_t) - - -@REGISTER_ORIG2PRIM('sqrt') -def sqrt_orig2prim(op, x): - return sqrt(x) - - -@REGISTER_ORIG2PRIM('rsqrt') -def rsqrt_orig2prim(op, x): - return rsqrt(x) - - -@REGISTER_ORIG2PRIM('matmul_v2') -def matmul_v2_orig2prim(op, x, y): - def trans(shape): - ret = list(range(len(shape))) - ret[-1], ret[-2] = ret[-2], ret[-1] - return ret - - assert ( - len(x.shape) < 4 and len(y.shape) < 4 - ), 'Do not support multi batchsize dimensions currently.' - - if len(x.shape) == 1: - x = broadcast(x, shape=[1, x.shape[0]]) - if len(y.shape) == 1: - y = broadcast(y, shape=[y.shape[0], 1]) - if op.attr('trans_x'): - x = transpose(x, axis=trans(x.shape)) - if op.attr('trans_y'): - y = transpose(y, axis=trans(y.shape)) - return matmul(x, y) - - -# NOTE(lml): The second output of reshape2 Xshape, which is only used in reshape2_grad, is meaningless in new autograd mechanism, thus we use a zero tensor instead. -@REGISTER_ORIG2PRIM('reshape2') -def reshape2_orig2prim(op, shape_t, shape_tl, x): - assert ( - shape_t is None - ), 'Can not lower reshape2 into prim ops with shapetensor.' - assert ( - shape_tl is None - ), 'Can not lower reshape2 into prim ops with shapetensorlist.' - y, xshape = get_output_var_list(op) - return reshape(x, shape=y.shape), fill_const( - shape=xshape.shape, dtype=xshape.dtype, value=0.0 - ) - - -@REGISTER_ORIG2PRIM('concat') -def concat_orig2prim(op, axis_t, xs): - assert axis_t is None, 'Can not lower concat into prim ops with axistensor.' - return concat(xs, axis=op.attr('axis')) - - -@REGISTER_ORIG2PRIM('slice') -def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl): - assert ( - starts_t is None - ), 'Can not lower concat into prim ops with startstensor.' - assert ends_t is None, 'Can not lower concat into prim ops with endstensor.' - assert ( - starts_tl is None - ), 'Can not lower concat into prim ops with startstensorlist.' - assert ( - ends_tl is None - ), 'Can not lower concat into prim ops with endstensorlist.' - starts = op.attr('starts') - ends = op.attr('ends') - strides = [1 for _ in starts] - axis = op.attr('axes') - y = slice_select(x, starts=starts, ends=ends, strides=strides, axis=axis) - if op.attr('decrease_axis'): - y = reshape(y, shape=get_output_var_list(op)[0].shape) - return y - - -@REGISTER_ORIG2PRIM('sigmoid') -def sigmoid_orig2prim(op, x): - return div( - fill_const(value=1.0, shape=x.shape, dtype=x.dtype), - (add(fill_const(value=1.0, shape=x.shape, dtype=x.dtype), exp(neg(x)))), - ) - - -@REGISTER_ORIG2PRIM('p_norm') -def p_norm_orig2prim(op, x): - def num_el(shape): - n = 1 - for s in shape: - n = n * s - return n - - assert op.attr( - 'asvector' - ), 'Only support lower pnorm when asvector=True currently' - if len(x.shape) > 1: - x = reshape(x, shape=[num_el(x.shape)]) - - if abs(op.attr('porder') - 2.0) < 1e-5: - return sqrt(reduce_sum(mul(x, x), axis=[0])) - elif abs(op.attr('porder') - 1.0) < 1e-5: - return reduce_sum(primops.abs(x), axis=[0]) - else: - raise RuntimeError('Only support lower l2/l1 norm currently') - - -@REGISTER_ORIG2PRIM('cast') -def cast_orig2prim(op, x): - return primops.cast(x, paddle.dtype(op.attr('out_dtype'))) - - -# TODO: support broadcast -@REGISTER_ORIG2PRIM('where') -def select_orig2prim(op, condition, x, y): - return select(condition, x, y) - - -@REGISTER_ORIG2PRIM('equal') -def equal_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return eq(x, y) - - -@REGISTER_ORIG2PRIM('not_equal') -def ne_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return ne(x, y) - - -@REGISTER_ORIG2PRIM('greater_than') -def gt_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return gt(x, y) - - -@REGISTER_ORIG2PRIM('greater_equal') -def ge_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return ge(x, y) - - -# paddle.pow API use "elementwise_pow" operator when y is a Tensor. -@REGISTER_ORIG2PRIM('elementwise_pow') -def elementwise_pow_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - z = primops.pow(x, y) - return z - - -# paddle.pow API use "pow" operator when y is a scalar. -@REGISTER_ORIG2PRIM('pow') -def pow_orig2prim(op, x, y): - # x is factorTensor defined in paddle phi op. Currently it is None. - return primops.pow(y, fill_const(op.attr('factor'), y.shape, y.dtype)) - - -@REGISTER_ORIG2PRIM('square') -def square_orig2prim(op, x): - return primops.square(x) - - -@REGISTER_ORIG2PRIM('elementwise_max') -def elementwise_max_orig2prim(op, x, y): - if x.shape != y.shape: - y = broadcast(y, shape=x.shape) - return primops.max(x, y) - - -@REGISTER_ORIG2PRIM('gelu') -def gelu_orig2prim(op, x): - if op.attr('approximate'): - cdf = mul( - fill_const(0.5, x.shape, x.dtype), - add( - fill_const(1.0, x.shape, x.dtype), - tanh( - mul( - fill_const(math.sqrt(2 / math.pi), x.shape, x.dtype), - add( - x, - mul( - fill_const(0.044715, x.shape, x.dtype), - primops.pow( - x, fill_const(3.0, x.shape, x.dtype) - ), - ), - ), - ) - ), - ), - ) - return mul(x, cdf) - else: - return mul( - mul(fill_const(0.5, x.shape, x.dtype), x), - add( - fill_const(1.0, x.shape, x.dtype), - erf(mul(x, fill_const(1 / math.sqrt(2.0), x.shape, x.dtype))), - ), - ) - - -@REGISTER_ORIG2PRIM('uniform_random') -def uniform_random_orig2prim(op, shape_t, shape_tl): - if shape_t or shape_tl: - raise TypeError( - 'uniform_random_orig2prim currently not support ShapeTensor input or ShapeTensorList input.' - ) - min_value = op.attr('min') - max_value = op.attr('max') - seed = op.attr('seed') - dtype = paddle.dtype(op.attr('dtype')) - shape = op.attr('shape') - return uniform_random(dtype, min_value, max_value, seed, shape=shape) - - -@REGISTER_ORIG2PRIM('reduce_sum') -def reduce_sum_orig2prim(op, x): - axes = ( - tuple(range(0, len(x.shape))) - if op.attr('reduce_all') - else op.attr('dim') - ) - return reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim')) - - -@REGISTER_ORIG2PRIM('reduce_mean') -def reduce_mean_orig2prim(op, x): - axes = ( - tuple(range(0, len(x.shape))) - if op.attr('reduce_all') - else op.attr('dim') - ) - return primops.mean(x, axes, op.attr('keep_dim')) - - -@REGISTER_ORIG2PRIM('batch_norm') -def batch_norm_orig2prim( - op, bias, run_mean, momentum_tensor, scale, run_var, x -): - momentum = op.attr('momentum') - eps = op.attr('epsilon') - is_test = op.attr('is_test') - data_layout = op.attr('data_layout') - use_global_stats = op.attr('use_global_stats') - trainable_statistics = op.attr('trainable_statistics') - reserve_space = ( - None if len(op.output_names) == 5 else get_output_var_list(op)[1] - ) - - feature_axis = ( - 1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1 - ) - use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats - - return primops.batch_norm( - x, - feature_axis, - scale, - bias, - run_mean, - run_var, - eps=eps, - momentum=momentum, - use_run_stat=use_run_stat, - reserve_space=reserve_space, - ) - - -@REGISTER_ORIG2PRIM('size') -def size_orig2prim(op, x): - return fill_const(functools.reduce(operator.mul, x.shape), (), paddle.int64) - - -# Register prim2orig lower rules -@REGISTER_PRIM2ORIG('add_p') -def add_prim2orig(op, x, y): - return paddle.add(x, y) - - -@REGISTER_PRIM2ORIG('sub_p') -def sub_prim2orig(op, x, y): - return paddle.subtract(x, y) - - -@REGISTER_PRIM2ORIG('rsqrt_p') -def rsqrt_prim2orig(op, x): - return paddle.rsqrt(x) - - -@REGISTER_PRIM2ORIG('mul_p') -def mul_prim2orig(op, x, y): - return paddle.multiply(x, y) - - -@REGISTER_PRIM2ORIG('div_p') -def div_prim2orig(op, x, y): - return paddle.divide(x, y) - - -@REGISTER_PRIM2ORIG('sqrt_p') -def sqrt_prim2orig(op, x): - return paddle.sqrt(x) - - -@REGISTER_PRIM2ORIG('tanh_p') -def tanh_prim2orig(op, x): - return paddle.tanh(x) - - -@REGISTER_PRIM2ORIG('sin_p') -def sin_prim2orig(op, x): - return paddle.sin(x) - - -@REGISTER_PRIM2ORIG('cos_p') -def cos_prim2orig(op, x): - return paddle.cos(x) - - -@REGISTER_PRIM2ORIG('exp_p') -def exp_prim2orig(op, x): - return paddle.exp(x) - - -@REGISTER_PRIM2ORIG('erf_p') -def erf_prim2orig(op, x): - return paddle.erf(x) - - -@REGISTER_PRIM2ORIG('abs_p') -def abs_prim2orig(op, x): - return paddle.abs(x) - - -@REGISTER_PRIM2ORIG('log_p') -def log_prim2orig(op, x): - return paddle.log(x) - - -@REGISTER_PRIM2ORIG('reshape_p') -def reshape_prim2orig(op, x): - return paddle.reshape(x, shape=op.attr('shape')) - - -@REGISTER_PRIM2ORIG('broadcast_p') -def broadcast_prim2orig(op, x): - return paddle.broadcast_to(x, shape=op.attr('shape')) - - -@REGISTER_PRIM2ORIG('transpose_p') -def transpose_prim2orig(op, x): - return paddle.transpose(x, perm=op.attr('axis')) - - -@REGISTER_PRIM2ORIG('split_p') -def split_prim2orig(op, x): - num_or_sections = op.attr('num_or_sections') - if len(num_or_sections) == 1: - num_or_sections = num_or_sections[0] - return paddle.split( - x, num_or_sections=num_or_sections, axis=op.attr('axis') - ) - - -@REGISTER_PRIM2ORIG('concat_p') -def concat_prim2orig(op, xs): - return paddle.concat(xs, axis=op.attr('axis')) - - -@REGISTER_PRIM2ORIG('reduce_sum_p') -def reduce_prim2orig(op, x): - return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim')) - - -@REGISTER_PRIM2ORIG('matmul_p') -def matmul_prim2orig(op, x, y): - return paddle.matmul(x, y) - - -@REGISTER_PRIM2ORIG('slice_select_p') -def slice_select_prim2orig(op, x): - return paddle.strided_slice( - x, - axes=op.attr('axis'), - starts=op.attr('starts'), - ends=op.attr('ends'), - strides=op.attr('strides'), - ) - - -@REGISTER_PRIM2ORIG('slice_assign_p') -def slice_assign_prim2orig(op, x, y): - x_copy = paddle.assign(x) - return set_value( - x_copy, - y, - axis=op.attr('axis'), - starts=op.attr('starts'), - ends=op.attr('ends'), - strides=op.attr('strides'), - out=x_copy, - ) - - -@REGISTER_PRIM2ORIG('gather_p') -def gather_prim2orig(op, index_t, x): - return paddle.gather(x, index_t, axis=op.attr('axis')) - - -@REGISTER_PRIM2ORIG('scatter_add_p') -def scatter_add_prim2orig(op, index_t, x, y): - assert op.attr('axis') == 0, 'Only support axis==0 currently' - zeros = paddle.zeros_like(x=x, dtype=x.dtype) - tmp = paddle.scatter(x=zeros, index=index_t, updates=y, overwrite=False) - return paddle.add(x, tmp) - - -@REGISTER_PRIM2ORIG('fill_constant_p') -def fill_constant_prim2orig(op): - return paddle.full( - shape=op.attr('shape'), - fill_value=op.attr('value'), - dtype=INT_DTYPE_2_STRING[op.attr('dtype')], - ) - - -@REGISTER_PRIM2ORIG('bernoulli_p') -def bernoulli_prim2orig(op): - t = paddle.full( - shape=op.attr('shape'), - fill_value=op.attr('p'), - dtype=INT_DTYPE_2_STRING[op.attr('dtype')], - ) - return paddle.bernoulli(t) - - -@REGISTER_PRIM2ORIG('uniform_random_p') -def uniform_random_prim2orig(op): - return paddle.uniform( - shape=op.attr('shape'), - dtype=INT_DTYPE_2_STRING[op.attr('dtype')], - min=op.attr('min'), - max=op.attr('max'), - seed=op.attr('seed'), - ) - - -@REGISTER_PRIM2ORIG('select_p') -def select_prim2orig(op, condition, x, y): - return paddle.where(condition, x, y) - - -@REGISTER_PRIM2ORIG('eq_p') -def eq_prim2orig(op, x, y): - return paddle.equal(x, y) - - -@REGISTER_PRIM2ORIG('gt_p') -def gt_prim2orig(op, x, y): - return paddle.greater_than(x, y) - - -@REGISTER_PRIM2ORIG('ge_p') -def ge_prim2orig(op, x, y): - return paddle.greater_equal(x, y) - - -@REGISTER_PRIM2ORIG('ne_p') -def ne_prim2orig(op, x, y): - return paddle.not_equal(x, y) - - -@REGISTER_PRIM2ORIG('pow_p') -def pow_prim2orig(op, x, y): - return paddle.pow(x, y) - - -@REGISTER_PRIM2ORIG('max_p') -def max_prim2orig(op, x, y): - return paddle.maximum(x, y) - - -@REGISTER_PRIM2ORIG('cast_p') -def cast_prim2orig(op, x): - return paddle.cast(x, paddle.dtype(op.attr('dtype'))) - - -# Register linearize rules -@REGISTER_JVP('add_p') -def add_jvp(op, x_dot, y_dot): - if x_dot is None: - return y_dot - elif y_dot is None: - return x_dot - else: - return linear_jvp(op, x_dot, y_dot) - - -@REGISTER_JVP('sub_p') -def sub_jvp(op, x_dot, y_dot): - if x_dot is None: - return neg(y_dot) - elif y_dot is None: - return x_dot - else: - return linear_jvp(op, x_dot, y_dot) - - -@REGISTER_JVP('mul_p') -def mul_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, y = op_position_inputs(op) - if x_dot is None: - return mul(x, y_dot) - elif y_dot is None: - return mul(x_dot, y) - else: - t1, t2 = mul(x_dot, y), mul(x, y_dot) - z_dot = add(t1, t2) - return z_dot - - -@REGISTER_JVP('div_p') -def div_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, y = op_position_inputs(op) - if y_dot is None: - return div(x_dot, y) - elif x_dot is None: - return neg(div(mul(x, y_dot), mul(y, y))) - else: - t1 = div(x_dot, y) - t2 = div(mul(x, y_dot), mul(y, y)) - return sub(t1, t2) - - -@REGISTER_JVP('sqrt_p') -def sqrt_jvp(op, x_dot): - if x_dot is None: - return None - y = op_position_output(op) - c2 = fill_const(value=2.0, shape=y.shape, dtype=y.dtype) - y_dot = div(x_dot, mul(c2, y)) - return y_dot - - -@REGISTER_JVP('tanh_p') -def tanh_jvp(op, x_dot): - if x_dot is None: - return None - y = op_position_output(op) - c1 = fill_const(value=1.0, shape=y.shape, dtype=y.dtype) - y_dot = mul(x_dot, sub(c1, mul(y, y))) - return y_dot - - -@REGISTER_JVP('sin_p') -def sin_jvp(op, x_dot): - if x_dot is None: - return None - (x,) = op_position_inputs(op) - return mul(x_dot, cos(x)) - - -@REGISTER_JVP('cos_p') -def cos_jvp(op, x_dot): - if x_dot is None: - return None - (x,) = op_position_inputs(op) - return mul(x_dot, neg(sin(x))) - - -@REGISTER_JVP('exp_p') -def exp_jvp(op, x_dot): - if x_dot is None: - return None - y = op_position_output(op) - return mul(x_dot, y) - - -@REGISTER_JVP('erf_p') -def erf_jvp(op, x_dot): - if x_dot is None: - return None - (x,) = op_position_inputs(op) - return mul( - fill_const(2.0 / math.sqrt(math.pi), x.shape, x.dtype), - mul(x_dot, exp(neg(primops.pow(x, fill_const(2.0, x.shape, x.dtype))))), - ) - - -@REGISTER_JVP('abs_p') -def abs_jvp(op, x_dot): - if x_dot is None: - return None - (x,) = op_position_inputs(op) - return select(ge(x, fill_const(0.0, x.shape, x.dtype)), x_dot, neg(x_dot)) - - -@REGISTER_JVP('log_p') -def log_jvp(op, x_dot): - if x_dot is None: - return None - (x,) = op_position_inputs(op) - return div(x_dot, x) - - -@REGISTER_JVP('reshape_p') -def reshape_jvp(op, x_dot): - if x_dot is None: - return None - shape = op.attr('shape') - return linear_jvp(op, x_dot, shape=shape) - - -@REGISTER_JVP('broadcast_p') -def broadcast_jvp(op, x_dot): - if x_dot is None: - return None - shape = op.attr('shape') - return linear_jvp(op, x_dot, shape=shape) - - -@REGISTER_JVP('transpose_p') -def transpose_jvp(op, x_dot): - if x_dot is None: - return None - axis = op.attr('axis') - return linear_jvp(op, x_dot, axis=axis) - - -@REGISTER_JVP('split_p') -def split_jvp(op, x_dot): - if x_dot is None: - return None - num_or_sections = op.attr('num_or_sections') - axis = op.attr('axis') - return linear_jvp(op, x_dot, num_or_sections=num_or_sections, axis=axis) - - -@REGISTER_JVP('concat_p') -def concat_jvp(op, xs_dot): - if xs_dot is None: - return None - axis = op.attr('axis') - return linear_jvp(op, xs_dot, axis=axis) - - -@REGISTER_JVP('reduce_sum_p') -def reduce_sum_jvp(op, x_dot): - if x_dot is None: - return None - axis = op.attr('axis') - keepdim = op.attr('keepdim') - return linear_jvp(op, x_dot, axis=axis, keepdim=keepdim) - - -@REGISTER_JVP('matmul_p') -def matmul_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, y = op_position_inputs(op) - if x_dot is None: - return matmul(x, y_dot) - elif y_dot is None: - return matmul(x_dot, y) - else: - t1 = matmul(x, y_dot) - t2 = matmul(x_dot, y) - return add(t1, t2) - - -@REGISTER_JVP('slice_select_p') -def slice_select_jvp(op, x_dot): - if x_dot is None: - return x_dot - axis = op.attr('axis') - starts = op.attr('starts') - ends = op.attr('ends') - strides = op.attr('strides') - return linear_jvp( - op, x_dot, axis=axis, starts=starts, ends=ends, strides=strides - ) - - -@REGISTER_JVP('slice_assign_p') -def slice_assign_jvp(op, x_dot, y_dot): - x, y = op_position_inputs(op) - assert ( - x_dot is not None or y_dot is not None - ), "x_dot and y_dot can't be None at the same time. " - axis = op.attr('axis') - starts = op.attr('starts') - ends = op.attr('ends') - strides = op.attr('strides') - if x_dot is None: - return linear_jvp( - op, - fill_const(value=0.0, shape=x.shape, dtype=x.dtype), - y_dot, - axis=axis, - starts=starts, - ends=ends, - strides=strides, - ) - elif y_dot is None: - return linear_jvp( - op, - x_dot, - fill_const(value=0.0, shape=y.shape, dtype=y.dtype), - axis=axis, - starts=starts, - ends=ends, - strides=strides, - ) - return add( - linear_jvp( - op, - fill_const(value=0.0, shape=x.shape, dtype=x.dtype), - y_dot, - axis=axis, - starts=starts, - ends=ends, - strides=strides, - ), - linear_jvp( - op, - x_dot, - fill_const(value=0.0, shape=y.shape, dtype=y.dtype), - axis=axis, - starts=starts, - ends=ends, - strides=strides, - ), - ) - - -@REGISTER_JVP('gather_p') -def gather_jvp(op, x_dot, indextensor): - if x_dot is None: - return None - _, indextensor = op_position_inputs(op) - axis = op.attr('axis') - return linear_jvp(op, x_dot, indextensor, axis=axis) - - -@REGISTER_JVP('scatter_add_p') -def scatter_add_jvp(op, x_dot, y_dot): - if x_dot is None: - return None - _, _, indextensor = op_position_inputs(op) - axis = op.attr('axis') - return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis) - - -@REGISTER_JVP('select_p') -def select_jvp(op, cond_dot, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - - cond, x, y = op_position_inputs(op) - if x_dot is None: - x_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - if y_dot is None: - y_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - return select(cond, x_dot, y_dot) - - -@REGISTER_JVP('eq_p') -def eq_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, _ = op_position_inputs(op) - z_dot = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - return z_dot - - -@REGISTER_JVP('gt_p') -def gt_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, _ = op_position_inputs(op) - z_dot = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - return z_dot - - -@REGISTER_JVP('ge_p') -def ge_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, _ = op_position_inputs(op) - z_dot = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - return z_dot - - -@REGISTER_JVP('ne_p') -def ne_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - x, _ = op_position_inputs(op) - z_dot = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - return z_dot - - -@REGISTER_JVP('pow_p') -def pow_jvp(op, x_dot, y_dot): - def _compute_t1(x, y): - zero_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - one_y = fill_const(value=1.0, shape=y.shape, dtype=y.dtype) - - cond = eq(y, zero_y) - new_y = select(cond, one_y, sub(y, one_y)) - t1 = mul(x_dot, mul(y, primops.pow(x, new_y))) - return t1 - - if x_dot is None and y_dot is None: - return None - x, y = op_position_inputs(op) - z = op_position_output(op) - - if y_dot is None: - return _compute_t1(x, y) - elif x_dot is None: - return mul(y_dot, mul(log(x), z)) - else: - t1, t2 = _compute_t1(x, y), mul(y_dot, mul(log(x), z)) - z_dot = add(t1, t2) - return z_dot - - -@REGISTER_JVP('max_p') -def max_jvp(op, x_dot, y_dot): - if x_dot is None and y_dot is None: - return None - - x, y = op_position_inputs(op) - z = op_position_output(op) - z_zeros = fill_const(value=0.0, shape=z.shape, dtype=z.dtype) - - # To make the grad of max_p consistent with paddle.maximum when x==y, - # we just let z_dot = y_dot when compute z_dot to y and x==y, - # instead of using balance_eq like Jax. - if y_dot is None: - return select(eq(y, z), z_zeros, x_dot) - elif x_dot is None: - return select(eq(y, z), y_dot, z_zeros) - else: - return select(eq(y, z), y_dot, x_dot) - - -@REGISTER_JVP('cast_p') -def cast_jvp(op, x_dot): - y = op_position_output(op) - return primops.cast(x_dot, y.dtype) - - -@REGISTER_JVP('rsqrt_p') -def rsqrt_jvp(op, x_dot): - if x_dot is None: - return None - y = op_position_output(op) - x = op_position_inputs(op) - c2 = fill_const(value=-2.0, shape=y.shape, dtype=y.dtype) - y_dot = mul(x_dot, div(div(y, x), c2)) - return y_dot - - -# Register transpose rules - - -@REGISTER_TRANSPOSE('add_p') -def add_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert check_dot(x) or check_dot(y), ( - f'(check_dot(x) or check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - x_bar = z_bar if check_dot(x) else None - y_bar = z_bar if check_dot(y) else None - return x_bar, y_bar - - -@REGISTER_TRANSPOSE('sub_p') -def sub_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert check_dot(x) or check_dot(y), ( - f'(check_dot(x) or check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - x_bar = z_bar if check_dot(x) else None - y_bar = neg(z_bar) if check_dot(y) else None - return x_bar, y_bar - - -@REGISTER_TRANSPOSE('mul_p') -def mul_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert check_dot(x) ^ check_dot(y), ( - f'(check_dot(x) ^ check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - if check_dot(x): - return mul(z_bar, y), None - else: - return None, mul(x, z_bar) - - -@REGISTER_TRANSPOSE('div_p') -def div_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert not check_dot(y), 'check_dot(y) must be False' - x_bar = div(z_bar, y) if check_dot(x) else None - return x_bar, None - - -@REGISTER_TRANSPOSE('reshape_p') -def reshape_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - return reshape(y_bar, shape=x.shape) - - -@REGISTER_TRANSPOSE('broadcast_p') -def broadcast_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - bat = len(y_bar.shape) - len(x.shape) - axis = list(range(bat)) - keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1] - axis += keepdim - # TODO: Change it. keepdim boolean - out = reduce_sum(y_bar, axis=axis, keepdim=False) - return reshape(out, x.shape) - - -@REGISTER_TRANSPOSE('transpose_p') -def transpose_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - axis = op.attr('axis') - reordered = sorted((k, i) for i, k in enumerate(axis)) - axis = [i for k, i in reordered] - return transpose(y_bar, axis=axis) - - -@REGISTER_TRANSPOSE('split_p') -def split_transpose(op, check_dot, ys_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - return concat(ys_bar, axis=op.attr('axis')) - - -@REGISTER_TRANSPOSE('concat_p') -def concat_transpose(op, check_dot, y_bar): - (xs,) = op_position_inputs(op) - if not isinstance(xs, typing.Sequence): - xs = [xs] - for x in xs: - assert check_dot(x), 'check_dot(x) must be True' - axis = op.attr('axis') - sections = [x.shape[axis] for x in xs] - if len(sections) == 1: - return y_bar - return split(y_bar, num_or_sections=sections, axis=axis) - - -@REGISTER_TRANSPOSE('reduce_sum_p') -def reduce_sum_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - axes = op.attr('axis') - shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape)) - t = reshape(y_bar, shape=shape) - return broadcast(t, shape=x.shape) - - -@REGISTER_TRANSPOSE('matmul_p') -def matmul_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert check_dot(x) ^ check_dot(y), ( - f'(check_dot(x) ^ check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - # TODO: replace it. this is hacky - axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1] - if check_dot(x): - return matmul(z_bar, transpose(y, axis=axis)), None - else: - return None, matmul(transpose(x, axis=axis), z_bar) - - -@REGISTER_TRANSPOSE('slice_select_p') -def slice_select_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - axis = op.attr('axis') - starts = op.attr('starts') - ends = op.attr('ends') - strides = op.attr('strides') - return slice_assign( - zeros, y_bar, axis=axis, starts=starts, ends=ends, strides=strides - ) - - -@REGISTER_TRANSPOSE('slice_assign_p') -def slice_assign_transpose(op, check_dot, z_bar): - x, y = op_position_inputs(op) - assert check_dot(x) ^ check_dot(y), ( - f'(check_dot(x) ^ check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - axis = op.attr('axis') - starts = op.attr('starts') - ends = op.attr('ends') - strides = op.attr('strides') - if check_dot(x): - return ( - slice_assign( - z_bar, - zeros, - axis=axis, - starts=starts, - ends=ends, - strides=strides, - ), - None, - ) - return None, slice_select( - z_bar, axis=axis, starts=starts, ends=ends, strides=strides - ) - - -@REGISTER_TRANSPOSE('gather_p') -def gather_transpose(op, check_dot, y_bar): - x, indextensor = op_position_inputs(op) - assert check_dot(x), 'check_dot(x) must be True' - axis = op.attr('axis') - zeros = fill_const(0.0, x.shape, x.dtype) - x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis) - indextensor_bar = None - return x_bar, indextensor_bar - - -@REGISTER_TRANSPOSE('scatter_add_p') -def scatter_add_transpose(op, check_dot, z_bar): - x, y, indextensor = op_position_inputs(op) - assert check_dot(x) and check_dot(y), ( - f'(check_dot(x) and check_dot(y)) must be True, ' - f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - axis = op.attr('axis') - zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis) - y_bar = gather(z_bar, indextensor, axis=axis) - indextensor_bar = None - return x_bar, y_bar, indextensor_bar - - -@REGISTER_TRANSPOSE('select_p') -def select_transpose(op, check_dot, z_bar): - cond, x, y = op_position_inputs(op) - assert check_dot(cond) or check_dot(x) or check_dot(y), ( - f'check_dot(cond) ^ (check_dot(x) ^ check_dot(y)) must be True, ' - f'but check_dot(cond)={check_dot(cond)}, check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' - ) - - zeros_x = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) - zeros_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - - cond_bar = ( - fill_const(value=0.0, shape=y.shape, dtype=cond.dtype) - if check_dot(cond) - else None - ) - x_bar = select(cond, z_bar, zeros_x) if check_dot(x) else None - y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None - - return cond_bar, x_bar, y_bar - - -@REGISTER_TRANSPOSE('cast_p') -def cast_transpose(op, check_dot, y_bar): - (x,) = op_position_inputs(op) - return primops.cast(y_bar, x.dtype) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 8ffd36031991a..901e23a649974 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -23,17 +23,13 @@ from paddle.incubate.autograd.utils import as_tensors from .composite_rules import _composite -from .primops import add, fill_const from .primreg import ( lookup_composite, lookup_orig2prim, lookup_prim2orig, - op_position_inputs, - op_position_output, ) -from .primrules import _jvp, _orig2prim, _prim2orig, _transpose +from .primrules import _orig2prim, _prim2orig from .utils import ( - flatten, flatten_and_remove_none, get_input_var_list, get_output_var_list, @@ -262,178 +258,6 @@ def dot2bar_rec(self, dots): bars = [self.dot2bar_rec(dot) for dot in dots] return bars - def linearize(self, xs, ys, xs_dot=None): - """Performs the linearization transform, a.k.a, forward mode AD - transform, on a primitive lowered program. - - Args: - xs: a list of input variables - ys: a list of output variables - xs_dot: optional, a list of gradient input variables. The list size - must be equal to `len(xs)`. The shape and dtype of each element - must be the same as in `xs` - - Returns: - (xs_dot, ys_dot): a tuple of two lists. `xs_dot` is the list of - gradient inputs of the resulting linearized program. `ys_dot` is - the list gradient outputs of the resulting linearized program - - """ - if xs_dot is None: - xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs] - self.add_vars(xs_dot) - else: - assert len(xs) == len(xs_dot), ( - f'len(xs) should be equal to len(xs_dot), ' - f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}' - ) - - for x, dot in zip(xs, xs_dot): - assert x.dtype == dot.dtype, ( - f'x.dtype should be equal to dot.dtype, ' - f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}' - ) - assert x.shape == dot.shape, ( - f'x.shape should be equal to dot.shape, ' - f'but x.shape={x.shape} and dot.shape={dot.shape}' - ) - self.var2dot.add(x, dot) - - path, unused_xs, _ = topo_path(xs, ys, self.block) - - # No need to track unused inputs - for x in unused_xs: - self.var2dot.delete(x) - - for op in path: - # An input var may not be on the input-output path, which implies - # there may be None's in `ins_dot`. In this case we place - # the original input in the position of the otherwise forward - # gradient. - ins = op_position_inputs(op) - jvp_ins = self.var2dot_rec(ins) - # apply op's forward ad rule - outs_dot = _jvp(op, *jvp_ins) - self.add_vars_rec(outs_dot) - outs = op_position_output(op) - self.var2dot.add_rec(outs, outs_dot) - - ys_dot = [self.var2dot.lookup(y) for y in ys] - return xs_dot, ys_dot - - def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False): - """Performs the transpose transform, a.k.a, reverse mode AD - transform, on a linearized primitive program. - - Note, `transpose` is supposed to be used in couple with `linearize`. - - Args: - ys_dot: a list of outputs of the linearized program. - xs_dot: a list of inputs of the linearized program. - ys_bar: optional, a list of inputs of the resulting transposed - program. The list size must be equal to `len(ys_dot)`. The shape - and dtype of each element must be the same as in `ys_dot` - - Returns: - (ys_bar, xs_bar): a tuple of two lists. `ys_bar` is the list of - inputs of the resulting transposed program. `xs_bar` is - the list outputs of the resulting transposed program - - """ - assert all(v is not None for v in xs_dot), '`xs_dot` includes None.' - assert all(v is not None for v in ys_dot), '`ys_dot` includes None.' - - if ys_bar is None: - ys_bar = [] - for y in ys_dot: - ys_bar.append(fill_const(1.0, shape=y.shape, dtype=y.dtype)) - self.add_vars(ys_bar) - else: - assert len(ys_dot) == len(ys_bar), ( - f'len(ys_dot) should be equal to len(ys_bar), ' - f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}' - ) - for y_dot, y_bar in zip(ys_dot, ys_bar): - assert y_dot.shape == y_bar.shape, ( - f'y_dot.shape should be equal to y_bar.shape, ' - f'but y_dot.shape={y_dot.shape} and y_bar.shape={y_bar.shape}' - ) - assert y_dot.dtype == y_bar.dtype, ( - f'y_dot.dtype should be equal to y_bar.dtype, ' - f'but y_dot.dtype={y_dot.dtype} and y_bar.dtype={y_bar.dtype}' - ) - - for dot, bar in zip(ys_dot, ys_bar): - self.dot2bar.add(dot, bar) - - # find all the relevant forward gradients - path, unused_xs_dot, _ = topo_path(xs_dot, ys_dot, self.block) - - # No need to track unused inputs - for dot in unused_xs_dot: - self.dot2bar.delete(dot) - - dotvars = output_vars_on_path(path) - dotvars.update((id(var), var) for var in xs_dot) - - is_dot = lambda v: id(v) in dotvars - - for op in reversed(path): - out = op_position_output(op) - out_bar_rec = self.dot2bar_rec(out) - ins_bar_rec = _transpose(op, is_dot, out_bar_rec) - - # TODO(Tongxin): this is hacky. Tuple implies the Transpose rule - # returns multiple entities. There should be better ways to handle - # outputs. - if isinstance(ins_bar_rec, tuple): - ins_bar_rec = list(ins_bar_rec) - else: - ins_bar_rec = [ins_bar_rec] - self.add_vars_rec(ins_bar_rec) - - ins_bar = flatten(ins_bar_rec) - ins = flatten(op_position_inputs(op)) - assert len(ins) == len(ins_bar), ( - f'len(ins) should be equal to len(ins_bar), ' - f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}' - ) - - for dot, bar in zip(ins, ins_bar): - if bar is not None: - # aggregate gradient - grad = self.dot2bar.lookup(dot) - if grad is None: - self.dot2bar.add(dot, bar) - else: - grad = add(grad, bar) - self.add_vars([grad]) - self.dot2bar.add(dot, grad) - - xs_bar = [self.dot2bar.lookup(x) for x in xs_dot] - - if not retain_fwd and len(path) > 0: - vars_to_remove = set() - for op in path: - vars_to_remove.update( - flatten_and_remove_none(get_output_var_list(op)) - ) - - op_indexes = [] - - block = self.block - for i, op in enumerate(block.ops): - if op in path: - op_indexes.append(i) - path.pop(0) - if len(path) == 0: - break - - self.erase_ops(op_indexes) - self.erase_dots(vars_to_remove) - - return ys_bar, xs_bar - # TODO(lml): supporting control flow, nested blocks, and block other than current block of main program. def _lower(block, reverse, blacklist): diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index f9a7214cf9321..f0b04b0efc441 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -262,7 +262,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_comp_cost MODULES test_comp_cost) py_test_modules(test_base_cost MODULES test_base_cost) py_test_modules(test_dist_context MODULES test_dist_context) - py_test_modules(test_prim_dist_op MODULES test_prim_dist_op) py_test_modules(test_to_static MODULES test_to_static) py_test_modules(test_dist_op_cost MODULES test_dist_op_cost) py_test_modules(test_cluster_v2 MODULES test_cluster_v2) diff --git a/test/auto_parallel/test_prim_dist_op.py b/test/auto_parallel/test_prim_dist_op.py deleted file mode 100644 index 99e12b2099874..0000000000000 --- a/test/auto_parallel/test_prim_dist_op.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import paddle -from paddle.base.layer_helper import LayerHelper -from paddle.distributed.auto_parallel.static.completion import Completer -from paddle.distributed.auto_parallel.static.dist_context import ( - DistributedContext, - get_default_distributed_context, -) -from paddle.distributed.auto_parallel.static.partitioner import Partitioner -from paddle.distributed.auto_parallel.static.utils import set_var_dist_attr -from paddle.distributed.fleet import auto -from paddle.incubate.autograd import enable_prim - -paddle.enable_static() -enable_prim() -nranks = 2 -rank = 0 - - -class TestPrimDistOp(unittest.TestCase): - def setUp(self): - self.main_program = paddle.static.Program() - self.startup_program = paddle.static.Program() - self.layer_help = LayerHelper('TestPrimDistOp') - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - self.init_prog() - - def init_prog(self): - # block = self.main_program.global_block() - # block = self.main_program.global_block() - self.w = self.layer_help.create_parameter( - dtype="float", shape=[20], attr=None - ) - self.w_grad = paddle.static.data( - name='w_grad', shape=[20], dtype='float' - ) - self.tmp1 = paddle.static.data(name='tmp1', shape=[20], dtype='float') - self.tmp2 = paddle.static.data(name='tmp2', shape=[20], dtype='float') - self.batch_reduced = paddle.static.data( - name='batch_reduced', shape=[], dtype='float' - ) - self.attrs = {} - - default_dist_context = get_default_distributed_context() - _global_process_mesh = auto.ProcessMesh(list(range(nranks))) - tensor_dist_attr = set_var_dist_attr( - default_dist_context, - self.tmp1, - [-1], - _global_process_mesh, - mark_annotated=True, - ) - tensor_dist_attr = set_var_dist_attr( - default_dist_context, - self.tmp1, - [-1], - _global_process_mesh, - mark_annotated=True, - ) - - op = self.layer_help.append_op( - type="add_p", - inputs={'X': self.tmp1, 'Y': self.w}, - outputs={'Z': self.w_grad}, - attrs=self.attrs, - ) - - op = self.layer_help.append_op( - type="reduce_sum_p", - inputs={'X': self.tmp2}, - outputs={'Y': self.batch_reduced}, - attrs={"axis": [0]}, - ) - - def test_loss_and_grad_allreduce(self): - dist_context = DistributedContext( - self.main_program, self.startup_program - ) - completer = Completer(dist_context) - completer.complete_prim_annotation(self.main_program) - dist_context.block_state.parse_forward_blocks(self.main_program) - dist_context.block_state.parse_backward_blocks(self.main_program) - dist_context.grads_params = {} - dist_context.grads_params[self.w_grad.name] = self.w.name - dist_context.synced_gradient = set() - dist_context.data_parallel_group = list(range(nranks)) - partitioner = Partitioner(dist_context, rank) - dist_main_prog, dist_startup_prog, _ = partitioner.partition( - self.main_program, self.startup_program, [(self.w, self.w_grad)] - ) - ops = dist_main_prog.global_block().ops - self.assertTrue(ops[1].type == "c_allreduce_sum") - self.assertTrue(ops[3].type == "c_allreduce_sum") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/autograd/CMakeLists.txt b/test/autograd/CMakeLists.txt index 592517cb8e3da..9bdb0b88daf63 100644 --- a/test/autograd/CMakeLists.txt +++ b/test/autograd/CMakeLists.txt @@ -8,7 +8,6 @@ set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) if(WIN32) # TODO: Fix these unittests failed on Windows list(REMOVE_ITEM TEST_OPS test_autograd_functional_prim) - list(REMOVE_ITEM TEST_OPS test_primapi) endif() foreach(TEST_OP ${TEST_OPS}) @@ -21,5 +20,4 @@ set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160) set_tests_properties(test_minimize PROPERTIES TIMEOUT 60) if(NOT WIN32) set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60) - set_tests_properties(test_primapi PROPERTIES TIMEOUT 60) endif() diff --git a/test/autograd/test_jvp_and_transpose.py b/test/autograd/test_jvp_and_transpose.py deleted file mode 100644 index b37fd4e201a4e..0000000000000 --- a/test/autograd/test_jvp_and_transpose.py +++ /dev/null @@ -1,1336 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import paddle -from paddle.base.layer_helper import LayerHelper -from paddle.incubate.autograd.primrules import _jvp, _transpose - -paddle.enable_static() - - -# --------------------- Test linearize rules ----------------------- # -class TestAddPJVPAndTranspose(unittest.TestCase): - def setUp(self): - self.main_program = paddle.static.Program() - self.startup_program = paddle.static.Program() - self.layer_help = LayerHelper('TestPrim2Orig') - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - self.init_data() - - def init_data(self): - # Set prim op - self.op_type = 'add_p' - X = paddle.static.data(name='X', shape=[2, 2], dtype='float') - Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[2, 2], dtype='float') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[2, 2], dtype='float') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: True - Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 2], dtype='float') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X, 1: Y} - - self.all_ops = [ - # prim op: - 'add_p', - # jvp op: - 'add_p', - # transpose op: - ] - - def test_op(self): - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - op = self.layer_help.append_op( - type=self.op_type, - inputs=self.prim_input, - outputs=self.prim_output, - attrs=self.prim_attrs, - ) - - jvp_out = _jvp(op, *self.jvp_args) - jvp_out = paddle.utils.flatten(jvp_out) - for k, v in self.jvp_out_shape_map.items(): - self.assertEqual(jvp_out[k].shape, v.shape) - - # Some prim ops dont have transpose rule - if hasattr(self, 'transpose_args'): - transpose_out = _transpose(op, *self.transpose_args) - transpose_out = paddle.utils.flatten(transpose_out) - for k, v in self.transpose_out_shape_map.items(): - self.assertEqual(transpose_out[k].shape, v.shape) - - all_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(all_ops), sorted(self.all_ops)) - - -class TestSubPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'sub_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: True - Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X, 1: Y} - - self.all_ops = [ - # prim op: - 'sub_p', - # jvp op: - 'sub_p', - # transpose op: - 'fill_constant_p', - 'sub_p', - ] - - -class TestMulPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'mul_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X - Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'mul_p', - # jvp op: - 'mul_p', - 'mul_p', - 'add_p', - # transpose op: - 'mul_p', - ] - - -class TestDivPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'div_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X - Z_BAR = paddle.static.data(name='Z_BAR', shape=[5, 6], dtype='int64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'div_p', - # jvp op: - 'div_p', - 'div_p', - 'mul_p', - 'mul_p', - 'sub_p', - # transpose op: - 'div_p', - ] - - -class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'sqrt_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'sqrt_p', - # jvp op: - 'div_p', - 'mul_p', - 'fill_constant_p', - # 'sqrt_p', - # transpose op: - ] - - -class TestRSqrtPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'rsqrt_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'rsqrt_p', - # jvp op: - 'div_p', - 'div_p', - 'mul_p', - 'fill_constant_p', - # 'sqrt_p', - # transpose op: - ] - - -class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'tanh_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'tanh_p', - # jvp op: - 'mul_p', - 'sub_p', - 'fill_constant_p', - 'mul_p', - # transpose op: - ] - - -class TestSinPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'sin_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'sin_p', - # jvp op: - 'mul_p', - 'cos_p', - # transpose op: - ] - - -class TestCosPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'cos_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'cos_p', - # jvp op: - 'mul_p', - 'sin_p', - 'fill_constant_p', - 'sub_p' - # transpose op: - ] - - -class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'exp_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'exp_p', - # jvp op: - 'mul_p', - # transpose op: - ] - - -class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'erf_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'erf_p', - # jvp op: - 'exp_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'mul_p', - 'pow_p', - 'sub_p', - # transpose op: - ] - - -class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'abs_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'abs_p', - # jvp op: - 'select_p', - 'ge_p', - 'fill_constant_p', - 'fill_constant_p', - 'sub_p', - # transpose op: - ] - - -class TestCastPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'cast_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'dtype': paddle.float64} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: True - Y_BAR = paddle.static.data(name='Y_BAR', shape=[5, 6], dtype='float') - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = {0: X} - - self.all_ops = [ - # prim op: - 'cast_p', - # jvp op: - 'cast_p', - # transpose op: - 'cast_p', - ] - - -class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'log_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - self.all_ops = [ - # prim op: - 'log_p', - # jvp op: - 'div_p', - # transpose op: - ] - - -class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'reshape_p' - X = paddle.static.data(name='X', shape=[8, 8], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'shape': [2, 32]} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[8, 8], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data(name='Y_BAR', shape=[2, 32], dtype='int64') - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'reshape_p', - # jvp op: - 'reshape_p', - # transpose op: - 'reshape_p', - ] - - -class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'broadcast_p' - X = paddle.static.data(name='X', shape=[10, 1], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'shape': [2, 10, 7]} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[10, 7], dtype='int64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data( - name='Y_BAR', shape=[2, 10, 7], dtype='int64' - ) - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'broadcast_p', - # jvp op: - 'broadcast_p', - # transpose op: - 'reduce_sum_p', - 'reshape_p', - ] - - -class TestTransposePJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'transpose_p' - X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'axis': [0, 2, 3, 1]} - - # Set JVP - X_DOT = paddle.static.data( - name='X_DOT', shape=[2, 3, 4, 5], dtype='int64' - ) - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data( - name='Y_BAR', shape=[2, 4, 5, 3], dtype='int64' - ) - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'transpose_p', - # jvp op: - 'transpose_p', - # transpose op: - 'transpose_p', - ] - - -class TestSplitPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'split_p' - X = paddle.static.data(name='X', shape=[2, 7, 10], dtype='int64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'YS': [ - self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - for i in range(4) - ] - } - self.prim_attrs = {'num_or_sections': [2, 3, 4, 1], 'axis': 2} - - # Set JVP - X_DOT = paddle.static.data( - name='X_DOT', shape=[2, 7, 10], dtype='int64' - ) - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = { - 0: self.prim_output['YS'][0], - 1: self.prim_output['YS'][1], - 2: self.prim_output['YS'][2], - 3: self.prim_output['YS'][3], - } - - # Set transpose - check_dot = lambda v: v is X - YS_BAR = [ - paddle.static.data(name='Y_BAR1', shape=[2, 7, 2], dtype='int64'), - paddle.static.data(name='Y_BAR2', shape=[2, 7, 3], dtype='int64'), - paddle.static.data(name='Y_BAR3', shape=[2, 7, 4], dtype='int64'), - paddle.static.data(name='Y_BAR4', shape=[2, 7, 1], dtype='int64'), - ] - self.transpose_args = (check_dot, YS_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'split_p', - # jvp op: - 'split_p', - # transpose op: - 'concat_p', - ] - - -class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'concat_p' - X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 2, 5], dtype='float64') - Z = paddle.static.data(name='Z', shape=[3, 3, 5], dtype='float64') - self.prim_input = { - 'XS': [X, Y, Z], - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'axis': 1} - - # Set JVP - XS_DOT = [ - paddle.static.data(name='X_DOT1', shape=[3, 9, 5], dtype='float64'), - paddle.static.data(name='X_DOT2', shape=[3, 2, 5], dtype='float64'), - paddle.static.data(name='X_DOT3', shape=[3, 3, 5], dtype='float64'), - ] - self.jvp_args = (XS_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X or v is Y or v is Z - Y_BAR = paddle.static.data( - name='Y_BAR', shape=[3, 14, 5], dtype='float64' - ) - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - 1: Y, - 2: Z, - } - - self.all_ops = [ - # prim op: - 'concat_p', - # jvp op: - 'concat_p', - # transpose op: - 'split_p', - ] - - -class TestReduceSumPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'reduce_sum_p' - X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64') - self.prim_input = {'X': X} - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'axis': [2], 'keepdim': False} - - # Set JVP - X_DOT = paddle.static.data( - name='X_DOT1', shape=[2, 3, 4, 5], dtype='float64' - ) - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data( - name='Y_BAR', shape=[2, 3, 5], dtype='float64' - ) - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'reduce_sum_p', - # jvp op: - 'reduce_sum_p', - # transpose op: - 'reshape_p', - 'broadcast_p', - ] - - -class TestMatmulPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'matmul_p' - X = paddle.static.data(name='X', shape=[2, 3], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 4], dtype='float64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[2, 3], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 4], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X - Z_BAR = paddle.static.data(name='Z_BAR', shape=[2, 4], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'matmul_p', - # jvp op: - 'matmul_p', - 'matmul_p', - 'add_p', - # transpose op: - 'matmul_p', - 'transpose_p', - ] - - -class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'slice_select_p' - X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') - self.prim_input = { - 'X': X, - } - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = { - 'axis': [1], - 'starts': [0], - 'ends': [20], - 'strides': [2], - } - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64') - self.jvp_args = (X_DOT,) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data(name='Y_BAR', shape=[3, 10], dtype='float64') - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'slice_select_p', - # jvp op: - 'slice_select_p', - # transpose op: - 'slice_assign_p', - 'fill_constant_p', - ] - - -class TestSliceAssignPJVPAndTranspose1(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'slice_assign_p' - X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = { - 'axis': [1], - 'starts': [0], - 'ends': [10], - 'strides': [2], - } - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 5], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X - Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X} - - self.all_ops = [ - # prim op: - 'slice_assign_p', - # jvp op: - 'slice_assign_p', - "slice_assign_p", - "add_p", - "fill_constant_p", - "fill_constant_p", - # transpose op: - 'slice_assign_p', - 'fill_constant_p', - ] - - -class TestSliceAssignPJVPAndTranspose2(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'slice_assign_p' - X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = { - 'axis': [1], - 'starts': [0], - 'ends': [10], - 'strides': [2], - } - - # Set JVP - Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 5], dtype='float64') - self.jvp_args = (None, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is Y - Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {1: Y} - - self.all_ops = [ - # prim op: - 'slice_assign_p', - # jvp op: - 'slice_assign_p', - "fill_constant_p", - # transpose op: - 'slice_select_p', - 'fill_constant_p', - ] - - -class TestSliceAssignPJVPAndTranspose3(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'slice_assign_p' - X = paddle.static.data(name='X', shape=[3, 20], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = { - 'axis': [1], - 'starts': [0], - 'ends': [10], - 'strides': [2], - } - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64') - self.jvp_args = (X_DOT, None) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X - Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X} - - self.all_ops = [ - # prim op: - 'slice_assign_p', - # jvp op: - 'slice_assign_p', - "fill_constant_p", - # transpose op: - 'slice_assign_p', - 'fill_constant_p', - ] - - -class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'gather_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - IndexTensor = paddle.static.data( - name='IndexTensor', shape=[3], dtype='int32' - ) - self.prim_input = {'X': X, 'IndexTensor': IndexTensor} - self.prim_output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'axis': 1} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') - self.jvp_args = ( - X_DOT, - IndexTensor, - ) - self.jvp_out_shape_map = {0: self.prim_output['Y']} - - # Set transpose - check_dot = lambda v: v is X - Y_BAR = paddle.static.data(name='Y_BAR', shape=[9, 3], dtype='float64') - self.transpose_args = (check_dot, Y_BAR) - self.transpose_out_shape_map = { - 0: X, - } - - self.all_ops = [ - # prim op: - 'gather_p', - # jvp op: - 'gather_p', - # transpose op: - 'scatter_add_p', - 'fill_constant_p', - ] - - -class TestScatterAddPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'scatter_add_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64') - IndexTensor = paddle.static.data( - name='IndexTensor', shape=[3], dtype='int32' - ) - self.prim_input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {'axis': 1} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 3], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: v is X or v is Y - Z_BAR = paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X, 1: Y} - - self.all_ops = [ - # prim op: - 'scatter_add_p', - # jvp op: - 'scatter_add_p', - # transpose op: - 'scatter_add_p', - 'fill_constant_p', - 'gather_p', - ] - - -class TestSelectPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'select_p' - Cond = paddle.static.data(name='Condition', shape=[9, 5], dtype='bool') - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[9, 5], dtype='float64') - - self.prim_input = {'Condition': Cond, 'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - Cond_DOT = paddle.static.data( - name='Cond_DOT', shape=[9, 5], dtype='float64' - ) - X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 5], dtype='float64') - self.jvp_args = (Cond_DOT, X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - # Set transpose - check_dot = lambda v: True - Z_BAR = paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64') - self.transpose_args = (check_dot, Z_BAR) - self.transpose_out_shape_map = {0: X, 1: Y} - - self.all_ops = [ - # prim op: - 'select_p', - # jvp op: - 'select_p', - # transpose op: - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'select_p', - 'select_p', - ] - - -class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'eq_p' - X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') - - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'eq_p', - # jvp op: - 'fill_constant_p', - # transpose op: - ] - - -class TestGtPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'gt_p' - X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') - - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'gt_p', - # jvp op: - 'fill_constant_p', - # transpose op: - ] - - -class TestGePJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'ge_p' - X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') - - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'ge_p', - # jvp op: - 'fill_constant_p', - # transpose op: - ] - - -class TestNePJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'ne_p' - X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') - - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'ne_p', - # jvp op: - 'fill_constant_p', - # transpose op: - ] - - -class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'pow_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='float32') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='float32') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'pow_p', - # jvp op: - 'fill_constant_p', - 'fill_constant_p', - 'eq_p', - 'select_p', - 'sub_p', - 'mul_p', - 'mul_p', - 'pow_p', - 'mul_p', - 'mul_p', - 'log_p', - 'add_p' - # transpose op: - ] - - -class TestMaxPJVPAndTranspose(TestAddPJVPAndTranspose): - def init_data(self): - # Set prim op - self.op_type = 'max_p' - X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') - self.prim_input = {'X': X, 'Y': Y} - self.prim_output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.prim_attrs = {} - - # Set JVP - X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='float32') - Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='float32') - self.jvp_args = (X_DOT, Y_DOT) - self.jvp_out_shape_map = {0: self.prim_output['Z']} - - self.all_ops = [ - # prim op: - 'max_p', - # jvp op: - 'fill_constant_p', - 'eq_p', - 'select_p', - # transpose op: - ] - - -if __name__ == '__main__': - unittest.main() diff --git a/test/autograd/test_orig2prim.py b/test/autograd/test_orig2prim.py deleted file mode 100644 index 4767cc29d8fa2..0000000000000 --- a/test/autograd/test_orig2prim.py +++ /dev/null @@ -1,1087 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import paddle -from paddle.base.layer_helper import LayerHelper -from paddle.incubate.autograd.primrules import _orig2prim - -paddle.enable_static() - - -# ----------------------- Test orig2prim rules ---------------------------- # -class TestElementWiseAddOrig2Prim(unittest.TestCase): - def setUp(self): - self.main_program = paddle.static.Program() - self.startup_program = paddle.static.Program() - self.layer_help = LayerHelper('TestOrig2Prim') - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - self.init_data() - - def init_data(self): - self.op_type = 'elementwise_add' - X = paddle.static.data(name='X', shape=[2, 2], dtype='float') - Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X, Y) - self.all_ops = ['elementwise_add', 'add_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - def test_op(self): - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - op = self.layer_help.append_op( - type=self.op_type, - inputs=self.input, - outputs=self.output, - attrs=self.attrs, - ) - - prim_out = _orig2prim(op, *self.orig2prim_args) - all_ops = [op.type for op in self.main_program.block(0).ops] - - self.assertEqual(sorted(all_ops), sorted(self.all_ops)) - prim_out = paddle.utils.flatten(prim_out) - for k, v in self.out_map.items(): - self.assertEqual(prim_out[k].shape, v.shape) - - -class TestSqrtOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'sqrt' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['sqrt', 'sqrt_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'elementwise_mul' - X = paddle.static.data(name='X', shape=[8, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X, Y) - self.all_ops = ['elementwise_mul', 'mul_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestElementWiseDivOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'elementwise_div' - X = paddle.static.data(name='X', shape=[8, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X, Y) - self.all_ops = ['elementwise_div', 'div_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'matmul_v2' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - Y = paddle.static.data(name='Y', shape=[4, 3], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'trans_x': True, 'trans_y': True} - - self.orig2prim_args = (X, Y) - self.all_ops = ['matmul_v2', 'transpose_p', 'transpose_p', 'matmul_p'] - self.out_map = {0: self.output['Out']} - - -class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'tanh' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['tanh', 'tanh_p'] - self.out_map = {0: self.output['Out']} - - -class TestSinOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'sin' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['sin', 'sin_p'] - self.out_map = {0: self.output['Out']} - - -class TestCosOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'cos' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['cos', 'cos_p'] - self.out_map = {0: self.output['Out']} - - -class TestExpOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'exp' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['exp', 'exp_p'] - self.out_map = {0: self.output['Out']} - - -class TestErfOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'erf' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['erf', 'erf_p'] - self.out_map = {0: self.output['Out']} - - -class TestAbsOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'abs' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['abs', 'abs_p'] - self.out_map = {0: self.output['Out']} - - -class TestLogOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'log' - X = paddle.static.data(name='X', shape=[3, 4], dtype='float') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['log', 'log_p'] - self.out_map = {0: self.output['Out']} - - -class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'reshape2' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': X, - 'XShape': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ), - } - self.attrs = {'shape': [6, 5]} - - self.orig2prim_args = ( - None, - None, - X, - ) - self.all_ops = ['reshape2', 'reshape_p', 'fill_constant_p'] - # Do not check XShape - self.out_map = {0: self.output['Out']} - - -class TestConcatOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'concat' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Y = paddle.static.data(name='Y', shape=[3, 6], dtype='int64') - - self.input = { - 'X': [X, Y], - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': 0} - - self.orig2prim_args = ( - None, - (X, Y), - ) - self.all_ops = ['concat', 'concat_p'] - self.out_map = {0: self.output['Out']} - - -class TestSliceOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'slice' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'Input': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'axes': [0], - 'starts': [1], - 'ends': [4], - } - - self.orig2prim_args = (None, None, X, None, None) - self.all_ops = ['slice', 'slice_select_p'] - self.out_map = {0: self.output['Out']} - - -class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'fill_zeros_like' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['fill_zeros_like', 'fill_constant_p'] - self.out_map = {0: self.output['Out']} - - -class TestFillAnyLikeOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'fill_any_like' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['fill_any_like', 'fill_constant_p'] - self.out_map = {0: self.output['Out']} - - -class TestFillAnyLikeOrig2Prim2(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'fill_any_like' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'dtype': paddle.float32, 'value': 5} - - self.orig2prim_args = (X,) - self.all_ops = ['fill_any_like', 'fill_constant_p'] - self.out_map = {0: self.output['Out']} - - -class TestSumOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'sum' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='int64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = ((X, Y),) - self.all_ops = ['sum', 'add_p'] - self.out_map = {0: self.output['Out']} - - -class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'p_norm' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'porder': 1, - 'asvector': True, - } - - self.orig2prim_args = (X,) - self.all_ops = [ - 'p_norm', - 'reshape_p', - 'abs_p', - 'reduce_sum_p', - ] - self.out_map = {0: self.output['Out']} - - -class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'p_norm' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'porder': 2, - 'asvector': True, - } - - self.orig2prim_args = (X,) - self.all_ops = [ - 'p_norm', - 'reshape_p', - 'sqrt_p', - 'reduce_sum_p', - 'mul_p', - ] - self.out_map = {0: self.output['Out']} - - -class TestIndexSelectOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'index_select' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') - Index = paddle.static.data(name='Index', shape=[2], dtype='int32') - - self.input = {'X': X, 'Index': Index} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'dim': 0, - } - - self.orig2prim_args = ( - Index, - X, - ) - self.all_ops = ['index_select', 'gather_p'] - self.out_map = {0: self.output['Out']} - - -class TestElementwiseSubOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'elementwise_sub' - X = paddle.static.data(name='X', shape=[5, 6], dtype='int32') - Y = paddle.static.data(name='Y', shape=[6], dtype='int32') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'dim': 0, - } - - self.orig2prim_args = ( - X, - Y, - ) - self.all_ops = ['elementwise_sub', 'broadcast_p', 'sub_p'] - self.out_map = {0: self.output['Out']} - - -class TestScaleOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'scale' - X = paddle.static.data(name='X', shape=[10, 7], dtype='int32') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'scale': 2.0, 'bias': 1.0, 'bias_after_scale': True} - - self.orig2prim_args = ( - None, - X, - ) - self.all_ops = [ - 'scale', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'add_p', - ] - self.out_map = {0: self.output['Out']} - - -class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'assign' - X = paddle.static.data(name='X', shape=[10, 7], dtype='int32') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['assign', 'fill_constant_p', 'add_p'] - self.out_map = {0: self.output['Out']} - - -class TestWhereOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'where' - Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool') - X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') - - self.input = {'Condition': Cond, 'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - self.orig2prim_args = (Cond, X, Y) - self.all_ops = ['where', 'select_p'] - self.out_map = {0: self.output['Out']} - - -class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'equal' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - self.orig2prim_args = (X, Y) - self.all_ops = ['equal', 'eq_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestNeOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'not_equal' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - self.orig2prim_args = (X, Y) - self.all_ops = ['not_equal', 'ne_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestGtOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'greater_than' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - self.orig2prim_args = (X, Y) - self.all_ops = ['greater_than', 'gt_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestGeOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'greater_equal' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - self.orig2prim_args = (X, Y) - self.all_ops = ['greater_equal', 'ge_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestPowOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'elementwise_pow' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X, Y) - self.all_ops = ['elementwise_pow', 'pow_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'elementwise_max' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X, Y) - self.all_ops = ['elementwise_max', 'max_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestGeluOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'gelu' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'approximate': False} - - self.orig2prim_args = (X,) - self.all_ops = [ - 'gelu', - 'add_p', - 'erf_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'mul_p', - 'mul_p', - ] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'gelu' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'approximate': True} - - self.orig2prim_args = (X,) - self.all_ops = [ - 'add_p', - 'add_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'gelu', - 'mul_p', - 'mul_p', - 'mul_p', - 'mul_p', - 'pow_p', - 'tanh_p', - ] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'reduce_sum' - - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [0, 1], 'keep_dim': False} - - self.orig2prim_args = (X,) - self.all_ops = ['reduce_sum', 'reduce_sum_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'reduce_mean' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [0, 1], 'keep_dim': False} - - self.orig2prim_args = (X,) - self.all_ops = [ - 'reduce_mean', - 'reduce_sum_p', - 'fill_constant_p', - 'div_p', - ] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestSizeOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'size' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'Input': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=paddle.int64 - ) - } - self.attrs = {} - self.orig2prim_args = (X,) - self.all_ops = ['size', 'fill_constant_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestCastOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'cast' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'in_dtype': X.dtype, 'out_dtype': paddle.float64} - self.orig2prim_args = (X,) - self.all_ops = ['cast', 'cast_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestPowScalarOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'pow' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'factor': 2.0} - self.orig2prim_args = (None, X) - self.all_ops = ['pow', 'pow_p', 'fill_constant_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'square' - X = paddle.static.data(name='X', shape=[5, 8], dtype='float') - - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - self.orig2prim_args = (X,) - self.all_ops = ['square', 'pow_p', 'fill_constant_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'rsqrt' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.orig2prim_args = (X,) - self.all_ops = ['rsqrt', 'rsqrt_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestBatchnormOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'batch_norm' - x = paddle.static.data(name='X', shape=[5, 8], dtype='float') - m = paddle.static.data(name='Mean', shape=[8], dtype='float') - v = paddle.static.data(name='Variance', shape=[8], dtype='float') - w = paddle.static.data(name='Scale', shape=[8], dtype='float') - b = paddle.static.data(name='Bias', shape=[8], dtype='float') - - self.input = { - "X": [x], - "Scale": [w], - "Bias": [b], - "Mean": [m], - "Variance": [v], - } - saved_variance = self.layer_help.create_variable_for_type_inference( - dtype=x.dtype, stop_gradient=True - ) - batch_norm_out = self.layer_help.create_variable_for_type_inference( - x.dtype - ) - saved_mean = self.layer_help.create_variable_for_type_inference( - dtype=x.dtype, stop_gradient=True - ) - self.output = { - "Y": [batch_norm_out], - "MeanOut": [m], - "VarianceOut": [v], - "SavedMean": [saved_mean], - "SavedVariance": [saved_variance], - } - - self.attrs = { - "momentum": 0.9, - "epsilon": 1e-5, - "is_test": False, - "data_layout": 'NCHW', - "use_mkldnn": False, - "fuse_with_relu": False, - "use_global_stats": False, - "trainable_statistics": False, - } - self.orig2prim_args = (b, m, None, w, v, x) - self.all_ops = [ - 'add_p', - 'add_p', - 'add_p', - 'add_p', - 'batch_norm', - 'broadcast_p', - 'broadcast_p', - 'broadcast_p', - 'broadcast_p', - 'broadcast_p', - 'div_p', - 'div_p', - 'div_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'mul_p', - 'mul_p', - 'mul_p', - 'mul_p', - 'pow_p', - 'reduce_sum_p', - 'reduce_sum_p', - 'reshape_p', - 'reshape_p', - 'reshape_p', - 'reshape_p', - 'sqrt_p', - 'sub_p', - 'sub_p', - 'sub_p', - 'sub_p', - ] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {} - - -class TestFillConstantOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'fill_constant' - - self.attrs = {'value': 1.0, 'shape': (2, 3), 'dtype': paddle.float32} - self.input = {} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=paddle.float32 - ) - } - - self.orig2prim_args = (None, None, None) - self.all_ops = ['fill_constant', 'fill_constant_p'] - # { prim_op_output_index: orig_op_output_var } - self.out_map = {0: self.output['Out']} - - -class TestUniformRandomOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'uniform_random' - self.input = {} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=paddle.float32 - ) - } - self.attrs = {'shape': [1, 2]} - - self.orig2prim_args = (None, None) - self.all_ops = ['uniform_random', 'uniform_random_p'] - self.out_map = {0: self.output['Out']} - - -class TestSigmoidOrig2Prim(TestElementWiseAddOrig2Prim): - def init_data(self): - self.op_type = 'sigmoid' - X = paddle.static.data(name='X', shape=[3], dtype='float32') - - self.attrs = {} - self.input = {'X': X} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=paddle.float32 - ) - } - - self.orig2prim_args = (X,) - self.all_ops = [ - 'sigmoid', - 'div_p', - 'fill_constant_p', - 'add_p', - 'fill_constant_p', - 'exp_p', - 'fill_constant_p', - 'sub_p', - ] - self.out_map = {0: self.output['Out']} - - -if __name__ == '__main__': - unittest.main() diff --git a/test/autograd/test_prim2orig.py b/test/autograd/test_prim2orig.py deleted file mode 100644 index 7dfd2c79068c5..0000000000000 --- a/test/autograd/test_prim2orig.py +++ /dev/null @@ -1,744 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import paddle -from paddle.base.layer_helper import LayerHelper -from paddle.incubate.autograd.primrules import _prim2orig - -paddle.enable_static() - - -# ------------------------ Test prim2orig rules ---------------------------- # -class TestAddPPrim2Orig(unittest.TestCase): - def setUp(self): - self.main_program = paddle.static.Program() - self.startup_program = paddle.static.Program() - self.layer_help = LayerHelper('TestPrim2Orig') - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - self.init_data() - - def init_data(self): - self.op_type = 'add_p' - X = paddle.static.data(name='X', shape=[2, 2], dtype='float') - Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['add_p', 'elementwise_add'] - # { prim_op_output_var: origin_op_out_index } - self.out_map = {self.output['Z']: 0} - - def test_op(self): - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - op = self.layer_help.append_op( - type=self.op_type, - inputs=self.input, - outputs=self.output, - attrs=self.attrs, - ) - - orig_out = _prim2orig(op, *self.prim2orig_args) - all_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(all_ops), sorted(self.all_ops)) - orig_out = paddle.utils.flatten(orig_out) - for k, v in self.out_map.items(): - self.assertEqual(k.shape, orig_out[v].shape) - - -class TestSubPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'sub_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['sub_p', 'elementwise_sub'] - self.out_map = {self.output['Z']: 0} - - -class TestMulPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'mul_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['mul_p', 'elementwise_mul'] - self.out_map = {self.output['Z']: 0} - - -class TestDivPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'div_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['div_p', 'elementwise_div'] - self.out_map = {self.output['Z']: 0} - - -class TestSqrtPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'sqrt_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['sqrt_p', 'sqrt'] - self.out_map = {self.output['Y']: 0} - - -class TestTanhPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'tanh_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['tanh_p', 'tanh'] - self.out_map = {self.output['Y']: 0} - - -class TestSinPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'sin_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['sin_p', 'sin'] - self.out_map = {self.output['Y']: 0} - - -class TestCosPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'cos_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['cos_p', 'cos'] - self.out_map = {self.output['Y']: 0} - - -class TestExpPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'exp_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['exp_p', 'exp'] - self.out_map = {self.output['Y']: 0} - - -class TestErfPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'erf_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['erf_p', 'erf'] - self.out_map = {self.output['Y']: 0} - - -class TestAbsPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'abs_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['abs_p', 'abs'] - self.out_map = {self.output['Y']: 0} - - -class TestLogPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'log_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['log_p', 'log'] - self.out_map = {self.output['Y']: 0} - - -class TestReshapePPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'reshape_p' - X = paddle.static.data(name='X', shape=[2, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'shape': [4, 4]} - - self.prim2orig_args = (X,) - self.all_ops = ['reshape_p', 'reshape2'] - self.out_map = {self.output['Y']: 0} - - -class TestBroadcastPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'broadcast_p' - X = paddle.static.data(name='X', shape=[2, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'shape': [10, 2, 8]} - - self.prim2orig_args = (X,) - self.all_ops = ['broadcast_p', 'expand_v2'] - self.out_map = {self.output['Y']: 0} - - -class TestTransposePPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'transpose_p' - X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [1, 2, 0, 3]} - - self.prim2orig_args = (X,) - self.all_ops = ['transpose_p', 'transpose2'] - self.out_map = {self.output['Y']: 0} - - -class TestSplitPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'split_p' - X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'YS': [ - self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - for i in range(3) - ] - } - self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1} - - self.prim2orig_args = (X,) - self.all_ops = ['split_p', 'split'] - self.out_map = { - self.output['YS'][0]: 0, - self.output['YS'][1]: 1, - self.output['YS'][2]: 2, - } - - -class TestConcatPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'concat_p' - X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64') - Z = paddle.static.data(name='Z', shape=[1, 9, 5], dtype='float64') - - self.input = { - 'XS': [X, Y, Z], - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': 0} - - self.prim2orig_args = ((X, Y, Z),) - self.all_ops = ['concat_p', 'concat'] - self.out_map = {self.output['Y']: 0} - - -class TestReducePPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'reduce_sum_p' - X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') - - self.input = {'X': X} - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [1], 'keepdim': True} - - self.prim2orig_args = (X,) - self.all_ops = ['reduce_sum_p', 'reduce_sum'] - self.out_map = {self.output['Y']: 0} - - -class TestMatmulPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'matmul_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['matmul_p', 'matmul_v2'] - self.out_map = {self.output['Z']: 0} - - -class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'slice_select_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]} - - self.prim2orig_args = (X,) - self.all_ops = ['slice_select_p', 'strided_slice'] - self.out_map = {self.output['Y']: 0} - - -class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'slice_assign_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]} - - self.prim2orig_args = (X, Y) - self.all_ops = ['slice_assign_p', 'assign', 'set_value'] - self.out_map = {self.output['Z']: 0} - - -class TestGatherPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'gather_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - IndexTensor = paddle.static.data( - name='IndexTensor', shape=[3], dtype='int32' - ) - - self.input = {'X': X, 'IndexTensor': IndexTensor} - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'axis': 0, - } - - self.prim2orig_args = ( - IndexTensor, - X, - ) - self.all_ops = ['gather_p', 'gather'] - self.out_map = {self.output['Y']: 0} - - -class TestScatterAddPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'scatter_add_p' - X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') - Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64') - IndexTensor = paddle.static.data( - name='IndexTensor', shape=[3], dtype='int32' - ) - - self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = { - 'axis': 0, - } - - self.prim2orig_args = (IndexTensor, X, Y) - self.all_ops = [ - 'scatter_add_p', - 'fill_any_like', - 'scatter', - 'elementwise_add', - ] - self.out_map = {self.output['Z']: 0} - - -class TestFillConstantPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'fill_constant_p' - - self.input = {} - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - paddle.int32 - ) - } - self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32} - - self.prim2orig_args = () - self.all_ops = ['fill_constant_p', 'fill_constant'] - self.out_map = {self.output['Y']: 0} - - -class TestSelectPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'select_p' - Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool') - X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') - Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') - - self.input = {'Condition': Cond, 'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - self.prim2orig_args = (Cond, X, Y) - self.all_ops = ['select_p', 'where'] - self.out_map = {self.output['Z']: 0} - - -class TestEqPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'eq_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['eq_p', 'equal'] - self.out_map = {self.output['Z']: 0} - - -class TestNePPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'ne_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['ne_p', 'not_equal'] - self.out_map = {self.output['Z']: 0} - - -class TestGtPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'gt_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['gt_p', 'greater_than'] - self.out_map = {self.output['Z']: 0} - - -class TestGePPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'ge_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype='bool' - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['ge_p', 'greater_equal'] - self.out_map = {self.output['Z']: 0} - - -class TestPowPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'pow_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['pow_p', 'elementwise_pow'] - self.out_map = {self.output['Z']: 0} - - -class TestMaxPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'max_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') - - self.input = {'X': X, 'Y': Y} - self.output = { - 'Z': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X, Y) - self.all_ops = ['max_p', 'elementwise_max'] - self.out_map = {self.output['Z']: 0} - - -class TestCastPPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'cast_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {'dtype': paddle.int64} - - self.prim2orig_args = (X,) - self.all_ops = ['cast_p', 'cast'] - self.out_map = {self.output['Y']: 0} - - -class TestRsqrtPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'rsqrt_p' - X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') - - self.input = { - 'X': X, - } - self.output = { - 'Y': self.layer_help.create_variable_for_type_inference( - dtype=X.dtype - ) - } - self.attrs = {} - - self.prim2orig_args = (X,) - self.all_ops = ['rsqrt_p', 'rsqrt'] - self.out_map = {self.output['Y']: 0} - - -class TestUniformRandomPrim2Orig(TestAddPPrim2Orig): - def init_data(self): - self.op_type = 'uniform_random_p' - - self.input = {} - self.output = { - 'Out': self.layer_help.create_variable_for_type_inference( - dtype=paddle.float64 - ) - } - self.attrs = { - 'shape': [1, 2, 3], - 'min': -1.0, - 'max': 1.0, - 'seed': 0, - 'dtype': paddle.float64, - } - - self.prim2orig_args = () - self.all_ops = ['uniform_random_p', 'uniform_random'] - self.out_map = {self.output['Out']: 0} - - -if __name__ == '__main__': - unittest.main() diff --git a/test/autograd/test_primapi.py b/test/autograd/test_primapi.py deleted file mode 100644 index 7bbe4e4476046..0000000000000 --- a/test/autograd/test_primapi.py +++ /dev/null @@ -1,1097 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing -import unittest - -import autograd -import autograd.numpy as anp -import autograd.scipy as ascipy -import config -import numpy as np -import parameterized as param -import utils - -import paddle -from paddle.base import core -from paddle.incubate.autograd import primapi, primx - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun', 'xs', 'dtype'), - ( - ( - 'uniform_random', - lambda: paddle.uniform( - [1, 2, 3], dtype='float32', min=0, max=1.0, seed=1 - ), - (), - 'int32', - ), - ( - 'sigmoid', - paddle.nn.functional.sigmoid, - ( - np.random.rand( - 5, - ), - ), - 'float32', - ), - ), -) -class TestFowardApi(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - def test_grad(self): - def expected(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs = utils.gen_static_inputs_and_feed( - self.xs, stop_gradient=False - ) - out = self.fun(*static_xs) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=out) - paddle.incubate.autograd.enable_prim() - return out - - def actual(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs = utils.gen_static_inputs_and_feed( - self.xs, stop_gradient=False - ) - out = self.fun(*static_xs) - primx.orig2prim(mp.block(0)) - primx.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=out) - paddle.incubate.autograd.disable_prim() - return out - - expected = expected() - actual = actual() - self.assertEqual(type(actual), type(expected)) - for i, j in zip(actual, expected): - np.testing.assert_allclose(i, j, rtol=1e-6) - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), - ( - ( - 'dropout', - paddle.nn.functional.dropout, - (np.random.rand(5000, 5000),), - None, - 'float32', - ), - ), -) -class TestDropoutGrad(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - cls._rtol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("rtol") - ) - cls._atol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("atol") - ) - - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - def test_grad(self): - def expected(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - _, ys_grad = paddle.incubate.autograd.vjp( - self.fun, static_xs, static_v - ) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - return out - - def actual(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.disable_prim() - return out - - expected = expected() - actual = actual() - self.assertEqual(type(actual), type(expected)) - for i, j in zip(actual, expected): - np.testing.assert_allclose(np.sum(i), np.sum(j), rtol=1e-1) - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), - ( - ( - 'matmul', - paddle.matmul, - (np.random.rand(2, 3), np.random.rand(3, 2)), - None, - 'float32', - ), - ), -) -class TestWithoutProgramGuard(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - cls._rtol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("rtol") - ) - cls._atol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("atol") - ) - - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - def test_forward_grad_without_program_guard(self): - def with_program_guard(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.forward_grad( - ys, static_xs, static_v - ) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.disable_prim() - return out - - def without_program_guard(): - paddle.incubate.autograd.enable_prim() - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.forward_grad( - ys, static_xs, static_v - ) - sp = paddle.base.framework.default_startup_program() - mp = paddle.base.framework.default_main_program() - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.disable_prim() - return out - - expected = with_program_guard() - actual = without_program_guard() - self.assertEqual(type(actual), type(expected)) - np.testing.assert_allclose( - np.concatenate(actual), - np.concatenate(expected), - rtol=self._rtol, - atol=self._atol, - ) - - def test_grad_without_program_guard(self): - def with_program_guard(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - xs_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=xs_grad) - paddle.incubate.autograd.disable_prim() - return out - - def without_program_guard(): - paddle.incubate.autograd.enable_prim() - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - xs_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) - sp = paddle.base.framework.default_startup_program() - mp = paddle.base.framework.default_main_program() - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=xs_grad) - paddle.incubate.autograd.disable_prim() - return out - - expected = with_program_guard() - actual = without_program_guard() - for i, j in zip(actual, expected): - self.assertEqual(type(i), type(j)) - np.testing.assert_allclose( - np.concatenate(i), - np.concatenate(j), - rtol=self._rtol, - atol=self._atol, - ) - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), - ( - ( - 'matmul', - paddle.matmul, - (np.random.rand(2, 3), np.random.rand(3, 2)), - None, - 'float32', - ), - ( - 'multiply', - paddle.multiply, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float64', - ), - ( - 'add', - paddle.add, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float32', - ), - ( - 'input_not_sequence', - paddle.tanh, - (np.random.rand(5, 5),), - None, - 'float64', - ), - ( - 'input_gradients_not_none', - paddle.matmul, - (np.random.rand(3, 3), np.random.rand(3, 3)), - (np.random.rand(3, 3), np.random.rand(3, 3)), - 'float64', - ), - ('log', paddle.log, (np.random.rand(3, 4),), None, 'float32'), - ( - 'abs', - paddle.abs, - (np.random.uniform(-10, 10, (10, 10)),), - None, - 'float32', - ), - ('rsqrt', paddle.rsqrt, (np.random.rand(100, 200),), None, 'float32'), - ( - 'sigmoid', - paddle.nn.functional.sigmoid, - ( - np.random.rand( - 5, - ), - ), - None, - 'float32', - ), - ), -) -# paddle.where, paddle.pow, paddle.maximum has no double grad definition, -# can not compute forward grad use double trick -class TestForwardGrad(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - cls._rtol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("rtol") - ) - cls._atol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("atol") - ) - - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - def test_forward_grad(self): - def expected(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - _, ys_grad = paddle.incubate.autograd.jvp( - self.fun, static_xs, static_v - ) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - return out - - def actual(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.forward_grad( - ys, static_xs, static_v - ) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.disable_prim() - return out - - actual = actual() - expected = expected() - self.assertEqual(type(actual), type(expected)) - np.testing.assert_allclose( - np.concatenate(actual), - np.concatenate(expected), - rtol=self._rtol, - atol=self._atol, - ) - - def test_prim_disabled(self): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with self.assertRaises(RuntimeError): - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.forward_grad( - ys, static_xs, static_v - ) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - - def test_illegal_param(self): - paddle.incubate.autograd.enable_prim() - with self.assertRaises(TypeError): - paddle.incubate.autograd.forward_grad( - 1, paddle.static.data('inputs', shape=[1]) - ) - - with self.assertRaises(TypeError): - paddle.incubate.autograd.forward_grad( - paddle.static.data('targets', shape=[1]), 1 - ) - paddle.incubate.autograd.disable_prim() - - -where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), - ( - ( - 'matmul', - paddle.matmul, - (np.random.rand(2, 3), np.random.rand(3, 2)), - None, - 'float32', - ), - ( - 'multiply', - paddle.multiply, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float64', - ), - ( - 'div', - paddle.divide, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float64', - ), - ( - 'add', - paddle.add, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float32', - ), - ( - 'input_not_sequence', - paddle.tanh, - (np.random.rand(5, 5),), - None, - 'float64', - ), - ( - 'input_gradients_not_none', - paddle.matmul, - (np.random.rand(3, 3), np.random.rand(3, 3)), - (np.random.rand(3, 3),), - 'float64', - ), - ('sin', paddle.sin, (np.random.rand(100, 200),), None, 'float32'), - ('rsqrt', paddle.rsqrt, (np.random.rand(100, 200),), None, 'float32'), - ('cos', paddle.cos, (np.random.rand(200, 90),), None, 'float32'), - ('exp', paddle.exp, (np.random.rand(299, 320),), None, 'float32'), - # In where op, grad of condition computed by paddle.static.gradients is None, - # and paddle.incubate.autograd.grad will replace None with zeros while transpose - # will just return None because cond_dot is unused, that is a diff. - ( - 'select', - where_wrap, - (np.random.rand(3, 4), np.random.rand(3, 4)), - None, - 'float32', - ), - # pow_p and pow has diff when compute z_dot of 0^0 - ( - 'pow', - paddle.pow, - (np.array([1, 2, 3]), np.array([0, 2, 7])), - None, - 'float32', - ), - # To make max_p consistent with paddle.maximum, be sure x.grad = 0 and y.grad = 1 when x==y. - ( - 'max', - paddle.maximum, - ( - np.array([1, 2, 3]), - np.array([2, 2, 2]), - ), - None, - 'float32', - ), - ('erf', paddle.erf, (np.random.rand(300, 288),), None, 'float32'), - ( - 'gelu', - paddle.nn.functional.gelu, - (np.random.rand(200, 189),), - None, - 'float32', - ), - ( - 'gelu_approximate', - lambda x: paddle.nn.functional.gelu(x, True), - (np.random.rand(200, 189),), - None, - 'float32', - ), - ('sum', paddle.sum, (np.random.rand(200, 345),), None, 'float32'), - ( - 'sigmoid', - paddle.nn.functional.sigmoid, - ( - np.random.rand( - 5, - ), - ), - None, - 'float32', - ), - ( - 'sum_with_axis', - lambda x: paddle.sum(x, axis=1), - (np.random.rand(200, 345),), - None, - 'float32', - ), - ( - 'sum_with_keepdim', - lambda x: paddle.sum(x, keepdim=True), - (np.random.rand(200, 345),), - None, - 'float32', - ), - ('mean', paddle.mean, (np.random.rand(200, 345),), None, 'float32'), - ( - 'mean_with_axis', - lambda x: paddle.mean(x, axis=1), - (np.random.rand(200, 345),), - None, - 'float32', - ), - ( - 'mean_with_keepdim', - lambda x: paddle.mean(x, keepdim=True), - (np.random.rand(200, 345),), - None, - 'float32', - ), - ( - 'mean_with_axis_keepdim', - lambda x: paddle.mean(x, axis=0, keepdim=True), - (np.random.rand(200, 345),), - None, - 'float32', - ), - ( - 'abs', - paddle.abs, - (np.random.uniform(-10, 10, (200, 345)),), - None, - 'float32', - ), - ( - 'cast_float', - lambda x: paddle.cast(x, paddle.float64), - (np.random.rand(10, 20),), - None, - 'float32', - ), - ( - 'cast_int', - lambda x: paddle.cast(x, paddle.int32), - (np.random.rand(10, 20),), - None, - 'float32', - ), - ('square', paddle.square, (np.random.rand(100),), None, 'float32'), - ( - 'pow_scalar', - lambda x: paddle.pow(x, 2), - (np.random.rand(20, 30),), - None, - 'float32', - ), - ( - 'var', - lambda x: paddle.var(x, unbiased=False), - (np.random.rand(200, 324),), - None, - 'float32', - ), - ( - 'var_with_axis', - lambda x: paddle.var(x, axis=1, unbiased=False), - (np.random.rand(10, 20, 30),), - None, - 'float32', - ), - ( - 'var_with_keepdim', - lambda x: paddle.var(x, axis=1, keepdim=True, unbiased=False), - (np.random.rand(10, 20, 30),), - None, - 'float32', - ), - ( - 'bn', - lambda x, w, b: paddle.nn.functional.batch_norm( - x, paddle.ones((10,)), paddle.ones((10,)), w, b - ), - (np.random.rand(10, 10), np.random.rand(10), np.random.rand(10)), - None, - 'float32', - ), - ( - 'bn_train', - lambda x, w, b: paddle.nn.functional.batch_norm( - x, paddle.ones((10,)), paddle.ones((10,)), w, b, training=True - ), - (np.random.rand(10, 10), np.random.rand(10), np.random.rand(10)), - None, - 'float32', - ), - ( - 'bn_nhwc', - lambda x, w, b: paddle.nn.functional.batch_norm( - x, - paddle.ones((10,)) + 1, - paddle.ones((10,)), - w, - b, - training=True, - data_format='NHWC', - ), - (np.random.rand(10, 10), np.random.rand(10), np.random.rand(10)), - None, - 'float32', - ), - ( - 'bn_global_stat', - lambda x, w, b: paddle.nn.functional.batch_norm( - x, - paddle.ones((10,)) + 3.2, - paddle.ones((10,)) + 6.7, - w, - b, - training=True, - data_format='NHWC', - use_global_stats=True, - ), - (np.random.rand(10, 10), np.random.rand(10), np.random.rand(10)), - None, - 'float32', - ), - ), -) -class TestGrad(unittest.TestCase): - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - cls._rtol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("rtol") - ) - cls._atol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("atol") - ) - - def test_grad(self): - def expected(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - _, ys_grad = paddle.incubate.autograd.vjp( - self.fun, static_xs, static_v - ) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - return out - - def actual(): - paddle.incubate.autograd.enable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) - paddle.incubate.autograd.prim2orig(mp.block(0)) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.disable_prim() - return out - - actual = actual() - expected = expected() - self.assertEqual(type(actual), type(expected)) - for i, j in zip(actual, expected): - np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) - - def test_illegal_param(self): - paddle.incubate.autograd.enable_prim() - with self.assertRaises(TypeError): - paddle.incubate.autograd.grad( - 1, paddle.static.data('inputs', shape=[1]) - ) - - with self.assertRaises(TypeError): - paddle.incubate.autograd.grad( - paddle.static.data('targets', shape=[1]), 1 - ) - paddle.incubate.autograd.disable_prim() - - def test_disable_prim(self): - def expected(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - return out - - def actual(): - paddle.incubate.autograd.disable_prim() - sp = paddle.static.Program() - mp = paddle.static.Program() - with paddle.static.program_guard(mp, sp): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun(static_xs) - ) - ys_grad = paddle.static.gradients(ys, static_xs, static_v) - exe = paddle.static.Executor() - exe.run(sp) - out = exe.run(mp, feed=feed, fetch_list=ys_grad) - paddle.incubate.autograd.enable_prim() - return out - - actual = actual() - expected = expected() - self.assertEqual(type(actual), type(expected)) - for i, j in zip(actual, expected): - np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) - - -def multiply_pd(x): - x2 = paddle.multiply(x, x) - x3 = paddle.multiply(x2, x2) - x4 = paddle.multiply(x3, x) - return x4 - - -multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0] -sin_ag = lambda xs: anp.sin(xs[0]) -cos_ag = lambda xs: anp.cos(xs[0]) -exp_ag = lambda xs: anp.exp(xs[0]) -pow_ag = lambda xs: xs[0] ** xs[1] -log_ag = lambda xs: anp.log(xs[0]) -erf_ag = lambda xs: ascipy.special.erf(xs[0]) -sigmoid_ag = lambda xs: 1.0 / (1 + anp.exp(-xs[0])) - - -def gelu_ag(x, approximate=False): - if approximate: - sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) - cdf = 0.5 * (1.0 + anp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3)))) - return x * cdf - else: - return x * (ascipy.special.erf(x / np.sqrt(2)) + 1) / 2 - - -@utils.place(config.DEVICES) -@utils.parameterize( - (utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), - ( - ( - 'multiply', - multiply_pd, - multiply_ag, - (np.random.rand(3, 5),), - None, - 'float32', - ), - ('sin', paddle.sin, sin_ag, (np.random.rand(2, 3),), None, 'float32'), - ('cos', paddle.cos, cos_ag, (np.random.rand(3, 4),), None, 'float32'), - ('exp', paddle.exp, exp_ag, (np.random.rand(2, 3),), None, 'float32'), - ( - 'pow', - paddle.pow, - pow_ag, - (np.random.rand(2, 3), np.random.rand(2, 3)), - None, - 'float32', - ), - ('log', paddle.log, log_ag, (np.random.rand(3, 8),), None, 'float32'), - ( - 'erf', - paddle.erf, - erf_ag, - (np.random.rand(100, 200),), - None, - 'float32', - ), - ( - 'gelu', - paddle.nn.functional.gelu, - lambda xs: gelu_ag(xs[0]), - (np.random.rand(10, 20, 30),), - None, - 'float32', - ), - ( - 'gelu_approximate', - lambda x: paddle.nn.functional.gelu(x, approximate=True), - lambda xs: gelu_ag(xs[0], approximate=True), - (np.random.rand(10, 20, 30),), - None, - 'float32', - ), - ( - 'sigmoid', - paddle.nn.functional.sigmoid, - sigmoid_ag, - (np.random.rand(10, 20),), - None, - 'float32', - ), - ), -) -class TestGradWithHigherOrder(unittest.TestCase): - def setUp(self): - paddle.enable_static() - paddle.incubate.autograd.enable_prim() - - def tearDown(self): - paddle.incubate.autograd.disable_prim() - paddle.disable_static() - - @classmethod - def setUpClass(cls): - cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) - cls._rtol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("rtol") - ) - cls._atol = ( - config.TOLERANCE.get(str(cls.dtype)) - .get("first_order_grad") - .get("atol") - ) - - def test_grad(self): - def expected(): - egrad = autograd.elementwise_grad - grad_3 = egrad(egrad(egrad(self.fun_ag)))(self.xs) - grad_4 = egrad(egrad(egrad(egrad(self.fun_ag))))(self.xs) - grad_5 = egrad(egrad(egrad(egrad(egrad(self.fun_ag)))))(self.xs) - # the output of egrad is tuple - return list(grad_3 + grad_4 + grad_5) - - def actual(): - paddle_grad = paddle.incubate.autograd.grad - paddle.incubate.autograd.enable_prim() - main = paddle.static.Program() - startup = paddle.static.Program() - with paddle.static.program_guard(main, startup): - feed, static_xs, static_v = utils.gen_static_data_and_feed( - self.xs, self.v, stop_gradient=False - ) - ys = ( - self.fun_pd(*static_xs) - if isinstance(static_xs, typing.Sequence) - else self.fun_pd(static_xs) - ) - - grad1 = paddle_grad(ys, static_xs, static_v) - grad2 = paddle_grad(grad1, static_xs, static_v) - grad3 = paddle_grad(grad2, static_xs, static_v) - grad4 = paddle_grad(grad3, static_xs, static_v) - grad5 = paddle_grad(grad4, static_xs, static_v) - paddle.incubate.autograd.prim2orig() - - fetch_list = [grad3, grad4, grad5] - - place = paddle.CPUPlace() - if paddle.device.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - exe.run(startup) - outs = exe.run(main, feed=feed, fetch_list=fetch_list) - paddle.incubate.autograd.disable_prim() - return outs - - actual = actual() - expected = expected() - self.assertEqual(type(actual), type(expected)) - for i, j in zip(actual, expected): - np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) - - -class TestToPrim(unittest.TestCase): - def setUp(self): - paddle.enable_static() - core._set_prim_forward_enabled(True) - - def tearDown(self): - core._set_prim_forward_enabled(False) - paddle.disable_static() - - @param.parameterized.expand((({'dropout'},),)) - def test_blacklist(self, blacklist): - program = paddle.static.Program() - with paddle.static.program_guard(program): - paddle.nn.functional.softmax( - paddle.nn.functional.dropout(paddle.rand((1,))) - ) - primapi.to_prim(program.blocks, blacklist=blacklist) - ops = tuple(op.type for op in program.block(0).ops) - self.assertTrue(all(tuple(op in ops for op in blacklist))) - - @param.parameterized.expand((({'dropout'},),)) - def test_whitelist(self, whitelist): - program = paddle.static.Program() - with paddle.static.program_guard(program): - paddle.nn.functional.softmax( - paddle.nn.functional.dropout(paddle.rand((1,))) - ) - primapi.to_prim(program.blocks, whitelist=whitelist) - ops = tuple(op.type for op in program.block(0).ops) - self.assertTrue(all(tuple(op not in ops for op in whitelist))) - - @param.parameterized.expand((({'softmax'}, {'softmax', 'dropout'}),)) - def test_both_not_empty(self, blacklist, whitelist): - program = paddle.static.Program() - with paddle.static.program_guard(program): - paddle.nn.functional.softmax( - paddle.nn.functional.dropout(paddle.rand((1,))) - ) - primapi.to_prim( - program.blocks, blacklist=blacklist, whitelist=whitelist - ) - ops = tuple(op.type for op in program.block(0).ops) - self.assertTrue(all(tuple(op in ops for op in blacklist))) - - @param.parameterized.expand(((('dropout',), 'softmax'),)) - def test_type_error(self, blacklist, whitelist): - program = paddle.static.Program() - with paddle.static.program_guard(program): - paddle.nn.functional.softmax( - paddle.nn.functional.dropout(paddle.rand((1,))) - ) - with self.assertRaises(TypeError): - primapi.to_prim( - program.blocks, blacklist=blacklist, whitelist=whitelist - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/autograd/test_primops.py b/test/autograd/test_primops.py deleted file mode 100644 index 9a20dd377b2c5..0000000000000 --- a/test/autograd/test_primops.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -import unittest -import uuid - -sys.path.insert(0, '.') - -import config -import numpy as np -import utils -from numpy.random import randint, randn - -import paddle -from paddle.incubate.autograd import primops - -paddle.enable_static() - - -@utils.place(config.DEVICES) -@utils.parameterize( - ( - utils.TEST_CASE_NAME, - 'op', - 'args', - 'kwargs', - 'expected_shape', - 'expected_dtype', - ), - ( - ('add', primops.add, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('mul', primops.mul, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('div', primops.div, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('sqrt', primops.sqrt, randn(2, 3), {}, (2, 3), 'float64'), - ('tanh', primops.tanh, randn(2, 3), {}, (2, 3), 'float64'), - ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'), - ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'), - ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'), - ('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'), - ('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'), - ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'), - ( - 'cast', - primops.cast, - randn(2, 3), - {'dtype': paddle.int64}, - (2, 3), - 'int64', - ), - ( - 'reshape', - primops.reshape, - randn(2, 3), - {'shape': (3, 2)}, - (3, 2), - 'float64', - ), - ( - 'broadcast', - primops.broadcast, - randn(2), - {'shape': (3, 2)}, - (3, 2), - 'float64', - ), - ( - 'transpose', - primops.transpose, - randn(2, 3), - {'axis': (1, 0)}, - (3, 2), - 'float64', - ), - ( - 'concat_axis0', - primops.concat, - ((randn(2, 3), randn(2, 3)),), - {'axis': 0}, - (4, 3), - 'float64', - ), - ( - 'concat_axis1', - primops.concat, - ((randn(2, 3), randn(2, 3)),), - {'axis': 1}, - (2, 6), - 'float64', - ), - ( - 'reduce_axis1', - primops.reduce_sum, - randn(2, 3), - {'axis': (1,)}, - (2,), - 'float64', - ), - ( - 'reduce_axis01', - primops.reduce_sum, - randn(2, 3), - {'axis': (0, 1)}, - (), - 'float64', - ), - ( - 'split', - primops.split, - randn(2, 3), - {'num_or_sections': [1, 2], 'axis': 1}, - ((2, 1), (2, 2)), - ('float64', 'float64'), - ), - ( - 'matmul', - primops.matmul, - (randn(2, 3), randn(3, 2)), - {}, - (2, 2), - 'float64', - ), - ( - 'slice_select', - primops.slice_select, - randn(3, 2), - {'axis': [0], 'starts': [0], 'ends': [2], 'strides': [1]}, - (2, 2), - 'float64', - ), - ( - 'slice_assign', - primops.slice_assign, - (randn(2, 3), randn(2, 2)), - {'axis': [1], 'starts': [1], 'ends': [3], 'strides': [1]}, - (2, 3), - 'float64', - ), - ( - 'gather', - primops.gather, - (randn(3, 2), randint(0, 2, (5,), np.int32)), - {'axis': 0}, - (5, 2), - 'float64', - ), - ( - 'scatter_add', - primops.scatter_add, - (randn(3, 2), randn(5, 2), randint(0, 2, (5,), np.int32)), - {'axis': 0}, - (3, 2), - 'float64', - ), - ( - 'fill_const', - primops.fill_const, - (), - {'value': 10, 'shape': (3, 2), 'dtype': paddle.float32}, - (3, 2), - 'float32', - ), - ('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'), - ( - 'select', - primops.select, - (randn(2, 3) > 0, randn(2, 3), randn(2, 3)), - {}, - (2, 3), - 'float64', - ), - ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), - ('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), - ('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), - ('ge', primops.ge, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), - ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), - ), -) -class TestPrimops(unittest.TestCase): - @classmethod - def setUpClass(cls): - paddle.enable_static() - - @classmethod - def tearDownClass(cls): - paddle.disable_static() - - def test_prim_ops(self): - program = paddle.static.Program() - with paddle.static.program_guard(program): - args = self._as_tuple(self.args) - args = self.arr2var(args) - results = self.op(*args, **self.kwargs) - results = self._as_tuple(results) - expected_shape = self._as_tuple(self.expected_shape) - expected_dtype = self._as_tuple(self.expected_dtype) - - for r, shape, dtype in zip(results, expected_shape, expected_dtype): - self.assertEqual(r.shape, shape) - self.assertEqual(str(r.dtype).split('.')[1], dtype) - - def arr2var(self, arr): - """convert numpy ndarray to paddle Variable recursively.""" - return [ - paddle.static.data(f'x{uuid.uuid4()}', v.shape, v.dtype) - if isinstance(v, np.ndarray) - else self.arr2var(v) - for v in arr - ] - - def _as_tuple(self, input): - if isinstance(input, (tuple, list)) and len(input) == 0: - return input - if not isinstance(input, (tuple, list)) or all( - isinstance(i, int) for i in input - ): - return (input,) - return input - - -if __name__ == '__main__': - unittest.main() diff --git a/test/autograd/test_transform.py b/test/autograd/test_transform.py deleted file mode 100644 index 6116c0b5b490c..0000000000000 --- a/test/autograd/test_transform.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import paddle -from paddle.incubate.autograd.primx import Transform, orig2prim, prim2orig - -paddle.enable_static() - - -class TestAutoGradTransformForAdd(unittest.TestCase): - # This UT is deprecated for 'prim2org' mechanism has been already deprecated - # so this UT will be skipped as method 'test_run' was renamed to '_test_run' - def setUp(self): - self.main_program = paddle.static.Program() - self.startup_program = paddle.static.Program() - - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - self.init_data() - - def init_data(self): - # { input_index: input_shape } - self.xs_shape_map = {0: (20, 40), 1: (20, 40)} - # { output_index: output_shape } - self.ys_shape_map = {0: (20, 40)} - X0 = paddle.static.data( - name='X0', shape=self.xs_shape_map[0], dtype='float32' - ) - X0.stop_gradient = False - X1 = paddle.static.data( - name='X1', shape=self.xs_shape_map[1], dtype='float32' - ) - X1.stop_gradient = False - - A = paddle.tanh(X0) - B = paddle.tanh(X1) - C = paddle.rsqrt(B) - Y = paddle.add(A, C) - - self.orig_xs = [X0, X1] - self.orig_ys = [ - Y, - ] - - self.orig_ops = ['tanh', 'tanh', 'elementwise_add', 'rsqrt'] - self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p', 'rsqrt_p'] - self.linearize_ops = self.orig2prim_ops + [ - # call fill_const() in linearize() function - 'fill_constant_p', - 'fill_constant_p', - # linearized op - 'mul_p', - 'sub_p', - 'fill_constant_p', - 'mul_p', - 'mul_p', - 'sub_p', - 'fill_constant_p', - 'mul_p', - 'add_p', - 'fill_constant_p', - 'div_p', - 'div_p', - 'mul_p', - ] - self.transpose_ops = self.orig2prim_ops + [ - # call fill_const() in transpose() function - 'fill_constant_p', - # linearized op after remove path - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'sub_p', - 'fill_constant_p', - 'mul_p', - 'sub_p', - 'fill_constant_p', - 'mul_p', - 'div_p', - 'div_p', - 'fill_constant_p', - # transposed op - 'mul_p', - 'mul_p', - ] - self.prim2orig_ops_with_blacklist = [ - 'tanh', - 'tanh', - 'add_p', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'sub_p', - 'fill_constant', - 'elementwise_mul', - 'sub_p', - 'fill_constant', - 'elementwise_mul', - 'elementwise_mul', - 'rsqrt', - 'fill_constant', - 'elementwise_div', - 'elementwise_div', - 'elementwise_mul', - ] - self.prim2orig_ops = [ - 'tanh', - 'tanh', - 'elementwise_add', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'elementwise_sub', - 'fill_constant', - 'elementwise_mul', - 'elementwise_sub', - 'fill_constant', - 'elementwise_mul', - 'elementwise_mul', - 'rsqrt', - 'fill_constant', - 'elementwise_div', - 'elementwise_div', - 'elementwise_mul', - ] - - def _test_run(self): - # Must using with program_guard(), otherwise prim ops will append other block - with paddle.static.program_guard( - self.main_program, self.startup_program - ): - ad = Transform(self.main_program.block(0)) - orig_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(orig_ops), sorted(self.orig_ops)) - - # Test orig2prim - orig2prim(block=self.main_program.block(0)) - orig2prim_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(orig2prim_ops), sorted(self.orig2prim_ops)) - - # Test linearize - xs_dot, ys_dot = ad.linearize(self.orig_xs, self.orig_ys) - linearize_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(linearize_ops), sorted(self.linearize_ops)) - flatten_xs_dot = paddle.utils.flatten(xs_dot) - for k, v in self.xs_shape_map.items(): - self.assertEqual(flatten_xs_dot[k].shape, v) - flatten_ys_dot = paddle.utils.flatten(ys_dot) - for k, v in self.ys_shape_map.items(): - self.assertEqual(flatten_ys_dot[k].shape, v) - - # Test transpose - ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False) - transpose_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(transpose_ops), sorted(self.transpose_ops)) - flatten_xs_bar = paddle.utils.flatten(xs_bar) - for k, v in self.xs_shape_map.items(): - # There may be None in the result of transpose like gather op - if flatten_xs_bar[k] is not None: - self.assertEqual(flatten_xs_bar[k].shape, v) - flatten_ys_bar = paddle.utils.flatten(ys_bar) - for k, v in self.ys_shape_map.items(): - self.assertEqual(flatten_ys_bar[k].shape, v) - - # Test prim2orig with blacklist - prim2orig( - block=self.main_program.block(0), blacklist=['add_p', 'sub_p'] - ) - prim2orig_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual( - sorted(prim2orig_ops), sorted(self.prim2orig_ops_with_blacklist) - ) - - # Test prim2orig - prim2orig(block=self.main_program.block(0)) - prim2orig_ops = [op.type for op in self.main_program.block(0).ops] - self.assertEqual(sorted(prim2orig_ops), sorted(self.prim2orig_ops)) - - -class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd): - def init_data(self): - # { input_index: input_shape } - self.xs_shape_map = {0: (100, 2), 1: (5, 2)} - # { output_index: output_shape } - self.ys_shape_map = {0: (100, 5)} - X0 = paddle.static.data( - 'X0', shape=self.xs_shape_map[0], dtype='float32' - ) - X0.stop_gradient = False - X1 = paddle.static.data( - 'X1', shape=self.xs_shape_map[1], dtype='float32' - ) - X1.stop_gradient = False - - A = paddle.reshape(X1, [2, 5]) - B = paddle.scale(A, scale=2.0, bias=2.0) - Y = paddle.matmul(X0, B) - - self.orig_xs = [X0, X1] - self.orig_ys = [ - Y, - ] - - self.orig_ops = ['reshape2', 'scale', 'matmul_v2'] - self.orig2prim_ops = [ - 'reshape_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - 'add_p', - 'matmul_p', - ] - self.linearize_ops = self.orig2prim_ops + [ - # call fill_const() in linearize() function - 'fill_constant_p', - 'fill_constant_p', - # linearized op - 'reshape_p', - 'mul_p', - # 'mul_p', # JVP rules handle `None` input, some op will not be appended - # 'add_p', - # 'add_p', - 'matmul_p', - 'matmul_p', - 'add_p', - ] - self.transpose_ops = self.orig2prim_ops + [ - # call fill_const() in transpose() function - 'fill_constant_p', - # linearized op after remove path - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - # transposed op - 'transpose_p', - 'matmul_p', - 'transpose_p', - 'matmul_p', - # 'mul_p', - 'reshape_p', - ] - - self.prim2orig_ops_with_blacklist = [ - 'reshape2', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'add_p', - 'matmul_v2', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'transpose2', - 'matmul_v2', - 'transpose2', - 'matmul_v2', - # 'elementwise_mul', - 'reshape2', - ] - - self.prim2orig_ops = [ - 'reshape2', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'elementwise_add', - 'matmul_v2', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'transpose2', - 'matmul_v2', - 'transpose2', - 'matmul_v2', - # 'elementwise_mul', - 'reshape2', - ] - - -class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): - def init_data(self): - # { input_index: input_shape } - self.xs_shape_map = {0: (7, 8, 9), 1: (8, 1), 2: (7, 8, 9), 3: (3,)} - # { output_index: output_shape } - self.ys_shape_map = {0: (3, 16, 9)} - - X0 = paddle.static.data( - 'X0', shape=self.xs_shape_map[0], dtype='float32' - ) - X0.stop_gradient = False - X1 = paddle.static.data( - 'X1', shape=self.xs_shape_map[1], dtype='float32' - ) - X1.stop_gradient = False - X2 = paddle.static.data( - 'X2', shape=self.xs_shape_map[2], dtype='float32' - ) - X2.stop_gradient = False - X3 = paddle.static.data('X3', shape=self.xs_shape_map[3], dtype='int32') - X3.stop_gradient = False - - A = paddle.add(X0, X1) # (7, 8, 9) - B = paddle.norm(x=A, p=2) # (1, ) - C = paddle.subtract(X2, B) # (7, 8, 9) - D = paddle.concat(x=(A, C), axis=1) # (7, 16, 9) - Y = paddle.index_select(D, X3, axis=0) # (3, 16, 9) - - self.orig_xs = [X0, X1, X2, X3] - self.orig_ys = [ - Y, - ] - self.orig_ops = [ - 'elementwise_add', - 'p_norm', - 'elementwise_sub', - 'concat', - 'index_select', - ] - self.orig2prim_ops = [ - 'broadcast_p', - 'add_p', - 'reshape_p', - 'mul_p', - 'reduce_sum_p', - 'sqrt_p', - 'broadcast_p', - 'sub_p', - 'concat_p', - 'gather_p', - ] - self.linearize_ops = self.orig2prim_ops + [ - # call fill_const() in linearize() function - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - # linearized op - 'broadcast_p', - 'add_p', - 'reshape_p', - 'mul_p', - 'mul_p', - 'add_p', - 'reduce_sum_p', - 'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p - 'mul_p', - 'div_p', - 'broadcast_p', - 'sub_p', - 'concat_p', - 'gather_p', - ] - self.transpose_ops = self.orig2prim_ops + [ - # call fill_const() in transpose() function - 'fill_constant_p', - # linearized op after remove path - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'fill_constant_p', - 'mul_p', - # transposed op - 'reduce_sum_p', - 'reshape_p', - 'reshape_p', - 'mul_p', - 'mul_p', - 'reshape_p', - 'broadcast_p', - 'div_p', - 'reduce_sum_p', - 'reshape_p', - 'fill_constant_p', - 'sub_p', - 'split_p', - 'fill_constant_p', - 'scatter_add_p', - 'add_p', # The output of the op is used by multiple subsequent ops - 'add_p', - ] - - self.prim2orig_ops_with_blacklist = [ - 'expand_v2', - 'add_p', - 'reshape2', - 'elementwise_mul', - 'reduce_sum', - 'sqrt', - 'expand_v2', - 'sub_p', - 'concat', - 'gather', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'reduce_sum', - 'reshape2', - 'reshape2', - 'elementwise_mul', - 'elementwise_mul', - 'reshape2', - 'expand_v2', - 'elementwise_div', - 'reduce_sum', - 'reshape2', - 'fill_constant', - 'sub_p', - 'split', - 'fill_constant', - 'fill_any_like', - 'add_p', - 'scatter', - 'elementwise_add', - 'add_p', - ] - - self.prim2orig_ops = [ - 'expand_v2', - 'elementwise_add', - 'reshape2', - 'elementwise_mul', - 'reduce_sum', - 'sqrt', - 'expand_v2', - 'elementwise_sub', - 'concat', - 'gather', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'fill_constant', - 'elementwise_mul', - 'reduce_sum', - 'reshape2', - 'reshape2', - 'elementwise_mul', - 'elementwise_mul', - 'reshape2', - 'expand_v2', - 'elementwise_div', - 'reduce_sum', - 'reshape2', - 'fill_constant', - 'elementwise_sub', - 'split', - 'fill_constant', - 'fill_any_like', - 'elementwise_add', - 'scatter', - 'elementwise_add', - 'elementwise_add', - ] - - -if __name__ == '__main__': - unittest.main()