diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
index b454d3b961035..02938ef122d19 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -57,6 +57,8 @@
     'slice_double',
     'poisson',
     'gumbel_softmax',
+    'tril',
+    'triu',
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
@@ -92,4 +94,6 @@
     'slice_double',
     'poisson',
     'gumbel_softmax',
+    'tril',
+    'triu',
 ]
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 200b6a05b493f..0d6db93d39f36 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -38,6 +38,8 @@
 
 
 VJPS = [
+    'tril_grad',
+    'triu_grad',
     'tanh_grad',
     'mean_grad',
     'add_grad',
@@ -91,6 +93,8 @@
 VJP_COMPS = PRIM_VJP + CUSTOM_VJP
 
 BACKENDS = [
+    'tril_grad',
+    'triu_grad',
     'add_n',
     'mean',
     'sum',
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index e71b7ff65a63a..bb7d0f6142f4e 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -1504,7 +1504,7 @@ def tril(x, diagonal=0, name=None):
              [5 , 0 , 0 , 0 ],
              [9 , 10, 0 , 0 ]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.tril(x, diagonal)
     else:
         return _tril_triu_op(LayerHelper('tril', **locals()))
@@ -1581,7 +1581,7 @@ def triu(x, diagonal=0, name=None):
              [0 , 10, 11, 12]])
 
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.triu(x, diagonal)
     else:
         return _tril_triu_op(LayerHelper('triu', **locals()))
diff --git a/test/legacy_test/test_tril_triu_op.py b/test/legacy_test/test_tril_triu_op.py
index a3add39f00f3f..1c64288dabbe5 100644
--- a/test/legacy_test/test_tril_triu_op.py
+++ b/test/legacy_test/test_tril_triu_op.py
@@ -45,10 +45,10 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_new_ir=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_new_ir=True)
 
     def init_dtype(self):
         self.dtype = np.float64
@@ -86,11 +86,15 @@ def initTestCase(self):
         self.X = np.arange(1, 101, dtype="float32").reshape([10, -1])
 
     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True)
 
     def test_check_grad_normal(self):
         self.check_grad_with_place(
-            core.CUDAPlace(0), ['X'], 'Out', numeric_grad_delta=0.05
+            core.CUDAPlace(0),
+            ['X'],
+            'Out',
+            numeric_grad_delta=0.05,
+            check_new_ir=True,
         )
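
For context, a minimal usage sketch (not part of the patch) of the behavior this change routes through the new IR: `paddle.tril`/`paddle.triu` plus their gradients, which the updated tests now check with `check_new_ir=True`. Only public Paddle API is used here; the gradient comment reflects the standard tril/triu VJP (the incoming gradient is masked the same way the forward op masks its input).

    import paddle

    x = paddle.arange(1, 13, dtype='float32').reshape([3, 4])
    x.stop_gradient = False

    lower = paddle.tril(x, diagonal=0)   # keep the main diagonal and below
    upper = paddle.triu(x, diagonal=1)   # keep strictly above the main diagonal

    # tril_grad simply masks the incoming gradient with the same lower-triangular
    # pattern, so x.grad here is the tril mask of ones.
    lower.sum().backward()
    print(x.grad)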