Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Check mechanism and Support single tuple/list of fetch_list in Executor #35726

Merged
merged 4 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions python/paddle/fluid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,14 +703,18 @@ def _get_targets(_optimize_ops, _fetch_list, item):
"The item in fetch_list should be str, variable or optimize_op, but recieved %s.",
type(item))

for item in fetch_list:
for index, item in enumerate(fetch_list):
# NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list
# we should handle tuple and list in fetch_list.
# TODO(zhiqiu): find a better way to handle that.
if isinstance(item, list):
for i in item:
_get_targets(_optimize_ops, _fetch_list, i)
elif isinstance(item, tuple):
if not isinstance(item[0], (list, tuple)):
raise TypeError(
"Requires fetch_list[{}][0] shall be one of (list, tuple) when type(fetch_list[{}]) is `tuple`, but received fetch_list[{}][0]'s type is `{}`.".
format(index, index, index, type(item[0]).__name__))
for i in item[0]:
_get_targets(_optimize_ops, _fetch_list, i)
else:
Expand Down Expand Up @@ -1119,17 +1123,7 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
if program is None:
program = default_main_program()

if fetch_list is not None:
if isinstance(fetch_list, Variable) or isinstance(
fetch_list, str) or isinstance(fetch_list,
six.string_types):
fetch_list = [fetch_list]
assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))
else:
fetch_list = []
fetch_list = self._check_fetch_list(fetch_list)

if isinstance(program, Program) and program._pipeline_opt:
if "startup_program" in program._pipeline_opt:
Expand Down Expand Up @@ -1343,6 +1337,35 @@ def _run_program(self, program, feed, fetch_list, feed_var_name,
def _run_inference(self, exe, feed):
return exe.run(feed)

def _check_fetch_list(self, fetch_list):
is_fetch_var = lambda var: isinstance(var, (Variable, str, six.string_types))
is_tuple_list = lambda var: isinstance(var, (tuple, list))

if fetch_list is None: return []
if is_fetch_var(fetch_list): return [fetch_list]

assert is_tuple_list(fetch_list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))

res = []
for i, var in enumerate(fetch_list):
if is_fetch_var(var):
res.append(var)
# such as [x, 'mean_out', loss]
elif is_tuple_list(var):
if all(is_fetch_var(v) for v in var):
res.extend(list(var))
else:
res.append(var)
Aurelius84 marked this conversation as resolved.
Show resolved Hide resolved
else:
raise TypeError(
"Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.".
format(i, type(var).__name__))

return res

def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(str(trainer))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
import paddle
import unittest


class TestCheckFetchList(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.feed = {"x": np.array([[0], [0], [1], [0]], dtype='float32')}
self.expected = np.array([[0], [1], [0]], dtype='float32')
self.build_program()
self.exe = paddle.static.Executor(paddle.CPUPlace())

def build_program(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(name='x', shape=[4, 1], dtype='float32')
output = paddle.unique_consecutive(
x, return_inverse=True, return_counts=True, axis=0)

self.main_program = main_program
self.fetch_list = output

def test_with_tuple(self):

res = self.exe.run(
self.main_program,
feed=self.feed,
fetch_list=[self.fetch_list], # support single list/tuple
return_numpy=True)

self.assertTrue(np.array_equal(res[0], self.expected))

def test_with_error(self):
with self.assertRaises(TypeError):
fetch_list = [23]
res = self.exe.run(self.main_program,
feed=self.feed,
fetch_list=fetch_list)

with self.assertRaises(TypeError):
fetch_list = [(self.fetch_list[0], 32)]
res = self.exe.run(self.main_program,
feed=self.feed,
fetch_list=fetch_list)


if __name__ == '__main__':
unittest.main()