batch add inplace api (#55078)
* batch add inplace api

* fix inplace fn generate

* add test for new inplace api

* fix typo

* fix typo

* fix typo

* fix test error

* fix atan2

* remove atan2

* auto generate inplace api

* fix inplace generate fn error

* fix windows error

* fix test error

* fix test error

* fix windows ci error

* fix test error

* fix test error

* fix test error

* fix eigen aliasing error in inplace

* remove elementwise_pow inplace

* fix doc error

* fix test error
GGBond8488 authored Jul 18, 2023
1 parent 5e6645d commit 1930293
Showing 10 changed files with 377 additions and 79 deletions.
1 change: 0 additions & 1 deletion paddle/phi/api/yaml/legacy_ops.yaml
@@ -237,7 +237,6 @@
func : ElementwiseInferMeta
kernel :
func : elementwise_pow
- inplace: (x -> out)
backward : elementwise_pow_grad

- op : embedding
45 changes: 30 additions & 15 deletions paddle/phi/api/yaml/ops.yaml
@@ -13,6 +13,7 @@
kernel :
func : abs
data_type : x
+ inplace: (x -> out)
backward : abs_grad

- op : accuracy
@@ -26,20 +27,22 @@

- op : acos
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : acos
+ inplace: (x -> out)
backward : acos_grad

- op : acosh
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : acosh
+ inplace: (x -> out)
backward : acosh_grad

- op : adagrad_
@@ -90,12 +93,13 @@

- op : addmm
args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : AddmmInferMeta
kernel :
func : addmm
data_type : x
+ inplace: (input -> out)
backward : addmm_grad

- op : affine_grid
@@ -176,34 +180,37 @@

- op : asin
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : asin
+ inplace: (x -> out)
backward : asin_grad

- op : asinh
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : asinh
+ inplace: (x -> out)
backward : asinh_grad

- op : atan
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : atan
+ inplace: (x -> out)
backward : atan_grad

- op : atan2
args : (Tensor x, Tensor y)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : Atan2InferMeta
kernel :
@@ -212,11 +219,12 @@

- op : atanh
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : atanh
+ inplace: (x -> out)
backward : atanh_grad

- op : auc
@@ -524,20 +532,22 @@

- op : cos
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : cos
+ inplace: (x -> out)
backward : cos_grad

- op : cosh
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : cosh
+ inplace: (x -> out)
backward : cosh_grad

- op : crop
@@ -756,11 +766,12 @@

- op : erf
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : erf
+ inplace : (x -> out)
backward : erf_grad

- op : erfinv
@@ -806,12 +817,13 @@

- op : expm1
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : expm1
+ inplace: (x -> out)
backward : expm1_grad

- op : fft_c2c
@@ -2250,20 +2262,22 @@

- op : sin
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sin
+ inplace: (x -> out)
backward : sin_grad

- op : sinh
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sinh
+ inplace: (x -> out)
backward : sinh_grad

- op : slogdet
@@ -2409,11 +2423,12 @@

- op : tan
args : (Tensor x)
- output : Tensor
+ output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : tan
+ inplace: (x -> out)
backward : tan_grad

- op : tanh
28 changes: 17 additions & 11 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -116,7 +116,10 @@ template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Sine<T>());
+ // Note(GGBond8488): Since Eigen 3.3, behavior like {A = (B * A).cwiseAbs()}
+ // gives a wrong result due to aliasing; for details see
+ // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
+ out.device(d) = x.unaryExpr(Sine<T>()).eval();
}
};
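
The Note(GGBond8488) comment above refers to Eigen's aliasing rules. Below is a minimal, self-contained sketch of the pitfall, assuming Eigen 3.3+ is available; the matrix, block sizes, and values are illustrative only and are not part of the Paddle kernels. It mirrors the example from the linked Eigen documentation: when source and destination overlap, some coefficients can be overwritten before they are read, and forcing evaluation into a temporary with .eval() avoids this. In the in-place APIs added by this commit, out and x may share storage, which is why the functors now append .eval().

// Standalone illustration of Eigen aliasing (not Paddle code).
#include <iostream>

#include <Eigen/Dense>

int main() {
  Eigen::MatrixXi mat(3, 3);
  mat << 1, 2, 3,
         4, 5, 6,
         7, 8, 9;

  // Aliasing: the source and destination blocks overlap, so without a
  // temporary some coefficients would be overwritten before being read:
  //   mat.bottomRightCorner(2, 2) = mat.topLeftCorner(2, 2);   // wrong result
  //
  // Evaluating the right-hand side into a temporary first gives the expected
  // result; this is what appending .eval() does in the functors above.
  mat.bottomRightCorner(2, 2) = mat.topLeftCorner(2, 2).eval();

  std::cout << mat << std::endl;
  return 0;
}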

@@ -448,7 +451,7 @@ template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Cosine<T>());
+ out.device(d) = x.unaryExpr(Cosine<T>()).eval();
}
};

@@ -762,7 +765,7 @@ template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Tangent<T>());
+ // Note(GGBond8488): Since Eigen 3.3, behavior like {A = (B * A).cwiseAbs()}
+ // gives a wrong result due to aliasing; for details see
+ // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
+ out.device(d) = x.unaryExpr(Tangent<T>()).eval();
}
};

@@ -795,7 +801,7 @@ template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Sinh<T>());
+ out.device(d) = x.unaryExpr(Sinh<T>()).eval();
}
};

@@ -804,7 +810,7 @@ template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Cosh<T>());
+ out.device(d) = x.unaryExpr(Cosh<T>()).eval();
}
};

@@ -855,7 +861,7 @@ template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Acos<T>());
+ out.device(d) = x.unaryExpr(Acos<T>()).eval();
}
};

@@ -892,7 +898,7 @@ template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Asin<T>());
+ out.device(d) = x.unaryExpr(Asin<T>()).eval();
}
};

@@ -929,7 +935,7 @@ template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Atan<T>());
+ out.device(d) = x.unaryExpr(Atan<T>()).eval();
}
};

@@ -977,7 +983,7 @@ template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Acosh<T>());
+ out.device(d) = x.unaryExpr(Acosh<T>()).eval();
}
};

@@ -1014,7 +1020,7 @@ template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Asinh<T>());
+ out.device(d) = x.unaryExpr(Asinh<T>()).eval();
}
};

@@ -1051,7 +1057,7 @@ template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.unaryExpr(Atanh<T>());
+ out.device(d) = x.unaryExpr(Atanh<T>()).eval();
}
};

(Diff for the remaining 7 changed files not shown.)
