Skip to content

Commit

Permalink
export paddle.static.normalize_program method. test=develop (#31080)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shibo Tao authored Feb 23, 2021
1 parent 1d2bd35 commit 29543da
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
42 changes: 42 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inference_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,48 @@ def test_serialize_program_and_persistables(self):
self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
None, None, None)

def test_normalize_program(self):
init_program = fluid.default_startup_program()
program = fluid.default_main_program()

# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')

y_predict = layers.fc(input=x, size=1, act=None)

cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)

sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost, init_program)

place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])

tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
tensor_y = np.array([[-2], [-3], [-7]]).astype("float32")
for i in six.moves.xrange(3):
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])

# test if return type of serialize_program is bytes
res = paddle.static.normalize_program(program, [x, y], [avg_cost])
self.assertTrue(isinstance(res, Program))
# test program type
self.assertRaises(TypeError, paddle.static.normalize_program, None,
[x, y], [avg_cost])
# test feed_vars type
self.assertRaises(TypeError, paddle.static.normalize_program, program,
['x', 'y'], [avg_cost])
# test fetch_vars type
self.assertRaises(TypeError, paddle.static.normalize_program, program,
[x, y], ['avg_cost'])


class TestLoadInferenceModelError(unittest.TestCase):
def test_load_model_not_exist(self):
Expand Down
1 change: 1 addition & 0 deletions python/paddle/static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from .io import serialize_program #DEFINE_ALIAS
from .io import load_from_file #DEFINE_ALIAS
from .io import save_to_file #DEFINE_ALIAS
from .io import normalize_program #DEFINE_ALIAS
from ..fluid import Scope #DEFINE_ALIAS
from .input import data #DEFINE_ALIAS
from .input import InputSpec #DEFINE_ALIAS
Expand Down
65 changes: 60 additions & 5 deletions python/paddle/static/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'deserialize_program',
'deserialize_persistables',
'load_from_file',
'normalize_program',
]

_logger = get_logger(
Expand Down Expand Up @@ -127,10 +128,64 @@ def _clone_var_in_block(block, var):
persistable=True)


def _normalize_program(program, feed_vars, fetch_vars):
def normalize_program(program, feed_vars, fetch_vars):
"""
optimize program according feed_vars and fetch_vars.
:api_attr: Static Graph
Normalize/Optimize a program according to feed_vars and fetch_vars.
Args:
program(Program): Specify a program you want to optimize.
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
Program: Normalized/Optimized program.
Raises:
TypeError: If `program` is not a Program, an exception is thrown.
TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# normalize main program.
program = default_main_program()
normalized_program = paddle.static.normalize_program(program, [image], [predict])
"""
if not isinstance(program, Program):
raise TypeError(
"program type must be `fluid.Program`, but received `%s`" %
type(program))
if not isinstance(feed_vars, list):
feed_vars = [feed_vars]
if not all(isinstance(v, Variable) for v in feed_vars):
raise TypeError(
"feed_vars type must be a Variable or a list of Variable.")
if not isinstance(fetch_vars, list):
fetch_vars = [fetch_vars]
if not all(isinstance(v, Variable) for v in fetch_vars):
raise TypeError(
"fetch_vars type must be a Variable or a list of Variable.")

# remind users to set auc_states to 0 if auc op were found.
for op in program.global_block().ops:
# clear device of Op
Expand Down Expand Up @@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
_check_vars('fetch_vars', fetch_vars)

program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)


Expand Down Expand Up @@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
_check_vars('fetch_vars', fetch_vars)

program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor)


Expand Down Expand Up @@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
_check_vars('fetch_vars', fetch_vars)

program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(program, feed_vars, fetch_vars)
# serialize and save program
program_bytes = _serialize_program(program)
save_to_file(model_path, program_bytes)
Expand Down

0 comments on commit 29543da

Please sign in to comment.