Skip to content

Commit

Permalink
[Phi] Migrate infermeta and add yaml for solve op (#44379)
Browse files Browse the repository at this point in the history
* migrate solve kernel to phi

* re useless header file, fix a bug in grad_kernel_impl

* add header file in need

* add yaml for solve op

* fix solve_sig.cc ArgumentMapping and update tests case

* disable legacy dygraph check in op_test

* rm solve_op.cc / solve_sig.cc and migrate yaml config

* Update op_test.py

disable legacy dygraph check when check_eager is True
  • Loading branch information
veyron95 authored Jul 19, 2022
1 parent 6fb2958 commit 5dfb87d
Show file tree
Hide file tree
Showing 13 changed files with 238 additions and 282 deletions.
222 changes: 0 additions & 222 deletions paddle/fluid/operators/solve_op.cc

This file was deleted.

10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@
func : poisson
backward : poisson_grad

# solve: computes the solution of the linear system A x = b (out = x).
# Output dtype follows input x (data_type : x).
- api : solve
  args : (Tensor x, Tensor y)
  output : Tensor
  infer_meta :
    func : SolveInferMeta
  kernel :
    func : solve
    data_type : x
  backward : solve_grad

- api : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/api_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@
outputs :
out : Out

# Compatibility mapping between the new phi argument names (x, y, out)
# and the legacy fluid operator names (X, Y, Out).
- api : solve
  inputs :
    {x : X, y : Y}
  outputs :
    out : Out

- api : trace
inputs :
x : Input
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@
kernel :
func : poisson_grad

# Gradient of solve: takes forward inputs (x, y), forward output (out) and
# the incoming gradient (out_grad); produces gradients for both inputs.
# GeneralBinaryGradInferMeta shapes x_grad/y_grad after x/y respectively.
- backward_api : solve_grad
  forward : solve (Tensor x, Tensor y) -> Tensor(out)
  args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
  output : Tensor(x_grad), Tensor(y_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [x, y]
  kernel :
    func : solve_grad

- backward_api : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
Expand Down
87 changes: 87 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,93 @@ void ValueCompareInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL);
}

// Infers the output meta (dims/dtype/layout/lod) for solve(x, y), which
// solves the linear system x * out = y.
//
// Shape rules (matmul-like broadcasting):
//   - x must have rank >= 2 and its innermost two dims must be square.
//   - y must have rank >= 1; a 1-D y is treated as a column vector and the
//     appended dim of size 1 is dropped from the output again.
//   - Batch (leading) dims of the output follow the higher-rank operand.
void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();

  std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
  std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());

  auto x_dims_n = x_dims_vec.size();
  auto y_dims_n = y_dims_vec.size();

  // x must be (at least a batch of) square matrices: rank >= 2.
  PADDLE_ENFORCE_GT(
      x_dims_n,
      1,
      phi::errors::InvalidArgument("The input tensor X's dimensions of SolveOp "
                                   "should be larger than 1. But received X's "
                                   "dimensions = %d, X's shape = [%s]",
                                   x_dims_n,
                                   x_dims));

  PADDLE_ENFORCE_GE(y_dims_n,
                    1,
                    phi::errors::InvalidArgument(
                        "The input tensor Y's dimensions of SolveOp "
                        "should be larger than or equal 1. But received Y's "
                        "dimensions = %d, Y's shape = [%s]",
                        y_dims_n,
                        y_dims));

  // The coefficient matrix must be square (innermost two dims equal).
  PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
                    x_dims[x_dims_n - 1],
                    phi::errors::InvalidArgument(
                        "The inner-most 2 dimensions of Input(X) all should "
                        "be square matrices "
                        "But received X's shape[-2] = %d and shape[-1] = %d.",
                        x_dims[x_dims_n - 2],
                        x_dims[x_dims_n - 1]));

  // A 1-D operand is promoted to 2-D for the shape computation; the
  // broadcast dim is removed from the result afterwards.
  bool x_broadcasted = false;
  bool y_broadcasted = false;
  if (x_dims_n == 1) {
    x_dims_vec.insert(x_dims_vec.begin(), 1);
    x_dims_n = 2;
    x_broadcasted = true;
  }

  if (y_dims_n == 1) {
    y_dims_vec.push_back(1);
    y_dims_n = 2;
    y_broadcasted = true;
  }

  // Neither operand is transposed in solve, so the output rows come from
  // x's second-to-last dim and the columns from y's last dim.
  // (The original code carried always-false trans_x/trans_y flags copied
  // from a matmul-style infer meta; those dead branches are removed.)
  const int64_t M = x_dims_vec[x_dims_n - 2];
  const int64_t N = y_dims_vec[y_dims_n - 1];

  // Batch dims follow the higher-rank operand, then the matrix dims are
  // appended, skipping any dim that was introduced by 1-D promotion.
  std::vector<int64_t> new_dims;
  if (x_dims_n >= y_dims_n) {
    new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
  } else {
    new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
  }
  if (!x_broadcasted) {
    new_dims.push_back(M);
  }
  if (!y_broadcasted) {
    new_dims.push_back(N);
  }
  if (x_broadcasted && y_broadcasted) {
    // Both inputs were vectors: the result is a single scalar-like dim.
    new_dims.push_back(1);
  }

  auto out_dims = phi::make_ddim(new_dims);

  out->set_dims(out_dims);
  // Output inherits dtype/layout/lod from x.
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
  out->share_lod(x);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,6 @@ void ValueCompareInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/solve_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
bool is_vector = false;
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/solve_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);

Expand Down
Loading

0 comments on commit 5dfb87d

Please sign in to comment.