Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Stat]Add build_strategy in @to_static to support open pass #34347

Merged
merged 4 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,16 @@ class PartialProgramLayer:
Layer: A Layer object that run all ops internally in static mode.
"""

def __init__(self, main_program, inputs, outputs, parameters=None):
def __init__(self, main_program, inputs, outputs, parameters=None,
**kwargs):
super(PartialProgramLayer, self).__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []

self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
assert isinstance(self._build_strategy, BuildStrategy)

self._origin_main_program = self._verify_program(main_program)
self._tmp_scope_vec = self._create_scope_vec()
# A fake_var to handle empty input or output
Expand Down Expand Up @@ -170,7 +174,11 @@ def _infer_program_id(self):

@LazyInitialized
def _train_program_id(self):
return _hash_with_id(self._train_program, self)
program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)

return program_id

def _verify_program(self, main_program):
"""
Expand Down Expand Up @@ -451,6 +459,6 @@ def partial_program_from(concrete_program):
if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:]

return PartialProgramLayer(concrete_program.main_program, inputs,
concrete_program.outputs,
concrete_program.parameters)
return PartialProgramLayer(
concrete_program.main_program, inputs, concrete_program.outputs,
concrete_program.parameters, **concrete_program.kwargs)
39 changes: 24 additions & 15 deletions python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,13 @@ class CacheKey(object):
"""
Cached key for ProgramCache.
"""

__slots__ = [
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
'class_instance'
'class_instance', 'kwargs'
]

def __init__(self, function_spec, input_args_with_spec,
input_kwargs_with_spec, class_instance):
input_kwargs_with_spec, class_instance, **kwargs):
"""
Initializes a cache key.

Expand All @@ -161,11 +160,14 @@ def __init__(self, function_spec, input_args_with_spec,
input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
class_instance(object): an instance of class `Layer`.
**kwargs(dict): extra keyword arguments (e.g. `build_strategy`) stored on the key for extensibility.
Aurelius84 marked this conversation as resolved.
Show resolved Hide resolved
"""
self.function_spec = function_spec
self.input_args_with_spec = input_args_with_spec
self.input_kwargs_with_spec = input_kwargs_with_spec
self.class_instance = class_instance
# NOTE: `kwargs` is usually not considered as basic member for `__hash__`
self.kwargs = kwargs

@classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
Expand Down Expand Up @@ -235,13 +237,14 @@ class StaticFunction(object):

"""

def __init__(self, function, input_spec=None):
def __init__(self, function, input_spec=None, **kwargs):
"""
Initializes a `StaticFunction`.

Args:
function(callable): A function or method that will be converted into static program.
input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None.
**kwargs(dict): other arguments, such as `build_strategy`.
"""
# save the instance `self` while decorating a method of class.
if inspect.ismethod(function):
Expand All @@ -257,6 +260,7 @@ def __init__(self, function, input_spec=None):
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator()
self._kwargs = kwargs

def __get__(self, instance, owner):
"""
Expand Down Expand Up @@ -395,7 +399,8 @@ def get_concrete_program(self, *args, **kwargs):

# 2. generate cache key
cache_key = CacheKey(self._function_spec, input_args_with_spec,
input_kwargs_with_spec, self._class_instance)
input_kwargs_with_spec, self._class_instance,
**self._kwargs)

# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
Expand Down Expand Up @@ -586,7 +591,7 @@ class ConcreteProgram(object):

__slots__ = [
'inputs', 'outputs', 'main_program', "startup_program", "parameters",
"function"
"function", 'kwargs'
]

def __init__(self,
Expand All @@ -595,18 +600,20 @@ def __init__(self,
parameters,
function,
main_program,
startup_program=None):
startup_program=None,
**kwargs):
self.inputs = inputs
self.outputs = outputs
self.main_program = main_program
self.startup_program = startup_program
self.parameters = parameters
self.function = function
self.kwargs = kwargs

@staticmethod
@switch_to_static_graph
def from_func_spec(func_spec, input_spec, input_kwargs_spec,
class_instance):
def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
**kwargs):
"""
Builds the main_program with specialized inputs and returns outputs
of program as fetch_list.
Expand Down Expand Up @@ -635,8 +642,8 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec,
# 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs_with_spec(input_spec,
main_program)
kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec,
main_program)
_kwargs = func_spec.to_static_inputs_with_spec(
input_kwargs_spec, main_program)
if class_instance:
inputs = tuple([class_instance] + list(inputs))

Expand All @@ -649,8 +656,8 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec,
class_instance, False)), param_guard(
get_buffers(class_instance, False)):
try:
if kwargs:
outputs = static_func(*inputs, **kwargs)
if _kwargs:
outputs = static_func(*inputs, **_kwargs)
else:
outputs = static_func(*inputs)
except BaseException as e:
Expand All @@ -675,7 +682,8 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec,
parameters=all_parameters_and_buffers,
function=dygraph_function,
main_program=main_program,
startup_program=startup_program)
startup_program=startup_program,
**kwargs)


def _extract_indeed_params_buffers(class_instance):
Expand All @@ -702,7 +710,8 @@ def _build_once(self, cache_key):
func_spec=cache_key.function_spec,
input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance)
class_instance=cache_key.class_instance,
**cache_key.kwargs)
return concrete_program, partial_program_from(concrete_program)

def __getitem__(self, item):
Expand Down
18 changes: 16 additions & 2 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def copy_decorator_attrs(original_func, decorated_obj):
return decorated_obj


def declarative(function=None, input_spec=None):
def declarative(function=None, input_spec=None, build_strategy=None):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
Expand All @@ -171,6 +171,12 @@ def declarative(function=None, input_spec=None):
function (callable): callable imperative function.
input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specify the shape/dtype/name
information of each input Tensor.
build_strategy(BuildStrategy|None): This argument is used to compile the
converted program with the specified options, such as operators' fusion
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.


Returns:
Tensor(s): containing the numerical result.
Expand Down Expand Up @@ -206,10 +212,18 @@ def decorated(python_func):
static_layer = copy_decorator_attrs(
original_func=python_func,
decorated_obj=StaticFunction(
function=python_func, input_spec=input_spec))
function=python_func,
input_spec=input_spec,
build_strategy=build_strategy))

return static_layer

build_strategy = build_strategy or BuildStrategy()
if not isinstance(build_strategy, BuildStrategy):
raise TypeError(
"Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}".
format(type(build_strategy).__name__))

# for usage: `declarative(foo, ...)`
if function is not None:
if isinstance(function, Layer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
#set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)

if(NOT WIN32)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import paddle
import unittest
import numpy as np
from paddle.jit import ProgramTranslator

from test_resnet import ResNet, train, predict_dygraph_jit
from test_resnet import predict_dygraph, predict_static, predict_analysis_inference

# Shared translator handle; `TestResnetWithPass.train` calls `enable(...)` on it
# with the `to_static` flag before each training run.
program_translator = ProgramTranslator()


class TestResnetWithPass(unittest.TestCase):
    """ResNet tests run with a `BuildStrategy` that has fusion passes enabled.

    The strategy built in `setUp` is handed to the module-level `train`
    helper imported from `test_resnet`.
    """

    def setUp(self):
        """Create the build strategy used by every test in this case."""
        strategy = paddle.static.BuildStrategy()
        strategy.fuse_elewise_add_act_ops = True
        strategy.fuse_bn_act_ops = True
        strategy.fuse_bn_add_act_ops = True
        strategy.enable_addto = True
        self.build_strategy = strategy
        # NOTE: for enable_addto
        paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8})

    def train(self, to_static):
        """Toggle the translator with `to_static`, then train and return the result."""
        program_translator.enable(to_static)
        result = train(to_static, self.build_strategy)
        return result

    def verify_predict(self):
        """Compare every prediction path against the static prediction."""
        sample = np.random.random([1, 3, 224, 224]).astype('float32')
        dygraph_out = predict_dygraph(sample)
        static_out = predict_static(sample)
        jit_out = predict_dygraph_jit(sample)
        inference_out = predict_analysis_inference(sample)

        dygraph_ok = np.allclose(dygraph_out, static_out)
        dygraph_msg = "dy_pre:\n {}\n, st_pre: \n{}.".format(dygraph_out,
                                                             static_out)
        self.assertTrue(dygraph_ok, msg=dygraph_msg)

        jit_ok = np.allclose(jit_out, static_out)
        jit_msg = "dy_jit_pre:\n {}\n, st_pre: \n{}.".format(jit_out,
                                                             static_out)
        self.assertTrue(jit_ok, msg=jit_msg)

        inference_ok = np.allclose(inference_out, static_out)
        inference_msg = "predictor_pre:\n {}\n, st_pre: \n{}.".format(
            inference_out, static_out)
        self.assertTrue(inference_ok, msg=inference_msg)

    def test_resnet(self):
        """Static and dygraph training results must agree, then predictions too."""
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)
        losses_match = np.allclose(static_loss, dygraph_loss)
        failure_msg = "static_loss: {} \n dygraph_loss: {}".format(
            static_loss, dygraph_loss)
        self.assertTrue(losses_match, msg=failure_msg)
        self.verify_predict()

    def test_in_static_mode_mkldnn(self):
        """Train in static mode with `FLAGS_use_mkldnn` on, when MKLDNN is compiled in."""
        paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
        try:
            mkldnn_available = paddle.fluid.core.is_compiled_with_mkldnn()
            if mkldnn_available:
                train(True, self.build_strategy)
        finally:
            # Always switch the flag back off, whether or not training ran or failed.
            paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})


class TestError(unittest.TestCase):
    """Checks the argument validation of `paddle.jit.to_static`."""

    def test_type_error(self):
        """Passing a non-`BuildStrategy` value as `build_strategy` raises TypeError."""

        def foo(x):
            out = x + 1
            return out

        # The call is expected to raise, so its result is deliberately not
        # bound to a (never-used) local name.
        with self.assertRaises(TypeError):
            paddle.jit.to_static(foo, build_strategy="x")


# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def __init__(self, layers=50, class_dim=102):
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))

@declarative
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
Expand All @@ -213,7 +212,7 @@ def __reader__():
return __reader__


def train(to_static):
def train(to_static, build_strategy=None):
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
Expand All @@ -231,6 +230,8 @@ def train(to_static):
data_loader.set_sample_list_generator(train_reader)

resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())

for epoch in range(epoch_num):
Expand Down
1 change: 1 addition & 0 deletions tools/windows/run_unittests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ disable_wincpu_test="^jit_kernel_test$|\
^test_bmn$|\
^test_mobile_net$|\
^test_resnet_v2$|\
^test_build_strategy$|\
^test_se_resnet$|\
^disable_wincpu_test$"

Expand Down