[PIR / dy2static] Fix tostatic unittest bugs. (#58959)
---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
2742195759 and SigureMo authored Nov 14, 2023
1 parent e960192 commit e80c65e
Showing 4 changed files with 128 additions and 115 deletions.
10 changes: 6 additions & 4 deletions paddle/fluid/pybind/pir.cc
@@ -1034,10 +1034,10 @@ void AppendSetParameter(Program *forward_program,
}
}

void AppendSetParameters(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
int AppendSetParameters(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::OpResult> added_op_result;

@@ -1051,6 +1051,8 @@ void AppendSetParameters(Program *forward_program,
added_op_result.insert(result);
}
}
// Return the number of inserted set_parameter ops.
return counter;
}

SplitedResult SplitForwardBackward(
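The change above turns AppendSetParameters from void into int, so callers can tell how many set_parameter ops were actually inserted (duplicated OpResults are skipped via the unordered_set). A minimal Python sketch of that count-while-deduplicating pattern, using plain stand-in values rather than the real pir::Program and pir::OpResult types:

# Sketch only: a list of tuples stands in for the program, strings for OpResults.
def append_set_parameters(program_ops, outputs_op_result, start_point=0, name_prefix="output_"):
    """Append one set_parameter record per distinct output and return the count."""
    added_op_result = set()
    counter = 0
    for result in outputs_op_result:
        if result in added_op_result:
            continue  # duplicated results are appended only once
        program_ops.insert(start_point + counter, (f"{name_prefix}{counter}", result))
        added_op_result.add(result)
        counter += 1
    return counter  # number of inserted ops, mirroring the new int return value

ops = []
inserted = append_set_parameters(ops, ["v1", "v2", "v1"])
assert inserted == 2 and len(ops) == 2
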
210 changes: 101 additions & 109 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -25,11 +25,9 @@
from paddle.base.compiler import BuildStrategy
from paddle.base.data_feeder import check_type, convert_dtype
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.framework import use_pir_api
from paddle.optimizer.lr import LRScheduler
from paddle.pir import OpResult, fake_op_result, is_fake_op_result

from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, backend_guard

__all__ = []
@@ -52,60 +50,57 @@ def __get__(self, instance, cls):
class NestSequence:
"""
A wrapper class that makes it easy to flatten and restore the nested structure of a
given sequence.
given sequence. It also removes duplicate variables from the sequence.
For example:
>>> t = [v1, v2, v1]
>>> m = NestSequence(t)
>>> m.var_list
[v1, v2]
>>> m.restore([t1, t2])
[t1, t2, t1]
"""

def __init__(self, raw_input, need_check=False):
self.__raw_input = raw_input
self.__input_list = self.tolist()
self.__var_ids = self._get_var_ids()
self._check_non_variable(need_check)
def __init__(self, raw_input):
self._raw_input = raw_input
self._var_map, self._var_list = self._tolist()

def tolist(self):
@property
def var_list(self):
return self._var_list

def _tolist(self):
"""
Flattens the nested sequences into single list.
Flattens the nested sequence into a single list, dropping duplicate variables and non-variable elements.
"""
return paddle.utils.flatten(self.__raw_input)
variable_map = {} # opresult -> list idx
variable_list = []
for value in paddle.utils.flatten(self._raw_input):
if not isinstance(value, OpResult):
continue
if value in variable_map:
# skip duplicate OpResults.
continue
variable_map[value] = len(variable_list)
variable_list.append(value)
return variable_map, variable_list

def restore(self, value_list):
"""
Restores the nested sequence from value list.
"""
assert len(self.__input_list) == len(value_list)
return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

def _get_var_ids(self):
var_ids = []
for idx, var in enumerate(self.__input_list):
if isinstance(var, (OpResult, core.eager.Tensor)):
var_ids.append(idx)
assert len(self._var_list) == len(value_list)

return var_ids

def _check_non_variable(self, need_check):
"""
Raises warning if output of traced function contains non-tensor type values.
"""
if need_check:
warning_types = set()
for var in self.__input_list:
if not isinstance(var, (framework.Variable, core.eager.Tensor)):
warning_types.add(type(var))
if warning_types:
logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.".format(
list(warning_types)
)
)
def to_value(x):
if isinstance(x, OpResult):
return value_list[self._var_map[x]]
return x

@property
def var_ids(self):
return self.__var_ids
return paddle.utils.pack_sequence_as(
self._raw_input,
list(map(to_value, paddle.utils.flatten(self._raw_input))),
)

def __getitem__(self, item):
return self.__input_list[item]
return self._var_list[item]


class RunableProgram:
@@ -146,10 +141,7 @@ def convert_name(self, values):
return []
if isinstance(values[0], str):
return values
try:
return [self.get_value_name_map[v] for v in values]
except:
breakpoint()
return [self.get_value_name_map[v] for v in values]

@cached_property
def x_values(self):
@@ -415,7 +407,7 @@ def __init__(
):
super().__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._outputs = NestSequence(outputs)
self._params, self._param_values = (
parameters if parameters is not None else ([], [])
)
@@ -458,15 +450,7 @@ def __call__(self, inputs):
"""
in_vars, out_vars = self._prepare(inputs)
attrs = self._prepare_attributes()

# self._sync_lr_value_with_scheduler()

c_run_program_fn = None
if use_pir_api():
c_run_program_fn = _legacy_C_ops.pir_run_program
else:
c_run_program_fn = _legacy_C_ops.run_program
c_run_program_fn(
_legacy_C_ops.pir_run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
@@ -482,12 +466,8 @@ def __call__(self, inputs):

@cached_property
def origin_runable_program(self):
inputs = list(
filter(lambda x: isinstance(x, OpResult), self._inputs.tolist())
)
outputs = list(
filter(lambda x: isinstance(x, OpResult), self._outputs.tolist())
)
inputs = list(self._inputs.var_list)
outputs = list(self._outputs.var_list)
params = self._param_values
paddle.base.libpaddle.pir.append_set_parameters(
self._origin_main_program,
@@ -638,7 +618,7 @@ def _need_aggregation(var):
"""
Return True if there exists an op whose input is var.
"""
if not isinstance(var, framework.Variable) or var.type not in [
if var.type not in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
]:
@@ -692,7 +672,7 @@ def _insert_aggregation_ops_for_var(target_program, var):
return None

to_processed_vars = list(
filter(_need_aggregation, self._outputs.tolist())
filter(_need_aggregation, self._outputs.var_list)
)
for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var)
@@ -710,31 +690,58 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
params = train_runnable_program.param_values
combined_inputs = list(itertools.chain(inputs, params))
forward_end_idx = len(program.global_block().ops)
if targets:
with backend_guard(self._backend):
check_type(
targets,
'targets',
(OpResult, list, tuple),
'paddle.static.gradients',
)
with ir_static.program_guard(program, None):
grad_info_map = grad(
inputs=combined_inputs, outputs=targets
)
grad_info_map = [None] * len(combined_inputs)
with backend_guard(self._backend):
check_type(
targets,
'targets',
(OpResult, list, tuple),
'paddle.static.gradients',
)
with ir_static.program_guard(program, None):
# Pre-create output grads for backward to avoid extra full and full_like ops.
forward_outputs_grads = []
not_stop_gradient_num = 0
for out_op_result in self._outputs.tolist():
for out_op_result in targets:
if out_op_result.stop_gradient is True:
forward_outputs_grads.append(None)
continue
opres = (
program.global_block()
.ops[forward_end_idx + 2 * not_stop_gradient_num + 1]
.results()[0]
forward_outputs_grads.append(fake_op_result())
else:
value = paddle.full_like(
out_op_result,
fill_value=1.0,
dtype=out_op_result.dtype,
)
forward_outputs_grads.append(value)
paddle.base.libpaddle.pir.append_set_parameters(
program,
forward_outputs_grads,
len(program.global_block().ops),
"grad_input_",
)
op_between_forward_and_backward = (
len(program.global_block().ops) - forward_end_idx
)

# call grad to get backward ops.
if (
len(
list(
filter(lambda x: x.stop_gradient is False, targets)
)
)
> 0
):
grad_info_map = grad(
inputs=combined_inputs,
outputs=list(
filter(lambda x: x.stop_gradient is False, targets)
),
grad_outputs=list(
filter(
lambda x: not is_fake_op_result(x),
forward_outputs_grads,
)
),
)
forward_outputs_grads.append(opres)
not_stop_gradient_num += 1

if self._hooker:
(
@@ -744,7 +751,6 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
) = self._hooker.after_append_backward(
program, targets, forward_end_idx
)

# TODO: add later
# self.prepare_gradient_aggregation(
# start_idx + 1, main_program, program
@@ -759,9 +765,6 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
)
p_grad_value = list(map(mapping_op_result, grad_info_map[inputs_size:]))
o_grad_value = list(map(mapping_op_result, forward_outputs_grads))
backward_start_op_index = forward_end_idx + 2 * len(
list(filter(lambda r: r.stop_gradient is False, self._outputs))
)

# Insert grad names for RunableProgram (we need names for grad_inputs and grad_outputs).
input_grads_to_append = list(
@@ -772,13 +775,6 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
lambda x: not is_fake_op_result(x), x_grad_value + p_grad_value
)
)
paddle.base.libpaddle.pir.append_set_parameters(
program,
input_grads_to_append,
backward_start_op_index,
"grad_input_",
)
backward_start_op_index += len(input_grads_to_append)
backward_end_op_index = len(program.global_block().ops)
paddle.base.libpaddle.pir.append_set_parameters(
program,
@@ -787,6 +783,9 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
"grad_output_",
)

backward_start_op_index = (
forward_end_idx + op_between_forward_and_backward
)
# construct a runnable program.
return RunableProgram(
program,
@@ -878,8 +877,7 @@ def _prepare(self, inputs):
# mapping from name(string) -> Tensor
out_tensor_map = {}

def create_out(var_id):
var = self._outputs[var_id]
def create_out(var):
assert isinstance(var, OpResult)

if id(var) in out_tensor_map:
Expand All @@ -901,7 +899,7 @@ def create_out(var_id):
return out

# Create Tensor to receive output data.
out_vars = list(map(create_out, self._outputs.var_ids))
out_vars = list(map(create_out, self._outputs.var_list))
return input_vars, out_vars

def _create_scope_vec(self, program_id=None, use_scope_cache=False):
Expand All @@ -923,26 +921,20 @@ def _create_cuda_graph_vec(self):

def _update_stop_gradient(self, out_vars):
# Update stop_gradient for all outputs
def set_stop_gradient(var_id, eager_tensor):
var = self._outputs[var_id]
def set_stop_gradient(var, eager_tensor):
assert isinstance(var, OpResult)
eager_tensor.stop_gradient = var.stop_gradient

for idx, var in zip(self._outputs.var_ids, out_vars):
for idx, var in zip(self._outputs.var_list, out_vars):
set_stop_gradient(idx, var)

def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with Tensor.
"""

flatten_outputs = self._outputs.tolist()
for i, idx in enumerate(self._outputs.var_ids):
flatten_outputs[idx] = out_vars[i]
outs = self._outputs.restore(flatten_outputs)
outs = self._outputs.restore(out_vars)
if outs is not None and len(outs) == 1:
outs = outs[0]

return outs

@switch_to_static_graph
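Two of the pir_partial_program.py changes above are easier to see in isolation. First, the reworked NestSequence flattens a nested structure, keeps each OpResult only once, and maps every flattened position back to a run-time tensor on restore. A self-contained sketch of that behaviour, with strings standing in for OpResult values:

import paddle

raw = ["v1", ("v2", "v1"), 3]                      # nested outputs with a duplicate
flat = paddle.utils.flatten(raw)                   # ['v1', 'v2', 'v1', 3]

var_map, var_list = {}, []                         # value -> index of its first occurrence
for v in flat:
    if isinstance(v, str) and v not in var_map:    # isinstance(v, str) stands in for isinstance(v, OpResult)
        var_map[v] = len(var_list)
        var_list.append(v)
assert var_list == ["v1", "v2"]                    # duplicates and non-variables are dropped

value_list = ["t1", "t2"]                          # one run-time tensor per distinct variable
restored = paddle.utils.pack_sequence_as(
    raw,
    [value_list[var_map[v]] if isinstance(v, str) else v for v in flat],
)
assert restored == ["t1", ("t2", "t1"), 3]         # the duplicate position reuses the same tensor

Second, _append_backward_desc now builds one gradient seed per forward output before calling grad: a full_like(..., 1.0) value when the output needs a gradient, and a fake placeholder when stop_gradient is set, then passes only the non-fake pairs to grad(outputs=..., grad_outputs=...). A pure-Python sketch of that pairing logic, using hypothetical stand-in classes rather than real OpResults:

from dataclasses import dataclass

@dataclass(frozen=True)
class Value:                 # stand-in for an OpResult
    name: str
    stop_gradient: bool = False

class FakeValue:             # stand-in for fake_op_result()
    pass

targets = [Value("out0"), Value("out1", stop_gradient=True), Value("out2")]

# one seed per output: a real seed (full_like(..., 1.0) in the diff) or a fake placeholder
forward_outputs_grads = [
    FakeValue() if t.stop_gradient else Value(f"{t.name}@GRAD") for t in targets
]

# grad() is then invoked only for outputs that actually require a gradient
grad_targets = [t for t in targets if not t.stop_gradient]
grad_seeds = [g for g in forward_outputs_grads if not isinstance(g, FakeValue)]
assert [t.name for t in grad_targets] == ["out0", "out2"]
assert [g.name for g in grad_seeds] == ["out0@GRAD", "out2@GRAD"]
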
7 changes: 5 additions & 2 deletions test/dygraph_to_static/test_duplicate_output.py
@@ -15,7 +15,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle

@@ -55,7 +58,7 @@ def _run_static(self):

self.assertEqual(param[0].grad.numpy(), 1.0)

@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_ast_to_func(self):
self._run_static()

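The test change above switches the decorator to test_legacy_and_pir_exe_and_pir_api so the case also runs under the PIR executor and the PIR API; the network itself is not shown in this diff. For orientation, a hypothetical minimal layer with a duplicated output, the pattern test_duplicate_output exercises and the NestSequence deduplication has to handle (names and shapes are illustrative, not the file's actual contents):

import paddle

class HypotheticalDuplicateNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        y = self.fc(x)
        return y, y  # the same tensor returned twice

net = paddle.jit.to_static(HypotheticalDuplicateNet())
out1, out2 = net(paddle.randn([2, 4]))  # both entries map back to one deduplicated output
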