【pir】deal with if build stop gradient #59585

Merged: 4 commits, Dec 4, 2023

Changes from 2 commits
34 changes: 34 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -20,14 +20,17 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp

#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_trait.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/core/utils.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

using pir::TuplePopOp;
using pir::TuplePushOp;
constexpr char kStopGradientAttrName[] = "stop_gradient";
namespace paddle {
namespace dialect {

@@ -52,9 +55,26 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
true_block->back().isa<pir::YieldOp>()) {
auto &op = true_block->back();

std::vector<pir::Attribute> outs_stop_gradient;
for (size_t i = 0; i < op.num_operands(); ++i) {
argument.AddOutput(op.operand(i).type());
bool input_stop_gradient = true;
auto input = op.operand_source(i).dyn_cast<pir::OpResult>();
Contributor:

If `input` is not an OpResult but a BlockArgument, this line returns an empty OpResult, the `defining_op` obtained on the next line is then a null pointer, and the code crashes right after.

Contributor Author:

done
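For reference, a guard along these lines would avoid the crash described above. This is only a minimal sketch, not the fix actually merged in the later commits; the helper name GetInputStopGradient is hypothetical, it is meant to live in control_flow_op.cc (which already includes the needed pir headers), and it assumes the pir API already used in this file — in particular that a failed dyn_cast<pir::OpResult>() yields a result that converts to false.

bool GetInputStopGradient(pir::Value value) {
  // Conservative default when nothing better is known.
  bool stop_gradient = true;
  auto result = value.dyn_cast<pir::OpResult>();
  if (!result) {
    // The value is a BlockArgument (or otherwise has no defining op),
    // so there is no owner() to consult; return the default instead of
    // dereferencing a null pointer.
    return stop_gradient;
  }
  auto *defining_op = result.owner();
  if (defining_op != nullptr &&
      defining_op->HasAttribute(kStopGradientAttrName)) {
    auto attrs = defining_op->attribute(kStopGradientAttrName)
                     .dyn_cast<pir::ArrayAttribute>()
                     .AsVector();
    stop_gradient =
        attrs[result.index()].dyn_cast<pir::BoolAttribute>().data();
  }
  return stop_gradient;
}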

auto *defining_op = input.owner();
if (defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
input_stop_gradient =
attrs[input.index()].dyn_cast<pir::BoolAttribute>().data();
}
outs_stop_gradient.push_back(pir::BoolAttribute::get(
Contributor:

Suggested change:
-      outs_stop_gradient.push_back(pir::BoolAttribute::get(
+      outs_stop_gradient.push_back(builder.bool_attr(input_stop_gradient));

Contributor Author:

done

pir::IrContext::Instance(), input_stop_gradient));
}

argument.AddAttribute(kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(),
Contributor:

Suggested change:
-      pir::ArrayAttribute::get(pir::IrContext::Instance(),
+      pir::ArrayAttribute::get(builder.ir_context(),

Contributor Author:

done

outs_stop_gradient));
}
if (false_block && !false_block->empty() &&
false_block->back().isa<pir::YieldOp>()) {
@@ -85,6 +105,20 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
argument.AddRegion().push_back(true_block.release());
argument.AddRegion().push_back(false_block.release());
argument.AddInput(cond);

auto cond_ = cond.dyn_cast<pir::OpResult>();
auto cond_op = cond_.owner();
if (cond_op->HasAttribute(kStopGradientAttrName)) {
Contributor:

Likewise, if `cond` here is a BlockArgument rather than an OpResult, this also crashes, and without any diagnostic.

Contributor Author:

done
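The same guard pattern applies here. A minimal sketch (again assuming the pir API used in this file, and that an OpResult from a failed cast converts to false) that only rewrites the attribute when cond actually has a defining op:

// Leave a BlockArgument cond untouched; only ops with a defining op
// carry a stop_gradient array attribute to update.
if (auto cond_result = cond.dyn_cast<pir::OpResult>()) {
  auto *cond_op = cond_result.owner();
  if (cond_op->HasAttribute(kStopGradientAttrName)) {
    auto attrs = cond_op->attribute(kStopGradientAttrName)
                     .dyn_cast<pir::ArrayAttribute>()
                     .AsVector();
    attrs[cond_result.index()] =
        pir::BoolAttribute::get(pir::IrContext::Instance(), true);
    cond_op->set_attribute(
        kStopGradientAttrName,
        pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs));
  }
}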

auto attrs = cond_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
attrs[cond_.index()] =
pir::BoolAttribute::get(pir::IrContext::Instance(), true);

cond_op->set_attribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs));
}
}

pir::Block *IfOp::true_block() {
1 change: 0 additions & 1 deletion paddle/pir/core/builtin_op.cc
@@ -16,7 +16,6 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/enforce.h"

namespace pir {

60 changes: 27 additions & 33 deletions python/paddle/autograd/ir_backward.py
@@ -55,6 +55,13 @@ def check_all_puts(block, inputs, outputs):
)


def get_real_op_inputs(op):
if op.name() in ["pd_op.if", "pd_op.while"]:
return get_used_external_value(op)
else:
return op.operands_source()


def update_no_grad_set_by_stopgradient(block, no_grad_set):
for op in block.ops:
for value in op.results():
@@ -199,12 +206,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
union_op_flags[i] = True
continue

op_inputs = (
get_used_external_value(op)
if op.name() in ["pd_op.if", "pd_op.while"]
else op.operands_source()
)
if some_in_set(op_inputs, inputs_set):
if some_in_set(get_real_op_inputs(op), inputs_set):
union_op_flags[i] = True
for value in op.results():
if value not in no_grad_set:
@@ -216,12 +218,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
for i, op in reversed(list(enumerate(total_ops))):
if some_in_set(op.results(), outputs_set):
union_op_flags[i] = True
op_inputs = (
get_used_external_value(op)
if op.name() in ["pd_op.if", "pd_op.while"]
else op.operands_source()
)
for operand in op_inputs:
for operand in get_real_op_inputs(op):
if operand not in no_grad_set:
outputs_set.add(operand)
else:
@@ -269,13 +266,13 @@ def update_no_grad_set_after_prune(
inputs_set = set(inputs)
if inputs_set:
for op in block.ops:
if some_in_set(op.operands_source(), inputs_set):
if some_in_set(get_real_op_inputs(op), inputs_set):
for value in op.results():
if value not in no_grad_set:
inputs_set.add(value)

for op in effective_forward_ops:
for value in op.operands_source():
for value in get_real_op_inputs(op):
if value not in inputs_set:
no_grad_set.add(value)

@@ -284,11 +281,11 @@ def update_no_grad_set_after_prune(
for op in reversed(effective_forward_ops):
for output in op.results():
if output not in outputs_set and not some_in_set(
[output], set(op.operands_source())
[output], set(get_real_op_inputs(op))
):
no_grad_set_tmp.add(output)

for input in op.operands_source():
for input in get_real_op_inputs(op):
if input not in no_grad_set:
outputs_set.add(input)

@@ -309,9 +306,9 @@ def inverse_sort_op(ops):
ops_set = set(ops)
sorted_list = []
for op in ops:
for x in op.operands():
if x.source() and x.source().get_defining_op() in ops_set:
pending_count[x.source().get_defining_op()] += 1
for x in get_real_op_inputs(op):
if x and x.get_defining_op() in ops_set:
pending_count[x.get_defining_op()] += 1

queue = collections.deque()

@@ -323,8 +320,8 @@ def inverse_sort_op(ops):
op = queue.popleft()
sorted_list.append(op)

for x in op.operands():
x_op = x.source().get_defining_op()
for x in get_real_op_inputs(op):
x_op = x.get_defining_op()
pending_count[x_op] -= 1
if pending_count[x_op] == 0:
queue.append(x_op)
@@ -471,11 +468,6 @@ def make_output_with_output_grad(op):
return zero_flag, outputs, output_grads

def make_input_with_input_stopgradient(op):
origin_inputs = (
get_used_external_value(op)
if op.name() in ["pd_op.if", "pd_op.while"]
else op.operands_source()
)
inputs = []
input_grad_stopgradients = []
if op.name() in [
@@ -484,11 +476,15 @@ def make_input_with_input_stopgradient(op):
"pd_op.while",
"cf.tuple_push",
]:
grad_semantic_info = [True for _ in range(len(origin_inputs))]
grad_semantic_info = [
True for _ in range(len(get_real_op_inputs(op)))
]
else:
grad_semantic_info = op.get_input_grad_semantics()

for input, grad_semantic in zip(origin_inputs, grad_semantic_info):
for input, grad_semantic in zip(
get_real_op_inputs(op), grad_semantic_info
):
if not grad_semantic:
if (
input.get_defining_op() is not None
@@ -570,7 +566,7 @@ def update_input_grad_map(op, input_grads, origin_inputs):

def append_yield(block, inputs):
with block:
inputs_grad = [paddle.pir.fake_op_result()]
inputs_grad = []
for value in inputs:
if value in state.value_to_valuegrad:
if len(state.value_to_valuegrad[value]) > 1:
@@ -677,7 +673,6 @@ def append_yield(block, inputs):
sub_backward_ops,
sub_state,
)

# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
else:
@@ -726,7 +721,7 @@ def prepare_backward_prune_set(inputs, outputs):
outputs_fwd_set = set()
for input_ in inputs:
if not input_.use_empty():
for item in input_.first_use().owner().operands_source():
for item in get_real_op_inputs(input_.first_use().owner()):
outputs_fwd_set.add(item)
else:
logging.warning("input privided by inputs has no use")
@@ -754,7 +749,7 @@ def create_backward_prune_set(
inputs_set_tmp = set()
for out_grad in inputs_set:
if not out_grad.use_empty():
for item in out_grad.first_use().owner().operands_source():
for item in get_real_op_inputs(out_grad.first_use().owner()):
inputs_set_tmp.add(item)
inputs_set.update(inputs_set_tmp)

@@ -792,7 +787,6 @@ def remove_op(block, op, state):

def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block = outputs[0].get_defining_op().get_parent_block()
block.refresh_stopgradient()
state = State(block)

# check all inputs and outputs in the same block
8 changes: 5 additions & 3 deletions test/ir/pir/test_if_api.py
@@ -135,15 +135,17 @@ def test_if_op_backward(self):
if_op = main_program.global_block().ops[-1]
self.assertEqual(if_op.name(), "pd_op.if")
with paddle.pir.core.program_guard(main_program):
if_op.result(0).stop_gradient = False
self.assertEqual(
main_program.global_block().ops[-2].result(0).stop_gradient,
True,
)
self.assertEqual(if_op.result(0).stop_gradient, False)
# check vjp interface for if_op
print("main_program ", main_program)
grad_outs = grad(
if_op.results(),
[dataop0.result(0), dataop1.result(0)],
)

print("main_program ", main_program)
self.assertEqual(grad_outs[0].get_defining_op().name(), "pd_op.if")
self.assertEqual(grad_outs[1].get_defining_op().name(), "pd_op.if")
