Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【pir】 modify array_write and array_read vjp , add a simple while with array_write #60575

Merged
merged 59 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
001d799
optimize backward
xiaoguoguo626807 Dec 8, 2023
05ca298
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 11, 2023
4fd113e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 12, 2023
8f60538
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 13, 2023
8854896
[PIR] add vjp interface for while op
winter-wang Dec 12, 2023
7e177f6
[PIR] fix ci error.
winter-wang Dec 13, 2023
11c8656
modify while stopgradient
xiaoguoguo626807 Dec 14, 2023
d8c3936
merge
xiaoguoguo626807 Dec 14, 2023
da62e16
merge
xiaoguoguo626807 Dec 15, 2023
67ed811
merge
xiaoguoguo626807 Dec 15, 2023
30bba32
modify while grad bug
xiaoguoguo626807 Dec 18, 2023
53f2920
merge
xiaoguoguo626807 Dec 18, 2023
fde161c
modify while grad op
xiaoguoguo626807 Dec 18, 2023
fdc12c7
modify
xiaoguoguo626807 Dec 18, 2023
e3d19b9
increment vp
xiaoguoguo626807 Dec 19, 2023
600d99c
merge
xiaoguoguo626807 Dec 20, 2023
0913436
[PIR] add get_used_external_value interface for block.
winter-wang Dec 19, 2023
63344b7
while case
xiaoguoguo626807 Dec 20, 2023
59ad2fc
delete print
xiaoguoguo626807 Dec 20, 2023
f4eceb6
delete print
xiaoguoguo626807 Dec 20, 2023
1c9eb96
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 20, 2023
4beaa79
Merge branch 'develop' into while_2
xiaoguoguo626807 Dec 20, 2023
df0b46a
[PIR] add unit_test for get_used_external_value
winter-wang Dec 20, 2023
65083df
modify while_loop
xiaoguoguo626807 Dec 21, 2023
f2f4fa0
Merge branch 'while_2' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Dec 21, 2023
f8e3ac4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 21, 2023
95bc3d7
code_style
xiaoguoguo626807 Dec 21, 2023
37e807c
modofy ci bug
xiaoguoguo626807 Dec 21, 2023
52afa31
Merge branch 'develop', commit 'refs/pull/60159/head' of https://gith…
xiaoguoguo626807 Dec 21, 2023
48de124
modify while api
xiaoguoguo626807 Dec 22, 2023
a7f13c9
merge
xiaoguoguo626807 Dec 25, 2023
adb627a
modify ci
xiaoguoguo626807 Dec 25, 2023
e90cd79
modify array
xiaoguoguo626807 Dec 26, 2023
17e17d4
merge
xiaoguoguo626807 Dec 26, 2023
1aa50c0
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 26, 2023
eef3e24
Update test/legacy_test/test_cond.py
xiaoguoguo626807 Dec 26, 2023
d78b574
update
xiaoguoguo626807 Dec 26, 2023
d404059
modify array_write grad info
xiaoguoguo626807 Dec 26, 2023
fb8c52d
merge
xiaoguoguo626807 Dec 26, 2023
f3e09e5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 26, 2023
44d856f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 27, 2023
39fcb4b
merge
xiaoguoguo626807 Dec 27, 2023
655482a
add_n and createarraylike
xiaoguoguo626807 Dec 29, 2023
ec43be4
merge
xiaoguoguo626807 Dec 29, 2023
785d367
conflict
xiaoguoguo626807 Dec 29, 2023
800d8e4
modify
xiaoguoguo626807 Jan 4, 2024
467c0c3
modify array_write vjp
xiaoguoguo626807 Jan 4, 2024
0ae98cf
modify array_write vjp
xiaoguoguo626807 Jan 4, 2024
4aeca79
Update paddle/fluid/pybind/manual_static_op_function.h
xiaoguoguo626807 Jan 4, 2024
032ea27
modify array_write vjp
xiaoguoguo626807 Jan 4, 2024
577b58e
Merge branch 'while_9' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Jan 5, 2024
5034426
modify ci bug
xiaoguoguo626807 Jan 5, 2024
04ab90a
modify
xiaoguoguo626807 Jan 5, 2024
064e55d
modify
xiaoguoguo626807 Jan 5, 2024
32ca4ac
Update test/legacy_test/test_while_loop_op.py
xiaoguoguo626807 Jan 5, 2024
493643f
modify inplace array_read
xiaoguoguo626807 Jan 5, 2024
9e0ea64
Merge branch 'while_9' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Jan 5, 2024
188e779
Update test/legacy_test/test_while_op.py
xiaoguoguo626807 Jan 5, 2024
8b47a66
Update test/ir/pir/test_while_api.py
xiaoguoguo626807 Jan 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ OpInfoTuple ArrayReadOp::GetOpInfo() {
false,
false,
false,
true),
false),
OpInputInfo(
"i", "paddle::dialect::ScalarAttribute", false, false, true, false)};

Expand Down
40 changes: 20 additions & 20 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ std::vector<std::vector<pir::OpResult>> ArrayWrite_Op::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs_,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<pir::Value>>& in_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
PADDLE_ENFORCE_EQ(
inputs_.size(),
Expand All @@ -212,19 +212,21 @@ std::vector<std::vector<pir::OpResult>> ArrayWrite_Op::Vjp(
outputs.size()));

PADDLE_ENFORCE_EQ(
out_grads.size(),
in_grads.size(),
1,
platform::errors::InvalidArgument(
"ArrayWrite_ op's outputs size should be 1, but now is %d.",
outputs.size()));

VLOG(6) << "Vjp prepare call ArrayWrite_'s vjp inteface";
pir::OpResult tensor_res =
paddle::dialect::array_read(out_grads[0][0], inputs_[2][0]);

std::vector<std::vector<pir::OpResult>> res{{tensor_res}};
if (stop_gradients[0][0]) {
res = {{}};
pir::OpResult x_grad =
paddle::dialect::array_read(in_grads[0][0], inputs_[2][0]);
pir::OpResult zero = paddle::dialect::zeros_like(inputs_[1][0]);
paddle::dialect::array_write_(in_grads[0][0], zero, inputs_[2][0]);
std::vector<std::vector<pir::OpResult>> res(1);
res[0].resize(1);
if (!stop_gradients[0][0]) {
res[0][0] = x_grad;
}
return res;
}
Expand All @@ -247,27 +249,25 @@ std::vector<std::vector<pir::OpResult>> ArrayReadOp::Vjp(
platform::errors::InvalidArgument(
"Array_read op's outputs size should be 1, but now is %d.",
outputs.size()));

// x = array_read(input, i)
// out_grads[0][0] is x_grad
// out_grads[1][0] is input_array_grad
PADDLE_ENFORCE_EQ(
out_grads.size(),
1,
2,
platform::errors::InvalidArgument(
"Array_read op's outputs size should be 1, but now is %d.",
outputs.size()));

VLOG(6) << "Vjp prepare call Array_read's vjp inteface";

paddle::dialect::DenseTensorType outgrad_type =
out_grads[0][0].type().dyn_cast<paddle::dialect::DenseTensorType>();
pir::Value new_array = paddle::dialect::create_array(
paddle::dialect::TransToPhiDataType(outgrad_type.dtype()));
pir::OpResult tensor_res =
paddle::dialect::array_write_(new_array, out_grads[0][0], inputs_[1][0]);
pir::Value array_grad_i_origin =
paddle::dialect::array_read(out_grads[1][0], inputs_[1][0]);
pir::Value array_grad_i =
paddle::dialect::add(array_grad_i_origin, out_grads[0][0]);
paddle::dialect::array_write_(out_grads[1][0], array_grad_i, inputs_[1][0]);

std::vector<std::vector<pir::OpResult>> res{{tensor_res}};
if (stop_gradients[0][0]) {
res = {{}};
}
std::vector<std::vector<pir::OpResult>> res;
return res;
}
} // namespace dialect
Expand Down
177 changes: 134 additions & 43 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,26 @@ def check_all_puts(block, inputs, outputs):


def append_full_like(float_value, copy_value, value, state, backward_ops):
value_grad = paddle.full_like(
copy_value,
float_value,
dtype=copy_value.dtype,
)
full_like_op = value_grad.get_defining_op()
full_op = full_like_op.operand_source(1).get_defining_op()
if copy_value.is_tensorarray():
value_grad = paddle._pir_ops.create_array_like(
copy_value,
float_value,
)
full_like_op = value_grad.get_defining_op()
backward_ops_ = [full_like_op]
else:
value_grad = paddle.full_like(
copy_value,
float_value,
dtype=copy_value.dtype,
)
full_like_op = value_grad.get_defining_op()
full_op = full_like_op.operand_source(1).get_defining_op()
backward_ops_ = [full_like_op, full_op]
update_bwdop_structure(
backward_ops,
state.op_to_opgrad[value.get_defining_op()],
[full_like_op, full_op],
backward_ops_,
)
state.value_to_valuegrad[value] = [[value_grad]]
return value_grad
Expand Down Expand Up @@ -367,6 +376,16 @@ def inverse_sort_op(ops):
return sorted_list


def inplace_net(op_list):
op_name_list = [op.name() for op in op_list]
if (
"pd_op.array_write_" in op_name_list
or "pd_op.array_read" in op_name_list
):
return True
return False


def append_backward_ops(
base_op,
base_inputs,
Expand Down Expand Up @@ -425,13 +444,28 @@ def return_map_value(value, map):
output = map[output]
return output

def return_map_value_list(grad_value, map):
output = []
for i in range(len(grad_value)):
if grad_value[i] in map:
output.append(map[grad_value[i]])
else:
output.append(grad_value[i])
return output

def append_add_n(value):
# value is input of more than one fwd_op,
# so more than one bwd_op create input_grad,
# need add sum op to accumulate gradient
add_n_value = paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)
if value.is_tensorarray():
add_n_value = paddle._pir_ops.add_n_array(
[item[0] for item in state.value_to_valuegrad[value]]
)
else:
add_n_value = paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)

add_n_op = add_n_value.get_defining_op()
combine_op = add_n_op.operand_source(0).get_defining_op()
update_bwdop_structure(
Expand All @@ -446,7 +480,12 @@ def make_output_with_output_grad(op):
zero_flag = [False] * op.num_results()
outputs = []
output_grads = []
for i, value in enumerate(op.results()):
if op.name() == "pd_op.array_write_":
output_list = [op.operand_source(0)]
else:
output_list = op.results()

for i, value in enumerate(output_list):
new_value = [
return_map_value(value, control_flow_value_to_copyvalue_map)
]
Expand Down Expand Up @@ -496,9 +535,39 @@ def make_output_with_output_grad(op):
outputs.append(new_value)
grad_value = state.value_to_valuegrad[value][0]
output_grads.append(
[bwd_value_to_block_argument_map[grad_value[0]]]
if grad_value[0] in bwd_value_to_block_argument_map
else grad_value
return_map_value_list(
grad_value, bwd_value_to_block_argument_map
)
)

if op.name() == "pd_op.array_read":
value = op.operand_source(0)
while value in state.inside_value_to_outside_value_map:
value = state.inside_value_to_outside_value_map[value]

if value in state.value_to_valuegrad:
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(value)

if (
value not in state.value_to_valuegrad
or state.value_to_valuegrad[value] == []
):
append_full_like(
0.0,
return_map_value(
value, control_flow_value_to_copyvalue_map
),
value,
state,
backward_ops,
)

grad_value = state.value_to_valuegrad[value][0]
output_grads.append(
return_map_value_list(
grad_value, bwd_value_to_block_argument_map
)
)

return zero_flag, outputs, output_grads
Expand Down Expand Up @@ -692,7 +761,11 @@ def argument_to_value(while_op):
else:
forward_ops = effective_forward_ops

inverse_effective_forward_ops = inverse_sort_op(forward_ops)
if inplace_net(forward_ops):
inverse_effective_forward_ops = reversed(forward_ops)
else:
inverse_effective_forward_ops = inverse_sort_op(forward_ops)

clear_effective_forward_ops = []
for op in inverse_effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
Expand Down Expand Up @@ -743,7 +816,9 @@ def argument_to_value(while_op):
else:
# all(zero_flag) support this op has no contribution for grad
# should be delete (prune sub_graph)
if len(output_grads) == 0 or all(zero_flag):
if (
len(output_grads) == 0 or all(zero_flag)
) and op.name() != "pd_op.while":
continue

if op.name() == "pd_op.if":
Expand Down Expand Up @@ -787,17 +862,28 @@ def argument_to_value(while_op):
for i, input in enumerate(
get_used_external_value(while_block)
):
append_full_like(
0.0, input, input, sub_state, backward_ops
)
if input in sub_state.value_to_valuegrad:
if len(sub_state.value_to_valuegrad[input]) > 1:
append_add_n(input)

if (
input not in sub_state.value_to_valuegrad
or sub_state.value_to_valuegrad[input] == []
):
append_full_like(
0.0, input, input, sub_state, backward_ops
)

grad_value = sub_state.value_to_valuegrad[input][0]
for tmp in state.value_to_valuegrad[input]:
state.value_to_sumvaluegrad[input].append(tmp)
state.value_to_valuegrad[input] = []
output_grads.append(
[bwd_value_to_block_argument_map[grad_value[0]]]
if grad_value[0]
in bwd_value_to_block_argument_map
else grad_value
return_map_value_list(
grad_value,
bwd_value_to_block_argument_map,
)
)

build_pipe_for_block(while_block)
with dynamic_shape_prim_vjp_guard(op, inputs):
input_grads = paddle.framework.core.call_vjp(
Expand Down Expand Up @@ -953,6 +1039,11 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block = outputs[0].get_defining_op().get_parent_block()
state = State(block)

total_ops = []
if block.parent_block is not None:
total_ops += block.parent_block.ops
total_ops += block.ops

# update no_grad_set if some value stop_gradient=True
update_no_grad_set_by_stopgradient(block, no_grad_set)
complete_outputs, _, backward_ops = prepare_grad_outputs(
Expand All @@ -961,14 +1052,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):

inputs_set = ValueSet(inputs)
outputs_set = ValueSet(complete_outputs)
total_ops = []
if block.parent_block is not None:
total_ops += block.parent_block.ops
total_ops += block.ops

effective_forward_ops, _ = prune_ops(
total_ops, inputs_set, outputs_set, no_grad_set
)
if inplace_net(total_ops):
effective_forward_ops = total_ops
else:
effective_forward_ops, _ = prune_ops(
total_ops, inputs_set, outputs_set, no_grad_set
)

update_no_grad_set_after_prune(
total_ops, effective_forward_ops, no_grad_set, inputs, complete_outputs
)
Expand All @@ -993,18 +1084,18 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
)
if not inplace_net(backward_ops):
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)

_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)

state.turn_map()
for bwd_op in inverse_sort_op(remove_ops):
if bwd_op.result(0) in ValueSet(grad_outputs):
continue
if bwd_op.result(0).use_empty():
remove_op(block, bwd_op, state)
state.turn_map()
state.turn_map()
for bwd_op in inverse_sort_op(remove_ops):
if bwd_op.result(0) in ValueSet(grad_outputs):
continue
if bwd_op.result(0).use_empty():
remove_op(block, bwd_op, state)
state.turn_map()

input_grad_map = state.value_to_valuegrad

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def array_write(x, i, array=None):
if array is None:
array = paddle._pir_ops.create_array(x.dtype)

array = paddle._pir_ops.array_write_(array, x, i)
paddle._pir_ops.array_write_(array, x, i)
return array
else:
check_variable_and_dtype(i, 'i', ['int64'], 'array_write')
Expand Down
5 changes: 3 additions & 2 deletions test/ir/pir/test_while_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,13 @@ def test_backward(self):
out,
[i, j],
)
print(main_program)
xiaoguoguo626807 marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(
grad_outs[0].get_defining_op().name(), "pd_op.while"
)
self.assertEqual(
main_program.global_block()
.ops[-3]
.ops[-1]
.as_while_op()
.body()
.ops[-4]
Expand All @@ -187,7 +188,7 @@ def test_backward(self):

self.assertEqual(
main_program.global_block()
.ops[-3]
.ops[-1]
.as_while_op()
.body()
.ops[-5]
Expand Down
Loading