[Prim][PIR] softmax forward sink (PaddlePaddle#58591)
* relu forward sink

* softmax sink pir

* remove relu

* pir relu sink

* remove softmax register

* remove python softmax

* update code

* merge code

---------

Co-authored-by: kevincheng2 <cheng112@gmail.com>
kevincheng2 and kevincheng2 authored Nov 13, 2023
1 parent fb476c8 commit 5eb7f6d
Showing 3 changed files with 33 additions and 28 deletions.
@@ -19,7 +19,13 @@

 # come into effect in generated file pd_op.h
 # manual decomp interface declare are located in manual_op.h
-decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n", "relu"]
+decomp_interface_declare_gen_op_list = [
+    "mean",
+    "squeeze",
+    "add_n",
+    "relu",
+    "softmax",
+]

 # come into effect in generated file op_decomp.cc
 # manual decomp interface implementation are located in manual_op_decomp.cc
@@ -28,4 +34,5 @@
     "squeeze",
     "add_n",
     "relu",
+    "softmax",
 ]
25 changes: 25 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -62,6 +62,31 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
   }
 }

+template <typename T>
+Tensor softmax_decomp(const Tensor& x, const int& axis) {
+  auto org_dtype = x.dtype();
+  auto x_tmp = x;
+  auto axis_tmp = IntArray({axis});
+
+  bool need_cast =
+      org_dtype == phi::DataType::FLOAT16 || org_dtype == phi::DataType::UINT16;
+  if (need_cast) {
+    x_tmp = cast<T>(x, phi::DataType::FLOAT32);
+  }
+
+  auto max_tmp = max<T>(x_tmp, axis_tmp, true);
+  auto molecular = exp<T>(subtract<T>(x_tmp, max_tmp));
+
+  auto denominator = sum<T>(molecular, axis_tmp, molecular.dtype(), true);
+  auto res = divide<T>(molecular, denominator);
+
+  if (need_cast) {
+    return cast<T>(res, org_dtype);
+  } else {
+    return res;
+  }
+}
+
 template <typename T>
 Tensor relu_decomp(const Tensor& x) {
   return maximum<T>(x, full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()));
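For reference, the softmax_decomp added above computes the numerically stable form of softmax along the requested axis, with FLOAT16/UINT16 inputs cast to FLOAT32 for the intermediate computation and cast back at the end:

\operatorname{softmax}(x)_i = \frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)}, \qquad m = \max_j x_j

where the max and the sum are taken over the given axis with keepdim set to true; subtracting m does not change the result but prevents exp from overflowing.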
27 changes: 0 additions & 27 deletions python/paddle/decomposition/rules.py
@@ -248,33 +248,6 @@ def silu(x):
     return res if not is_amp else cast(res, dtype)


-@register_decomp('pd_op.softmax')
-def softmax(x, axis):
-    """define composite rule of op softmax"""
-    is_amp = False
-    from paddle.base.data_feeder import convert_dtype
-
-    # Softmax need fp32 compute since it has sum op in
-    dtype = convert_dtype(x.dtype)
-    if dtype in ["float16", "uint16"]:
-        is_amp = True
-        x = cast(x, "float32")
-    if not x.shape:
-        # do not return 1, to ensure gradients
-        res = exp(x - x)
-        if is_amp:
-            res = cast(res, "float16")
-        return res
-    max_temp = max(x, axis, keepdim=True)
-    max_temp.stop_gradient = True
-    molecular = exp(x - max_temp)
-    denominator = sum(molecular, axis=axis, keepdim=True)
-    res = divide(molecular, denominator)
-    if is_amp:
-        res = cast(res, dtype)
-    return res
-
-
 @register_decomp('pd_op.full_like')
 def full_like(x, fill_value, dtype, place=None):
     """define composite rule of op full_like."""
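The removed Python rule and the new C++ softmax_decomp reduce to the same computation. A minimal NumPy sketch of that computation (illustrative only; softmax_reference is a hypothetical name, not a Paddle API):

import numpy as np

def softmax_reference(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the per-axis max so exp() cannot overflow.
    x_max = np.max(x, axis=axis, keepdims=True)
    numerator = np.exp(x - x_max)
    denominator = np.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator

# Example: rows of a 2-D array sum to 1 along the last axis.
probs = softmax_reference(np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]), axis=-1)
assert np.allclose(probs.sum(axis=-1), 1.0)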
