From b32314b67fa95a7b7e50d48a5a282c52eb2ca239 Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Fri, 25 Aug 2023 14:09:30 +0800
Subject: [PATCH] [NewIR] polish names in backward.py (#56650)

* modify sum with divide net bug, mutable attribute

* delete print

* polish backward

* polish backward

---
 python/paddle/autograd/backward.py | 110 ++++++++++++++---------------
 1 file changed, 52 insertions(+), 58 deletions(-)

diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py
index e631c02b0bd3e0..13a5acd46a9175 100644
--- a/python/paddle/autograd/backward.py
+++ b/python/paddle/autograd/backward.py
@@ -52,8 +52,7 @@ def check_all_puts(block, inputs, outputs):
 
 def update_no_grad_set_by_stopgradient(block, no_grad_set):
     for op in block.ops:
-        for opresult_idx in range(op.num_results()):
-            value = op.result(opresult_idx)
+        for value in op.results():
             if value.stop_gradient and value not in no_grad_set:
                 no_grad_set.add(value)
 
@@ -63,9 +62,7 @@ def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
     op_to_opgrad_list.append(grad_op)
 
 
-def prepare_grad_outputs(
-    block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad
-):
+def prepare_grad_outputs(grad_outputs, outputs, state):
     """
     if grad_outputs is none, add fill_1 op to create grad_outputs,
     else check whether outputs shape and dtype is same to grad_outputs, otherwise raise error.
@@ -100,10 +97,10 @@ def prepare_grad_outputs(
 
             update_bwdop_structure(
                 backward_ops,
-                op_to_opgrad[output.get_defining_op()],
+                state.op_to_opgrad[output.get_defining_op()],
                 fillop,
             )
-            value_to_valuegrad[output] = [[output_grad]]
+            state.value_to_valuegrad[output] = [[output_grad]]
         else:
             if output.shape != grad.shape:
                 raise ValueError(
@@ -117,9 +114,11 @@ def prepare_grad_outputs(
                 )
             feedop = grad.get_defining_op()
             update_bwdop_structure(
-                backward_ops, op_to_opgrad[output.get_defining_op()], feedop
+                backward_ops,
+                state.op_to_opgrad[output.get_defining_op()],
+                feedop,
             )
-            value_to_valuegrad[output] = [[grad]]
+            state.value_to_valuegrad[output] = [[grad]]
 
     # add input for bwd first op
     complete_outputs = outputs
@@ -130,7 +129,7 @@ def prepare_grad_outputs(
         if output in visited_output:
             continue
         for opresult in output.get_defining_op().results():
-            if opresult in value_to_valuegrad:
+            if opresult in state.value_to_valuegrad:
                 visited_output.add(opresult)
                 continue
             else:
@@ -143,10 +142,10 @@ def prepare_grad_outputs(
 
                 update_bwdop_structure(
                     backward_ops,
-                    op_to_opgrad[opresult.get_defining_op()],
+                    state.op_to_opgrad[opresult.get_defining_op()],
                     fillop,
                 )
-                value_to_valuegrad[opresult] = [grad_value]
+                state.value_to_valuegrad[opresult] = [grad_value]
 
                 visited_output.add(opresult)
 
@@ -196,7 +195,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
 
     # from output to input
     for i, op in reversed(list(enumerate(total_ops))):
-        # while op support
         if some_in_set(op.results(), outputs_set):
             for operand in op.operands_source():
                 if operand not in no_grad_set:
@@ -233,7 +231,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
 
 
 def update_no_grad_set_after_prune(
-    block, effective_forward_op, no_grad_set, inputs, outputs
+    block, effective_forward_ops, no_grad_set, inputs, outputs
 ):
     '''
     update no_grad_set after forward prune
@@ -249,14 +247,14 @@ def update_no_grad_set_after_prune(
                 if value not in no_grad_set:
                     inputs_set.add(value)
 
-    for op in effective_forward_op:
+    for op in effective_forward_ops:
         for value in op.operands_source():
-            if value not in inputs_set:  # and value.get_stopgradient():
+            if value not in inputs_set:
                 no_grad_set.add(value)
 
     outputs_set = set(outputs)
     no_grad_set_tmp = set()
-    for op in reversed(effective_forward_op):
+    for op in reversed(effective_forward_ops):
         for output in op.results():
             if output not in outputs_set and not some_in_set(
                 [output], set(op.operands_source())
             ):
@@ -313,7 +311,7 @@ def inverse_sort_op(ops):
 
 
 def append_backward_ops(
-    block, effective_forward_op, no_grad_set, backward_ops, state
+    block, effective_forward_ops, no_grad_set, backward_ops, state
 ):
     '''
     add grad_op in order of topological inverse sort
@@ -417,28 +415,26 @@ def make_output_grad(op):
         return zero_flag, output_grad
 
     def make_input_stopgradient(op):
-        input_grad_stopgradient_list = []
+        input_grad_stopgradients = []
         for input in op.operands_source():
             if input.get_defining_op().name() == "builtin.combine":
                 stop_gradient = make_input_stopgradient(input.get_defining_op())
-                input_grad_stopgradient_list.append(
+                input_grad_stopgradients.append(
                     [info[0] for info in stop_gradient]
                 )
             else:
                 if input in no_grad_set:
-                    input_grad_stopgradient_list.append([True])
+                    input_grad_stopgradients.append([True])
                 else:
-                    input_grad_stopgradient_list.append([False])
-        return input_grad_stopgradient_list
+                    input_grad_stopgradients.append([False])
+        return input_grad_stopgradients
 
-    def update_input_grad_map(op, input_grad_list):
+    def update_input_grad_map(op, input_grads):
         for i, input in enumerate(op.operands_source()):
             if input.get_defining_op().name() == "builtin.combine":
-                update_input_grad_map(
-                    input.get_defining_op(), input_grad_list[i]
-                )
+                update_input_grad_map(input.get_defining_op(), input_grads[i])
             else:
-                input_grad = input_grad_list[i]
+                input_grad = input_grads[i]
                 if isinstance(input_grad, list):
                     state.value_to_valuegrad[input].append(input_grad)
                 else:
@@ -451,31 +447,31 @@ def update_input_grad_map(op, input_grad_list):
     # [op4] (op4's inputs and outputs are not vectorType)
     # einsum has twp vectorType outputs, special pattern
 
-    clear_effective_forward_op = []
+    clear_effective_forward_ops = []
 
-    for op in effective_forward_op:
+    for op in effective_forward_ops:
         if op.name() != "builtin.combine" and op.name() != "builtin.split":
-            clear_effective_forward_op.append(op)
+            clear_effective_forward_ops.append(op)
 
-    for op in clear_effective_forward_op:
+    for op in clear_effective_forward_ops:
         if paddle.framework.core.has_vjp(op):
             # prepare output_grad
-            output_grad_list = []  # (opresult)
+            output_grads = []  # (opresult)
             zero_flag, output_grad = make_output_grad(op)
-            output_grad_list.append(output_grad)
+            output_grads.append(output_grad)
 
             # all(zero_flag) support this op has no contribution for grad
             # should be delete (prune sub_graph)
-            if len(output_grad_list) == 0 or all(zero_flag):
+            if len(output_grads) == 0 or all(zero_flag):
                 continue
 
             # prepare input_grad stop_gradient info.
-            input_grad_stopgradient_list = make_input_stopgradient(op)
+            input_grad_stopgradients = make_input_stopgradient(op)
 
             # create grad_op
             before_ops_num = len(block.ops)
-            input_grad_list = paddle.framework.core.call_vjp(
-                op, output_grad_list, input_grad_stopgradient_list
+            input_grads = paddle.framework.core.call_vjp(
+                op, output_grads, input_grad_stopgradients
             )
             after_ops_num = len(block.ops)
 
@@ -486,7 +482,7 @@ def update_input_grad_map(op, input_grad_list):
             )
 
             # update input_grad map
-            update_input_grad_map(op, input_grad_list)
+            update_input_grad_map(op, input_grads)
 
         else:
             if op.num_operands() == 0 and op.num_results() != 0:
@@ -519,15 +515,18 @@ def update_input_grad_map(op, input_grad_list):
 
 def create_backward_prune_set(inputs, outputs, no_grad_set, state):
     outputs_set = set()
-    for input in inputs:
-        for item in input.first_use().owner().operands_source():
-            if state.value_to_valuegrad[item] != []:
-                outputs_set.add(state.value_to_valuegrad[item][0][0])
+    for input_ in inputs:
+        if not input_.use_empty():
+            for item in input_.first_use().owner().operands_source():
+                if state.value_to_valuegrad[item] != []:
+                    outputs_set.add(state.value_to_valuegrad[item][0][0])
+        else:
+            raise ValueError("input provided by inputs has no use")
+
     inputs_set = set()
     for output in outputs:
         if state.value_to_valuegrad[output] != []:
             inputs_set.add(state.value_to_valuegrad[output][0][0])
-
     inputs_set_tmp = set()
     for out_grad in inputs_set:
         for item in out_grad.first_use().owner().operands_source():
@@ -538,7 +537,6 @@ def update_input_grad_map(op, input_grad_list):
     for key in state.value_to_valuegrad:
         if key in no_grad_set:
             no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
-
     for key in state.value_to_sumvaluegrad:
         if key in no_grad_set:
             for item in state.value_to_sumvaluegrad[key][0]:
@@ -575,26 +573,22 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
     # update no_grad_set if some value stop_gradient=True
    update_no_grad_set_by_stopgradient(block, no_grad_set)
     complete_outputs, _, backward_ops = prepare_grad_outputs(
-        block,
-        grad_outputs,
-        outputs,
-        state.value_to_valuegrad,
-        state.op_to_opgrad,
+        grad_outputs, outputs, state
     )
 
     inputs_set = set(inputs)
     outputs_set = set(complete_outputs)
-    effective_forward_op, _ = prune_ops(
+    effective_forward_ops, _ = prune_ops(
         block.ops, inputs_set, outputs_set, no_grad_set
     )
 
     update_no_grad_set_after_prune(
-        block, effective_forward_op, no_grad_set, inputs, complete_outputs
+        block, effective_forward_ops, no_grad_set, inputs, complete_outputs
    )
 
-    inverse_effective_forward_op = inverse_sort_op(effective_forward_op)
+    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
     append_backward_ops(
-        block, inverse_effective_forward_op, no_grad_set, backward_ops, state
+        block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
     )
 
     # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
@@ -719,26 +713,26 @@ def grad(
         outputs,
         'outputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
     check_type(
         inputs,
         'inputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
     check_type(
         grad_outputs,
         'grad_outputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
     check_type(
        no_grad_vars,
         'no_grad_vars',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
 
     outputs = _as_list(outputs)
     inputs = _as_list(inputs)
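
For reference, a minimal usage sketch of the grad entry point whose arguments are validated by the check_type calls above. This sketch is not part of the patch: fwd_output and fwd_input are assumed to be paddle.ir.OpResult values taken from a program already built under the new IR (construction elided), and input_gradients is a hypothetical helper name.

    # Hypothetical helper; only the parameter names and accepted types mirror
    # the check_type calls in this patch. Building the new-IR program that
    # produces fwd_output / fwd_input is assumed to happen elsewhere.
    from paddle.autograd.backward import grad

    def input_gradients(fwd_output, fwd_input, seed_grad=None):
        # outputs/inputs: a single Value/OpResult or a list/tuple of them;
        # grad_outputs and no_grad_vars may also be None, per the checks above.
        return grad(
            outputs=fwd_output,
            inputs=fwd_input,
            grad_outputs=seed_grad,
            no_grad_vars=None,
        )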