
[Dy2St] fix test_grad in PIR mode (#60621)
---------

Co-authored-by: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
SigureMo and xiaoguoguo626807 authored Jan 9, 2024
1 parent 0e13ae0 commit 7eb6b0d
Showing 9 changed files with 128 additions and 135 deletions.
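For orientation, the affected test exercises calling paddle.grad inside a paddle.jit.to_static-decorated function with PIR enabled. The sketch below is an assumed shape of such a case, not code copied from test_grad:

# Hedged sketch of a Dy2St + PIR double-grad scenario (assumed shape of
# test_grad, not the actual test); requires a Paddle build with PIR
# enabled, e.g. via FLAGS_enable_pir_api=1.
import paddle


@paddle.jit.to_static
def square_then_grad(x):
    y = x * x
    # paddle.grad inside a to_static function is the Dy2St + PIR path
    # this commit repairs.
    (dx,) = paddle.grad([y], [x], create_graph=True)
    return dx


x = paddle.randn([4], dtype="float32")
x.stop_gradient = False
print(square_then_grad(x))  # expected to equal 2 * x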
138 changes: 56 additions & 82 deletions paddle/fluid/pybind/pir.cc
@@ -171,12 +171,22 @@ std::string GetValueInfo(Value v) {
} else if (auto arg = v.dyn_cast<BlockArgument>()) {
ss << "block_arg, index = " << arg.index();
}
ss << ", dtype=" << v.type();
if (v.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ss << ", place="
<< v.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
if (!v.type()) {
ss << ", dtype=<<NULL TYPE>>";
} else {
ss << ", dtype=" << v.type();
if (v.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ss << ", place="
<< v.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
}
}
auto stop_gradient = v.attribute<BoolAttribute>(kAttrStopGradients);
if (stop_gradient && !stop_gradient.data()) {
ss << ", stop_gradient=False";
} else {
ss << ", stop_gradient=True";
}
return ss.str();
}
@@ -527,13 +537,20 @@ void BindOperation(py::module *m) {
return op_list;
})
.def("replace_all_uses_with",
[](Operation &self, const std::vector<OpResult> &op_results) {
self.ReplaceAllUsesWith(op_results);
[](Operation &self, const std::vector<Value> &values) {
self.ReplaceAllUsesWith(values);
})
.def("as_if_op",
[](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); })
.def("as_while_op",
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); });
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); })
.def("__repr__", [](Operation &self) {
std::ostringstream print_stream;
print_stream << "Operation(";
self.Print(print_stream);
print_stream << ")";
return print_stream.str();
});
py::class_<Operation::BlockContainer> block_container(
*m, "Operation_BlockContainer", R"DOC(
The Operation_BlockContainer is only used to walk all blocks in the operation.
@@ -555,6 +572,9 @@ py::str Value2String(Value self) {
}

phi::DataType GetValueDtype(Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DenseTensorType>().dtype());
@@ -572,6 +592,9 @@ phi::DataType GetValueDtype(Value value) {
}

const phi::DDim &GetValueDims(Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
@@ -768,21 +791,6 @@ void BindValue(py::module *m) {
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
.def("__str__",
[](Value self) -> py::str {
std::ostringstream print_stream;
print_stream << "Value(";
print_stream << GetValueInfo(self);
auto stop_gradient =
self.attribute<BoolAttribute>(kAttrStopGradients);
if (stop_gradient && !stop_gradient.data()) {
print_stream << ", stop_gradient=False";
} else {
print_stream << ", stop_gradient=True";
}
print_stream << ")";
return print_stream.str();
})
.def("apply", &apply)
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
@@ -965,14 +973,14 @@ using SplitedProgram = std::vector<std::shared_ptr<Program>>;
using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;

pir::OpResult FakeOpResult() {
// create a fake opresults to simplify `ForwardBackwardSplit`.
return pir::OpResult(nullptr);
pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeOpResult(const pir::OpResult &result) {
// create a fake opresults to simplify `ForwardBackwardSplit`.
return result.Value::impl() == nullptr || !result.Value::type();
bool IsFakeValue(const pir::Value &value) {
// create a fake value to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
@@ -1032,7 +1040,7 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
}

void AppendShadowOutput(Program *forward_program,
const pir::OpResult &result,
const pir::Value &value,
const std::string &name,
size_t start_point) {
pir::IrContext *ctx = pir::IrContext::Instance();
@@ -1041,7 +1049,7 @@ void AppendShadowOutput(Program *forward_program,
{"output_name", pir::StrAttribute::get(ctx, name)},
};
pir::Operation *operation =
pir::Operation::Create({result}, attribute_map, {}, op_info);
pir::Operation::Create({value}, attribute_map, {}, op_info);
auto position = forward_program->block()->begin();
std::advance(position, start_point);
if (position == forward_program->block()->end()) {
@@ -1052,19 +1060,19 @@
}

int AppendShadowOutputs(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
const std::vector<pir::Value> &outputs,
int start_point,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::OpResult> added_op_result;
std::unordered_set<pir::Value> added_value;

for (const auto &result : outputs_op_result) {
if (!added_op_result.count(result) || IsFakeOpResult(result)) {
for (const auto &value : outputs) {
if (!added_value.count(value) || IsFakeValue(value)) {
std::string shadow_output_name = name_prefix + std::to_string(counter);
AppendShadowOutput(
forward_program, result, shadow_output_name, start_point + counter);
forward_program, value, shadow_output_name, start_point + counter);
counter += 1;
added_op_result.insert(result);
added_value.insert(value);
}
}
// return the inserted op.
@@ -1073,51 +1081,17 @@ int AppendShadowOutputs(Program *forward_program,

SplitedResult SplitForwardBackward(
const Program &program,
const std::vector<pir::OpResult> &op_result_forward_inputs,
const std::vector<pir::OpResult> &op_result_forward_params,
const std::vector<pir::OpResult> &op_result_forward_outputs,
const std::vector<pir::OpResult> &op_result_forward_inputs_grads,
const std::vector<pir::OpResult> &op_result_forward_params_grads,
const std::vector<pir::OpResult> &op_result_forward_outputs_grads,
const std::vector<pir::Value> &forward_inputs,
const std::vector<pir::Value> &forward_params,
const std::vector<pir::Value> &forward_outputs,
const std::vector<pir::Value> &forward_inputs_grads,
const std::vector<pir::Value> &forward_params_grads,
const std::vector<pir::Value> &forward_outputs_grads,
const std::vector<int> &forward_range,
const std::vector<int> &backward_range) {
// transform opresult -> value
std::vector<pir::Value> forward_inputs, forward_outputs, forward_inputs_grads,
forward_outputs_grads, forward_params, forward_params_grads;

auto op_result_to_value = [](const pir::OpResult &r) {
if (r.impl() == nullptr) return Value(nullptr);
return Value(r.Value::impl());
};

std::transform(op_result_forward_inputs.begin(),
op_result_forward_inputs.end(),
std::back_inserter(forward_inputs),
op_result_to_value);
std::transform(op_result_forward_outputs.begin(),
op_result_forward_outputs.end(),
std::back_inserter(forward_outputs),
op_result_to_value);
std::transform(op_result_forward_inputs_grads.begin(),
op_result_forward_inputs_grads.end(),
std::back_inserter(forward_inputs_grads),
op_result_to_value);
std::transform(op_result_forward_outputs_grads.begin(),
op_result_forward_outputs_grads.end(),
std::back_inserter(forward_outputs_grads),
op_result_to_value);
std::transform(op_result_forward_params.begin(),
op_result_forward_params.end(),
std::back_inserter(forward_params),
op_result_to_value);
std::transform(op_result_forward_params_grads.begin(),
op_result_forward_params_grads.end(),
std::back_inserter(forward_params_grads),
op_result_to_value);

std::vector<pir::Value> forward_in_out_values;
for (auto &v : std::vector<std::vector<pir::Value> *>(
{&forward_inputs, &forward_outputs, &forward_params})) {
for (auto &v :
std::vector({&forward_inputs, &forward_outputs, &forward_params})) {
forward_in_out_values.insert(
forward_in_out_values.end(), v->begin(), v->end());
}
@@ -1393,8 +1367,8 @@ void BindUtils(pybind11::module *m) {
m->def("reset_shadow_output_name", ResetShadowOutputName);
m->def("split_program", SplitForwardBackward);
m->def("append_shadow_outputs", AppendShadowOutputs);
m->def("fake_op_result", FakeOpResult);
m->def("is_fake_op_result", IsFakeOpResult);
m->def("fake_value", FakeValue);
m->def("is_fake_value", IsFakeValue);
m->def("get_current_insertion_point", []() -> PyInsertionPoint {
return {ApiBuilder::Instance().GetCurrentInsertionPoint()};
});
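A minimal usage sketch of the renamed placeholder helpers bound above (assuming a Paddle build that includes these bindings; FakeValue simply wraps a null pir::Value):

# fake_value() produces a null placeholder Value and is_fake_value()
# recognizes it, mirroring the C++ FakeValue/IsFakeValue helpers bound in
# pir.cc above. Fake values mark gradient slots that do not exist.
import paddle

placeholder = paddle.pir.fake_value()
assert paddle.pir.is_fake_value(placeholder)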
8 changes: 8 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -111,6 +111,10 @@ def __iter__(self):
def __contains__(self, key):
return ValueWrapper(key) in self._items

def __repr__(self) -> str:
items_str = ", ".join(f"{key}: {val}" for key, val in self.items())
return f'ValueDict({items_str})'


class ValueSet:
def __init__(
@@ -153,6 +157,10 @@ def __iter__(self):
def __contains__(self, val):
return ValueWrapper(val) in self._set

def __repr__(self) -> str:
items_str = ", ".join(repr(item) for item in self)
return f'ValueSet({items_str})'


class State:
"""
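As a standalone illustration of the repr format introduced above, the same string-building logic applied to plain built-in containers (the real ValueDict and ValueSet expect pir.Value keys, so they are not used directly here):

# Not the real ValueDict/ValueSet classes; only the repr formatting logic
# added in backward_utils.py is reproduced.
class ReprDemoDict(dict):
    def __repr__(self) -> str:
        items_str = ", ".join(f"{key}: {val}" for key, val in self.items())
        return f"ValueDict({items_str})"


class ReprDemoSet(set):
    def __repr__(self) -> str:
        items_str = ", ".join(repr(item) for item in self)
        return f"ValueSet({items_str})"


print(ReprDemoDict(a=1, b=2))  # ValueDict(a: 1, b: 2)
print(ReprDemoSet({"x"}))      # ValueSet('x')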
11 changes: 6 additions & 5 deletions python/paddle/autograd/ir_backward.py
@@ -63,6 +63,9 @@ def check_all_puts(block, inputs, outputs):


def append_full_like(float_value, copy_value, value, state, backward_ops):
if paddle.pir.is_fake_value(value):
state.value_to_valuegrad[value] = [[paddle.pir.fake_value()]]
return
if copy_value.is_tensorarray():
value_grad = paddle._pir_ops.create_array_like(
copy_value,
@@ -174,9 +177,9 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
visited_output.add(opresult)
continue
else:
if paddle.pir.is_fake_op_result(opresult):
if paddle.pir.is_fake_value(opresult):
state.value_to_valuegrad[opresult] = [
[paddle.pir.fake_op_result()]
[paddle.pir.fake_value()]
]
else:
grad_value = append_full_like(
@@ -702,9 +705,7 @@ def append_yield(
new_value = return_map_value(
value, control_flow_value_to_copyvalue_map
)
value_grad = append_full_like(
0.0, new_value, value, state, backward_ops
)
append_full_like(0.0, new_value, value, state, backward_ops)
input_grad = state.value_to_valuegrad[value][0][0]

inputs_grad.append(input_grad)
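The Python-side changes above share one pattern: a gradient slot that does not exist is filled with a fake value instead of None, so later passes can detect and skip it uniformly. A condensed, hypothetical restatement (not the real append_full_like):

# Hypothetical helper condensing the pattern: represent a missing gradient
# with a fake value that downstream code can identify via is_fake_value().
import paddle


def grad_or_placeholder(grad_value):
    if grad_value is None or paddle.pir.is_fake_value(grad_value):
        # No real gradient exists for this slot; hand back a placeholder.
        return paddle.pir.fake_value()
    return grad_value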
6 changes: 3 additions & 3 deletions python/paddle/decomposition/decomp.py
@@ -605,7 +605,7 @@ def _prepare_grad_outputs(fwd_op, bwd_op):
new_grad_outputs.append(grad_outputs[index])
index += 1
else:
new_grad_outputs.append([pir.fake_op_result()])
new_grad_outputs.append([pir.fake_value()])
return new_grad_outputs


Expand Down Expand Up @@ -679,7 +679,7 @@ def _decomp_bwd_with_vjp(
if grad_input[0] is not None and grad_input[0].initialized():
res.append(grad_input[0])
else:
res.append(pir.fake_op_result())
res.append(pir.fake_value())
assert len(res) == len(
bwd_op.results()
), "results of original backward op do not match results of decomposed backward op"
@@ -752,7 +752,7 @@ def _decomp_bwd_without_vjp(
res.append(new_grad_inputs[input_grads_idx])
input_grads_idx += 1
else:
res.append(pir.fake_op_result())
res.append(pir.fake_value())

# step4: upgrade grad_var_to_var
_upgrade_grad_var_to_var(