From 628e2dfbb717de95728b5ea21a99beb56de46e28 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 13 Sep 2021 09:19:28 +0000 Subject: [PATCH 01/11] Add linalg.eigvals API --- paddle/fluid/framework/ddim.cc | 24 ++ paddle/fluid/framework/ddim.h | 7 + paddle/fluid/operators/eigvals_op.cc | 124 +++++++ paddle/fluid/operators/eigvals_op.h | 221 ++++++++++++ .../fluid/tests/unittests/test_eigvals_op.py | 322 ++++++++++++++++++ python/paddle/linalg.py | 4 +- python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 59 ++++ 8 files changed, 762 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/eigvals_op.cc create mode 100644 paddle/fluid/operators/eigvals_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_eigvals_op.py diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index fe7d243066237..975711c2d73e8 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -107,6 +107,30 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } +DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims){ + PADDLE_ENFORCE_GE(src.size(), 3, platform::errors::InvalidArgument( + "The rank of src dim should be at least 3 in flatten_to_3d, but received %d.", + src.size())); + PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true, + platform::errors::InvalidArgument( + "The num_row_dims should be inside [1, %d] in flatten_to_3d, but received %d.", + src.size() - 1, num_row_dims)); + PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true, + platform::errors::InvalidArgument( + "The num_col_dims should be inside [2, %d] in flatten_to_3d, but received %d.", + src.size(), num_col_dims)); + PADDLE_ENFORCE_GE( + num_col_dims, num_row_dims, + platform::errors::InvalidArgument( + "The num_row_dims should be less than num_col_dims in flatten_to_3d," + "but received num_row_dims = %d, num_col_dims = %d.", + num_row_dims, num_col_dims)); + + return DDim({product(slice_ddim(src, 0, num_row_dims)), + product(slice_ddim(src, num_row_dims, num_col_dims)), + product(slice_ddim(src, num_col_dims, src.size()))}); +} + DDim flatten_to_2d(const DDim& src, int num_col_dims) { return DDim({product(slice_ddim(src, 0, num_col_dims)), product(slice_ddim(src, num_col_dims, src.size()))}); diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index e69fb4e761939..565e0b430dfdc 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -230,6 +230,13 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +/** +* \brief Flatten dim to 3d +* e.g., DDim d = mak_ddim({1, 2, 3, 4, 5, 6}) +* flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30} +*/ +DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims); + // Reshape a tensor to a matrix. The matrix's first dimension(column length) // will be the product of tensor's first `num_col_dims` dimensions. DDim flatten_to_2d(const DDim& src, int num_col_dims); diff --git a/paddle/fluid/operators/eigvals_op.cc b/paddle/fluid/operators/eigvals_op.cc new file mode 100644 index 0000000000000..37420d4232fd0 --- /dev/null +++ b/paddle/fluid/operators/eigvals_op.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigvals_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), A complex- or real-valued tensor with shape (*, n, n)" + "where * is zero or more batch dimensions"); + AddOutput("Out", + "(Tensor) The output tensor with shape (*,n) cointaining the eigenvalues of X."); + AddComment(R"DOC(eigvals operator + Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real. + )DOC"); + } +}; + +class EigvalsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals"); + + DDim x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The dimensions of Input(X) for Eigvals operator should be at least 2, " + "but received X's dimension = %d, X's shape = [%s].", + x_dims.size(), x_dims)); + + if(ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)){ + int last_dim = x_dims.size() - 1; + PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim-1], + platform::errors::InvalidArgument( + "The last two dimensions of Input(X) for Eigvals operator should be equal, " + "but received X's shape = [%s].", + x_dims)); + } + + auto output_dims = vectorize(x_dims); + output_dims.resize(x_dims.size() - 1); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + } +}; + +class EigvalsOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const { + auto input_dtype = ctx->GetInputDataType("X"); + auto output_dtype = framework::IsComplexType(input_dtype) ? 
+ input_dtype : framework::ToComplexType(input_dtype); + ctx->SetOutputDataType("Out", output_dtype); + } +}; + +class EigvalsGradOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "EigvalsGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@Grad", "EigvalsGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "EigvalsGrad"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +template +class EigvalsGradOpMaker : public framework::SingleGradOpMaker{ + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + protected: + void Apply(GradOpPtr retv) const override{ + retv->SetType("eigvals_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + + + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(eigvals, + ops::EigvalsOp, ops::EigvalsOpMaker, ops::EigvalsOpVarTypeInference, + ops::EigvalsGradOpMaker, + ops::EigvalsGradOpMaker); +REGISTER_OPERATOR(eigvals_grad, ops::EigvalsGradOp); +REGISTER_OP_CPU_KERNEL(eigvals, + ops::EigvalsKernel, + ops::EigvalsKernel, + ops::EigvalsKernel>, + ops::EigvalsKernel>); + +// TODO(Ruibiao): Support gradient kernel for Eigvals OP +REGISTER_OP_CPU_KERNEL(eigvals_grad, + ops::EigvalsGradKernel, + ops::EigvalsGradKernel, + ops::EigvalsGradKernel>, + ops::EigvalsGradKernel>); \ No newline at end of file diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h new file mode 100644 index 0000000000000..2d601f2ce3935 --- /dev/null +++ b/paddle/fluid/operators/eigvals_op.h @@ -0,0 +1,221 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
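// The rest of this header provides the CPU implementation of the eigvals
// operator registered in eigvals_op.cc:
//   * SpiltBatchSquareMatrix flattens an input of shape (*, n, n) into a
//     batch of individual n x n matrices (via flatten_to_3d when batch
//     dimensions are present);
//   * EigvalsKernel maps each matrix into an Eigen matrix, casts it to
//     std::complex, computes its eigenvalues with Eigen's eigenvalues(),
//     and writes them back as paddle::platform::complex values;
//   * the output dtype is the complex counterpart of the input dtype
//     (float -> complex<float>, double -> complex<double>); complex inputs
//     keep their dtype, matching EigvalsOpVarTypeInference.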
+ +#pragma once + +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/data_type.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +struct PaddleComplex{ + using Type = paddle::platform::complex; +}; +template <> +struct PaddleComplex>{ + using Type = paddle::platform::complex; +}; +template <> +struct PaddleComplex>{ + using Type = paddle::platform::complex; +}; + +template +struct StdComplex{ + using Type = std::complex; +}; +template <> +struct StdComplex>{ + using Type = std::complex; +}; +template <> +struct StdComplex>{ + using Type = std::complex; +}; + +template +using PaddleCType = typename PaddleComplex::Type; +template +using StdCType = typename StdComplex::Type; +template +using EigenMatrixPaddle = Eigen::Matrix; +template +using EigenVectorPaddle = Eigen::Matrix, Eigen::Dynamic, 1>; +template +using EigenMatrixStd = Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>; +template +using EigenVectorStd = Eigen::Matrix, Eigen::Dynamic, 1>; + +static void SpiltBatchSquareMatrix(const Tensor* input, std::vector& output){ + DDim input_dims = input -> dims(); + int last_dim = input_dims.size() - 1; + int n_dim = input_dims[last_dim]; + + DDim flattened_input_dims, flattened_output_dims; + if(input_dims.size() > 2){ + flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim); + } + else{ + flattened_input_dims = framework::make_ddim({1, n_dim, n_dim}); + } + + Tensor flattened_input; + flattened_input.ShareDataWith(*input); + flattened_input.Resize(flattened_input_dims); + output = flattened_input.Split(1, 0); +} + +template +class EigvalsKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + const Tensor *input = ctx.Input("X"); + Tensor *output = ctx.Output("Out"); + + auto input_type = input -> type(); + auto output_type = framework::IsComplexType(input_type) ? 
+ input_type : framework::ToComplexType(input_type); + output -> mutable_data(ctx.GetPlace(), output_type); + + std::vector input_matrices; + SpiltBatchSquareMatrix(input, /*->*/ input_matrices); + + int n_dim = input_matrices[0].dims()[1]; + int n_batch = input_matrices.size(); + + DDim output_dims = output->dims(); + output -> Resize(framework::make_ddim({n_batch, n_dim})); + std::vector output_vectors = output->Split(1, 0); + + Eigen::Map> input_emp(NULL, n_dim, n_dim); + Eigen::Map> output_evp(NULL, n_dim); + EigenMatrixStd input_ems; + EigenVectorStd output_evs; + + for(int i = 0; i < n_batch; ++i){ + new (&input_emp) Eigen::Map>( + input_matrices[i].data(), n_dim, n_dim); + new (&output_evp) Eigen::Map>( + output_vectors[i].data>(), n_dim); + input_ems = input_emp.template cast>(); + output_evs = input_ems.eigenvalues(); + output_evp = output_evs.template cast>(); + } + output -> Resize(output_dims); + } +}; + + +template +inline void CastToPaddleType( + EigenMatrixStd& input, + Eigen::Map>& output){ + output = input.template cast(); +} +template<> +inline void CastToPaddleType( + EigenMatrixStd& input, + Eigen::Map>& output){ + output = input.real(); +} +template<> +inline void CastToPaddleType( + EigenMatrixStd& input, + Eigen::Map>& output){ + output = input.real(); +} + +template +class EigvalsGradKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + const Tensor* input = ctx.Input("X"); + const Tensor* output_grad = ctx.Input(framework::GradVarName("Out")); + Tensor* input_grad = ctx.Output(framework::GradVarName("X")); + input_grad -> mutable_data(ctx.GetPlace(), input -> type()); + + Tensor output; + output.Resize(output_grad -> dims()); + output.mutable_data(ctx.GetPlace(), output_grad -> type()); + + std::vector input_matrices, input_grad_matrices; + SpiltBatchSquareMatrix(input, /*->*/ input_matrices); + SpiltBatchSquareMatrix(input_grad, /*->*/ input_grad_matrices); + + int n_dim = input_matrices[0].dims()[1]; + int n_batch = input_matrices.size(); + + Tensor flattened_output_grad; + flattened_output_grad.ShareDataWith(*output_grad); + flattened_output_grad.Resize(framework::make_ddim({n_batch, n_dim})); + std::vector output_grad_vectors = flattened_output_grad.Split(1, 0); + + /* + input_emp -> input_ems -> v_ems -> vh_ems + output_grad_evp -> output_grad_evs + output_grad_evs + vh_ems -> input_grad_ems -> input_grad_emp + */ + Eigen::Map> input_emp(NULL, n_dim, n_dim); + Eigen::Map> input_grad_emp(NULL, n_dim, n_dim); + Eigen::Map> output_grad_evp(NULL, n_dim); + EigenMatrixStd input_ems; + EigenVectorStd output_grad_evs; + EigenMatrixStd v_ems; + EigenMatrixStd input_grad_ems; + + for(std::vector::size_type i = 0; i < input_matrices.size(); ++i){ + new (&input_emp) Eigen::Map>( + input_matrices[i].data(), n_dim, n_dim); + new (&input_grad_emp) Eigen::Map>( + input_grad_matrices[i].data(), n_dim, n_dim); + new (&output_grad_evp) Eigen::Map>( + output_grad_vectors[i].data>(), n_dim); + + /** + * Let the input square matrix + * A = VLV^{-1}, + * the gradient of A for eigenvalues L is + * A_grad = V^{-H} L_grad V^H. + * See Eq. 
4.77 in https://arxiv.org/pdf/1701.00392.pdf + */ + input_ems = input_emp.template cast>(); + output_grad_evs = output_grad_evp.template cast>(); +VLOG(4) << "input_ems:\n" << input_ems; +VLOG(4) << "output_grad_evs:\n" << output_grad_evs; + Eigen::ComplexEigenSolver> es(input_ems); + v_ems = es.eigenvectors(); +VLOG(4) << "eigenvalues:\n" << es.eigenvalues(); +VLOG(4) << "v_ems:\n" << v_ems; + v_ems.adjointInPlace(); +VLOG(4) << "vh_ems:\n" << v_ems; + input_grad_ems = v_ems.colPivHouseholderQr().solve(output_grad_evs.asDiagonal() * v_ems); +VLOG(4) << "input_grad_ems:\n" << input_grad_ems; + CastToPaddleType(input_grad_ems, /*->*/ input_grad_emp); +VLOG(4) << "input_grad_emp:\n" << input_grad_emp; + } + + + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py new file mode 100644 index 0000000000000..6f540c42a73b1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -0,0 +1,322 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +from op_test import OpTest + +np.set_printoptions(threshold=np.inf) + +def np_eigvals(a): + res = np.linalg.eigvals(a) + if(a.dtype == np.float32 or a.dtype == np.complex64): + res = res.astype(np.complex64) + else: + res = res.astype(np.complex128) + + return res + +def np_eigvals_grad(a, out_grad): + l, v = np.linalg.eig(a) + print("l:") + print(l) + print("v:") + print(v) + vh = v.conj().T + print("vh:") + print(vh) + print("out_grad:") + print(out_grad) + a_grad = np.linalg.solve(vh, np.diagflat(out_grad, 0) * vh) + print("a_grad") + print(a_grad) + + + return a_grad.astype(a.dtype) + + + +class TestEigvalsOp(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "eigvals" + self.set_dtype() + self.set_input_dims() + self.set_input_data() + + np_output = np_eigvals(self.input_data) + + self.inputs = {'X': self.input_data} + self.outputs = {'Out': np_output} + + def set_dtype(self): + #self.dtype = np.complex128 + self.dtype = np.float64 + + def set_input_dims(self): + self.input_dims = (3, 3) + + def set_input_data(self): + if(self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype(self.dtype) + else: + self.input_data = (np.random.random(self.input_dims) + np.random.random(self.input_dims) * 1j).astype(self.dtype) + + def test_check_output(self): + #self.__class__.no_need_check_grad = True + self.check_output_customized(checker = self.verify_output) + + def test_check_grad_normal(self): + self.grad_dtype = self.dtype + if self.dtype == np.float32: + self.grad_dtype = np.complex64 + elif self.dtype == np.float64: + self.grad_dtype = np.complex128 + + self.out_grad = (np.random.random(self.input_dims[-1:]) + np.random.random(self.input_dims[-1:]) * 
1j).astype(self.grad_dtype) + self.x_grad = np_eigvals_grad(self.input_data, self.out_grad) + + print("np_eigvals_grad:\n") + print(self.x_grad) + + self.check_grad(['X'], 'Out', + user_defined_grads=[self.x_grad], + user_defined_grad_outputs=[self.out_grad]) + + def verify_output(self, outs): + actual_outs = np.array(outs[0]) + expect_outs = np.array(self.outputs['Out']) + self.assertTrue(actual_outs.shape == expect_outs.shape, + "Output shape has diff." + "\nExpect shape " + str(expect_outs.shape) + + "\n" + "But Got" + str(actual_outs.shape) + + " in class " + self.__class__.__name__) + + n_dim = actual_outs.shape[-1] + for actual_row, expect_row in zip(actual_outs.reshape((-1, n_dim)), expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim,)) + for i in range(n_dim): + is_mapped = False + for j in range(n_dim): + if is_mapped_index[j] == 0 and np.isclose(np.array(actual_row[i]), np.array(expect_row[j]), atol=1e-5): + is_mapped_index[j] = True + is_mapped = True + break + self.assertTrue(is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + + "\n" + "But Got" + str(actual_outs) + + "\nThe data " + str(actual_row[i]) + " in " + + str(actual_row) + " mismatch." + ) + +''' +class TestEigvalsOpFloat64(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.float64 + +class TestEigvalsOpComplex64(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.complex64 + +class TestEigvalsOpComplex128(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.complex128 + + +class TestEigvalsOpLargeScare(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (128, 128) + +class TestEigvalsOpLargeScareFloat64(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.float64 + +class TestEigvalsOpLargeScareComplex64(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.complex64 + +class TestEigvalsOpLargeScareComplex128(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.complex128 + + +class TestEigvalsOpBatch1(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (1, 2, 3, 4, 4) + +class TestEigvalsOpBatch2(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (3, 1, 4, 5, 5) + +class TestEigvalsOpBatch3(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (6, 2, 9, 6, 6) + + + +class TestEigvalsAPI(unittest.TestCase): + def setUp(self): + self.small_dims = [6, 6] + self.large_dims = [128, 128] + self.batch_dims = [6, 9, 2, 2] + + self.set_dtype() + + self.input_dims = self.small_dims + self.set_input_data() + self.small_input = np.copy(self.input_data) + + self.input_dims = self.large_dims + self.set_input_data() + self.large_input = np.copy(self.input_data) + + self.input_dims = self.batch_dims + self.set_input_data() + self.batch_input = np.copy(self.input_data) + + + def set_dtype(self): + self.dtype = np.float32 + + + def set_input_data(self): + if(self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype(self.dtype) + else: + self.input_data = (np.random.random(self.input_dims) + np.random.random(self.input_dims) * 1j).astype(self.dtype) + + + def verify_output(self, actural_outs, expect_outs): + actual_outs = np.array(actural_outs) + expect_outs = np.array(expect_outs) + self.assertTrue(actual_outs.shape == expect_outs.shape, + "Output shape has diff." 
+ "\nExpect shape " + str(expect_outs.shape) + + "\n" + "But Got" + str(actual_outs.shape) + + " in class " + self.__class__.__name__) + + n_dim = actual_outs.shape[-1] + for actual_row, expect_row in zip(actual_outs.reshape((-1, n_dim)), expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim,)) + for i in range(n_dim): + is_mapped = False + for j in range(n_dim): + if is_mapped_index[j] == 0 and np.isclose(np.array(actual_row[i]), np.array(expect_row[j]), atol=1e-5): + is_mapped_index[j] = True + is_mapped = True + break + self.assertTrue(is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + + "\n" + "But Got" + str(actual_outs) + + "\nThe data " + str(actual_row[i]) + " in " + + str(actual_row) + " mismatch." + ) + + + def run_dygraph(self, place): + paddle.disable_static() + + small_input_tensor = paddle.to_tensor(self.small_input) + large_input_tensor = paddle.to_tensor(self.large_input) + batch_input_tensor = paddle.to_tensor(self.batch_input) + + paddle_outs = paddle.linalg.eigvals(small_input_tensor, name = 'small_x') + np_outs = np_eigvals(self.small_input) + self.verify_output(paddle_outs, np_outs) + + paddle_outs = paddle.linalg.eigvals(large_input_tensor, name = 'large_x') + np_outs = np_eigvals(self.large_input) + self.verify_output(paddle_outs, np_outs) + + paddle_outs = paddle.linalg.eigvals(batch_input_tensor, name = 'small_x') + np_outs = np_eigvals(self.batch_input) + self.verify_output(paddle_outs, np_outs) + + + def run_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + small_input_tensor = paddle.static.data( + name = 'small_x', shape = self.small_dims, dtype = self.dtype) + large_input_tensor = paddle.static.data( + name = 'large_x', shape = self.large_dims, dtype = self.dtype) + batch_input_tensor = paddle.static.data( + name = 'batch_x', shape = self.batch_dims, dtype = self.dtype) + + small_outs = paddle.linalg.eigvals(small_input_tensor, name = 'small_x') + large_outs = paddle.linalg.eigvals(large_input_tensor, name = 'large_x') + batch_outs = paddle.linalg.eigvals(batch_input_tensor, name = 'batch_x') + + exe = paddle.static.Executor(place) + + paddle_outs = exe.run( + feed={ + "small_x": self.small_input, + "large_x": self.large_input, + "batch_x": self.batch_input + }, + fetch_list=[small_outs, large_outs, batch_outs]) + + np_outs = np_eigvals(self.small_input) + self.verify_output(paddle_outs[0], np_outs) + + np_outs = np_eigvals(self.large_input) + self.verify_output(paddle_outs[1], np_outs) + + np_outs = np_eigvals(self.batch_input) + self.verify_output(paddle_outs[2], np_outs) + + + def test_cases(self): + places = [core.CPUPlace()] + #if core.is_compiled_with_cuda(): + # places.append(core.CUDAPlace(0)) + for place in places: + self.run_dygraph(place) + self.run_static(place) + + + def test_error(self): + paddle.disable_static() + x = paddle.to_tensor([1]) + with self.assertRaises(BaseException): + paddle.linalg.eigvals(x) + + self.input_dims = [1, 2, 3, 4] + self.set_input_data() + x = paddle.to_tensor(self.input_data) + with self.assertRaises(BaseException): + paddle.linalg.eigvals(x) + + +class TestEigvalsAPIFloat64(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.float64 + +class TestEigvalsAPIComplex64(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.complex64 + +class TestEigvalsAPIComplex128(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.complex128 +''' +if __name__ == 
"__main__": + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index ec6b7aa9e3d82..96dbb4fd4fec4 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,10 +16,12 @@ from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor import eigvals # noqa: F401 __all__ = [ 'cholesky', #noqa 'norm', 'inv', - 'matrix_power' + 'matrix_power', + 'eigvals' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 040bec2f67b9e..171f087a7b031 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -45,6 +45,7 @@ from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import eigvals # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -223,6 +224,7 @@ 'histogram', 'mv', 'matrix_power', + 'eigvals', 'abs', 'acos', 'all', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 74d9876cddd5c..9a087aa65bc9d 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1011,3 +1011,62 @@ def matrix_power(x, n, name=None): outputs={'Out': out}, attrs={'n': n}) return out + + +def eigvals(x, name=None): + """ + Compute the eigenvalues of one or more general matrices. + + Warning: + The gradient kernel of this operator does not yet developed. If you want to backpropagate through this operator, please replace it with paddle.linalg.eig. + + Args: + x (Tensor): A square matrix or a batch of square matrices whose eigenvalues will be computed. + Its shape should be `[*, M, M]`, where `*` is zero or more batch dimensions. + Its data type should be float32, float64, complex64, or complex128. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A tensor cointaining the unsorted eigenvalues. The eigenvalues are complex-valued even when `x` is real. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.rand(shape=[3, 3], dtype='float64') + # [[0.12163323, 0.35750244, 0.00040121], #random + # [0.36698967, 0.95818203, 0.40474149], #random + # [0.27632808, 0.63281696, 0.70740548]] #random + + print(paddle.linalg.eigvals(x)) + # [(-0.003106318667270132+0j), (0.3449088087647463+0j), (1.4454182494638632+0j)] #complex128 + """ + + check_variable_and_dtype(x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'eigvals') + + x_shape = list(x.shape) + if len(x_shape) < 2: + raise ValueError( + "The dimension of Input(x) should be at least 2, but received x's dimention = {}, x's shape = {}". + format(len(x_shape), x_shape)) + + if x_shape[-1] != x_shape[-2]: + raise ValueError( + "The last two dimensions of Input(x) should be equal, but received x's shape = {}". 
+ format(x_shape)) + + if in_dygraph_mode(): + return _C_ops.eigvals(x) + + helper = LayerHelper('eigvals', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='eigvals', + inputs={'X': x}, + outputs={'Out': out}) + return out + + + From 87b11cc621be90c2cccb330a9c29238046b69e07 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 14 Sep 2021 03:58:52 +0000 Subject: [PATCH 02/11] pre-commit check --- paddle/fluid/operators/eigvals_op.cc | 97 +++++----- paddle/fluid/operators/eigvals_op.h | 182 +++++------------- .../fluid/tests/unittests/test_eigvals_op.py | 169 ++++++++-------- python/paddle/tensor/linalg.py | 26 ++- 4 files changed, 203 insertions(+), 271 deletions(-) diff --git a/paddle/fluid/operators/eigvals_op.cc b/paddle/fluid/operators/eigvals_op.cc index 37420d4232fd0..d8ec98247cd2e 100644 --- a/paddle/fluid/operators/eigvals_op.cc +++ b/paddle/fluid/operators/eigvals_op.cc @@ -17,15 +17,15 @@ limitations under the License. */ namespace paddle { namespace operators { - class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor), A complex- or real-valued tensor with shape (*, n, n)" - "where * is zero or more batch dimensions"); + "(Tensor), A complex- or real-valued tensor with shape (*, n, n)" + "where * is zero or more batch dimensions"); AddOutput("Out", - "(Tensor) The output tensor with shape (*,n) cointaining the eigenvalues of X."); + "(Tensor) The output tensor with shape (*,n) cointaining the " + "eigenvalues of X."); AddComment(R"DOC(eigvals operator Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real. )DOC"); @@ -41,84 +41,95 @@ class EigvalsOp : public framework::OperatorWithKernel { DDim x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_GE(x_dims.size(), 2, - platform::errors::InvalidArgument( - "The dimensions of Input(X) for Eigvals operator should be at least 2, " - "but received X's dimension = %d, X's shape = [%s].", - x_dims.size(), x_dims)); - - if(ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)){ + platform::errors::InvalidArgument( + "The dimensions of Input(X) for Eigvals operator " + "should be at least 2, " + "but received X's dimension = %d, X's shape = [%s].", + x_dims.size(), x_dims)); + + if (ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)) { int last_dim = x_dims.size() - 1; - PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim-1], - platform::errors::InvalidArgument( - "The last two dimensions of Input(X) for Eigvals operator should be equal, " - "but received X's shape = [%s].", - x_dims)); + PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim - 1], + platform::errors::InvalidArgument( + "The last two dimensions of Input(X) for Eigvals " + "operator should be equal, " + "but received X's shape = [%s].", + x_dims)); } auto output_dims = vectorize(x_dims); output_dims.resize(x_dims.size() - 1); - ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); } }; class EigvalsOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext *ctx) const { + void operator()(framework::InferVarTypeContext* ctx) const { auto input_dtype = ctx->GetInputDataType("X"); - auto output_dtype = framework::IsComplexType(input_dtype) ? - input_dtype : framework::ToComplexType(input_dtype); + auto output_dtype = framework::IsComplexType(input_dtype) + ? 
input_dtype + : framework::ToComplexType(input_dtype); ctx->SetOutputDataType("Out", output_dtype); } }; class EigvalsGradOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "EigvalsGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", "Out@Grad", "EigvalsGrad"); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X@Grad", "EigvalsGrad"); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } }; template -class EigvalsGradOpMaker : public framework::SingleGradOpMaker{ +class EigvalsGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; + protected: - void Apply(GradOpPtr retv) const override{ + void Apply(GradOpPtr retv) const override { retv->SetType("eigvals_grad"); retv->SetInput("X", this->Input("X")); retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); } }; - - - } // namespace operators } // namespace paddle - namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OPERATOR(eigvals, - ops::EigvalsOp, ops::EigvalsOpMaker, ops::EigvalsOpVarTypeInference, - ops::EigvalsGradOpMaker, - ops::EigvalsGradOpMaker); +REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker, + ops::EigvalsOpVarTypeInference, + ops::EigvalsGradOpMaker, + ops::EigvalsGradOpMaker); REGISTER_OPERATOR(eigvals_grad, ops::EigvalsGradOp); -REGISTER_OP_CPU_KERNEL(eigvals, - ops::EigvalsKernel, - ops::EigvalsKernel, - ops::EigvalsKernel>, - ops::EigvalsKernel>); - +REGISTER_OP_CPU_KERNEL(eigvals, + ops::EigvalsKernel, + ops::EigvalsKernel, + ops::EigvalsKernel>, + ops::EigvalsKernel>); + // TODO(Ruibiao): Support gradient kernel for Eigvals OP -REGISTER_OP_CPU_KERNEL(eigvals_grad, - ops::EigvalsGradKernel, - ops::EigvalsGradKernel, - ops::EigvalsGradKernel>, - ops::EigvalsGradKernel>); \ No newline at end of file +// REGISTER_OP_CPU_KERNEL(eigvals_grad, +// ops::EigvalsGradKernel, +// ops::EigvalsGradKernel, +// ops::EigvalsGradKernel>, +// ops::EigvalsGradKernel>); diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h index 2d601f2ce3935..bea1a52cbf7b6 100644 --- a/paddle/fluid/operators/eigvals_op.h +++ b/paddle/fluid/operators/eigvals_op.h @@ -14,207 +14,115 @@ #pragma once -#include #include +#include #include "Eigen/Dense" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { - using Tensor = framework::Tensor; using DDim = framework::DDim; -template -struct PaddleComplex{ +template +struct PaddleComplex { using Type = paddle::platform::complex; }; -template <> -struct PaddleComplex>{ +template <> +struct PaddleComplex> { using Type = paddle::platform::complex; }; -template <> -struct PaddleComplex>{ +template <> +struct PaddleComplex> { using Type = 
paddle::platform::complex; }; -template -struct StdComplex{ +template +struct StdComplex { using Type = std::complex; }; -template <> -struct StdComplex>{ +template <> +struct StdComplex> { using Type = std::complex; }; -template <> -struct StdComplex>{ +template <> +struct StdComplex> { using Type = std::complex; }; -template +template using PaddleCType = typename PaddleComplex::Type; -template +template using StdCType = typename StdComplex::Type; -template +template using EigenMatrixPaddle = Eigen::Matrix; -template +template using EigenVectorPaddle = Eigen::Matrix, Eigen::Dynamic, 1>; -template -using EigenMatrixStd = Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>; -template +template +using EigenMatrixStd = + Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>; +template using EigenVectorStd = Eigen::Matrix, Eigen::Dynamic, 1>; -static void SpiltBatchSquareMatrix(const Tensor* input, std::vector& output){ - DDim input_dims = input -> dims(); +static void SpiltBatchSquareMatrix(const Tensor *input, + std::vector *output) { + DDim input_dims = input->dims(); int last_dim = input_dims.size() - 1; int n_dim = input_dims[last_dim]; - DDim flattened_input_dims, flattened_output_dims; - if(input_dims.size() > 2){ + DDim flattened_input_dims, flattened_output_dims; + if (input_dims.size() > 2) { flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim); - } - else{ + } else { flattened_input_dims = framework::make_ddim({1, n_dim, n_dim}); } Tensor flattened_input; flattened_input.ShareDataWith(*input); - flattened_input.Resize(flattened_input_dims); - output = flattened_input.Split(1, 0); + flattened_input.Resize(flattened_input_dims); + (*output) = flattened_input.Split(1, 0); } template class EigvalsKernel : public framework::OpKernel { -public: + public: void Compute(const framework::ExecutionContext &ctx) const override { const Tensor *input = ctx.Input("X"); Tensor *output = ctx.Output("Out"); - - auto input_type = input -> type(); - auto output_type = framework::IsComplexType(input_type) ? - input_type : framework::ToComplexType(input_type); - output -> mutable_data(ctx.GetPlace(), output_type); + + auto input_type = input->type(); + auto output_type = framework::IsComplexType(input_type) + ? 
input_type + : framework::ToComplexType(input_type); + output->mutable_data(ctx.GetPlace(), output_type); std::vector input_matrices; SpiltBatchSquareMatrix(input, /*->*/ input_matrices); - + int n_dim = input_matrices[0].dims()[1]; int n_batch = input_matrices.size(); DDim output_dims = output->dims(); - output -> Resize(framework::make_ddim({n_batch, n_dim})); + output->Resize(framework::make_ddim({n_batch, n_dim})); std::vector output_vectors = output->Split(1, 0); Eigen::Map> input_emp(NULL, n_dim, n_dim); Eigen::Map> output_evp(NULL, n_dim); EigenMatrixStd input_ems; EigenVectorStd output_evs; - - for(int i = 0; i < n_batch; ++i){ + + for (int i = 0; i < n_batch; ++i) { new (&input_emp) Eigen::Map>( - input_matrices[i].data(), n_dim, n_dim); + input_matrices[i].data(), n_dim, n_dim); new (&output_evp) Eigen::Map>( - output_vectors[i].data>(), n_dim); + output_vectors[i].data>(), n_dim); input_ems = input_emp.template cast>(); output_evs = input_ems.eigenvalues(); output_evp = output_evs.template cast>(); } - output -> Resize(output_dims); - } -}; - - -template -inline void CastToPaddleType( - EigenMatrixStd& input, - Eigen::Map>& output){ - output = input.template cast(); -} -template<> -inline void CastToPaddleType( - EigenMatrixStd& input, - Eigen::Map>& output){ - output = input.real(); -} -template<> -inline void CastToPaddleType( - EigenMatrixStd& input, - Eigen::Map>& output){ - output = input.real(); -} - -template -class EigvalsGradKernel : public framework::OpKernel { -public: - void Compute(const framework::ExecutionContext &ctx) const override { - const Tensor* input = ctx.Input("X"); - const Tensor* output_grad = ctx.Input(framework::GradVarName("Out")); - Tensor* input_grad = ctx.Output(framework::GradVarName("X")); - input_grad -> mutable_data(ctx.GetPlace(), input -> type()); - - Tensor output; - output.Resize(output_grad -> dims()); - output.mutable_data(ctx.GetPlace(), output_grad -> type()); - - std::vector input_matrices, input_grad_matrices; - SpiltBatchSquareMatrix(input, /*->*/ input_matrices); - SpiltBatchSquareMatrix(input_grad, /*->*/ input_grad_matrices); - - int n_dim = input_matrices[0].dims()[1]; - int n_batch = input_matrices.size(); - - Tensor flattened_output_grad; - flattened_output_grad.ShareDataWith(*output_grad); - flattened_output_grad.Resize(framework::make_ddim({n_batch, n_dim})); - std::vector output_grad_vectors = flattened_output_grad.Split(1, 0); - - /* - input_emp -> input_ems -> v_ems -> vh_ems - output_grad_evp -> output_grad_evs - output_grad_evs + vh_ems -> input_grad_ems -> input_grad_emp - */ - Eigen::Map> input_emp(NULL, n_dim, n_dim); - Eigen::Map> input_grad_emp(NULL, n_dim, n_dim); - Eigen::Map> output_grad_evp(NULL, n_dim); - EigenMatrixStd input_ems; - EigenVectorStd output_grad_evs; - EigenMatrixStd v_ems; - EigenMatrixStd input_grad_ems; - - for(std::vector::size_type i = 0; i < input_matrices.size(); ++i){ - new (&input_emp) Eigen::Map>( - input_matrices[i].data(), n_dim, n_dim); - new (&input_grad_emp) Eigen::Map>( - input_grad_matrices[i].data(), n_dim, n_dim); - new (&output_grad_evp) Eigen::Map>( - output_grad_vectors[i].data>(), n_dim); - - /** - * Let the input square matrix - * A = VLV^{-1}, - * the gradient of A for eigenvalues L is - * A_grad = V^{-H} L_grad V^H. - * See Eq. 
4.77 in https://arxiv.org/pdf/1701.00392.pdf - */ - input_ems = input_emp.template cast>(); - output_grad_evs = output_grad_evp.template cast>(); -VLOG(4) << "input_ems:\n" << input_ems; -VLOG(4) << "output_grad_evs:\n" << output_grad_evs; - Eigen::ComplexEigenSolver> es(input_ems); - v_ems = es.eigenvectors(); -VLOG(4) << "eigenvalues:\n" << es.eigenvalues(); -VLOG(4) << "v_ems:\n" << v_ems; - v_ems.adjointInPlace(); -VLOG(4) << "vh_ems:\n" << v_ems; - input_grad_ems = v_ems.colPivHouseholderQr().solve(output_grad_evs.asDiagonal() * v_ems); -VLOG(4) << "input_grad_ems:\n" << input_grad_ems; - CastToPaddleType(input_grad_ems, /*->*/ input_grad_emp); -VLOG(4) << "input_grad_emp:\n" << input_grad_emp; - } - - + output->Resize(output_dims); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index 81f4f9a7826e1..b71178c86d031 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -21,15 +21,17 @@ np.set_printoptions(threshold=np.inf) + def np_eigvals(a): res = np.linalg.eigvals(a) - if(a.dtype == np.float32 or a.dtype == np.complex64): + if (a.dtype == np.float32 or a.dtype == np.complex64): res = res.astype(np.complex64) else: res = res.astype(np.complex128) return res + def np_eigvals_grad(a, out_grad): l, v = np.linalg.eig(a) print("l:") @@ -45,9 +47,7 @@ def np_eigvals_grad(a, out_grad): print("a_grad") print(a_grad) - return a_grad.astype(a.dtype) - class TestEigvalsOp(OpTest): @@ -71,15 +71,18 @@ def set_input_dims(self): self.input_dims = (5, 5) def set_input_data(self): - if(self.dtype == np.float32 or self.dtype == np.float64): - self.input_data = np.random.random(self.input_dims).astype(self.dtype) + if (self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype( + self.dtype) else: - self.input_data = (np.random.random(self.input_dims) + np.random.random(self.input_dims) * 1j).astype(self.dtype) + self.input_data = ( + np.random.random(self.input_dims) + + np.random.random(self.input_dims) * 1j).astype(self.dtype) def test_check_output(self): self.__class__.no_need_check_grad = True - self.check_output_customized(checker = self.verify_output) - + self.check_output_customized(checker=self.verify_output) + ''' The gradient kernel of this operator does not yet developed. def test_check_grad_normal(self): self.grad_dtype = self.dtype @@ -90,7 +93,7 @@ def test_check_grad_normal(self): self.out_grad = (np.random.random(self.input_dims[-1:]) + np.random.random(self.input_dims[-1:]) * 1j).astype(self.grad_dtype) self.x_grad = np_eigvals_grad(self.input_data, self.out_grad) - + print("np_eigvals_grad:\n") print(self.x_grad) @@ -102,39 +105,44 @@ def test_check_grad_normal(self): def verify_output(self, outs): actual_outs = np.array(outs[0]) expect_outs = np.array(self.outputs['Out']) - self.assertTrue(actual_outs.shape == expect_outs.shape, - "Output shape has diff." - "\nExpect shape " + str(expect_outs.shape) + - "\n" + "But Got" + str(actual_outs.shape) + - " in class " + self.__class__.__name__) + self.assertTrue( + actual_outs.shape == expect_outs.shape, "Output shape has diff." 
+ "\nExpect shape " + str(expect_outs.shape) + "\n" + "But Got" + + str(actual_outs.shape) + " in class " + self.__class__.__name__) n_dim = actual_outs.shape[-1] - for actual_row, expect_row in zip(actual_outs.reshape((-1, n_dim)), expect_outs.reshape((-1, n_dim))): - is_mapped_index = np.zeros((n_dim,)) + for actual_row, expect_row in zip( + actual_outs.reshape((-1, n_dim)), + expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim, )) for i in range(n_dim): is_mapped = False for j in range(n_dim): - if is_mapped_index[j] == 0 and np.isclose(np.array(actual_row[i]), np.array(expect_row[j]), atol=1e-5): + if is_mapped_index[j] == 0 and np.isclose( + np.array(actual_row[i]), + np.array(expect_row[j]), + atol=1e-5): is_mapped_index[j] = True is_mapped = True break - self.assertTrue(is_mapped, - "Output has diff in class " + self.__class__.__name__ + - "\nExpect " + str(expect_outs) + - "\n" + "But Got" + str(actual_outs) + - "\nThe data " + str(actual_row[i]) + " in " + - str(actual_row) + " mismatch." - ) - + self.assertTrue( + is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + "\n" + "But Got" + + str(actual_outs) + "\nThe data " + str(actual_row[i]) + + " in " + str(actual_row) + " mismatch.") + class TestEigvalsOpFloat64(TestEigvalsOp): def set_dtype(self): self.dtype = np.float64 + class TestEigvalsOpComplex64(TestEigvalsOp): def set_dtype(self): self.dtype = np.complex64 + class TestEigvalsOpComplex128(TestEigvalsOp): def set_dtype(self): self.dtype = np.complex128 @@ -143,32 +151,36 @@ def set_dtype(self): class TestEigvalsOpLargeScare(TestEigvalsOp): def set_input_dims(self): self.input_dims = (128, 128) - + + class TestEigvalsOpLargeScareFloat64(TestEigvalsOpLargeScare): def set_dtype(self): self.dtype = np.float64 + class TestEigvalsOpLargeScareComplex64(TestEigvalsOpLargeScare): def set_dtype(self): self.dtype = np.complex64 + class TestEigvalsOpLargeScareComplex128(TestEigvalsOpLargeScare): def set_dtype(self): self.dtype = np.complex128 class TestEigvalsOpBatch1(TestEigvalsOp): - def set_input_dims(self): - self.input_dims = (1, 2, 3, 4, 4) + def set_input_dims(self): + self.input_dims = (1, 2, 3, 4, 4) + class TestEigvalsOpBatch2(TestEigvalsOp): - def set_input_dims(self): - self.input_dims = (3, 1, 4, 5, 5) + def set_input_dims(self): + self.input_dims = (3, 1, 4, 5, 5) -class TestEigvalsOpBatch3(TestEigvalsOp): - def set_input_dims(self): - self.input_dims = (6, 2, 9, 6, 6) +class TestEigvalsOpBatch3(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (6, 2, 9, 6, 6) class TestEigvalsAPI(unittest.TestCase): @@ -191,46 +203,48 @@ def setUp(self): self.set_input_data() self.batch_input = np.copy(self.input_data) - def set_dtype(self): self.dtype = np.float32 - def set_input_data(self): - if(self.dtype == np.float32 or self.dtype == np.float64): - self.input_data = np.random.random(self.input_dims).astype(self.dtype) + if (self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype( + self.dtype) else: - self.input_data = (np.random.random(self.input_dims) + np.random.random(self.input_dims) * 1j).astype(self.dtype) - + self.input_data = ( + np.random.random(self.input_dims) + + np.random.random(self.input_dims) * 1j).astype(self.dtype) def verify_output(self, actural_outs, expect_outs): actual_outs = np.array(actural_outs) expect_outs = np.array(expect_outs) - self.assertTrue(actual_outs.shape == expect_outs.shape, - "Output shape has diff." 
- "\nExpect shape " + str(expect_outs.shape) + - "\n" + "But Got" + str(actual_outs.shape) + - " in class " + self.__class__.__name__) + self.assertTrue( + actual_outs.shape == expect_outs.shape, "Output shape has diff." + "\nExpect shape " + str(expect_outs.shape) + "\n" + "But Got" + + str(actual_outs.shape) + " in class " + self.__class__.__name__) n_dim = actual_outs.shape[-1] - for actual_row, expect_row in zip(actual_outs.reshape((-1, n_dim)), expect_outs.reshape((-1, n_dim))): - is_mapped_index = np.zeros((n_dim,)) + for actual_row, expect_row in zip( + actual_outs.reshape((-1, n_dim)), + expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim, )) for i in range(n_dim): is_mapped = False for j in range(n_dim): - if is_mapped_index[j] == 0 and np.isclose(np.array(actual_row[i]), np.array(expect_row[j]), atol=1e-5): + if is_mapped_index[j] == 0 and np.isclose( + np.array(actual_row[i]), + np.array(expect_row[j]), + atol=1e-5): is_mapped_index[j] = True is_mapped = True break - self.assertTrue(is_mapped, - "Output has diff in class " + self.__class__.__name__ + - "\nExpect " + str(expect_outs) + - "\n" + "But Got" + str(actual_outs) + - "\nThe data " + str(actual_row[i]) + " in " + - str(actual_row) + " mismatch." - ) - - + self.assertTrue( + is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + "\n" + "But Got" + + str(actual_outs) + "\nThe data " + str(actual_row[i]) + + " in " + str(actual_row) + " mismatch.") + def run_dygraph(self, place): paddle.disable_static() @@ -238,33 +252,35 @@ def run_dygraph(self, place): large_input_tensor = paddle.to_tensor(self.large_input) batch_input_tensor = paddle.to_tensor(self.batch_input) - paddle_outs = paddle.linalg.eigvals(small_input_tensor, name = 'small_x') + paddle_outs = paddle.linalg.eigvals(small_input_tensor, name='small_x') np_outs = np_eigvals(self.small_input) self.verify_output(paddle_outs, np_outs) - paddle_outs = paddle.linalg.eigvals(large_input_tensor, name = 'large_x') + paddle_outs = paddle.linalg.eigvals(large_input_tensor, name='large_x') np_outs = np_eigvals(self.large_input) self.verify_output(paddle_outs, np_outs) - paddle_outs = paddle.linalg.eigvals(batch_input_tensor, name = 'small_x') + paddle_outs = paddle.linalg.eigvals(batch_input_tensor, name='small_x') np_outs = np_eigvals(self.batch_input) self.verify_output(paddle_outs, np_outs) - def run_static(self, place): paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program(), + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): small_input_tensor = paddle.static.data( - name = 'small_x', shape = self.small_dims, dtype = self.dtype) + name='small_x', shape=self.small_dims, dtype=self.dtype) large_input_tensor = paddle.static.data( - name = 'large_x', shape = self.large_dims, dtype = self.dtype) + name='large_x', shape=self.large_dims, dtype=self.dtype) batch_input_tensor = paddle.static.data( - name = 'batch_x', shape = self.batch_dims, dtype = self.dtype) + name='batch_x', shape=self.batch_dims, dtype=self.dtype) - small_outs = paddle.linalg.eigvals(small_input_tensor, name = 'small_x') - large_outs = paddle.linalg.eigvals(large_input_tensor, name = 'large_x') - batch_outs = paddle.linalg.eigvals(batch_input_tensor, name = 'batch_x') + small_outs = paddle.linalg.eigvals( + small_input_tensor, name='small_x') + large_outs = paddle.linalg.eigvals( + large_input_tensor, name='large_x') + batch_outs = paddle.linalg.eigvals( + batch_input_tensor, 
name='batch_x') exe = paddle.static.Executor(place) @@ -285,7 +301,6 @@ def run_static(self, place): np_outs = np_eigvals(self.batch_input) self.verify_output(paddle_outs[2], np_outs) - def test_cases(self): places = [core.CPUPlace()] #if core.is_compiled_with_cuda(): @@ -294,7 +309,6 @@ def test_cases(self): self.run_dygraph(place) self.run_static(place) - def test_error(self): paddle.disable_static() x = paddle.to_tensor([1]) @@ -306,19 +320,22 @@ def test_error(self): x = paddle.to_tensor(self.input_data) with self.assertRaises(BaseException): paddle.linalg.eigvals(x) - -class TestEigvalsAPIFloat64(TestEigvalsAPI): + +class TestEigvalsAPIFloat64(TestEigvalsAPI): def set_dtype(self): - self.dtype = np.float64 + self.dtype = np.float64 -class TestEigvalsAPIComplex64(TestEigvalsAPI): + +class TestEigvalsAPIComplex64(TestEigvalsAPI): def set_dtype(self): - self.dtype = np.complex64 + self.dtype = np.complex64 -class TestEigvalsAPIComplex128(TestEigvalsAPI): + +class TestEigvalsAPIComplex128(TestEigvalsAPI): def set_dtype(self): - self.dtype = np.complex128 + self.dtype = np.complex128 + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index df74f9f286ef8..64cbb68db49c9 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1178,7 +1178,7 @@ def eigvals(x, name=None): Compute the eigenvalues of one or more general matrices. Warning: - The gradient kernel of this operator does not yet developed. If you want to backpropagate through this operator, please replace it with paddle.linalg.eig. + The gradient kernel of this operator does not yet developed. If you need back propagation through this operator, please replace it with paddle.linalg.eig. Args: x (Tensor): A square matrix or a batch of square matrices whose eigenvalues will be computed. @@ -1188,7 +1188,7 @@ def eigvals(x, name=None): For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor: A tensor cointaining the unsorted eigenvalues. The eigenvalues are complex-valued even when `x` is real. + Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. The eigenvalues are complex-valued even when `x` is real. Examples: .. code-block:: python @@ -1203,30 +1203,26 @@ def eigvals(x, name=None): print(paddle.linalg.eigvals(x)) # [(-0.003106318667270132+0j), (0.3449088087647463+0j), (1.4454182494638632+0j)] #complex128 """ - - check_variable_and_dtype(x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'eigvals') - + + check_variable_and_dtype(x, 'dtype', + ['float32', 'float64', 'complex64', 'complex128'], + 'eigvals') + x_shape = list(x.shape) if len(x_shape) < 2: raise ValueError( "The dimension of Input(x) should be at least 2, but received x's dimention = {}, x's shape = {}". format(len(x_shape), x_shape)) - + if x_shape[-1] != x_shape[-2]: - raise ValueError( + raise ValueError( "The last two dimensions of Input(x) should be equal, but received x's shape = {}". 
format(x_shape)) - + if in_dygraph_mode(): return _C_ops.eigvals(x) helper = LayerHelper('eigvals', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type='eigvals', - inputs={'X': x}, - outputs={'Out': out}) + helper.append_op(type='eigvals', inputs={'X': x}, outputs={'Out': out}) return out - - - From 4878c3d867c37ae42e3f661d5e044e24118552e3 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 14 Sep 2021 06:30:47 +0000 Subject: [PATCH 03/11] Adjust code style --- paddle/fluid/operators/eigvals_op.h | 2 +- python/paddle/fluid/tests/unittests/test_eigvals_op.py | 7 ++++--- python/paddle/tensor/linalg.py | 6 ++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h index bea1a52cbf7b6..cfe8d3cd363fa 100644 --- a/paddle/fluid/operators/eigvals_op.h +++ b/paddle/fluid/operators/eigvals_op.h @@ -99,7 +99,7 @@ class EigvalsKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace(), output_type); std::vector input_matrices; - SpiltBatchSquareMatrix(input, /*->*/ input_matrices); + SpiltBatchSquareMatrix(input, /*->*/ &input_matrices); int n_dim = input_matrices[0].dims()[1]; int n_batch = input_matrices.size(); diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index b71178c86d031..5c0d90a2effd0 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -91,7 +91,8 @@ def test_check_grad_normal(self): elif self.dtype == np.float64: self.grad_dtype = np.complex128 - self.out_grad = (np.random.random(self.input_dims[-1:]) + np.random.random(self.input_dims[-1:]) * 1j).astype(self.grad_dtype) + self.out_grad = (np.random.random(self.input_dims[-1:]) + + np.random.random(self.input_dims[-1:]) * 1j).astype(self.grad_dtype) self.x_grad = np_eigvals_grad(self.input_data, self.out_grad) print("np_eigvals_grad:\n") @@ -106,8 +107,8 @@ def verify_output(self, outs): actual_outs = np.array(outs[0]) expect_outs = np.array(self.outputs['Out']) self.assertTrue( - actual_outs.shape == expect_outs.shape, "Output shape has diff." - "\nExpect shape " + str(expect_outs.shape) + "\n" + "But Got" + + actual_outs.shape == expect_outs.shape, "Output shape has diff.\n" + "Expect shape " + str(expect_outs.shape) + "\n" + "But Got" + str(actual_outs.shape) + " in class " + self.__class__.__name__) n_dim = actual_outs.shape[-1] diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 64cbb68db49c9..4b4358de0dfcf 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1178,7 +1178,8 @@ def eigvals(x, name=None): Compute the eigenvalues of one or more general matrices. Warning: - The gradient kernel of this operator does not yet developed. If you need back propagation through this operator, please replace it with paddle.linalg.eig. + The gradient kernel of this operator does not yet developed. + If you need back propagation through this operator, please replace it with paddle.linalg.eig. Args: x (Tensor): A square matrix or a batch of square matrices whose eigenvalues will be computed. @@ -1188,7 +1189,8 @@ def eigvals(x, name=None): For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. The eigenvalues are complex-valued even when `x` is real. 
+ Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. + The eigenvalues are complex-valued even when `x` is real. Examples: .. code-block:: python From 16ed5ea3edc6b28fa855eafc92189e2da338b458 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 14 Sep 2021 07:23:34 +0000 Subject: [PATCH 04/11] Fix conflict --- python/paddle/linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 7622e947f2aae..d647767c59c08 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -24,7 +24,6 @@ 'cholesky', #noqa 'norm', 'inv', - 'matrix_power', 'eigvals', 'matrix_rank', 'svd', From 14c0ef45f923d1bf298c7a6e674d24e3f050ec9e Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 14 Sep 2021 09:20:00 +0000 Subject: [PATCH 05/11] Improve code style --- paddle/fluid/framework/ddim.cc | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 975711c2d73e8..8bac8b7df6d2d 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -107,24 +107,28 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } -DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims){ - PADDLE_ENFORCE_GE(src.size(), 3, platform::errors::InvalidArgument( - "The rank of src dim should be at least 3 in flatten_to_3d, but received %d.", - src.size())); +DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) { + PADDLE_ENFORCE_GE(src.size(), 3, + platform::errors::InvalidArgument( + "The rank of src dim should be at least 3 " + "in flatten_to_3d, but received %d.", + src.size())); PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true, - platform::errors::InvalidArgument( - "The num_row_dims should be inside [1, %d] in flatten_to_3d, but received %d.", - src.size() - 1, num_row_dims)); + platform::errors::InvalidArgument( + "The num_row_dims should be inside [1, %d] " + "in flatten_to_3d, but received %d.", + src.size() - 1, num_row_dims)); PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true, - platform::errors::InvalidArgument( - "The num_col_dims should be inside [2, %d] in flatten_to_3d, but received %d.", - src.size(), num_col_dims)); + platform::errors::InvalidArgument( + "The num_col_dims should be inside [2, %d] " + "in flatten_to_3d, but received %d.", + src.size(), num_col_dims)); PADDLE_ENFORCE_GE( - num_col_dims, num_row_dims, - platform::errors::InvalidArgument( - "The num_row_dims should be less than num_col_dims in flatten_to_3d," - "but received num_row_dims = %d, num_col_dims = %d.", - num_row_dims, num_col_dims)); + num_col_dims, num_row_dims, + platform::errors::InvalidArgument( + "The num_row_dims should be less than num_col_dims in flatten_to_3d," + "but received num_row_dims = %d, num_col_dims = %d.", + num_row_dims, num_col_dims)); return DDim({product(slice_ddim(src, 0, num_row_dims)), product(slice_ddim(src, num_row_dims, num_col_dims)), From 562e65f04ebc8e48025a5b8d5236feab4e988db8 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 14 Sep 2021 14:07:14 +0000 Subject: [PATCH 06/11] Modify the test code to ignore testing CUDA kernel --- paddle/fluid/operators/eigvals_op.h | 8 ++++---- python/paddle/fluid/tests/unittests/op_test.py | 6 ++++++ .../fluid/tests/unittests/test_eigvals_op.py | 15 +++++++++------ python/paddle/tensor/linalg.py | 10 ++++++---- 4 files changed, 25 
insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h index cfe8d3cd363fa..998dcd9f1efda 100644 --- a/paddle/fluid/operators/eigvals_op.h +++ b/paddle/fluid/operators/eigvals_op.h @@ -66,9 +66,9 @@ using EigenMatrixStd = template using EigenVectorStd = Eigen::Matrix, Eigen::Dynamic, 1>; -static void SpiltBatchSquareMatrix(const Tensor *input, +static void SpiltBatchSquareMatrix(const Tensor &input, std::vector *output) { - DDim input_dims = input->dims(); + DDim input_dims = input.dims(); int last_dim = input_dims.size() - 1; int n_dim = input_dims[last_dim]; @@ -80,7 +80,7 @@ static void SpiltBatchSquareMatrix(const Tensor *input, } Tensor flattened_input; - flattened_input.ShareDataWith(*input); + flattened_input.ShareDataWith(input); flattened_input.Resize(flattened_input_dims); (*output) = flattened_input.Split(1, 0); } @@ -99,7 +99,7 @@ class EigvalsKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace(), output_type); std::vector input_matrices; - SpiltBatchSquareMatrix(input, /*->*/ &input_matrices); + SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices); int n_dim = input_matrices[0].dims()[1]; int n_batch = input_matrices.size(); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 018a979bc5eaa..6a996730f09a9 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1368,6 +1368,12 @@ def check_output_customized(self, checker, custom_place=None): outs.sort(key=len) checker(outs) + def check_output_with_place_customized(self, checker, place): + outs = self.calc_output(place) + outs = [np.array(out) for out in outs] + outs.sort(key=len) + checker(outs) + def _assert_is_close(self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix): for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index 5c0d90a2effd0..64d451b245fcb 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -32,6 +32,7 @@ def np_eigvals(a): return res +''' Keep it for testing gradient kernel in the future. def np_eigvals_grad(a, out_grad): l, v = np.linalg.eig(a) print("l:") @@ -48,6 +49,7 @@ def np_eigvals_grad(a, out_grad): print(a_grad) return a_grad.astype(a.dtype) +''' class TestEigvalsOp(OpTest): @@ -81,9 +83,10 @@ def set_input_data(self): def test_check_output(self): self.__class__.no_need_check_grad = True - self.check_output_customized(checker=self.verify_output) + self.check_output_with_place_customized( + checker=self.verify_output, place=core.CPUPlace()) - ''' The gradient kernel of this operator does not yet developed. + ''' The gradient kernel of this operator does not yet develop. 
def test_check_grad_normal(self): self.grad_dtype = self.dtype if self.dtype == np.float32: @@ -248,10 +251,10 @@ def verify_output(self, actural_outs, expect_outs): def run_dygraph(self, place): paddle.disable_static() - - small_input_tensor = paddle.to_tensor(self.small_input) - large_input_tensor = paddle.to_tensor(self.large_input) - batch_input_tensor = paddle.to_tensor(self.batch_input) + paddle.set_device("cpu") + small_input_tensor = paddle.to_tensor(self.small_input, place=place) + large_input_tensor = paddle.to_tensor(self.large_input, place=place) + batch_input_tensor = paddle.to_tensor(self.batch_input, place=place) paddle_outs = paddle.linalg.eigvals(small_input_tensor, name='small_x') np_outs = np_eigvals(self.small_input) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4b4358de0dfcf..8c7685c72edbf 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1197,13 +1197,15 @@ def eigvals(x, name=None): import paddle + paddle.seed(1234) + x = paddle.rand(shape=[3, 3], dtype='float64') - # [[0.12163323, 0.35750244, 0.00040121], #random - # [0.36698967, 0.95818203, 0.40474149], #random - # [0.27632808, 0.63281696, 0.70740548]] #random + # [[0.02773777, 0.93004224, 0.06911496], + # [0.24831591, 0.45733623, 0.07717843], + # [0.48016702, 0.14235102, 0.42620817]]) print(paddle.linalg.eigvals(x)) - # [(-0.003106318667270132+0j), (0.3449088087647463+0j), (1.4454182494638632+0j)] #complex128 + # [(-0.27078833542132674+0j), (0.29962280156230725+0j), (0.8824477020120244+0j)] #complex128 """ check_variable_and_dtype(x, 'dtype', From 2649d99a0a9ee3a611472c23da1bd870e9c0a9e5 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 15 Sep 2021 02:59:56 +0000 Subject: [PATCH 07/11] Sort ouput data before checking in test code --- python/paddle/fluid/tests/unittests/test_eigvals_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index 64d451b245fcb..a7bf3efb49e4f 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -107,8 +107,8 @@ def test_check_grad_normal(self): ''' def verify_output(self, outs): - actual_outs = np.array(outs[0]) - expect_outs = np.array(self.outputs['Out']) + actual_outs = np.sort(np.array(outs[0])) + expect_outs = np.sort(np.array(self.outputs['Out'])) self.assertTrue( actual_outs.shape == expect_outs.shape, "Output shape has diff.\n" "Expect shape " + str(expect_outs.shape) + "\n" + "But Got" + From 4b57a9140e625ec45682aa3774472b8a3807d901 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 15 Sep 2021 06:39:20 +0000 Subject: [PATCH 08/11] Set timeout value for UT --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + python/paddle/fluid/tests/unittests/test_eigvals_op.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f4dca4f1bf49c..b55cffb1f6e6f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1026,3 +1026,4 @@ if(WITH_GPU OR WITH_ROCM) set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120) endif() set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120) +set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) diff --git 
a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index a7bf3efb49e4f..e3e0448687cbd 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -66,7 +66,6 @@ def setUp(self): self.outputs = {'Out': np_output} def set_dtype(self): - #self.dtype = np.complex128 self.dtype = np.float32 def set_input_dims(self): From 8d26031573a04d377644548e444a19bc7f0b844e Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 15 Sep 2021 11:24:29 +0000 Subject: [PATCH 09/11] Improve API example code to pass CI --- python/paddle/fluid/tests/unittests/test_eigvals_op.py | 3 +++ python/paddle/tensor/linalg.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index e3e0448687cbd..9f93023de0768 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -54,6 +54,7 @@ def np_eigvals_grad(a, out_grad): class TestEigvalsOp(OpTest): def setUp(self): + np.random.seed(0) paddle.enable_static() self.op_type = "eigvals" self.set_dtype() @@ -188,6 +189,8 @@ def set_input_dims(self): class TestEigvalsAPI(unittest.TestCase): def setUp(self): + np.random.seed(0) + self.small_dims = [6, 6] self.large_dims = [128, 128] self.batch_dims = [6, 9, 2, 2] diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 8c7685c72edbf..cd7285db5fbcd 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1196,7 +1196,8 @@ def eigvals(x, name=None): .. code-block:: python import paddle - + + paddle.set_device("cpu") paddle.seed(1234) x = paddle.rand(shape=[3, 3], dtype='float64') From 7b1aa3cd5c3d50d8cc90dbecc5a7d97cc7886761 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Thu, 16 Sep 2021 12:02:58 +0000 Subject: [PATCH 10/11] Fix bug for None fetch_list in Windows --- python/paddle/tensor/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 9a558150099bf..805ec014ed3e2 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1173,7 +1173,6 @@ def matrix_power(x, n, name=None): return out - def eigvals(x, name=None): """ Compute the eigenvalues of one or more general matrices. @@ -1231,8 +1230,9 @@ def eigvals(x, name=None): helper = LayerHelper('eigvals', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(type='eigvals', inputs={'X': x}, outputs={'Out': out}) + return out + - def multi_dot(x, name=None): """ Multi_dot is an operator that calculates multiple matrix multiplications. 
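Note: the runtime behaviour these patches rely on can be checked independently of the unit test. The sketch below is an illustrative, CPU-only use of paddle.linalg.eigvals cross-checked against np.linalg.eigvals; it mirrors the sorted comparison that verify_output() performs in test_eigvals_op.py and is not part of any patch in this series.

    import numpy as np
    import paddle

    paddle.set_device("cpu")   # only a CPU kernel is registered for eigvals
    np.random.seed(0)

    x = np.random.rand(6, 9, 2, 2).astype(np.float64)          # batch of 2x2 matrices
    got = paddle.linalg.eigvals(paddle.to_tensor(x)).numpy()   # shape (6, 9, 2), complex128
    expect = np.linalg.eigvals(x)                               # NumPy broadcasts over the batch

    # Eigenvalue order is unspecified, so sort along the last axis before comparing,
    # as the unit test does (the test itself uses a tolerance-aware elementwise check).
    np.testing.assert_allclose(
        np.sort(got, axis=-1), np.sort(expect, axis=-1), rtol=1e-6, atol=1e-6)
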
From 971bab0ee7e7d530efcfff8f961e9a0196ba8bc9 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Fri, 17 Sep 2021 07:28:30 +0000 Subject: [PATCH 11/11] Delete grad Op --- paddle/fluid/operators/eigvals_op.cc | 48 +------------------ .../fluid/tests/unittests/test_eigvals_op.py | 40 ---------------- 2 files changed, 1 insertion(+), 87 deletions(-) diff --git a/paddle/fluid/operators/eigvals_op.cc b/paddle/fluid/operators/eigvals_op.cc index d8ec98247cd2e..dcf350190951e 100644 --- a/paddle/fluid/operators/eigvals_op.cc +++ b/paddle/fluid/operators/eigvals_op.cc @@ -73,50 +73,13 @@ class EigvalsOpVarTypeInference : public framework::VarTypeInference { ctx->SetOutputDataType("Out", output_dtype); } }; - -class EigvalsGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "EigvalsGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@Grad", "EigvalsGrad"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", - "X@Grad", "EigvalsGrad"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -template -class EigvalsGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("eigvals_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker, - ops::EigvalsOpVarTypeInference, - ops::EigvalsGradOpMaker, - ops::EigvalsGradOpMaker); -REGISTER_OPERATOR(eigvals_grad, ops::EigvalsGradOp); + ops::EigvalsOpVarTypeInference); REGISTER_OP_CPU_KERNEL(eigvals, ops::EigvalsKernel, ops::EigvalsKernel, @@ -124,12 +87,3 @@ REGISTER_OP_CPU_KERNEL(eigvals, paddle::platform::complex>, ops::EigvalsKernel>); - -// TODO(Ruibiao): Support gradient kernel for Eigvals OP -// REGISTER_OP_CPU_KERNEL(eigvals_grad, -// ops::EigvalsGradKernel, -// ops::EigvalsGradKernel, -// ops::EigvalsGradKernel>, -// ops::EigvalsGradKernel>); diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py index 9f93023de0768..eff9d4ea6e801 100644 --- a/python/paddle/fluid/tests/unittests/test_eigvals_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -32,26 +32,6 @@ def np_eigvals(a): return res -''' Keep it for testing gradient kernel in the future. 
-def np_eigvals_grad(a, out_grad): - l, v = np.linalg.eig(a) - print("l:") - print(l) - print("v:") - print(v) - vh = v.conj().T - print("vh:") - print(vh) - print("out_grad:") - print(out_grad) - a_grad = np.linalg.solve(vh, np.diagflat(out_grad, 0) * vh) - print("a_grad") - print(a_grad) - - return a_grad.astype(a.dtype) -''' - - class TestEigvalsOp(OpTest): def setUp(self): np.random.seed(0) @@ -86,26 +66,6 @@ def test_check_output(self): self.check_output_with_place_customized( checker=self.verify_output, place=core.CPUPlace()) - ''' The gradient kernel of this operator does not yet develop. - def test_check_grad_normal(self): - self.grad_dtype = self.dtype - if self.dtype == np.float32: - self.grad_dtype = np.complex64 - elif self.dtype == np.float64: - self.grad_dtype = np.complex128 - - self.out_grad = (np.random.random(self.input_dims[-1:]) + - np.random.random(self.input_dims[-1:]) * 1j).astype(self.grad_dtype) - self.x_grad = np_eigvals_grad(self.input_data, self.out_grad) - - print("np_eigvals_grad:\n") - print(self.x_grad) - - self.check_grad(['X'], 'Out', - user_defined_grads=[self.x_grad], - user_defined_grad_outputs=[self.out_grad]) - ''' - def verify_output(self, outs): actual_outs = np.sort(np.array(outs[0])) expect_outs = np.sort(np.array(self.outputs['Out']))
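
For future reference, the analytic adjoint that a reinstated eigvals_grad kernel (and the deleted np_eigvals_grad helper) would build on is the standard identity a_bar = V^{-H} diag(out_grad) V^{H} for a diagonalizable a = V diag(lam) V^{-1}. A minimal NumPy sketch follows; it assumes distinct eigenvalues, uses a matrix product for diag(out_grad)·V^H (the deleted helper used an elementwise *), and its names are illustrative rather than part of this series.

    import numpy as np

    def np_eigvals_vjp(a, out_grad):
        """Vector-Jacobian product of eigvals at a diagonalizable square matrix a.

        With a = V diag(lam) V^{-1}, d(lam_i) = [V^{-1} dA V]_{ii}, so the adjoint
        of the eigenvalues is a_bar = V^{-H} diag(out_grad) V^{H}.
        """
        _, v = np.linalg.eig(a)
        vh = v.conj().T
        # solve(vh, M) computes (V^H)^{-1} @ M, i.e. V^{-H} @ M.
        return np.linalg.solve(vh, np.diagflat(out_grad) @ vh)

    a = np.random.rand(3, 3)
    out_grad = np.random.rand(3) + 1j * np.random.rand(3)
    print(np_eigvals_vjp(a, out_grad).shape)   # (3, 3), complex128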