[Prim][PIR] add leaky_relu, sigmoid, instance_norm op forward prim #60564

Merged · 8 commits · Jan 8, 2024
@@ -24,11 +24,14 @@
"dropout",
"full_like",
"gelu",
"instance_norm",
"layer_norm",
"leaky_relu",
"mean",
"pow",
"relu",
"rsqrt",
"sigmoid",
"silu",
"softmax",
"sqrt",
@@ -44,11 +47,14 @@
"dropout",
"full_like",
"gelu",
"instance_norm",
"layer_norm",
"leaky_relu",
"mean",
"pow",
"relu",
"rsqrt",
"sigmoid",
"silu",
"softmax",
"sqrt",
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -344,6 +344,7 @@
kernel :
func : hardswish_grad
inplace : (out_grad -> x_grad)
composite : hardswish_grad(x, out_grad, x_grad)

- backward_op : hsigmoid_loss_grad
forward : hsigmoid_loss (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out)
109 changes: 109 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -528,6 +528,115 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) {
}
}

template <typename T>
Tensor sigmoid_decomp(const Tensor& x) {
auto org_dtype = x.dtype();
Tensor x_cast = x;

bool need_cast = is_half_dtype(org_dtype);
if (need_cast) {
x_cast = cast<T>(x, phi::DataType::FLOAT32);
}

// res = 1 / (1 + exp(-x))
auto one = full<T>(common::vectorize(x_cast.dims()), 1, x_cast.dtype());
auto exp_tmp = exp<T>(
full<T>(common::vectorize(x_cast.dims()), -1, x_cast.dtype()) * x_cast);
auto res = one / (one + exp_tmp);
if (need_cast) {
return cast<T>(res, org_dtype);
} else {
return res;
}
}
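
For a quick sanity check of the formula, the decomposition can be mirrored in NumPy; the helper name, test values, and tolerance below are illustrative assumptions rather than part of the PR:

```python
import numpy as np

def sigmoid_decomp_ref(x):
    # Mirror of the composite rule: upcast half-precision inputs to float32,
    # evaluate 1 / (1 + exp(-x)), then cast back to the original dtype.
    org_dtype = x.dtype
    x_cast = x.astype(np.float32) if org_dtype == np.float16 else x
    res = 1.0 / (1.0 + np.exp(-x_cast))
    return res.astype(org_dtype)

x = np.linspace(-6.0, 6.0, 7).astype(np.float16)
ref = 1.0 / (1.0 + np.exp(-x.astype(np.float64)))
assert np.allclose(sigmoid_decomp_ref(x), ref, atol=1e-3)
```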

template <typename T>
Tensor leaky_relu_decomp(const Tensor& x, float negative_slope) {
auto multiply_tmp =
full<T>(phi::vectorize(x.dims()), negative_slope, x.dtype()) * x;
if (negative_slope < 1.0) {
return maximum<T>(x, multiply_tmp);
} else {
return minimum<T>(x, multiply_tmp);
}
}
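
The branch on negative_slope keeps the decomposition to a single maximum/minimum call: for a slope below 1 the scaled value is smaller on the positive side, and for a slope of 1 or more the ordering flips. A small NumPy sketch of the same selection rule, with illustrative test values that are assumptions only:

```python
import numpy as np

def leaky_relu_decomp_ref(x, negative_slope):
    # slope < 1: max(x, slope * x) picks x for x >= 0 and slope * x for x < 0.
    # slope >= 1: the ordering reverses, so min(...) makes the same selection.
    scaled = negative_slope * x
    return np.maximum(x, scaled) if negative_slope < 1.0 else np.minimum(x, scaled)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
for slope in (0.1, 2.0):
    expected = np.where(x >= 0, x, slope * x)
    assert np.allclose(leaky_relu_decomp_ref(x, slope), expected)
```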

template <typename T>
std::tuple<Tensor, Tensor, Tensor> instance_norm_decomp(
const Tensor& x,
const paddle::optional<Tensor>& scale,
const paddle::optional<Tensor>& bias,
float epsilon) {
auto org_dtype = x.dtype();
Tensor x_cast = x;

bool need_cast = is_half_dtype(org_dtype);
if (need_cast) {
x_cast = cast<T>(x, phi::DataType::FLOAT32);
}

std::vector<int64_t> axis;
auto x_dim = common::vectorize<int64_t>(x.dims());
for (size_t i = 2; i < x_dim.size(); i++) {
axis.push_back(static_cast<int64_t>(i));
}

// out = (x - mean(x)) / sqrt(var + epsilon)
// var = mean((x-mean(x))^2)
auto mean_ = mean_decomp<T>(x_cast, IntArray(axis), true);
auto difference = x_cast - mean_;
auto var_tmp1 = difference * difference;
auto variance = mean_decomp<T>(var_tmp1, IntArray(axis), true);
auto var_tmp3 = variance + epsilon;
auto rsqrt_var = elementwise_pow<T>(
var_tmp3,
full<T>(common::vectorize(var_tmp3.dims()), 0.5, var_tmp3.dtype()));
auto out = difference / rsqrt_var;

auto scale_ptr = scale.get_ptr();
auto bias_ptr = bias.get_ptr();
std::vector<int64_t> slice_shape(x_dim.size(), 1);
slice_shape[1] = x_dim[1];

Tensor scale_cast;
if (scale_ptr) {
if (slice_shape != scale_ptr->shape()) {
scale_cast = reshape<T>(*scale_ptr, slice_shape);
} else {
scale_cast = *scale_ptr;
}
if (need_cast) {
scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
}
out = out * scale_cast;
}
Tensor bias_cast;
if (bias_ptr) {
if (slice_shape != bias_ptr->shape()) {
bias_cast = reshape<T>(*bias_ptr, slice_shape);
} else {
bias_cast = *bias_ptr;
}
if (need_cast) {
bias_cast = cast<T>(bias_cast, phi::DataType::FLOAT32);
}
out = out + bias_cast;
}

std::vector<int64_t> res_shape(1, -1);
auto mean_out = reshape<T>(mean_, res_shape);
auto variance_out = reshape<T>(1 / rsqrt_var, res_shape);

Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}

return std::make_tuple(res, mean_out, variance_out);
}
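
For reference, the same normalization can be expressed in NumPy: reduce over every axis after N and C, normalize, apply the optional per-channel scale and bias, and return the flattened mean and inverse standard deviation as the extra outputs. The names and shapes below are assumptions for illustration, not part of the PR:

```python
import numpy as np

def instance_norm_decomp_ref(x, scale=None, bias=None, epsilon=1e-5):
    # Reduce over all axes after N and C, keeping dims for broadcasting.
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    std = np.sqrt(var + epsilon)
    out = (x - mean) / std
    channel_shape = [1, x.shape[1]] + [1] * (x.ndim - 2)
    if scale is not None:
        out = out * scale.reshape(channel_shape)
    if bias is not None:
        out = out + bias.reshape(channel_shape)
    # Extra outputs mirror the composite: flattened mean and 1 / sqrt(var + eps).
    return out, mean.reshape(-1), (1.0 / std).reshape(-1)

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
out, saved_mean, saved_inv_std = instance_norm_decomp_ref(
    x, scale=np.ones(3, np.float32), bias=np.zeros(3, np.float32)
)
assert out.shape == x.shape and saved_mean.shape == (2 * 3,)
```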

} // namespace details

} // namespace primitive
65 changes: 47 additions & 18 deletions test/legacy_test/test_activation_op.py
@@ -390,7 +390,7 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
@@ -411,7 +411,7 @@ def init_dtype(self):

def test_check_output(self):
with paddle.static.scope_guard(paddle.static.Scope()):
self.check_output(check_prim=False)
self.check_output(check_prim=False, check_prim_pir=False)

def test_check_grad(self):
self.check_grad(
@@ -420,6 +420,7 @@ def test_check_grad(self):
max_relative_error=0.006,
check_prim=False,
check_pir=True,
check_prim_pir=False,
)


@@ -428,7 +429,9 @@ def init_dtype(self):
self.dtype = np.complex128

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=False, check_pir=True)
self.check_grad(
['X'], 'Out', check_prim=False, check_pir=True, check_prim_pir=False
)


class TestSigmoid_ZeroDim(TestSigmoid):
@@ -469,7 +472,9 @@ def if_enable_cinn(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_prim=True, check_pir=True)
self.check_output_with_place(
place, check_prim=True, check_pir=True, check_prim_pir=True
)

def test_check_grad(self):
place = core.CUDAPlace(0)
@@ -2555,7 +2560,7 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True)
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
@@ -3038,7 +3043,9 @@ def test_check_grad(self):
else False,
only_check_prim=self.if_only_check_prim(),
check_pir=True,
check_prim_pir=True,
check_prim_pir=True
if self.dtype not in [np.complex64, np.complex128]
else False,
)

def test_check_output(self):
@@ -4832,7 +4839,11 @@ def test_check_grad(self):
)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(
TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True
TestSigmoid,
check_prim=True,
enable_cinn=True,
check_pir=True,
check_prim_pir=True,
)
create_test_act_fp16_class(
TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True
@@ -4929,18 +4940,24 @@ def test_check_grad(self):
create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_fp16_class(TestMish, check_pir=True)
create_test_act_fp16_class(
TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True
TestLeakyRelu,
check_prim=True,
enable_cinn=True,
check_pir=True,
check_prim_pir=True,
)
create_test_act_fp16_class(
TestLeakyReluAlpha1, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha1, check_prim=True, enable_cinn=True
TestLeakyReluAlpha2, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha2, check_prim=True, enable_cinn=True
TestLeakyReluAlpha3, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha3, check_prim=True, enable_cinn=True
TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True
)
create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True)
create_test_act_fp16_class(
TestRsqrt,
check_prim=True,
@@ -5017,7 +5034,9 @@ def test_check_grad(self):
TestExpFp32_Prim, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(TestExpm1)
create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True)
create_test_act_bf16_class(
TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestLogSigmoid)
create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True)
@@ -5089,11 +5108,21 @@ def test_check_grad(self):
create_test_act_bf16_class(TestSwish)
create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestMish, check_pir=True)
create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True)
create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True)
create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True)
create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True)
create_test_act_bf16_class(
TestLeakyRelu, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha1, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha2, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha3, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestRsqrt, check_prim=True, check_pir=True, check_prim_pir=True
)
2 changes: 1 addition & 1 deletion test/legacy_test/test_instance_norm_op.py
@@ -130,7 +130,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True)
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(
22 changes: 19 additions & 3 deletions test/legacy_test/test_instance_norm_op_v2.py
@@ -220,7 +220,12 @@ def setUp(self):

def test_check_output(self):
self.check_output(
atol=self.atol, check_prim=self.check_prim, check_pir=True
atol=self.atol,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):
@@ -275,7 +280,13 @@ def set_err_thre(self):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=self.atol, check_prim=self.check_prim, check_pir=True
place,
atol=self.atol,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):
@@ -350,7 +361,12 @@ def init_shape(self):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_prim=self.check_prim, check_pir=True
place,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):