diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py
index 0cf2a7a1700c0..c89f034c7ab82 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -183,7 +183,12 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
     'bool': 'pir::BoolAttribute',
 }
 
-PD_MANUAL_OP_LIST = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'}
+PD_MANUAL_OP_LIST = {
+    'add_n',
+    'add_n_',
+    'add_n_with_kernel',
+    'split_grad',
+}
 
 
 def to_phi_and_fluid_op_name(op_item):
diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
index 9dbbcd089c694..9998a40ec2c87 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -31,6 +31,7 @@
     "add",
     "concat",
     "split",
+    "split_with_num",
     "gelu",
     "matmul",
     "erf",
@@ -45,6 +46,7 @@
     'layer_norm',
     'reshape',
     'cast',
+    "scale",
     'softmax',
     'silu',
     'elementwise_pow',
@@ -53,7 +55,6 @@
     'slice',
     'transpose',
     'slice_double',
-    'scale',
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
@@ -63,6 +64,7 @@
     "add",
     "concat",
     "split",
+    "split_with_num",
     "gelu",
     "matmul",
     "erf",
@@ -77,6 +79,7 @@
     'layer_norm',
     'reshape',
     'cast',
+    "scale",
     'softmax',
     'silu',
     'elementwise_pow',
@@ -85,5 +88,4 @@
     'slice',
     'transpose',
     'slice_double',
-    'scale',
 ]
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
index 83f37f348eb43..ba8fc47744ed3 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -28,6 +28,18 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
   return combine_op.out();
 }
 
+std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
+                                      pir::Value out_grad) {
+  std::vector<pir::OpResult> inputs_grad;
+  for (size_t i = 0; i < inputs.size(); i++) {
+    paddle::dialect::ScaleOp scale_op =
+        APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ScaleOp>(
+            out_grad, 1.0, 0.0, true);
+    inputs_grad.push_back(scale_op.result(0));
+  }
+  return inputs_grad;
+}
+
 pir::OpResult zeros_like(pir::Value x,
                          phi::DataType dtype,
                          const Place& place) {
@@ -76,5 +88,23 @@ pir::OpResult embedding_grad(pir::Value x,
   }
 }
 
+pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
+  auto out_grad_combine_op =
+      APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
+  paddle::dialect::SplitGradOp split_grad_op =
+      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
+          out_grad_combine_op.out(), axis);
+  return split_grad_op.result(0);
+}
+
+pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
+                                  pir::Value axis) {
+  auto out_grad_combine_op =
+      APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
+  paddle::dialect::SplitGradOp split_grad_op =
+      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
+          out_grad_combine_op.out(), axis);
+  return split_grad_op.result(0);
+}
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
index 6c5018adc64d1..7e5aba6fcbaa8 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -25,6 +25,9 @@ namespace dialect {
 
 pir::OpResult builtin_combine(const std::vector<pir::Value>& x);
 
+std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
+                                      pir::Value out_grad);
+
 pir::OpResult zeros_like(pir::Value x,
                          phi::DataType dtype = phi::DataType::UNDEFINED,
                          const Place& place = {});
@@ -41,5 +44,9 @@ pir::OpResult embedding_grad(pir::Value x,
                              int64_t padding_idx = -1,
                              bool sparse = false);
 
+pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis);
+
+pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
+                                  pir::Value axis);
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
index 3dd489fd9e97d..76a94c9950560 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -16,6 +16,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
@@ -148,6 +149,7 @@ void AddNOp::Build(pir::Builder &builder,  // NOLINT
                     dense_out.offset());
   argument_outputs.push_back(out_dense_tensor_type);
   argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  ::pir::PassStopGradientsDefaultly(argument);
 }
 
 void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
index a94803089e1a0..b9f8474755ef7 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -12,18 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifdef GET_MANUAL_OP_LIST
-#undef GET_MANUAL_OP_LIST
-paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
-
-#else
-
 #pragma once
 #include <vector>
 
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
+#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
 #include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
 #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
 #include "paddle/fluid/pir/dialect/operator/utils/utils.h"
@@ -36,7 +31,10 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
 
 namespace paddle {
 namespace dialect {
-class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
+class AddNOp : public pir::Op<AddNOp,
+                              paddle::dialect::OpYamlInfoInterface,
+                              paddle::dialect::InferMetaInterface,
+                              paddle::dialect::VjpInterface> {
  public:
   using Op::Op;
   static const char *name() { return "pd_op.add_n"; }
@@ -51,6 +49,10 @@ class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface>
   pir::Value inputs() { return operand_source(0); }
   pir::OpResult out() { return result(0); }
   static void InferMeta(phi::InferMetaContext *infer_meta);
+  static std::vector<std::vector<pir::OpResult>> Vjp(
+      pir::Operation *op,
+      const std::vector<std::vector<pir::Value>> &out_grads,
+      const std::vector<std::vector<bool>> &stop_gradients);
 };
 
 class AddN_Op : public pir::Op<AddN_Op,
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
+std::vector<std::vector<pir::OpResult>> AddNOp::Vjp(
+    pir::Operation* op,
+    const std::vector<std::vector<pir::Value>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  AddNOp op_obj = op->dyn_cast<AddNOp>();
+
+  VLOG(6) << "Prepare inputs of add_n_grad";
+
+  pir::CombineOp combine_op_obj = op_obj.inputs()
+                                      .dyn_cast<pir::OpResult>()
+                                      .owner()
+                                      ->dyn_cast<pir::CombineOp>();
+  std::vector<Tensor> inputs;
+  for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
+    inputs.emplace_back(
+        std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
+  }
+
+  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
+
+  VLOG(6) << "Vjp prepare Prepare attributes of add_n_grad";
+
+  VLOG(6) << "Vjp prepare call add_n's vjp interface";
+
+  std::vector<std::vector<Tensor>> tensor_res =
+      primitive::add_n_vjp(inputs, out_grad, stop_gradients);
+
+  VLOG(6) << "Vjp prepare stop gradient of add_n_grad";
+
+  std::vector<std::vector<pir::OpResult>> res(tensor_res.size());
+  for (size_t i = 0; i < tensor_res.size(); ++i) {
+    res[i].resize(tensor_res[i].size());
+    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
+      if (tensor_res[i][j].defined()) {
+        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
+                        tensor_res[i][j].impl())
+                        ->value()
+                        .dyn_cast<pir::OpResult>();
+      }
+    }
+  }
+  return res;
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/manual/manual_backend.h b/paddle/fluid/primitive/backend/manual/manual_backend.h
index 438f06034b010..16c1facbd5354 100644
--- a/paddle/fluid/primitive/backend/manual/manual_backend.h
+++ b/paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/utils/optional.h"
 
 namespace paddle {
 namespace primitive {
@@ -28,6 +29,10 @@ using Scalar = paddle::experimental::Scalar;
 using IntArray = paddle::experimental::IntArray;
 using DataType = phi::DataType;
 
+template <typename T>
+std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
+                               const Tensor& out_grad);
+
 }  // namespace backend
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
index 7d96b4ddfecc2..7b33200336d00 100644
--- a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
+++ b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
+#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
 #include "paddle/fluid/primitive/backend/manual/manual_backend.h"
 #include "paddle/fluid/primitive/primitive/primitive.h"
 #include "paddle/fluid/primitive/type/lazy_tensor.h"
@@ -22,6 +24,26 @@ namespace primitive {
 namespace backend {
 using LazyTensor = paddle::primitive::LazyTensor;
 
+template <>
+std::vector<Tensor> add_n_grad<LazyTensor>(const std::vector<Tensor>& x,
+                                           const Tensor& out_grad) {
+  std::vector<pir::Value> x_res(x.size());
+  std::transform(x.begin(), x.end(), x_res.begin(), [](const Tensor& t) {
+    return std::static_pointer_cast<LazyTensor>(t.impl())->value();
+  });
+  pir::Value out_grad_res =
+      std::static_pointer_cast<LazyTensor>(out_grad.impl())->value();
+  auto op_res = paddle::dialect::add_n_grad(x_res, out_grad_res);
+
+  std::vector<Tensor> x_grad(op_res.size());
+  std::transform(op_res.begin(),
+                 op_res.end(),
+                 x_grad.begin(),
+                 [](const pir::OpResult& res) {
+                   return Tensor(std::make_shared<LazyTensor>(res));
+                 });
+  return x_grad;
+}
 
 }  // namespace backend
 }  // namespace primitive
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 0dced3ec2a3b9..0239f3d702e96 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -45,6 +45,7 @@
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'split_with_num_grad',
     'gelu_grad',
     'softmax_grad',
     'silu_grad',
@@ -97,6 +98,7 @@
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'split_with_num_grad',
     'gelu_grad',
     'softmax_grad',
     'silu_grad',
@@ -145,6 +147,7 @@
     'slice',
     'layer_norm_grad',
     'embedding_grad',
+    'sqrt',
     'uniform',
 ]
 
diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
index a882f78c52018..d7c94f9cd274d 100644
--- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
+++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -24,5 +24,21 @@
 #include "paddle/pir/core/operation.h"
 
 namespace paddle {
-namespace primitive {}  // namespace primitive
+namespace primitive {
+
+std::vector<std::vector<Tensor>> add_n_vjp(
+    const std::vector<Tensor>& x,
+    const Tensor& out_grad,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  std::vector<std::vector<Tensor>> vjp_res;
+  for (auto arg : stop_gradients) {
+    vjp_res.push_back(std::vector<Tensor>(arg.size()));
+  }
+  auto op_res = backend::add_n_grad<LazyTensor>(x, out_grad);
+  vjp_res[0] = op_res;
+  vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
+  return vjp_res;
+}
+
+}  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
index 35810f6d652ca..06b702ef9b50e 100644
--- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
+++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
@@ -23,5 +23,9 @@ namespace paddle {
 namespace primitive {
 
 using IntArray = paddle::experimental::IntArray;
+std::vector<std::vector<Tensor>> add_n_vjp(
+    const std::vector<Tensor>& x,
+    const Tensor& out_grad,
+    const std::vector<std::vector<bool>>& stop_gradients);
 }  // namespace primitive
 }  // namespace paddle
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 8989808846c76..8de465867273c 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -390,6 +390,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
     out->set_dims(in_dim);
   }
   out->share_lod(*x[0]);
+  out->set_dtype(x[0]->dtype());
 }
 
 // TODO(YuanRisheng) This InferMeta is used in Fluid
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 876072e48b8ba..e33c3a38bff74 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -207,6 +207,23 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
                 union_op_flags[i] = False
                 intersection_op_flags[i] = False
 
+    # some inputs in no_grad_set but its next op is effective,
+    # add their defining op here.
+    total_ops_list = list(total_ops)
+    for i, op in enumerate(total_ops_list):
+        if union_op_flags[i] is False:
+            for result in op.results():
+                if result.has_one_use():
+                    next_op = result.first_use().owner()
+                    if (
+                        next_op in total_ops
+                        and union_op_flags[total_ops_list.index(next_op)]
+                        is True
+                    ):
+                        union_op_flags[i] = True
+                else:
+                    continue
+
     effective_ops = [
         total_ops[i] for i in range(len(total_ops)) if intersection_op_flags[i]
     ]
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index 80396471e762a..22984f463ccfb 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -1370,7 +1370,7 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
                     result[key][0][
                         i
                     ].name = self.python_out_sig_sub_name[key][i]
-            return result
+        return result
 
     def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
         if os.getenv("FLAGS_NEW_IR_OPTEST") is None:
@@ -2087,14 +2087,16 @@ def compare_single_output_with_expect(self, name, expect):
                     expect, expect_np = self.find_expect_value(name)
                 else:
                     expect_np = (
-                        expect[0] if isinstance(expect, tuple) else expect
+                        expect[0]
+                        if isinstance(expect, (tuple, list))
+                        else expect
                     )
                 actual_np, expect_np = self.convert_uint16_to_float_ifneed(
                     actual_np, expect_np
                 )
                 # modify there for fp32 check
                 self._compare_numpy(name, actual_np, expect_np)
-                if isinstance(expect, tuple):
+                if isinstance(expect, (tuple, list)):
                     self._compare_list(name, actual, expect)
 
             def compare_outputs_with_expects(self):
@@ -3492,6 +3494,11 @@ def _get_gradient(
 
         return res
 
+    def _find_var_in_pir(self, output_vars, name):
+        if name in output_vars:
+            return output_vars[name]
+        raise AssertionError(name, " not in outputs:", output_vars.keys())
+
     def _get_ir_gradient(
         self,
         inputs_to_check,
@@ -3572,13 +3579,36 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
 
         # cast outputs
         if self.dtype == np.uint16:
-            for output in outputs:
-                outputs[output][0] = paddle.cast(
-                    outputs[output][0],
-                    paddle.base.core.DataType.FLOAT32,
-                )
+            cast_inputs = []
+            for output_name in output_names:
+                cast_input = self._find_var_in_pir(outputs, output_name)
+                cast_inputs = cast_inputs + cast_input
+            cast_outputs = []
+            for cast_input in cast_inputs:
+                if isinstance(
+                    cast_input, paddle.base.libpaddle.ir.OpResult
+                ):
+                    cast_outputs.append(
+                        paddle.cast(
+                            cast_input,
+                            paddle.base.core.DataType.FLOAT32,
+                        )
+                    )
+                else:
+                    raise TypeError(
+                        "Unsupported test data type %s."
+                        % type(cast_input)
+                    )
+
+            outputs = {}
+            for i in range(len(output_names)):
+                outputs.update({output_names[i]: [cast_outputs[i]]})
 
-        outputs_valid = outputs
+        outputs_valid = {}
+        for output_name in output_names:
+            outputs_valid[output_name] = self._find_var_in_pir(
+                outputs, output_name
+            )
         loss_inputs = []
         for input_name in inputs_to_check:
             loss_inputs.append(inputs_dict[input_name])
@@ -3613,7 +3643,6 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
             grad_outputs=grad_outputs,
         )
         fetch_list = list(grad_inputs)
-
         # executor run
         executor = paddle.static.Executor()
         outs = executor.run(
diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py
index 9520facc44394..964e127aafb81 100644
--- a/test/legacy_test/test_split_op.py
+++ b/test/legacy_test/test_split_op.py
@@ -60,7 +60,9 @@ def test_check_output(self):
         self.check_output(check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
+        self.check_grad(
+            ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True
+        )
 
 
 # test with attr(num)
@@ -114,7 +116,9 @@ def test_check_output(self):
         self.check_output(check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
+        self.check_grad(
+            ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True
+        )
 
 
 # attr(axis) is Tensor
@@ -151,7 +155,7 @@ def test_check_output(self):
         self.check_output(check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1', 'out2'])
+        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_new_ir=True)
 
 
 # attr(sections) is list containing Tensor
@@ -199,7 +203,7 @@ def test_check_output(self):
         self.check_output(check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1', 'out2'])
+        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_new_ir=True)
 
 
 class TestSplitOp_unk_section(OpTest):
@@ -238,7 +242,9 @@ def test_check_output(self):
         self.check_output(check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
+        self.check_grad(
+            ['X'], ['out0', 'out1', 'out2'], check_prim=True, check_new_ir=True
+        )
 
 
 class TestSplitByrefOp(OpTest):
@@ -284,7 +290,9 @@ def test_check_output(self):
 
         def test_check_grad(self):
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X'], 'out2', check_prim=True)
+            self.check_grad_with_place(
+                place, ['X'], 'out2', check_prim=True, check_new_ir=True
+            )
 
     cls_name = "{}_{}".format(parent.__name__, "BF16Op")
     TestSplitBF16Op.__name__ = cls_name
diff --git a/test/legacy_test/test_sum_op.py b/test/legacy_test/test_sum_op.py
index 0738f52f2ba7d..63a68442936ab 100644
--- a/test/legacy_test/test_sum_op.py
+++ b/test/legacy_test/test_sum_op.py
@@ -31,16 +31,16 @@
 
 
 def sum_wrapper(X, use_mkldnn=False):
-    res = 0
+    res = paddle.full(shape=X[0].shape, fill_value=0.0, dtype=X[0].dtype)
     for x in X:
-        res += x
+        res = paddle.add(res, x)
     return res
 
 
 class TestSumOp(OpTest):
     def setUp(self):
         self.op_type = "sum"
-        self.python_api = sum_wrapper
+        self.python_api = paddle.add_n
         self.public_python_api = paddle.add_n
         self.prim_op_type = "comp"
         self.init_kernel_type()
@@ -58,10 +58,12 @@ def init_kernel_type(self):
         self.dtype = np.float64
 
     def test_check_output(self):
-        self.check_output(check_prim=True, check_cinn=True)
+        self.check_output(check_prim=True, check_cinn=True, check_new_ir=True)
 
     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out', check_prim=True, check_cinn=True)
+        self.check_grad(
+            ['x0'], 'Out', check_prim=True, check_cinn=True, check_new_ir=True
+        )
 
 
 class TestSelectedRowsSumOp(unittest.TestCase):
@@ -286,21 +288,23 @@ def create_lod_tensor(self, scope, place, var_name):
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
-class TestFP16SumOp(TestSumOp):
+class TestAFP16SumOp(TestSumOp):
     def init_kernel_type(self):
         self.dtype = np.float16
 
     def test_check_output(self):
         place = core.CUDAPlace(0)
         if core.is_float16_supported(place):
-            self.check_output_with_place(place, check_cinn=True)
+            self.check_output_with_place(
+                place, check_cinn=True, check_new_ir=True
+            )
 
     # FIXME: Because of the precision fp16, max_relative_error
     # should be 0.15 here.
     def test_check_grad(self):
         place = core.CUDAPlace(0)
         if core.is_float16_supported(place):
-            self.check_grad(['x0'], 'Out', check_cinn=True)
+            self.check_grad(['x0'], 'Out', check_cinn=True, check_new_ir=True)
 
 
 def create_test_sum_fp16_class(parent):
@@ -326,6 +330,7 @@ def test_w_is_selected_rows(self):
 class TestSumBF16Op(OpTest):
     def setUp(self):
         self.op_type = "sum"
+        self.python_api = paddle.add_n
         self.init_kernel_type()
         x0 = np.random.random((3, 40)).astype(np.float32)
         x1 = np.random.random((3, 40)).astype(np.float32)
@@ -345,11 +350,11 @@ def init_kernel_type(self):
 
     def test_check_output(self):
         # new dynamic graph mode does not support unit16 type
-        self.check_output(check_dygraph=False)
+        self.check_output(check_dygraph=False, check_new_ir=True)
 
     def test_check_grad(self):
         # new dynamic graph mode does not support unit16 type
-        self.check_grad(['x0'], 'Out', check_dygraph=False)
+        self.check_grad(['x0'], 'Out', check_dygraph=False, check_new_ir=True)
 
 
 class API_Test_Add_n(unittest.TestCase):
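
A minimal sanity-check sketch, not part of the patch above: the rule that the new AddNOp::Vjp / add_n_grad path encodes is that every input of add_n receives the upstream gradient unchanged, which is why manual_api.cc builds one scale(out_grad, 1.0, 0.0) per input. The snippet below only checks that identity in dynamic-graph mode; routing it through the new-IR static executor (what check_new_ir=True exercises) is handled by the op_test harness changed above.

# Hedged usage sketch (dynamic graph only, not from the diff).
import numpy as np
import paddle

x0 = paddle.to_tensor(np.random.rand(3, 40), stop_gradient=False)
x1 = paddle.to_tensor(np.random.rand(3, 40), stop_gradient=False)
out = paddle.add_n([x0, x1])
out.sum().backward()

# d(add_n)/dx_i is the identity, so each input gradient equals the upstream
# gradient (all ones here), matching the per-input scale(out_grad, 1.0, 0.0)
# that paddle::dialect::add_n_grad builds in manual_api.cc.
assert np.allclose(x0.grad.numpy(), np.ones([3, 40]))
assert np.allclose(x1.grad.numpy(), np.ones([3, 40]))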