Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CPU and GPU eigh op implementation #34990

Merged
merged 56 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7a62db2
add CPU Eigh op
Zjq9409 Aug 18, 2021
512613a
add file
Zjq9409 Aug 18, 2021
1461d1a
add file
Zjq9409 Aug 18, 2021
2dda0d1
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Aug 19, 2021
cca0bdb
modify head file path
Zjq9409 Aug 19, 2021
d1bb551
modify cmake file
Zjq9409 Aug 19, 2021
fd50e3f
add test
Zjq9409 Aug 21, 2021
e386f6c
Merge branch 'develop' into EighOP
Zjq9409 Aug 21, 2021
c8218bd
merge conflict
Zjq9409 Aug 22, 2021
b29f124
add test
Zjq9409 Aug 22, 2021
f9bdc21
modify head file
Zjq9409 Aug 22, 2021
c96121e
test
Zjq9409 Aug 22, 2021
1c9ecc2
test
Zjq9409 Aug 22, 2021
dbbebd2
add backward
Zjq9409 Aug 25, 2021
ea7cc0f
add backward
Zjq9409 Aug 27, 2021
d945247
add tool
Zjq9409 Aug 27, 2021
ad9a412
add backward test
Zjq9409 Aug 28, 2021
2b01c35
Merge branch 'develop' into EighOP
Zjq9409 Aug 28, 2021
27230b6
Merge branch 'develop' into EighOP
Zjq9409 Aug 28, 2021
b0ad2b4
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Aug 30, 2021
a2b8897
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Aug 30, 2021
af8a892
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Aug 30, 2021
a08dd88
Modify the configuration file
Zjq9409 Aug 30, 2021
1e4c267
Modify the configuration file
Zjq9409 Aug 30, 2021
fa0ed0a
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Aug 30, 2021
e95d4c0
remove the reverse calculation create op
Zjq9409 Sep 2, 2021
042d203
merge conflict
Zjq9409 Sep 3, 2021
6da3219
Merge branch 'develop' into EighOP
Zjq9409 Sep 3, 2021
5f365f0
remove create op
Zjq9409 Sep 3, 2021
062229e
Merge branch 'develop' into EighOP
Zjq9409 Sep 3, 2021
b368e44
improve the performance of more than 32 dimensions on the gpu, and im…
Zjq9409 Sep 6, 2021
8876ccc
Merge branch 'develop' into EighOP
Zjq9409 Sep 6, 2021
3c4384d
perfect unit test
Zjq9409 Sep 6, 2021
44b301b
perfect unit test
Zjq9409 Sep 6, 2021
8da21ec
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Sep 6, 2021
c3d9e51
remove eigh_helper file and unit test replace fluid
Zjq9409 Sep 9, 2021
607c0d0
CPU forward calculation uses lapack to replace eigen library
Zjq9409 Sep 10, 2021
e312530
extract eigh to calculate eigenvalues ​​and eigenvectors
Zjq9409 Sep 10, 2021
4a1cbff
extract common header file
Zjq9409 Sep 13, 2021
4b7f80b
Merge branch 'develop' into EighOP
Zjq9409 Sep 13, 2021
5ea7373
extract the common header files of eigenvalues ​​and eigenvectors, an…
Zjq9409 Sep 13, 2021
cdf7260
extract the common header files of eigenvalues ​​and eigenvectors, an…
Zjq9409 Sep 13, 2021
94b9f2e
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Sep 13, 2021
2323730
add PADDLE_WITH_HIP
Zjq9409 Sep 13, 2021
b8c1f5e
Solve the problem of not being able to find cuda file
Zjq9409 Sep 14, 2021
d9a4739
Add Eigenvector to whitelist
Zjq9409 Sep 14, 2021
f7854e1
Add Eigenvector to whitelist
Zjq9409 Sep 14, 2021
b77a4b8
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Sep 14, 2021
6be5f8f
Modify variable name
Zjq9409 Sep 15, 2021
61b71a5
Merge branch 'develop' into EighOP
Zjq9409 Sep 15, 2021
98f46cb
Merge branch 'develop' into EighOP
Zjq9409 Sep 15, 2021
823a50f
Merge branch 'EighOP' of https://github.com/Zjq9409/Paddle into EighOP
Zjq9409 Sep 15, 2021
83939f2
Merge branch 'develop' into EighOP
Zjq9409 Sep 15, 2021
b629849
Merge branch 'develop' into EighOP
Zjq9409 Sep 15, 2021
3944f53
Merge branch 'develop' into EighOP
Zjq9409 Sep 15, 2021
d478e09
Merge branch 'develop' into EighOP
Zjq9409 Sep 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
list(REMOVE_ITEM hip_srcs "svd_op.cu")
list(REMOVE_ITEM hip_srcs "eigh_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
Expand Down
167 changes: 167 additions & 0 deletions paddle/fluid/operators/eigh_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class EighOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
"Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
"Eigh");

auto input_dim = ctx->GetInputDim("X");
auto rank = input_dim.size();

PADDLE_ENFORCE_GE(rank, 2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rank的检查挪到使用之前。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2], input_dim[rank - 1],
platform::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2], input_dim[rank - 1]));

std::vector<int64_t> values_dim;
if (rank > 2) {
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
} else {
values_dim = {input_dim[1]};
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入的shape如果是[, N, N],那特征值的shape应该就是[, N]吧,也就是特征值应该是rank - 1维的,而不应该固化为2维?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改,应该是rank-1维度的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也没必要用if else两个分支写。


ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim));
ctx->SetOutputDim("Eigenvectors", input_dim);
}
};

class EignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), Hermitian or real symmetric matrices."
"Its shape should be [*, N, N] where * is zero or"
"more batch dimensions. The data type is float32 ,"
"float64, complex64, complex128.");
AddOutput("Eigenvalues",
"(Tensor), The eigenvalues in ascending order."
"The data type is float32 or float64.");
AddOutput(
"Eigenvectors",
"(Tensor), The column is the normalized eigenvector "
"corresponding to the eigenvalue. The data type is the same as ``X``.");
AddAttr<std::string>(
"UPLO",
"(string, default 'L'), 'L' represents the lower triangular matrix,"
"'U' represents the upper triangular matrix.")
.SetDefault("L");
AddComment(R"DOC(
Eigh Operator.

Computes the eigenvalues and eigenvectors of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.

)DOC");
}
};

class EighGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")),
"Input", "Eigenvalues@GRAD", "EighGrad");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inputs表示Eigenvectors是一个std::vector<Tensor>,这里Eigenvectors也只是一个Tensor吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eigenvectors是一个Tensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那这里应该使用ctx->HasInput而不是ctx->HasInputs,没有s

OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")),
"Input", "Eigenvectors@GRAD", "EighGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Eigenvectors")),
ctx.device_context());
}
};

template <typename T>
class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Eigenvalues", this->Output("Eigenvalues"));
op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
op->SetInput(framework::GradVarName("Eigenvalues"),
this->OutputGrad("Eigenvalues"));
op->SetInput(framework::GradVarName("Eigenvectors"),
this->OutputGrad("Eigenvectors"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
ops::EighGradOpMaker<paddle::framework::OpDesc>,
ops::EighGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);

REGISTER_OP_CPU_KERNEL(
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext, float,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double,
paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double,
paddle::platform::complex<double>>);
53 changes: 53 additions & 0 deletions paddle/fluid/operators/eigh_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename ValueType, typename T>
class EighGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

实现和EighKernel一样,GPU也可以直接使用EighKernel来注册,没有必要实现EighGPUKernel

auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lower 的作用仅作为判断使用, ctx.Attrstd::string("UPLO") 可以替换成 ctx.Attr("Is_Uplo");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为numpy接口参数"UPLO"是string定义,所以直接采用string接收参数,如果使用bool类型,需要多加一次转换操作。

bool is_lower = (lower == "L");
math::MatrixEighFunctor<ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
eigh, ops::EighGPUKernel<float, float>, ops::EighGPUKernel<double, double>,
ops::EighGPUKernel<float, paddle::platform::complex<float>>,
ops::EighGPUKernel<double, paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double,
paddle::platform::complex<double>>);
80 changes: 80 additions & 0 deletions paddle/fluid/operators/eigh_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分Eigen的声明可以删除了?


template <typename DeviceContext, typename ValueType, typename T>
class EighKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

bool is_lower = (lower == "L");
math::MatrixEighFunctorCPU<DeviceContext, ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
}
};

template <typename DeviceContext, typename ValueType, typename T>
class EighGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w_var = *ctx.Input<Tensor>("Eigenvalues");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名为啥加_var后缀,直接用output_w

auto& output_v_var = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& output_v_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));

auto& dims = output_v_var.dims();
const int m = dims[dims.size() - 1];
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
ctx);
auto tV = dito.Transpose(dito.Conj(output_v_var));
auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2),
dito.Unsqueeze(output_w_var, -1));
Tensor result = dito.Matmul(tV, output_v_grad);
result.mutable_data<T>(dims, ctx.GetPlace());
std::vector<int> out_shape = framework::vectorize<int>(dims);
auto constant = dito.Fill(out_shape, 0.5);
result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
result = dito.Mul(result, constant);
result = dito.Div_(result, W);
result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV));
}
};

} // namespace operators
} // namespace paddle
Loading