Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new API "eigvals" in linalg #35720

Merged
merged 15 commits into from
Sep 17, 2021
Merged
24 changes: 24 additions & 0 deletions paddle/fluid/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,30 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
return os;
}

DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims){
PADDLE_ENFORCE_GE(src.size(), 3, platform::errors::InvalidArgument(
"The rank of src dim should be at least 3 in flatten_to_3d, but received %d.",
src.size()));
PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true,
platform::errors::InvalidArgument(
"The num_row_dims should be inside [1, %d] in flatten_to_3d, but received %d.",
src.size() - 1, num_row_dims));
PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true,
platform::errors::InvalidArgument(
"The num_col_dims should be inside [2, %d] in flatten_to_3d, but received %d.",
src.size(), num_col_dims));
PADDLE_ENFORCE_GE(
num_col_dims, num_row_dims,
platform::errors::InvalidArgument(
"The num_row_dims should be less than num_col_dims in flatten_to_3d,"
"but received num_row_dims = %d, num_col_dims = %d.",
num_row_dims, num_col_dims));

return DDim({product(slice_ddim(src, 0, num_row_dims)),
product(slice_ddim(src, num_row_dims, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
}

DDim flatten_to_2d(const DDim& src, int num_col_dims) {
return DDim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ int arity(const DDim& ddim);

std::ostream& operator<<(std::ostream&, const DDim&);

/**
* \brief Flatten dim to 3d
* e.g., DDim d = mak_ddim({1, 2, 3, 4, 5, 6})
* flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30}
*/
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims);

// Reshape a tensor to a matrix. The matrix's first dimension(column length)
// will be the product of tensor's first `num_col_dims` dimensions.
DDim flatten_to_2d(const DDim& src, int num_col_dims);
Expand Down
135 changes: 135 additions & 0 deletions paddle/fluid/operators/eigvals_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigvals_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A complex- or real-valued tensor with shape (*, n, n)"
"where * is zero or more batch dimensions");
AddOutput("Out",
"(Tensor) The output tensor with shape (*,n) cointaining the "
"eigenvalues of X.");
AddComment(R"DOC(eigvals operator
Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real.
)DOC");
}
};

class EigvalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals");

DDim x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s].",
x_dims.size(), x_dims));

if (ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)) {
int last_dim = x_dims.size() - 1;
PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim - 1],
platform::errors::InvalidArgument(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s].",
x_dims));
}

auto output_dims = vectorize(x_dims);
output_dims.resize(x_dims.size() - 1);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
}
};

class EigvalsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const {
auto input_dtype = ctx->GetInputDataType("X");
auto output_dtype = framework::IsComplexType(input_dtype)
? input_dtype
: framework::ToComplexType(input_dtype);
ctx->SetOutputDataType("Out", output_dtype);
}
};

class EigvalsGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "EigvalsGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "EigvalsGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "EigvalsGrad");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

template <typename T>
class EigvalsGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("eigvals_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker,
ops::EigvalsOpVarTypeInference,
ops::EigvalsGradOpMaker<paddle::framework::OpDesc>,
ops::EigvalsGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eigvals_grad, ops::EigvalsGradOp);
REGISTER_OP_CPU_KERNEL(eigvals,
ops::EigvalsKernel<plat::CPUDeviceContext, float>,
ops::EigvalsKernel<plat::CPUDeviceContext, double>,
ops::EigvalsKernel<plat::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EigvalsKernel<plat::CPUDeviceContext,
paddle::platform::complex<double>>);

// TODO(Ruibiao): Support gradient kernel for Eigvals OP
// REGISTER_OP_CPU_KERNEL(eigvals_grad,
// ops::EigvalsGradKernel<plat::CPUDeviceContext, float>,
// ops::EigvalsGradKernel<plat::CPUDeviceContext, double>,
// ops::EigvalsGradKernel<plat::CPUDeviceContext,
// paddle::platform::complex<float>>,
// ops::EigvalsGradKernel<plat::CPUDeviceContext,
// paddle::platform::complex<double>>);
129 changes: 129 additions & 0 deletions paddle/fluid/operators/eigvals_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <complex>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename T>
struct PaddleComplex {
using Type = paddle::platform::complex<T>;
};
template <>
struct PaddleComplex<paddle::platform::complex<float>> {
using Type = paddle::platform::complex<float>;
};
template <>
struct PaddleComplex<paddle::platform::complex<double>> {
using Type = paddle::platform::complex<double>;
};

template <typename T>
struct StdComplex {
using Type = std::complex<T>;
};
template <>
struct StdComplex<paddle::platform::complex<float>> {
using Type = std::complex<float>;
};
template <>
struct StdComplex<paddle::platform::complex<double>> {
using Type = std::complex<double>;
};

template <typename T>
using PaddleCType = typename PaddleComplex<T>::Type;
template <typename T>
using StdCType = typename StdComplex<T>::Type;
template <typename T>
using EigenMatrixPaddle = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorPaddle = Eigen::Matrix<PaddleCType<T>, Eigen::Dynamic, 1>;
template <typename T>
using EigenMatrixStd =
Eigen::Matrix<StdCType<T>, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorStd = Eigen::Matrix<StdCType<T>, Eigen::Dynamic, 1>;

static void SpiltBatchSquareMatrix(const Tensor *input,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better use const& for input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

std::vector<Tensor> *output) {
DDim input_dims = input->dims();
int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim];

DDim flattened_input_dims, flattened_output_dims;
if (input_dims.size() > 2) {
flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim);
} else {
flattened_input_dims = framework::make_ddim({1, n_dim, n_dim});
}

Tensor flattened_input;
flattened_input.ShareDataWith(*input);
flattened_input.Resize(flattened_input_dims);
(*output) = flattened_input.Split(1, 0);
}

template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");

auto input_type = input->type();
auto output_type = framework::IsComplexType(input_type)
? input_type
: framework::ToComplexType(input_type);
output->mutable_data(ctx.GetPlace(), output_type);

std::vector<Tensor> input_matrices;
SpiltBatchSquareMatrix(input, /*->*/ &input_matrices);

int n_dim = input_matrices[0].dims()[1];
int n_batch = input_matrices.size();

DDim output_dims = output->dims();
output->Resize(framework::make_ddim({n_batch, n_dim}));
std::vector<Tensor> output_vectors = output->Split(1, 0);

Eigen::Map<EigenMatrixPaddle<T>> input_emp(NULL, n_dim, n_dim);
Eigen::Map<EigenVectorPaddle<T>> output_evp(NULL, n_dim);
EigenMatrixStd<T> input_ems;
EigenVectorStd<T> output_evs;

for (int i = 0; i < n_batch; ++i) {
new (&input_emp) Eigen::Map<EigenMatrixPaddle<T>>(
input_matrices[i].data<T>(), n_dim, n_dim);
new (&output_evp) Eigen::Map<EigenVectorPaddle<T>>(
output_vectors[i].data<PaddleCType<T>>(), n_dim);
input_ems = input_emp.template cast<StdCType<T>>();
output_evs = input_ems.eigenvalues();
output_evp = output_evs.template cast<PaddleCType<T>>();
}
output->Resize(output_dims);
}
};
} // namespace operators
} // namespace paddle
Loading