[Dy2St] fix test_grad in PIR mode #60621

Merged · 13 commits · Jan 9, 2024
138 changes: 56 additions & 82 deletions paddle/fluid/pybind/pir.cc
@@ -171,12 +171,22 @@ std::string GetValueInfo(Value v) {
} else if (auto arg = v.dyn_cast<BlockArgument>()) {
ss << "block_arg, index = " << arg.index();
}
ss << ", dtype=" << v.type();
if (v.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ss << ", place="
<< v.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
if (!v.type()) {
ss << ", dtype=<<NULL TYPE>>";
} else {
ss << ", dtype=" << v.type();
if (v.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ss << ", place="
<< v.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
}
}
auto stop_gradient = v.attribute<BoolAttribute>(kAttrStopGradients);
if (stop_gradient && !stop_gradient.data()) {
ss << ", stop_gradient=False";
} else {
ss << ", stop_gradient=True";
}
return ss.str();
}
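
GetValueInfo now tolerates a value whose type is null (the case produced by the fake values introduced further down) and takes over the stop_gradient reporting that previously lived in the Value.__str__ lambda. A minimal sketch of how this surfaces on the Python side; paddle.pir_utils.IrGuard and the exact printed fields are assumptions of this sketch, not shown in the diff:

import paddle

paddle.enable_static()
# IrGuard (assumed helper) switches the static API onto the PIR representation.
with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data("x", [2, 3], "float32")
        y = paddle.mean(x)
        # str(y) is built from GetValueInfo, so it should now include a
        # stop_gradient field and no longer crash on a null-typed value.
        print(y)
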
@@ -527,13 +537,20 @@ void BindOperation(py::module *m) {
return op_list;
})
.def("replace_all_uses_with",
[](Operation &self, const std::vector<OpResult> &op_results) {
self.ReplaceAllUsesWith(op_results);
[](Operation &self, const std::vector<Value> &values) {
self.ReplaceAllUsesWith(values);
})
.def("as_if_op",
[](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); })
.def("as_while_op",
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); });
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); })
.def("__repr__", [](Operation &self) {
std::ostringstream print_stream;
print_stream << "Operation(";
self.Print(print_stream);
print_stream << ")";
return print_stream.str();
});
py::class_<Operation::BlockContainer> block_container(
*m, "Operation_BlockContainer", R"DOC(
The Operation_BlockContainer only use to walk all blocks in the operation.
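
The Operation binding gains a __repr__ that wraps the op's textual IR in "Operation(...)", and replace_all_uses_with now accepts plain pir::Values instead of only OpResults. A rough sketch of the Python-visible effect; the program construction and the global_block().ops accessors are assumptions, not part of this diff:

import paddle

paddle.enable_static()
with paddle.pir_utils.IrGuard():  # assumed helper to enter PIR mode
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data("x", [4], "float32")
        y = paddle.nn.functional.relu(x)
        for op in main.global_block().ops:
            # repr(op) prints "Operation(" + the op's textual IR + ")".
            print(repr(op))
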
@@ -555,6 +572,9 @@ py::str Value2String(Value self) {
}

phi::DataType GetValueDtype(Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DenseTensorType>().dtype());
@@ -572,6 +592,9 @@ phi::DataType GetValueDtype(Value value) {
}

const phi::DDim &GetValueDims(Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
@@ -768,21 +791,6 @@ void BindValue(py::module *m) {
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
.def("__str__",
[](Value self) -> py::str {
std::ostringstream print_stream;
print_stream << "Value(";
print_stream << GetValueInfo(self);
auto stop_gradient =
self.attribute<BoolAttribute>(kAttrStopGradients);
if (stop_gradient && !stop_gradient.data()) {
print_stream << ", stop_gradient=False";
} else {
print_stream << ", stop_gradient=True";
}
print_stream << ")";
return print_stream.str();
})
.def("apply", &apply)
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
@@ -1006,14 +1014,14 @@ using SplitedProgram = std::vector<std::shared_ptr<Program>>;
using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;

pir::OpResult FakeOpResult() {
// create a fake opresults to simplify `ForwardBackwardSplit`.
return pir::OpResult(nullptr);
pir::Value FakeValue() {
// create a fake value to simplify `ForwardBackwardSplit`.
return pir::Value(nullptr);
}

bool IsFakeOpResult(const pir::OpResult &result) {
// create a fake opresults to simplify `ForwardBackwardSplit`.
return result.Value::impl() == nullptr || !result.Value::type();
bool IsFakeValue(const pir::Value &value) {
// create a fake value to simplify `ForwardBackwardSplit`.
return value.impl() == nullptr || !value.type();
}

static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
@@ -1079,7 +1087,7 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
}

void AppendShadowOutput(Program *forward_program,
const pir::OpResult &result,
const pir::Value &value,
const std::string &name,
size_t start_point) {
pir::IrContext *ctx = pir::IrContext::Instance();
Expand All @@ -1088,7 +1096,7 @@ void AppendShadowOutput(Program *forward_program,
{"output_name", pir::StrAttribute::get(ctx, name)},
};
pir::Operation *operation =
pir::Operation::Create({result}, attribute_map, {}, op_info);
pir::Operation::Create({value}, attribute_map, {}, op_info);
auto position = forward_program->block()->begin();
std::advance(position, start_point);
if (position == forward_program->block()->end()) {
@@ -1099,19 +1107,19 @@
}

int AppendShadowOutputs(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
const std::vector<pir::Value> &outputs,
int start_point,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::OpResult> added_op_result;
std::unordered_set<pir::Value> added_value;

for (const auto &result : outputs_op_result) {
if (!added_op_result.count(result) || IsFakeOpResult(result)) {
for (const auto &value : outputs) {
if (!added_value.count(value) || IsFakeValue(value)) {
std::string shadow_output_name = name_prefix + std::to_string(counter);
AppendShadowOutput(
forward_program, result, shadow_output_name, start_point + counter);
forward_program, value, shadow_output_name, start_point + counter);
counter += 1;
added_op_result.insert(result);
added_value.insert(value);
}
}
// return the inserted op.
@@ -1120,51 +1128,17 @@

SplitedResult SplitForwardBackward(
const Program &program,
const std::vector<pir::OpResult> &op_result_forward_inputs,
const std::vector<pir::OpResult> &op_result_forward_params,
const std::vector<pir::OpResult> &op_result_forward_outputs,
const std::vector<pir::OpResult> &op_result_forward_inputs_grads,
const std::vector<pir::OpResult> &op_result_forward_params_grads,
const std::vector<pir::OpResult> &op_result_forward_outputs_grads,
const std::vector<pir::Value> &forward_inputs,
const std::vector<pir::Value> &forward_params,
const std::vector<pir::Value> &forward_outputs,
const std::vector<pir::Value> &forward_inputs_grads,
const std::vector<pir::Value> &forward_params_grads,
const std::vector<pir::Value> &forward_outputs_grads,
const std::vector<int> &forward_range,
const std::vector<int> &backward_range) {
// transform opresult -> value
std::vector<pir::Value> forward_inputs, forward_outputs, forward_inputs_grads,
forward_outputs_grads, forward_params, forward_params_grads;

auto op_result_to_value = [](const pir::OpResult &r) {
if (r.impl() == nullptr) return Value(nullptr);
return Value(r.Value::impl());
};

std::transform(op_result_forward_inputs.begin(),
op_result_forward_inputs.end(),
std::back_inserter(forward_inputs),
op_result_to_value);
std::transform(op_result_forward_outputs.begin(),
op_result_forward_outputs.end(),
std::back_inserter(forward_outputs),
op_result_to_value);
std::transform(op_result_forward_inputs_grads.begin(),
op_result_forward_inputs_grads.end(),
std::back_inserter(forward_inputs_grads),
op_result_to_value);
std::transform(op_result_forward_outputs_grads.begin(),
op_result_forward_outputs_grads.end(),
std::back_inserter(forward_outputs_grads),
op_result_to_value);
std::transform(op_result_forward_params.begin(),
op_result_forward_params.end(),
std::back_inserter(forward_params),
op_result_to_value);
std::transform(op_result_forward_params_grads.begin(),
op_result_forward_params_grads.end(),
std::back_inserter(forward_params_grads),
op_result_to_value);

std::vector<pir::Value> forward_in_out_values;
for (auto &v : std::vector<std::vector<pir::Value> *>(
{&forward_inputs, &forward_outputs, &forward_params})) {
for (auto &v :
std::vector({&forward_inputs, &forward_outputs, &forward_params})) {
forward_in_out_values.insert(
forward_in_out_values.end(), v->begin(), v->end());
}
@@ -1434,8 +1408,8 @@ void BindUtils(pybind11::module *m) {
m->def("reset_shadow_output_name", ResetShadowOutputName);
m->def("split_program", SplitForwardBackward);
m->def("append_shadow_outputs", AppendShadowOutputs);
m->def("fake_op_result", FakeOpResult);
m->def("is_fake_op_result", IsFakeOpResult);
m->def("fake_value", FakeValue);
m->def("is_fake_value", IsFakeValue);
m->def("get_current_insertion_point", []() -> PyInsertionPoint {
return {ApiBuilder::Instance().GetCurrentInsertionPoint()};
});
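
The old fake_op_result / is_fake_op_result helpers are re-exposed as fake_value / is_fake_value, matching the OpResult-to-Value migration above. A minimal usage sketch, assuming a Paddle build that contains this patch:

import paddle

placeholder = paddle.pir.fake_value()          # a Value with a null impl/type
print(paddle.pir.is_fake_value(placeholder))   # True
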
8 changes: 8 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -111,6 +111,10 @@ def __iter__(self):
def __contains__(self, key):
return ValueWrapper(key) in self._items

def __repr__(self) -> str:
items_str = ", ".join(f"{key}: {val}" for key, val in self.items())
return f'ValueDict({items_str})'


class ValueSet:
def __init__(
@@ -153,6 +157,10 @@ def __iter__(self):
def __contains__(self, val):
return ValueWrapper(val) in self._set

def __repr__(self) -> str:
items_str = ", ".join(repr(item) for item in self)
return f'ValueSet({items_str})'


class State:
"""
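
The new __repr__ methods make these internal backward-pass containers readable when printed while debugging. A rough sketch, assuming both containers can be constructed empty (their constructors are not shown in this diff):

from paddle.autograd.backward_utils import ValueDict, ValueSet

print(repr(ValueDict()))  # ValueDict()  -- entries render as "key: val"
print(repr(ValueSet()))   # ValueSet()   -- entries render via repr(item)
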
11 changes: 6 additions & 5 deletions python/paddle/autograd/ir_backward.py
@@ -63,6 +63,9 @@ def check_all_puts(block, inputs, outputs):


def append_full_like(float_value, copy_value, value, state, backward_ops):
if paddle.pir.is_fake_value(value):
state.value_to_valuegrad[value] = [[paddle.pir.fake_value()]]
return
if copy_value.is_tensorarray():
value_grad = paddle._pir_ops.create_array_like(
copy_value,
@@ -174,9 +177,9 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
visited_output.add(opresult)
continue
else:
if paddle.pir.is_fake_op_result(opresult):
if paddle.pir.is_fake_value(opresult):
state.value_to_valuegrad[opresult] = [
[paddle.pir.fake_op_result()]
[paddle.pir.fake_value()]
]
else:
grad_value = append_full_like(
@@ -702,9 +705,7 @@ def append_yield(
new_value = return_map_value(
value, control_flow_value_to_copyvalue_map
)
value_grad = append_full_like(
0.0, new_value, value, state, backward_ops
)
append_full_like(0.0, new_value, value, state, backward_ops)
input_grad = state.value_to_valuegrad[value][0][0]

inputs_grad.append(input_grad)
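
Taken together, the ir_backward.py changes let the PIR autograd pass propagate fake (placeholder) gradients instead of crashing on them, which is what double grad inside to_static exercises. A hypothetical sketch of that scenario; the real test_grad cases live in the Paddle test suite and are not reproduced here:

import paddle


class GradNet(paddle.nn.Layer):
    def forward(self, x):
        x.stop_gradient = False
        y = x * x
        # paddle.grad inside a to_static forward is the dy2st pattern that
        # test_grad exercises; under PIR it is lowered through ir_backward.
        return paddle.grad(outputs=[y], inputs=[x])[0]


net = paddle.jit.to_static(GradNet())
x = paddle.randn([4])
print(net(x))  # expected to equal 2 * x
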
6 changes: 3 additions & 3 deletions python/paddle/decomposition/decomp.py
@@ -605,7 +605,7 @@ def _prepare_grad_outputs(fwd_op, bwd_op):
new_grad_outputs.append(grad_outputs[index])
index += 1
else:
new_grad_outputs.append([pir.fake_op_result()])
new_grad_outputs.append([pir.fake_value()])
return new_grad_outputs


@@ -679,7 +679,7 @@ def _decomp_bwd_with_vjp(
if grad_input[0] is not None and grad_input[0].initialized():
res.append(grad_input[0])
else:
res.append(pir.fake_op_result())
res.append(pir.fake_value())
assert len(res) == len(
bwd_op.results()
), "results of original backward op do not match results of decomposed backward op"
@@ -752,7 +752,7 @@ def _decomp_bwd_without_vjp(
res.append(new_grad_inputs[input_grads_idx])
input_grads_idx += 1
else:
res.append(pir.fake_op_result())
res.append(pir.fake_value())

# step4: upgrade grad_var_to_var
_upgrade_grad_var_to_var(