Add CPU and GPU eigh op implementation #34990
Changes from 48 commits
paddle/fluid/operators/eigh_op.cc
@@ -0,0 +1,167 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"
namespace paddle {
namespace operators {

using framework::Tensor;

class EighOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
    OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
                   "Eigh");
    OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
                   "Eigh");

    auto input_dim = ctx->GetInputDim("X");
    auto rank = input_dim.size();
    PADDLE_ENFORCE_GE(rank, 2,
                      platform::errors::InvalidArgument(
                          "The Input(X) should have at least 2 dimensions. "
                          "But received a %d dimension tensor.",
                          rank));
    PADDLE_ENFORCE_EQ(
        input_dim[rank - 2], input_dim[rank - 1],
        platform::errors::InvalidArgument(
            "Eigh op is designed for square matrices, so the inner-most "
            "2 dimensions of Input(X) should be equal. "
            "But received X's shape[-2] = %d and shape[-1] = %d.",
            input_dim[rank - 2], input_dim[rank - 1]));

    std::vector<int64_t> values_dim;
    if (rank > 2) {
      for (auto i = 0; i < rank - 1; i++) {
        values_dim.emplace_back(input_dim[i]);
      }
    } else {
      values_dim = {input_dim[1]};
    }
Review comment: If the input shape is [*, N, N], the eigenvalues' shape should be [*, N]; that is, the eigenvalues should be rank - 1 dimensional, not hard-coded to 2-D?

Reply: Revised as suggested; the eigenvalues are now rank - 1 dimensional.

Review comment: There's also no need to use […] here.
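A standalone sketch (not part of the PR) of the shape rule settled in this exchange: for input [*, N, N] the eigenvalues come out as [*, N], one dimension fewer than the input. The helper name `EigenvaluesDims` is hypothetical.

```cpp
// Standalone illustration of the eigenvalue shape rule discussed above.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> EigenvaluesDims(const std::vector<int64_t>& x_dims) {
  // [d0, ..., N, N] -> [d0, ..., N]: drop the trailing matrix dimension.
  return std::vector<int64_t>(x_dims.begin(), x_dims.end() - 1);
}

int main() {
  assert((EigenvaluesDims({3, 4, 4}) == std::vector<int64_t>{3, 4}));
  assert((EigenvaluesDims({5, 5}) == std::vector<int64_t>{5}));
  return 0;
}
```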

    ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim));
    ctx->SetOutputDim("Eigenvectors", input_dim);
  }
};

class EignOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor), Hermitian or real symmetric matrices. "
             "Its shape should be [*, N, N] where * is zero or "
             "more batch dimensions. The data type is float32, "
             "float64, complex64 or complex128.");
    AddOutput("Eigenvalues",
              "(Tensor), The eigenvalues in ascending order. "
              "The data type is float32 or float64.");
    AddOutput(
        "Eigenvectors",
        "(Tensor), The columns are the normalized eigenvectors "
        "corresponding to the eigenvalues. The data type is the same as ``X``.");
    AddAttr<std::string>(
        "UPLO",
        "(string, default 'L'), 'L' represents the lower triangular matrix, "
        "'U' represents the upper triangular matrix.")
        .SetDefault("L");
    AddComment(R"DOC(
Eigh Operator.

Computes the eigenvalues and eigenvectors of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.

)DOC");
  }
};
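For reference, the factorization the DOC comment refers to (standard linear algebra, not spelled out in the PR): a Hermitian or real symmetric matrix $A$ admits

$$A = V \operatorname{diag}(w)\, V^{H}, \qquad A v_i = w_i v_i,$$

with real eigenvalues $w$ returned in ascending order, the columns $v_i$ of $V$ the corresponding normalized eigenvectors, and $V^{H}$ the conjugate transpose of $V$.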

class EighGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues",
                   "EighGrad");
    OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors",
                   "EighGrad");
    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")),
                   "Input", "Eigenvalues@GRAD", "EighGrad");
Reply: Eigenvectors is a single Tensor.

Review comment: In that case, […] should be used here.
    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")),
                   "Input", "Eigenvectors@GRAD", "EighGrad");
    auto dims = ctx->GetInputDim("Eigenvectors");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(
            ctx, framework::GradVarName("Eigenvectors")),
        ctx.device_context());
  }
};

template <typename T>
class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Eigenvalues", this->Output("Eigenvalues"));
    op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
    op->SetInput(framework::GradVarName("Eigenvalues"),
                 this->OutputGrad("Eigenvalues"));
    op->SetInput(framework::GradVarName("Eigenvectors"),
                 this->OutputGrad("Eigenvectors"));
    op->SetAttrMap(this->Attrs());
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
                  ops::EighGradOpMaker<paddle::framework::OpDesc>,
                  ops::EighGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);

REGISTER_OP_CPU_KERNEL(
    eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float, float>,
    ops::EighKernel<paddle::platform::CPUDeviceContext, double, double>,
    ops::EighKernel<paddle::platform::CPUDeviceContext, float,
                    paddle::platform::complex<float>>,
    ops::EighKernel<paddle::platform::CPUDeviceContext, double,
                    paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
    eigh_grad,
    ops::EighGradKernel<paddle::platform::CPUDeviceContext, float, float>,
    ops::EighGradKernel<paddle::platform::CPUDeviceContext, double, double>,
    ops::EighGradKernel<paddle::platform::CPUDeviceContext, float,
                        paddle::platform::complex<float>>,
    ops::EighGradKernel<paddle::platform::CPUDeviceContext, double,
                        paddle::platform::complex<double>>);
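The registrations above pair a real `ValueType` with each element type `T`: eigenvalues of a Hermitian matrix are always real, so complex64/complex128 inputs yield float/double eigenvalues. A standalone sketch of that mapping, using `std::complex` in place of Paddle's complex type; the `RealOf` trait is hypothetical and not part of the PR:

```cpp
// Hypothetical trait illustrating the (ValueType, T) pairing used in the
// kernel registrations: ValueType is the real scalar underlying T.
#include <complex>
#include <type_traits>

template <typename T>
struct RealOf { using type = T; };  // float -> float, double -> double
template <typename R>
struct RealOf<std::complex<R>> { using type = R; };  // complex<R> -> R

static_assert(std::is_same<RealOf<float>::type, float>::value, "");
static_assert(std::is_same<RealOf<std::complex<double>>::type, double>::value,
              "");

int main() { return 0; }
```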
paddle/fluid/operators/eigh_op.cu
@@ -0,0 +1,53 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename ValueType, typename T>
class EighGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
Review comment: The implementation and […]
    auto input_var = ctx.Input<Tensor>("X");
    auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
    auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
    std::string lower = ctx.Attr<std::string>("UPLO");
Reply: Because the numpy interface defines the "UPLO" parameter as a string, the parameter is received directly as a string; using a bool type would require an extra conversion step.
    bool is_lower = (lower == "L");
    math::MatrixEighFunctor<ValueType, T> functor;
    functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    eigh, ops::EighGPUKernel<float, float>, ops::EighGPUKernel<double, double>,
    ops::EighGPUKernel<float, paddle::platform::complex<float>>,
    ops::EighGPUKernel<double, paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
    eigh_grad,
    ops::EighGradKernel<paddle::platform::CUDADeviceContext, float, float>,
    ops::EighGradKernel<paddle::platform::CUDADeviceContext, double, double>,
    ops::EighGradKernel<paddle::platform::CUDADeviceContext, float,
                        paddle::platform::complex<float>>,
    ops::EighGradKernel<paddle::platform::CUDADeviceContext, double,
                        paddle::platform::complex<double>>);
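On the UPLO discussion above: by LAPACK/cuSOLVER convention, a Hermitian eigensolver reads only the triangle named by UPLO and ignores the other. A minimal standalone sketch of that convention (not Paddle code):

```cpp
// Standalone illustration: with UPLO == "L", only the lower triangle is
// meaningful; the upper-triangle entries (the 9s) are never consulted.
#include <cstdio>

int main() {
  const int n = 3;
  double a[n][n] = {{4, 9, 9}, {1, 3, 9}, {2, 5, 6}};  // 9s are ignored
  for (int i = 0; i < n; ++i)
    for (int j = i + 1; j < n; ++j) a[i][j] = a[j][i];  // mirror lower -> upper
  for (int i = 0; i < n; ++i)
    printf("%g %g %g\n", a[i][0], a[i][1], a[i][2]);  // full symmetric matrix
  return 0;
}
```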
paddle/fluid/operators/eigh_op.h
@@ -0,0 +1,80 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
Review comment: Can these Eigen declarations be removed now?

template <typename DeviceContext, typename ValueType, typename T>
class EighKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto input_var = ctx.Input<Tensor>("X");
    auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
    auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
    std::string lower = ctx.Attr<std::string>("UPLO");
Review comment: Same as above.
    bool is_lower = (lower == "L");
    math::MatrixEighFunctorCPU<DeviceContext, ValueType, T> functor;
    functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
  }
};

template <typename DeviceContext, typename ValueType, typename T>
class EighGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    x_grad.mutable_data<T>(ctx.GetPlace());
    auto& output_w_var = *ctx.Input<Tensor>("Eigenvalues");
Review comment: Why add […] to the variable names?

    auto& output_v_var = *ctx.Input<Tensor>("Eigenvectors");
    auto& output_w_grad =
        *ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
    auto& output_v_grad =
        *ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));

    auto& dims = output_v_var.dims();
    const int m = dims[dims.size() - 1];
    auto dito =
        math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
            ctx);
    // tV = V^H, the conjugate transpose of the eigenvector matrix.
    auto tV = dito.Transpose(dito.Conj(output_v_var));
    // W[i][j] = w[j] - w[i], the pairwise eigenvalue differences.
    auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2),
                       dito.Unsqueeze(output_w_var, -1));
    Tensor result = dito.Matmul(tV, output_v_grad);
    result.mutable_data<T>(dims, ctx.GetPlace());
    std::vector<int> out_shape = framework::vectorize<int>(dims);
    auto constant = dito.Fill(out_shape, 0.5);
    // Take the anti-Hermitian part of V^H dV, scale by 1/2, divide
    // elementwise by the eigenvalue gaps, then fill the diagonal with dw.
    result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
    result = dito.Mul(result, constant);
    result = dito.Div_(result, W);
    result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
    x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV));
  }
};

}  // namespace operators
}  // namespace paddle
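Read off from the kernel above (a sketch of the math, not stated in the PR): with $A = V \operatorname{diag}(w) V^{H}$ and incoming gradients $\bar{w}$, $\bar{V}$, the backward rule it implements appears to be the standard eigh gradient

$$\bar{A} = V \Big( \operatorname{diag}(\bar{w}) + F \circ \tfrac{1}{2}\big(V^{H}\bar{V} - \bar{V}^{H}V\big) \Big) V^{H}, \qquad F_{ij} = \frac{1}{w_j - w_i} \ (i \neq j),\quad F_{ii} = 0,$$

where $\circ$ is the elementwise product: `DiagFill` places $\operatorname{diag}(\bar{w})$ on the diagonal, and the elementwise `Div_` by `W` supplies the $1/(w_j - w_i)$ factors.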
Review comment: Move the rank check to before rank is used.

Reply: Revised as suggested.