Skip to content

Commit

Permalink
[JitLayer]Fix jit.save error when save params combined (#44504)
Browse files Browse the repository at this point in the history
* Fix jit.save error when save params combined

* Change dict_value to list
  • Loading branch information
0x45f authored Jul 25, 2022
1 parent e32e4a1 commit c0a29d2
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
if isinstance(var, Variable):
output_vars_dict[var.name] = var
if output_spec is None:
result_list = output_vars_dict.values()
result_list = list(output_vars_dict.values())
elif output_spec is not None and len(output_spec) == len(output_vars_dict):
result_list = output_vars_dict.values()
result_list = list(output_vars_dict.values())
for var in output_spec:
if var.name not in output_vars_dict:
warnings.warn(name_no_exists_error % var.name)
Expand Down Expand Up @@ -868,7 +868,7 @@ def fun(inputs):
layer,
]

all_vars = set()
combine_vars = {}
property_vals = [] # (value, key)
for attr_func in functions:
if isinstance(layer, Layer):
Expand Down Expand Up @@ -1020,19 +1020,28 @@ def fun(inputs):
program_only=configs._program_only,
clip_extra=configs.clip_extra)

# collect all vars
for var in concrete_program.main_program.list_vars():
all_vars.add(var)
if combine_params:
clone_main_program = concrete_program.main_program.clone()
clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)

# save shared params
if combine_params:
# sort vars by name
combine_vars = sorted(combine_vars.items(), key=lambda item: item[0])
ordered_vars = []
for name, var in combine_vars:
ordered_vars.append(var)

params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope):
paddle.static.save_vars(Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(paddle.fluid.io.is_persistable,
all_vars)),
ordered_vars)),
filename=params_filename)
# TODO: save property

Expand Down

0 comments on commit c0a29d2

Please sign in to comment.