Skip to content

Commit

Permalink
增加pir模式下op名称检测 (#66382)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lans1ot authored Jul 24, 2024
1 parent 6885613 commit e356e2e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion test/deprecated/legacy_test/test_nn_sigmoid_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def check_static_api(self, place):
exe = paddle.static.Executor(place)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
np.testing.assert_allclose(out[0], self.y, rtol=1e-05)
self.assertTrue(y.name.startswith("api_sigmoid"))

if paddle.framework.in_pir_mode():
y_name = y.get_defining_op().name()
self.assertTrue(y_name.startswith("pd_op.sigmoid"))
else:
self.assertTrue(y.name.startswith("api_sigmoid"))

def check_dynamic_api(self, place):
paddle.disable_static(place)
Expand Down

0 comments on commit e356e2e

Please sign in to comment.