Skip to content

Commit

Permalink
[NewIR] support symbol overload (PaddlePaddle#57164)
Browse files Browse the repository at this point in the history
* add symbol overload

* add test case

* fix code
  • Loading branch information
cyber-pioneer authored Sep 12, 2023
1 parent e8bdafa commit 5de380c
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 14 deletions.
21 changes: 21 additions & 0 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/phi/core/enforce.h"
Expand Down Expand Up @@ -407,6 +408,26 @@ void BindOpResult(py::module *m) {
[](OpResult &self, Value &other) {
return self.value_impl() == other.impl();
})
.def("__neg__",
[](OpResult &self) {
return paddle::dialect::scale(self, -1.0, 0.0, true);
})
.def("__add__",
[](OpResult &self, OpResult &other) {
return paddle::dialect::add(self, other);
})
.def("__sub__",
[](OpResult &self, OpResult &other) {
return paddle::dialect::subtract(self, other);
})
.def("__mul__",
[](OpResult &self, OpResult &other) {
return paddle::dialect::multiply(self, other);
})
.def("__truediv__",
[](OpResult &self, OpResult &other) {
return paddle::dialect::divide(self, other);
})
.def("__hash__",
[](OpResult &self) {
return std::hash<pir::Value>{}(self.dyn_cast<pir::Value>());
Expand Down
16 changes: 3 additions & 13 deletions python/paddle/decomposition/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def mean(x, axis, keepdim):
value=value_to_fill,
dtype=sum_x.dtype,
)
res = divide(sum_x, norm)
res = sum_x / norm
return res


Expand All @@ -60,16 +60,6 @@ def gelu_composite(x, approximate):
else:
# gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

cdf = _ir_ops.multiply(
half,
(
_ir_ops.add(
one,
_ir_ops.erf(
_ir_ops.multiply(x, full(x.shape, M_SQRT1_2, x.dtype))
),
)
),
)
out = _ir_ops.multiply(x, cdf)
cdf = half * (one + _ir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
out = x * cdf
return out
3 changes: 2 additions & 1 deletion test/ir/new_ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ file(
"test_*.py")
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")

set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass)
set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass
test_symbol_overload)
list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})

foreach(target ${TEST_INTERP_CASES})
Expand Down
111 changes: 111 additions & 0 deletions test/ir/new_ir/test_symbol_overload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import _ir_ops, nn
from paddle.autograd.ir_backward import grad

paddle.enable_static()


class Net(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x, y):
z1 = _ir_ops.add(x, y)
z2 = _ir_ops.multiply(x, y)
z3 = _ir_ops.subtract(z1, z2)
z4 = _ir_ops.scale(z3, -1, 0, True)
res = _ir_ops.divide(z3, z4)
return res


class SimbolNet(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x, y):
z1 = x + y
z2 = x * y
z3 = z1 - z2
z4 = -z3
res = z3 / z4
return res


class TestOpresultSymbol(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.shape_x = [2, 1024, 1024]
self.shape_y = [2, 1024, 1024]
self.x = np.random.random(self.shape_x).astype("float32")
self.y = np.random.random(self.shape_y).astype("float32")

def base_net(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
net = Net()
x = paddle.static.data('x', self.shape_x, dtype='float32')
y = paddle.static.data('y', self.shape_y, dtype='float32')

res = net(x, y)
gradients = grad(res, (x, y))

exe = paddle.static.Executor()
outs = exe.run(
feed={
'x': self.x,
'y': self.y,
},
fetch_list=[res, gradients[0], gradients[1]],
)
ops = [op.name() for op in main_program.global_block().ops]
return outs, ops

def symbol_net(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
net = SimbolNet()
x = paddle.static.data('x', self.shape_x, dtype='float32')
y = paddle.static.data('y', self.shape_y, dtype='float32')

res = net(x, y)
gradients = grad(res, (x, y))

exe = paddle.static.Executor()
outs = exe.run(
feed={
'x': self.x,
'y': self.y,
},
fetch_list=[res, gradients[0], gradients[1]],
)
ops = [op.name() for op in main_program.global_block().ops]
return outs, ops

def test_symbol_overload(self):
res_ref, ops_ref = self.base_net()
res, ops = self.symbol_net()
for ref, actual in zip(res_ref, res):
np.testing.assert_equal(ref, actual)
self.assertEqual(ops_ref, ops)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5de380c

Please sign in to comment.