Skip to content

Commit

Permalink
move eig operator from fluid to phi (#44398)
Browse files Browse the repository at this point in the history
* move eig operator from fluid to phi

* add eig_grad unitest, upgrade IsComplexType() from fluid to phi
  • Loading branch information
freeliuzc authored Jul 19, 2022
1 parent 9e30722 commit 3788f5e
Show file tree
Hide file tree
Showing 16 changed files with 704 additions and 408 deletions.
86 changes: 16 additions & 70 deletions paddle/fluid/operators/eig_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,18 @@
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class EigOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eig");
OP_INOUT_CHECK(
ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", "Eig");
OP_INOUT_CHECK(
ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", "Eig");

auto x_dims = ctx->GetInputDim("X");
int rank = x_dims.size();
PADDLE_ENFORCE_GE(rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
rank));
PADDLE_ENFORCE_EQ(x_dims[rank - 2],
x_dims[rank - 1],
platform::errors::InvalidArgument(
"The input matrix must be a square matrix, "
"but receive a matrix with %d rows and %d colums",
x_dims[rank - 2],
x_dims[rank - 1]));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < rank - 1; ++i) {
batch_dims_vec.emplace_back(x_dims[i]);
}

ctx->SetOutputDim("Eigenvectors", x_dims);
ctx->SetOutputDim("Eigenvalues", phi::make_ddim(batch_dims_vec));
}

protected:
// The output of eig is always complex-valued even for real-valued inputs
Expand Down Expand Up @@ -100,26 +73,6 @@ This API processes eigen decomposition for general square matrices.
class EigGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", "EigGrad");
OP_INOUT_CHECK(
ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EigGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
"Input",
"Eigenvalues@GRAD",
"EigGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
"Input",
"Eigenvectors@GRAD",
"EigGrad");

auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
Expand Down Expand Up @@ -152,27 +105,20 @@ class EigGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle

using complex64 = paddle::platform::complex<float>;
using complex128 = paddle::platform::complex<double>;

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eig,
EigInferShapeFunctor,
PD_INFER_META(phi::EigInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(eig_grad,
EigGradInferShapeFunctor,
PD_INFER_META(phi::EigGradInferMeta));

REGISTER_OPERATOR(eig,
ops::EigOp,
ops::EigOpMaker,
ops::EigGradOpMaker<paddle::framework::OpDesc>,
ops::EigGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(eig_grad, ops::EigGradOp);

REGISTER_OP_CPU_KERNEL(eig,
ops::EigKernel<phi::CPUContext, float, complex64>,
ops::EigKernel<phi::CPUContext, double, complex128>,
ops::EigKernel<phi::CPUContext, complex64, complex64>,
ops::EigKernel<phi::CPUContext, complex128, complex128>);

REGISTER_OP_CPU_KERNEL(
eig_grad,
ops::EigGradKernel<phi::CPUContext, float, complex64>,
ops::EigGradKernel<phi::CPUContext, double, complex128>,
ops::EigGradKernel<phi::CPUContext, complex64, complex64>,
ops::EigGradKernel<phi::CPUContext, complex128, complex128>);
ops::EigGradOpMaker<paddle::imperative::OpBase>,
EigInferShapeFunctor);

REGISTER_OPERATOR(eig_grad, ops::EigGradOp, EigGradInferShapeFunctor);
Loading

0 comments on commit 3788f5e

Please sign in to comment.