[NewIR] polish name in backward.py #56650

Merged 6 commits on Aug 25, 2023
110 changes: 52 additions & 58 deletions python/paddle/autograd/backward.py
@@ -52,8 +52,7 @@ def check_all_puts(block, inputs, outputs):

def update_no_grad_set_by_stopgradient(block, no_grad_set):
for op in block.ops:
for opresult_idx in range(op.num_results()):
value = op.result(opresult_idx)
for value in op.results():
if value.stop_gradient and value not in no_grad_set:
no_grad_set.add(value)

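The hunk above swaps index-based access to op results for direct iteration. A minimal sketch (plain Python, not Paddle's Operation API) of why the two forms visit the same results:

```python
# Minimal sketch, not Paddle code: FakeOp mimics an op exposing both access styles.
class FakeOp:
    def __init__(self, results):
        self._results = results

    def num_results(self):
        return len(self._results)

    def result(self, i):
        return self._results[i]

    def results(self):
        return list(self._results)

op = FakeOp(["r0", "r1"])
old_style = [op.result(i) for i in range(op.num_results())]  # removed form
new_style = list(op.results())                               # new form
assert old_style == new_style
```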
@@ -63,9 +62,7 @@ def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
op_to_opgrad_list.append(grad_op)


def prepare_grad_outputs(
block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad
):
def prepare_grad_outputs(grad_outputs, outputs, state):
"""
if grad_outputs is None, add a fill op (value 1) to create grad_outputs;
otherwise check that each output's shape and dtype match its grad_output, and raise an error if not.
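prepare_grad_outputs now receives the whole bookkeeping state instead of its two dicts. A hypothetical sketch of the kind of container the new signature assumes; the real State class is defined elsewhere in backward.py and is not part of this diff:

```python
# Hypothetical sketch of the state container the new signature assumes;
# the real State class lives elsewhere in backward.py and is not shown here.
from collections import defaultdict


class State:
    def __init__(self):
        # forward value -> list of grad values collected for it
        self.value_to_valuegrad = defaultdict(list)
        # forward op -> list of grad ops created for it
        self.op_to_opgrad = defaultdict(list)


state = State()
# callers now pass the container instead of its two dicts:
# prepare_grad_outputs(grad_outputs, outputs, state)
```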
@@ -100,10 +97,10 @@ def prepare_grad_outputs(

update_bwdop_structure(
backward_ops,
op_to_opgrad[output.get_defining_op()],
state.op_to_opgrad[output.get_defining_op()],
fillop,
)
value_to_valuegrad[output] = [[output_grad]]
state.value_to_valuegrad[output] = [[output_grad]]
else:
if output.shape != grad.shape:
raise ValueError(
@@ -117,9 +114,11 @@ def prepare_grad_outputs(
)
feedop = grad.get_defining_op()
update_bwdop_structure(
backward_ops, op_to_opgrad[output.get_defining_op()], feedop
backward_ops,
state.op_to_opgrad[output.get_defining_op()],
feedop,
)
value_to_valuegrad[output] = [[grad]]
state.value_to_valuegrad[output] = [[grad]]

# add input for bwd first op
complete_outputs = outputs
@@ -130,7 +129,7 @@ def prepare_grad_outputs(
if output in visited_output:
continue
for opresult in output.get_defining_op().results():
if opresult in value_to_valuegrad:
if opresult in state.value_to_valuegrad:
visited_output.add(opresult)
continue
else:
@@ -143,10 +142,10 @@ def prepare_grad_outputs(

update_bwdop_structure(
backward_ops,
op_to_opgrad[opresult.get_defining_op()],
state.op_to_opgrad[opresult.get_defining_op()],
fillop,
)
value_to_valuegrad[opresult] = [grad_value]
state.value_to_valuegrad[opresult] = [grad_value]

visited_output.add(opresult)

@@ -196,7 +195,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):

# from output to input
for i, op in reversed(list(enumerate(total_ops))):
# while op support
if some_in_set(op.results(), outputs_set):
for operand in op.operands_source():
if operand not in no_grad_set:
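prune_ops walks the ops from outputs back to inputs and keeps only those that reach the requested outputs. A toy sketch of that reachability prune (plain tuples standing in for ops; no_grad_set handling omitted):

```python
# Toy reachability prune, not Paddle code: ops are (name, inputs, outputs) tuples
# and no_grad_set handling is omitted for brevity.
ops = [
    ("w", [], ["w0"]),
    ("matmul", ["x", "w0"], ["m0"]),
    ("relu", ["m0"], ["r0"]),
    ("unused", ["x"], ["u0"]),
]
needed = {"r0"}  # the outputs we want gradients for
kept = []
for name, ins, outs in reversed(ops):  # from output to input
    if any(out in needed for out in outs):
        kept.append(name)
        needed.update(ins)
print(list(reversed(kept)))  # ['w', 'matmul', 'relu'] -- 'unused' is pruned
```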
@@ -233,7 +231,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):


def update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, outputs
block, effective_forward_ops, no_grad_set, inputs, outputs
):
'''
update no_grad_set after forward prune
@@ -249,14 +247,14 @@ def update_no_grad_set_after_prune(
if value not in no_grad_set:
inputs_set.add(value)

for op in effective_forward_op:
for op in effective_forward_ops:
for value in op.operands_source():
if value not in inputs_set: # and value.get_stopgradient():
if value not in inputs_set:
no_grad_set.add(value)

outputs_set = set(outputs)
no_grad_set_tmp = set()
for op in reversed(effective_forward_op):
for op in reversed(effective_forward_ops):
for output in op.results():
if output not in outputs_set and not some_in_set(
[output], set(op.operands_source())
@@ -313,7 +311,7 @@ def inverse_sort_op(ops):


def append_backward_ops(
block, effective_forward_op, no_grad_set, backward_ops, state
block, effective_forward_ops, no_grad_set, backward_ops, state
):
'''
add grad_op in order of topological inverse sort
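The docstring's "topological inverse sort" means grad ops are appended while visiting forward ops from outputs back to inputs. A small sketch using the standard-library graphlib on a toy dependency dict (not Paddle's Block/Operation API):

```python
# Toy sketch using graphlib (Python >= 3.9), not Paddle code: each key depends on
# the nodes in its set, so static_order() puts producers first and reversing it
# visits ops from outputs back to inputs.
from graphlib import TopologicalSorter

deps = {"matmul": {"x", "w"}, "relu": {"matmul"}, "loss": {"relu"}}
forward_order = list(TopologicalSorter(deps).static_order())
inverse_order = list(reversed(forward_order))
print(inverse_order)  # 'loss' first, then 'relu', 'matmul', and finally 'x'/'w'
```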
@@ -417,28 +415,26 @@ def make_output_grad(op):
return zero_flag, output_grad

def make_input_stopgradient(op):
input_grad_stopgradient_list = []
input_grad_stopgradients = []
for input in op.operands_source():
if input.get_defining_op().name() == "builtin.combine":
stop_gradient = make_input_stopgradient(input.get_defining_op())
input_grad_stopgradient_list.append(
input_grad_stopgradients.append(
[info[0] for info in stop_gradient]
)
else:
if input in no_grad_set:
input_grad_stopgradient_list.append([True])
input_grad_stopgradients.append([True])
else:
input_grad_stopgradient_list.append([False])
return input_grad_stopgradient_list
input_grad_stopgradients.append([False])
return input_grad_stopgradients

def update_input_grad_map(op, input_grad_list):
def update_input_grad_map(op, input_grads):
for i, input in enumerate(op.operands_source()):
if input.get_defining_op().name() == "builtin.combine":
update_input_grad_map(
input.get_defining_op(), input_grad_list[i]
)
update_input_grad_map(input.get_defining_op(), input_grads[i])
else:
input_grad = input_grad_list[i]
input_grad = input_grads[i]
if isinstance(input_grad, list):
state.value_to_valuegrad[input].append(input_grad)
else:
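Both helpers in this hunk recurse one level through builtin.combine ops so that stop-gradient flags and grads line up with the original source values. A toy sketch of the same nesting pattern (nested lists standing in for combine results, not Paddle API):

```python
# Toy illustration, not Paddle API: a nested list stands in for a builtin.combine
# operand, so stop-gradient flags come back nested and are flattened one level,
# mirroring make_input_stopgradient above.
def make_stopgradients(operands, no_grad_set):
    flags = []
    for operand in operands:
        if isinstance(operand, list):  # stands in for a builtin.combine result
            inner = make_stopgradients(operand, no_grad_set)
            flags.append([f[0] for f in inner])
        else:
            flags.append([operand in no_grad_set])
    return flags

print(make_stopgradients(["a", ["b", "c"]], {"c"}))  # [[False], [False, True]]
```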
@@ -451,31 +447,31 @@ def update_input_grad_map(op, input_grad_list):
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has two vectorType outputs, special pattern

clear_effective_forward_op = []
clear_effective_forward_ops = []

for op in effective_forward_op:
for op in effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_op.append(op)
clear_effective_forward_ops.append(op)

for op in clear_effective_forward_op:
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
# prepare output_grad
output_grad_list = [] # (opresult)
output_grads = [] # (opresult)
zero_flag, output_grad = make_output_grad(op)
output_grad_list.append(output_grad)
output_grads.append(output_grad)

# all(zero_flag) means this op has no contribution to grad
# and should be deleted (prune sub_graph)
if len(output_grad_list) == 0 or all(zero_flag):
if len(output_grads) == 0 or all(zero_flag):
continue

# prepare input_grad stop_gradient info.
input_grad_stopgradient_list = make_input_stopgradient(op)
input_grad_stopgradients = make_input_stopgradient(op)

# create grad_op
before_ops_num = len(block.ops)
input_grad_list = paddle.framework.core.call_vjp(
op, output_grad_list, input_grad_stopgradient_list
input_grads = paddle.framework.core.call_vjp(
op, output_grads, input_grad_stopgradients
)
after_ops_num = len(block.ops)

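The loop brackets call_vjp with before_ops_num/after_ops_num so the grad ops it appended to the block can be recorded afterwards. A minimal sketch of that bracketing idiom with a plain list standing in for block.ops:

```python
# Minimal sketch of the op-count bracketing idiom; a plain list stands in for
# block.ops, and extend() stands in for the grad ops call_vjp appends.
block_ops = ["forward_op_a", "forward_op_b"]

before_ops_num = len(block_ops)
block_ops.extend(["grad_op_1", "grad_op_2"])  # stands in for call_vjp(...)
after_ops_num = len(block_ops)

new_grad_ops = block_ops[before_ops_num:after_ops_num]
assert new_grad_ops == ["grad_op_1", "grad_op_2"]
```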
@@ -486,7 +482,7 @@ def update_input_grad_map(op, input_grad_list):
)

# update input_grad map
update_input_grad_map(op, input_grad_list)
update_input_grad_map(op, input_grads)

else:
if op.num_operands() == 0 and op.num_results() != 0:
@@ -519,15 +515,18 @@ def update_input_grad_map(op, input_grad_list):

def create_backward_prune_set(inputs, outputs, no_grad_set, state):
outputs_set = set()
for input in inputs:
for item in input.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
for input_ in inputs:
if not input_.use_empty():
for item in input_.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
else:
raise ValueError("input privided by inputs has no use")

inputs_set = set()
for output in outputs:
if state.value_to_valuegrad[output] != []:
inputs_set.add(state.value_to_valuegrad[output][0][0])

inputs_set_tmp = set()
for out_grad in inputs_set:
for item in out_grad.first_use().owner().operands_source():
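The new use_empty() guard avoids calling first_use() on an input value that is never used, failing fast with a clear error instead. A toy sketch of the guard (FakeValue is a stand-in, not Paddle's Value API):

```python
# Toy sketch, not Paddle's Value API: first_use() is only meaningful when the
# value has uses, which is what the new use_empty() check guards against.
class FakeValue:
    def __init__(self, uses):
        self._uses = uses

    def use_empty(self):
        return len(self._uses) == 0

    def first_use(self):
        return self._uses[0]


unused = FakeValue([])
if unused.use_empty():
    print("input provided by inputs has no use")  # mirrors the new error branch
```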
@@ -538,7 +537,6 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
for key in state.value_to_valuegrad:
if key in no_grad_set:
no_gradvar_set.add(state.value_to_valuegrad[key][0][0])

for key in state.value_to_sumvaluegrad:
if key in no_grad_set:
for item in state.value_to_sumvaluegrad[key][0]:
@@ -575,26 +573,22 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
# update no_grad_set if some value stop_gradient=True
update_no_grad_set_by_stopgradient(block, no_grad_set)
complete_outputs, _, backward_ops = prepare_grad_outputs(
block,
grad_outputs,
outputs,
state.value_to_valuegrad,
state.op_to_opgrad,
grad_outputs, outputs, state
)

inputs_set = set(inputs)
outputs_set = set(complete_outputs)
effective_forward_op, _ = prune_ops(
effective_forward_ops, _ = prune_ops(
block.ops, inputs_set, outputs_set, no_grad_set
)
update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, complete_outputs
block, effective_forward_ops, no_grad_set, inputs, complete_outputs
)

inverse_effective_forward_op = inverse_sort_op(effective_forward_op)
inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)

append_backward_ops(
block, inverse_effective_forward_op, no_grad_set, backward_ops, state
block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
)
# now value_to_valuegrad should map value <-> value (add a sum op for values with multiple grad values)

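The comment above refers to the accumulation step: when a forward value collected several grad values, they are summed into one so value_to_valuegrad becomes one-to-one. A toy sketch with plain floats standing in for grad OpResults and the inserted sum op:

```python
# Toy sketch of the accumulation the comment describes; plain floats stand in
# for grad OpResults and Python sum() stands in for the inserted sum op.
value_to_valuegrad = {"x": [[3.0], [4.0]]}  # 'x' received two partial grads
for value, grads in value_to_valuegrad.items():
    if len(grads) > 1:
        total = sum(g[0] for g in grads)
        value_to_valuegrad[value] = [[total]]
print(value_to_valuegrad)  # {'x': [[7.0]]}
```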
@@ -719,26 +713,26 @@ def grad(
outputs,
'outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
check_type(
inputs,
'inputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
check_type(
grad_outputs,
'grad_outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)

check_type(
no_grad_vars,
'no_grad_vars',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
outputs = _as_list(outputs)
inputs = _as_list(inputs)
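The check_type calls above take (value, name, allowed types, api name); only the reported api name changes in this hunk. A toy stand-in (not Paddle's implementation) illustrating the call shape and the kind of error now attributed to 'paddle.autograd.backward.grad':

```python
# Toy stand-in for check_type, not Paddle's implementation; it only shows the
# call shape (value, name, allowed types, api name) used above.
def check_type(value, name, expected_type, api_name):
    if not isinstance(value, expected_type):
        raise TypeError(
            f"The type of '{name}' in {api_name} must be {expected_type}, "
            f"but received {type(value)}."
        )


check_type([1, 2], 'inputs', (list, tuple), 'paddle.autograd.backward.grad')  # passes
```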