Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API/OP] Migrate Lstsq op into phi #44318

Merged
merged 22 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 23 additions & 82 deletions paddle/fluid/operators/lstsq_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,92 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/lstsq_op.h"

#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

class LstsqOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp");

OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("SingularValues"),
"Output",
"SingularValues",
"LstsqOp");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_rank = x_dims.size();
int y_rank = y_dims.size();

PADDLE_ENFORCE_GE(x_rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
x_rank));
PADDLE_ENFORCE_GE(y_rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor y to be not less than "
"2 dimentions, but got dimention %d",
y_rank));

PADDLE_ENFORCE_EQ(
x_rank,
y_rank,
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same dimension "
"but got x's dimention [%d] and y's dimention [%d]",
x_rank,
y_rank));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i],
y_dims[i],
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same batch "
"dimension, but got x's batch dimention [%d] and "
"y's batch dimention [%d] in %d-th dim",
x_dims[i],
y_dims[i],
i));
batch_dims_vec.emplace_back(x_dims[i]);
}

PADDLE_ENFORCE_EQ(
x_dims[x_rank - 2],
y_dims[y_rank - 2],
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same row dimension "
"of the inner-most 2-dims matrix, "
"but got x's row dimention [%d] and y's row dimention [%d]",
x_dims[x_rank - 2],
y_dims[y_rank - 2]));

ctx->SetOutputDim("Rank", phi::make_ddim(batch_dims_vec));

batch_dims_vec.emplace_back(
std::min(x_dims[x_rank - 2], x_dims[x_rank - 1]));
ctx->SetOutputDim("SingularValues", phi::make_ddim(batch_dims_vec));

batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1];
batch_dims_vec.emplace_back(y_dims[x_rank - 1]);
ctx->SetOutputDim("Solution", phi::make_ddim(batch_dims_vec));
}

protected:
// The output of lstsq is always complex-valued even for real-valued inputs
Expand Down Expand Up @@ -133,6 +58,9 @@ class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("gels");
AddOutput("Solution",
"(Tensor), The output Solution tensor with shape (*, n, k).");
AddOutput("Residuals",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么需要修改op定义

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lstsq API的输出有4个tensor,其中三个是调用第三方库后计算而得;之前第4个输出tensor是在python端通过组合API/OP来计算,迁移时进行了逻辑改进,将其计算过程放在c++端完成,不仅保证c++ Op和python API的表现一致,而且减少了op kernel调度时间。

"(Tensor), The output Residuals tensor with shape (*, k).")
.AsDispensable();
AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
AddOutput(
"SingularValues",
Expand All @@ -148,8 +76,21 @@ This API processes Lstsq functor for general matrices.
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker)

REGISTER_OP_CPU_KERNEL(lstsq,
ops::LstsqCPUKernel<phi::CPUContext, float>,
ops::LstsqCPUKernel<phi::CPUContext, double>);
DECLARE_INFER_SHAPE_FUNCTOR(lstsq,
LstsqInferShapeFunctor,
PD_INFER_META(phi::LstsqInferMeta));

REGISTER_OPERATOR(lstsq,
ops::LstsqOp,
ops::LstsqOpMaker,
LstsqInferShapeFunctor);

REGISTER_OP_VERSION(lstsq).AddCheckpoint(
R"ROC(
Upgrade lstsq, add 1 outputs [Residuals].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Residuals",
"Output tensor of lstsq operator, "
"meaning the squared residuals of the calculated solutions."));
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"SavedMean",
"SavedVariance",
"ReserveSpace"}},
{"lstsq", {"Solution", "Residuals", "Rank", "SingularValues"}},
{"inplace_abn",
{"Y",
"MeanOut",
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,15 @@
func : logsumexp
backward : logsumexp_grad

- api : lstsq
args : (Tensor x, Tensor y, Scalar rcond, str driver)
output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values)
infer_meta :
func : LstsqInferMeta
dtype : x
kernel :
func : lstsq

- api : lu
args : (Tensor x, bool pivot)
output : Tensor(out), Tensor(pivots), Tensor(infos)
Expand Down
84 changes: 84 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,90 @@ void TriangularSolveInferMeta(const MetaTensor& x,
out->share_lod(y);
}

void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
const std::string& driver,
MetaTensor* solution,
MetaTensor* residuals,
MetaTensor* rank,
MetaTensor* singular_values) {
auto x_dims = x.dims();
auto y_dims = y.dims();
int x_rank = x_dims.size();
int y_rank = y_dims.size();

int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int nrhs = y_dims[x_rank - 1];

PADDLE_ENFORCE_GE(
x_rank,
2,
phi::errors::InvalidArgument("Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
x_rank));
PADDLE_ENFORCE_GE(
y_rank,
2,
phi::errors::InvalidArgument("Expects input tensor y to be not less than "
"2 dimentions, but got dimention %d",
y_rank));

PADDLE_ENFORCE_EQ(
x_rank,
y_rank,
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same dimension "
"but got x's dimention [%d] and y's dimention [%d]",
x_rank,
y_rank));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(x_dims[i],
y_dims[i],
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same batch "
"dimension, but got x's batch dimention [%d] and "
"y's batch dimention [%d] in %d-th dim",
x_dims[i],
y_dims[i],
i));
batch_dims_vec.emplace_back(x_dims[i]);
}

PADDLE_ENFORCE_EQ(
m,
y_dims[y_rank - 2],
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same row dimension "
"of the inner-most 2-dims matrix, "
"but got x's row dimention [%d] and y's row dimention [%d]",
m,
y_dims[y_rank - 2]));

rank->set_dims(phi::make_ddim(batch_dims_vec));

if (m > n) {
batch_dims_vec.emplace_back(nrhs);
residuals->set_dims(phi::make_ddim(batch_dims_vec));
batch_dims_vec.pop_back();
} else {
residuals->set_dims(phi::make_ddim({0}));
}
residuals->set_dtype(y.dtype());

batch_dims_vec.emplace_back(std::min(m, n));
singular_values->set_dims(phi::make_ddim(batch_dims_vec));
singular_values->set_dtype(y.dtype());

batch_dims_vec[x_rank - 2] = n;
batch_dims_vec.emplace_back(nrhs);
solution->set_dims(phi::make_ddim(batch_dims_vec));
solution->set_dtype(y.dtype());
}

void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ void TriangularSolveInferMeta(const MetaTensor& x,
bool unitriangular,
MetaTensor* out);

void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
const std::string& driver,
MetaTensor* solution,
MetaTensor* residuals,
MetaTensor* rank,
MetaTensor* singular_values);

void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
Expand Down
Loading