Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] sot export test files #60547

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from functools import cached_property
from typing import Any, Callable

from paddle.utils import flatten

from ...infer_meta import (
InferMetaCache,
LayerInferMetaCache,
Expand Down Expand Up @@ -67,6 +69,7 @@
ListVariable,
NullVariable,
PaddleLayerVariable,
ParameterVariable,
TensorVariable,
VariableBase,
VariableFactory,
Expand Down Expand Up @@ -117,6 +120,19 @@ def func(x):
return output


def get_params_and_non_param_symbol(*args, **kwargs):
    """Partition the tensor-like variables found in ``args``/``kwargs``.

    Flattens the full argument structure once, then collects the symbol of
    every ``ParameterVariable`` into one set and the symbol of every other
    ``TensorVariable`` into another.

    Returns:
        tuple[set, set]: ``(params, non_params)`` — symbols of parameter
        variables and of non-parameter tensor variables, respectively.
    """
    flat_values = flatten([args, kwargs])
    # NOTE: ParameterVariable subclasses TensorVariable, so parameters must
    # be matched first and excluded from the non-parameter set.
    params = {
        v.get_symbol() for v in flat_values if isinstance(v, ParameterVariable)
    }
    non_params = {
        v.get_symbol()
        for v in flat_values
        if isinstance(v, TensorVariable) and not isinstance(v, ParameterVariable)
    }
    return params, non_params


class FunctionGraph:
"""
A Graph representation corresponding to each FunctionFrame
Expand Down Expand Up @@ -559,6 +575,8 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):

self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(args))
self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(kwargs))
params, non_params = get_params_and_non_param_symbol(*args, **kwargs)
self.sir_ctx.TOS.set_parameter_info(params, non_params)

log(3, f" inputs : {inputs_symbols}", "\n")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from .base import ( # noqa: F401
ConstTypes,
VariableBase,
VariableFactory,
find_traceable_vars,
Expand All @@ -30,6 +29,7 @@
NullVariable,
NumpyVariable,
ObjectVariable,
ParameterVariable,
SliceVariable,
TensorVariable,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
]


ConstTypes = (int, float, str, bool, type(None))


@event_register("find_traceable_vars")
def find_traceable_vars(
root_vars: list[VariableBase],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ....symbolic.statement_ir import Symbol
from ....utils import (
BreakGraphError,
ConstTypes,
FallbackError,
NameGenerator,
paddle_tensor_methods,
Expand All @@ -50,7 +51,7 @@
GlobalTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory
from .base import VariableBase, VariableFactory

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
Expand Down Expand Up @@ -549,6 +550,22 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
return None


class ParameterVariable(TensorVariable):
    """A ``TensorVariable`` specialization that wraps a Paddle parameter.

    Wrapping ``EagerParamBase`` values in their own variable type lets the
    graph distinguish trainable parameters from ordinary tensors (e.g. when
    recording parameter info for SOT export).
    """

    def __init__(
        self,
        param: paddle.Tensor | MetaInfo,
        graph: FunctionGraph,
        tracker: Tracker,
    ):
        # No parameter-specific state; all handling lives in TensorVariable.
        super().__init__(param, graph, tracker)

    @VariableFactory.register_from_value(successor="TensorVariable")
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
        # Only claim genuine parameters; anything else falls through to the
        # successor (TensorVariable) in the factory chain.
        if not isinstance(value, paddle.base.framework.EagerParamBase):
            return None
        return ParameterVariable(value, graph, tracker)


class ObjectVariable(VariableBase):
"""
ObjectVariable is a subclass of VariableBase used to wrap a Variable of the object type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .... import psdb
from ....profiler import EventGuard
from ....utils import (
ENV_SOT_EXPORT,
get_static_function,
is_break_graph_api,
is_break_graph_tensor_methods,
Expand Down Expand Up @@ -553,6 +554,9 @@ def main_info(self) -> dict[str, Any]:

@VariableFactory.register_from_value(successor="UserDefinedLayerVariable")
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
# TODO: @wuzhanfei, if we support create sub layer when export, remove this branch
if ENV_SOT_EXPORT.get() != "":
return None
# TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer.
if isinstance(value, paddle.nn.Layer):
# If there is a user-defined behavior, such as a container class layer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

from ....utils import ConstTypes
from ....utils.exceptions import FallbackError, InnerError
from ..dispatcher import Dispatcher
from ..guard import StringifyExpression, check_guard
Expand All @@ -32,7 +33,7 @@
GetIterTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory
from .base import VariableBase, VariableFactory
from .basic import ConstantVariable
from .callable import BuiltinVariable, UserDefinedFunctionVariable

Expand Down
5 changes: 5 additions & 0 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..infer_meta import convert_meta_to_input_spec
from ..profiler import EventGuard
from ..utils import (
ENV_SOT_EXPORT,
Cache,
GraphLogger,
Singleton,
Expand All @@ -33,6 +34,7 @@
log_do,
map_if,
)
from .export import export
from .interpreter import compile_sir

if TYPE_CHECKING:
Expand Down Expand Up @@ -143,6 +145,9 @@ def __call__(self, *args, **kwargs):
4,
lambda: print("[CompileCache] run sir forward success."),
)
if ENV_SOT_EXPORT.get() != "":
export(self.SIR, ENV_SOT_EXPORT.get())

return outputs


Expand Down
Loading