Skip to content

Commit

Permalink
【PIR】modify Subtract optest (#57608)
Browse files Browse the repository at this point in the history
* modify ci bug

* add sub test

* modify pd name
  • Loading branch information
xiaoguoguo626807 authored Sep 22, 2023
1 parent 71704ae commit c2ea73e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ static std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
// NOTE(zhangbo): add_grad cpu kernel can't do inplace, for the reason shown
// in the function: CommonElementwiseBroadcastBackward
// (paddle/phi/kernels/funcs/elementwise_grad_base.h)
if ((upper_op_name == "pd_op.add_grad") &&
if ((upper_op_name == "pd_op.add_grad" ||
upper_op_name == "pd_op.subtract_grad") &&
(upper_op_attrs.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data()
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def inverse_sort_op(ops):
sorted_list = []
for op in ops:
for x in op.operands():
if x.source().get_defining_op() in ops_set:
if x.source() and x.source().get_defining_op() in ops_set:
pending_count[x.source().get_defining_op()] += 1

queue = collections.deque()
Expand Down
52 changes: 39 additions & 13 deletions test/legacy_test/test_elementwise_sub_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def init_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output()
self.check_output(check_new_ir=True)

def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', check_prim=self.check_prim)
self.check_grad(
['X', 'Y'], 'Out', check_prim=self.check_prim, check_new_ir=True
)

def test_check_grad_ingore_x(self):
self.check_grad(
Expand All @@ -56,6 +58,7 @@ def test_check_grad_ingore_x(self):
max_relative_error=0.005,
no_grad_set=set("X"),
check_prim=self.check_prim,
check_new_ir=True,
)

def test_check_grad_ingore_y(self):
Expand All @@ -65,6 +68,7 @@ def test_check_grad_ingore_y(self):
max_relative_error=0.005,
no_grad_set=set('Y'),
check_prim=self.check_prim,
check_new_ir=True,
)

def if_check_prim(self):
Expand Down Expand Up @@ -116,7 +120,12 @@ def test_check_grad_normal(self):
def test_check_grad_ingore_x(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Y'], 'Out', no_grad_set=set("X"), max_relative_error=0.1
place,
['Y'],
'Out',
no_grad_set=set("X"),
max_relative_error=0.1,
check_new_ir=True,
)

def test_check_grad_ingore_y(self):
Expand All @@ -128,6 +137,7 @@ def test_check_grad_ingore_y(self):
no_grad_set=set('Y'),
max_relative_error=0.1,
check_prim=True,
check_new_ir=True,
)


Expand Down Expand Up @@ -372,10 +382,12 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output(check_dygraph=False, check_new_ir=False)

def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', check_dygraph=False)
self.check_grad(
['X', 'Y'], 'Out', check_dygraph=False, check_new_ir=False
)

def test_check_grad_ingore_x(self):
self.check_grad(
Expand All @@ -384,6 +396,7 @@ def test_check_grad_ingore_x(self):
max_relative_error=0.005,
no_grad_set=set("X"),
check_dygraph=False,
check_new_ir=False,
)

def test_check_grad_ingore_y(self):
Expand All @@ -393,6 +406,7 @@ def test_check_grad_ingore_y(self):
max_relative_error=0.005,
no_grad_set=set('Y'),
check_dygraph=False,
check_new_ir=False,
)


Expand Down Expand Up @@ -427,24 +441,36 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_dygraph=False)
self.check_output_with_place(
place, check_dygraph=False, check_new_ir=False
)

def test_check_grad_normal(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', check_dygraph=False
place, ['X', 'Y'], 'Out', check_dygraph=False, check_new_ir=False
)

def test_check_grad_ingore_x(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Y'], 'Out', no_grad_set=set("X"), check_dygraph=False
place,
['Y'],
'Out',
no_grad_set=set("X"),
check_dygraph=False,
check_new_ir=False,
)

def test_check_grad_ingore_y(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', no_grad_set=set('Y'), check_dygraph=False
place,
['X'],
'Out',
no_grad_set=set('Y'),
check_dygraph=False,
check_new_ir=False,
)


Expand Down Expand Up @@ -810,13 +836,11 @@ def init_input_output(self):
self.out = self.x - self.y

def test_check_output(self):
self.check_output()
self.check_output(check_new_ir=False)

def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
check_prim=self.check_prim,
['X', 'Y'], 'Out', check_prim=self.check_prim, check_new_ir=False
)

def test_check_grad_ingore_x(self):
Expand All @@ -825,6 +849,7 @@ def test_check_grad_ingore_x(self):
'Out',
no_grad_set=set("X"),
check_prim=self.check_prim,
check_new_ir=False,
)

def test_check_grad_ingore_y(self):
Expand All @@ -833,6 +858,7 @@ def test_check_grad_ingore_y(self):
'Out',
no_grad_set=set('Y'),
check_prim=self.check_prim,
check_new_ir=False,
)

def if_enable_cinn(self):
Expand Down

0 comments on commit c2ea73e

Please sign in to comment.