Skip to content

Commit

Permalink
[SOT] sot export test files (#60547)
Browse files Browse the repository at this point in the history
feifei-111 authored Jan 9, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 6935d75 commit b578350
Showing 14 changed files with 420 additions and 6 deletions.
18 changes: 18 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,8 @@
from functools import cached_property
from typing import Any, Callable

from paddle.utils import flatten

from ...infer_meta import (
InferMetaCache,
LayerInferMetaCache,
@@ -67,6 +69,7 @@
ListVariable,
NullVariable,
PaddleLayerVariable,
ParameterVariable,
TensorVariable,
VariableBase,
VariableFactory,
@@ -117,6 +120,19 @@ def func(x):
return output


def get_params_and_non_param_symbol(*args, **kwargs):
    """
    Collect the symbols of every tensor variable found in the arguments,
    split into parameter symbols and non-parameter symbols.

    Args:
        *args: Positional arguments; nested containers are flattened.
        **kwargs: Keyword arguments; nested containers are flattened.

    Returns:
        tuple[set, set]: ``(params, non_params)``. ``params`` holds the
        symbols of ``ParameterVariable`` values, ``non_params`` the symbols
        of all other ``TensorVariable`` values. Values of any other type
        are ignored.
    """
    params = set()
    non_params = set()

    for value in flatten([args, kwargs]):
        # ParameterVariable is a TensorVariable subclass, so it has to be
        # tested first or every parameter would land in ``non_params``.
        if isinstance(value, ParameterVariable):
            params.add(value.get_symbol())
        elif isinstance(value, TensorVariable):
            non_params.add(value.get_symbol())

    return params, non_params


class FunctionGraph:
"""
A Graph representation corresponding to each FunctionFrame
@@ -559,6 +575,8 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):

self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(args))
self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(kwargs))
params, non_params = get_params_and_non_param_symbol(*args, **kwargs)
self.sir_ctx.TOS.set_parameter_info(params, non_params)

log(3, f" inputs : {inputs_symbols}", "\n")

Original file line number Diff line number Diff line change
@@ -13,7 +13,6 @@
# limitations under the License.

from .base import ( # noqa: F401
ConstTypes,
VariableBase,
VariableFactory,
find_traceable_vars,
@@ -30,6 +29,7 @@
NullVariable,
NumpyVariable,
ObjectVariable,
ParameterVariable,
SliceVariable,
TensorVariable,
)
Original file line number Diff line number Diff line change
@@ -47,9 +47,6 @@
]


ConstTypes = (int, float, str, bool, type(None))


@event_register("find_traceable_vars")
def find_traceable_vars(
root_vars: list[VariableBase],
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
from ....symbolic.statement_ir import Symbol
from ....utils import (
BreakGraphError,
ConstTypes,
FallbackError,
NameGenerator,
paddle_tensor_methods,
@@ -50,7 +51,7 @@
GlobalTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory
from .base import VariableBase, VariableFactory

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
@@ -549,6 +550,22 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
return None


class ParameterVariable(TensorVariable):
    """
    ParameterVariable is a subclass of TensorVariable used to wrap values
    that are ``paddle.base.framework.EagerParamBase`` instances, so that
    they can be told apart from ordinary tensors.

    Args:
        param (paddle.Tensor | MetaInfo): The parameter (or its meta) to wrap.
        graph (FunctionGraph): The FunctionGraph object that this variable is associated with.
        tracker (Tracker): The Tracker object that tracks the information of this variable.
    """

    def __init__(
        self,
        param: paddle.Tensor | MetaInfo,
        graph: FunctionGraph,
        tracker: Tracker,
    ):
        super().__init__(param, graph, tracker)

    # Registered ahead of "TensorVariable": values that are not parameters
    # make this factory return None.
    @VariableFactory.register_from_value(successor="TensorVariable")
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
        if isinstance(value, paddle.base.framework.EagerParamBase):
            return ParameterVariable(value, graph, tracker)
        return None


class ObjectVariable(VariableBase):
"""
ObjectVariable is a subclass of VariableBase used to wrap a Variable of the object type.
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
from .... import psdb
from ....profiler import EventGuard
from ....utils import (
ENV_SOT_EXPORT,
get_static_function,
is_break_graph_api,
is_break_graph_tensor_methods,
@@ -553,6 +554,9 @@ def main_info(self) -> dict[str, Any]:

@VariableFactory.register_from_value(successor="UserDefinedLayerVariable")
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
# TODO(@wuzhanfei): remove this branch once creating sub layers is supported during export.
if ENV_SOT_EXPORT.get() != "":
return None
# TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer.
if isinstance(value, paddle.nn.Layer):
# If there is a user-defined behavior, such as a container class layer
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

from ....utils import ConstTypes
from ....utils.exceptions import FallbackError, InnerError
from ..dispatcher import Dispatcher
from ..guard import StringifyExpression, check_guard
@@ -32,7 +33,7 @@
GetIterTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory
from .base import VariableBase, VariableFactory
from .basic import ConstantVariable
from .callable import BuiltinVariable, UserDefinedFunctionVariable

5 changes: 5 additions & 0 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
from ..infer_meta import convert_meta_to_input_spec
from ..profiler import EventGuard
from ..utils import (
ENV_SOT_EXPORT,
Cache,
GraphLogger,
Singleton,
@@ -33,6 +34,7 @@
log_do,
map_if,
)
from .export import export
from .interpreter import compile_sir

if TYPE_CHECKING:
@@ -143,6 +145,9 @@ def __call__(self, *args, **kwargs):
4,
lambda: print("[CompileCache] run sir forward success."),
)
if ENV_SOT_EXPORT.get() != "":
export(self.SIR, ENV_SOT_EXPORT.get())

return outputs


304 changes: 304 additions & 0 deletions python/paddle/jit/sot/symbolic/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from itertools import chain

from paddle.utils import flatten

from ..utils import ConstTypes, ExportError, NameGenerator
from .statement_ir import Symbol


class PyStatement:
    """
    A node of generated Python source code.

    A statement owns a few lines of code plus a list of sub statements; the
    sub statements are rendered one indentation level (``tab``) deeper than
    the statement itself.
    """

    tab = " " * 4

    def __init__(self, *lines):
        self.sub_statement = []
        self.lines = lines

    def get_lines(self, prefix=""):
        """
        Render this statement and all of its sub statements.

        Args:
            prefix (str): Indentation put in front of this statement's lines.

        Returns:
            list[str]: The rendered lines, own lines first.
        """
        lines = [prefix + line for line in self.lines]
        for statement in self.sub_statement:
            lines.extend(statement.get_lines(self.tab + prefix))
        return lines

    def add_sub(self, *lines):
        """Append a new sub statement built from ``lines`` and return it."""
        sub = PyStatement(*lines)
        self.sub_statement.append(sub)
        return sub

    def __str__(self):
        return "\n".join(self.get_lines())


class PyFileGen:
    """
    Generate the source of a standalone unittest file from a StatementIR.

    The generated file contains a ``paddle.nn.Layer`` whose ``forward``
    replays the statements of the SIR, and a ``unittest.TestCase`` that
    compares the dygraph result of that layer with its ``to_static`` result
    (with and without CINN).

    Args:
        SIR: The StatementIR to export. This class reads ``name``,
            ``statements``, ``inputs``, ``outputs``, ``param_symbol``,
            ``non_param_symbol`` and ``symbol_meta_map`` from it.
    """

    def __init__(self, SIR):
        self.SIR = SIR
        self.roots = []

        self.layer_name_map = {}
        self.layer_name_generator = NameGenerator("_")
        # The SIR name without underscores is used as the generated class name.
        self.SIR_name = SIR.name.replace("_", "")

    def new_root(self, *args):
        """Append a new top level statement to the file and return it."""
        stmt = PyStatement(*args)
        self.roots.append(stmt)
        return stmt

    def roots_to_string(self):
        """Render all top level statements into one string."""
        lines = []
        for root in self.roots:
            lines.extend(root.get_lines())
        return "\n".join(lines)

    def gen_py_codes(self):
        """
        Build the whole file and return its source.

        Raises:
            ExportError: If the SIR contains something that can not be
                written as Python source.
        """
        self.check_exportable()
        self.create_header()
        self.new_root("\n")
        self.create_layer()
        self.new_root("\n")
        self.create_test()
        self.new_root("\n")
        self.create_tail()
        return self.roots_to_string()

    def check_exportable(self):
        """
        Make sure every statement only uses constants and symbols.

        Raises:
            ExportError: If a flattened input or output of a statement is
                neither one of ``ConstTypes`` nor a ``Symbol``.
        """
        exportable_types = (*ConstTypes, Symbol)
        for stmt in self.SIR.statements:
            for kind, values in (
                ("input", stmt.inputs),
                ("output", stmt.outputs),
            ):
                for value in flatten(values):
                    if not isinstance(value, exportable_types):
                        raise ExportError(
                            f"Not support create python file with {kind}: {value}"
                        )

    def create_header(self):
        """Emit the imports of the generated file."""
        self.new_root(
            "import paddle",
            "import unittest",
            "import numpy as np",
        )

    def create_layer(self):
        """Emit the layer class: ``__init__`` with parameters, then ``forward``."""
        layer_class = self.new_root(f"class {self.SIR_name}(paddle.nn.Layer):")

        init_fn = layer_class.add_sub("def __init__(self):")
        init_fn.add_sub("super().__init__()")

        # ``param_symbol`` is a set; sort by name so that exporting the same
        # SIR twice yields the same file.
        for param in sorted(self.SIR.param_symbol, key=lambda sym: sym.name):
            meta = self.SIR.symbol_meta_map[param.name]
            init_fn.add_sub(
                f"self.{param.name} = self.create_parameter(",
                f"    shape={meta.shape},",
                f"    dtype={meta.dtype},",
                ")",
            )

        for stmt in self.SIR.statements:
            if stmt.type == "layer":
                layer = stmt.layer()
                # Each distinct layer object is created once in __init__.
                if id(layer) not in self.layer_name_map:
                    layer_name = (
                        layer.__class__.__name__
                        + self.layer_name_generator.next()
                    )
                    self.layer_name_map[id(layer)] = layer_name
                    init_fn.add_sub(self.init_sub_layer(layer, layer_name))

        forward_definition = ["def forward(", "    self,"]

        # Parameters live on ``self``; only the other inputs become arguments.
        for inp in self.SIR.inputs:
            if inp in self.SIR.non_param_symbol:
                meta = self.SIR.symbol_meta_map[inp.name]
                forward_definition.append(
                    f"    {inp.name},    # {str(meta)}"
                )
        forward_definition.append("):")

        forward_fn = layer_class.add_sub(*forward_definition)

        for stmt in self.SIR.statements:
            forward_fn.add_sub(*self.create_stmt_line(stmt))

        forward_fn.add_sub(
            "return {}".format(
                ", ".join(self.true_name(out) for out in self.SIR.outputs)
            )
        )

    def create_test(self):
        """Emit the TestCase that compares dygraph and to_static results."""
        test_class = self.new_root(
            f"class Test{self.SIR_name}(unittest.TestCase):"
        )

        setup = test_class.add_sub("def setUp(self):")
        test_inputs = [
            "self.inputs = (",
        ]
        for inp in self.SIR.inputs:
            if inp in self.SIR.non_param_symbol:
                meta = self.SIR.symbol_meta_map[inp.name]
                test_inputs.append(
                    f"    paddle.rand(shape={meta.shape}, dtype={meta.dtype}),"
                )
        test_inputs.append(")")
        setup.add_sub(*test_inputs)

        train = test_class.add_sub(
            "def train(self, net, to_static, with_cinn=False):"
        )
        # These lines form one statement, so the nesting of the generated
        # ``if`` has to be spelled out inside the strings.
        train.add_sub(
            "if to_static:",
            "    if with_cinn:",
            "        build_strategy = paddle.static.BuildStrategy()",
            "        build_strategy.build_cinn_pass = True",
            "        net = paddle.jit.to_static(net, build_strategy=build_strategy, full_graph=True)",
            "    else:",
            "        net = paddle.jit.to_static(net, full_graph=True)",
            "outs = net(*self.inputs)",
            "return outs",
        )

        for test_name, with_cinn in (
            ("test_ast_static", False),
            ("test_ast_cinn_static", True),
        ):
            test_fn = test_class.add_sub(f"def {test_name}(self):")
            test_fn.add_sub(
                # Instantiate the class generated by create_layer, whatever
                # the name of the exported SIR is.
                f"net = {self.SIR_name}()",
                "dy_out = self.train(net, to_static=False)",
                f"st_out = self.train(net, to_static=True, with_cinn={with_cinn})",
                "for dy, st in zip(paddle.utils.flatten(dy_out), paddle.utils.flatten(st_out)):",
                "    np.testing.assert_allclose(dy.numpy(), st.numpy(), atol=1e-8)",
            )

    def create_tail(self):
        """Emit the ``__main__`` entry of the generated file."""
        self.new_root(
            "if __name__ == '__main__':",
            "    unittest.main()",
        )

    def true_name(self, var):
        """
        Return the source text that refers to ``var`` inside ``forward``.

        Parameter symbols are attributes of the generated layer and get a
        ``self.`` prefix, other symbols are plain local names. Containers
        are rendered element by element so that symbols nested in them are
        named the same way. Anything else is a constant and is written with
        ``repr`` so that strings keep their quotes.
        """
        if isinstance(var, (list, tuple)):
            items = ", ".join(self.true_name(item) for item in var)
            if isinstance(var, list):
                return f"[{items}]"
            # A tuple with one element needs the trailing comma.
            return f"({items},)" if len(var) == 1 else f"({items})"
        if isinstance(var, dict):
            items = ", ".join(
                f"{key!r}: {self.true_name(value)}"
                for key, value in var.items()
            )
            return "{" + items + "}"
        if isinstance(var, Symbol):
            if var in self.SIR.param_symbol:
                return "self." + var.name
            return var.name
        return repr(var)

    def init_sub_layer(self, layer, layer_name):
        # TODO @wuzhanfei need more efficient way to create a sub layer
        # now, we just close call_Layer behavior
        raise ExportError("Not support create sub layer now.")

    def create_input_string(self, args, kwargs):
        """Return the argument list text of a call with ``args``/``kwargs``."""
        return ", ".join(
            chain(
                (self.true_name(arg) for arg in args),
                (f"{k}={self.true_name(v)}" for k, v in kwargs.items()),
            )
        )

    def create_unpack_output_string(self, outputs):
        """
        Return the assignments that unpack a nested call result named ``out``.

        Every symbol found in ``outputs`` yields one line of the form
        ``<symbol> = out[...]...``; values that are not symbols are skipped.
        """
        result = []

        def search(value, path):
            if isinstance(value, (list, tuple)):
                for idx, item in enumerate(value):
                    search(item, f"{path}[{idx}]")
            elif isinstance(value, dict):
                for key, item in value.items():
                    # repr keeps the quotes of string keys in the subscript.
                    search(item, f"{path}[{key!r}]")
            elif isinstance(value, Symbol):
                result.append(f"{self.true_name(value)} = {path}")

        search(outputs, "out")
        return result

    def create_stmt_line(self, stmt):
        """
        Return the source lines of one statement, dispatching on ``stmt.type``.

        Raises:
            ExportError: If there is no ``create_<type>_stmt`` method for
                the type of the statement.
        """
        creator = getattr(self, f"create_{stmt.type}_stmt", None)
        if creator is None:
            # Reported as ExportError so that export() can skip this SIR
            # instead of failing the run that triggered the export.
            raise ExportError(
                f"Not support create python file with statement type: {stmt.type}"
            )
        return creator(stmt)

    def _create_call_lines(self, callee, input_str, outputs):
        """Return the lines that call ``callee`` and bind its outputs."""
        call = f"{callee}({input_str})"
        if isinstance(outputs, Symbol):
            return [f"{outputs.name} = {call}"]
        return [f"out = {call}"] + self.create_unpack_output_string(outputs)

    def create_api_stmt(self, stmt):
        args, kwargs = stmt.inputs
        input_str = self.create_input_string(args, kwargs)
        api = stmt.api
        api_str = api.__module__ + "." + api.__name__
        return self._create_call_lines(api_str, input_str, stmt.outputs)

    def create_method_stmt(self, stmt):
        args, kwargs = stmt.inputs
        # The first argument is the object the method is called on.
        input_str = self.create_input_string(args[1:], kwargs)
        method_str = self.true_name(args[0]) + "." + stmt.method
        return self._create_call_lines(method_str, input_str, stmt.outputs)

    def create_layer_stmt(self, stmt):
        args, kwargs = stmt.inputs
        input_str = self.create_input_string(args, kwargs)
        layer_str = "self." + self.layer_name_map[id(stmt.layer())]
        return self._create_call_lines(layer_str, input_str, stmt.outputs)


def export(SIR, path):
    """
    Write ``SIR`` as a python unittest file into the directory ``path``.

    The file is named ``<SIR.name>.py``. If the SIR can not be exported, a
    message is printed and nothing is written.

    Args:
        SIR: The StatementIR to export.
        path (str): The target directory; it is created when missing.
    """
    try:
        pygen = PyFileGen(SIR)
        string = pygen.gen_py_codes()
    except ExportError as e:
        print("[SOT] Export SIR Failed:", e)
        return

    # exist_ok avoids failing when the directory appears between a separate
    # existence check and its creation.
    os.makedirs(path, exist_ok=True)

    with open(
        os.path.join(path, f"{SIR.name}.py"), "w", encoding="utf-8"
    ) as f:
        f.write(string)
6 changes: 6 additions & 0 deletions python/paddle/jit/sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
@@ -216,6 +216,8 @@ def __init__(self, name: str):
self.statements = [] # list of Statement

self.symbol_meta_map = {}
self.param_symbol = set()
self.non_param_symbol = set()

def __len__(self):
return len(self.statements)
@@ -228,6 +230,10 @@ def __deepcopy__(self, memo=None):
new_sir.symbol_meta_map = dict(self.symbol_meta_map.items())
return new_sir

def set_parameter_info(self, params, non_params):
    """
    Record which symbols are parameters and which are not.

    Both arguments are merged into the sets already stored on this object,
    so repeated calls accumulate.

    Args:
        params: Iterable of parameter symbols, added to ``param_symbol``.
        non_params: Iterable of the other tensor symbols, added to
            ``non_param_symbol``.
    """
    self.param_symbol.update(params)
    self.non_param_symbol.update(non_params)

def set_symbol_meta_map(self, meta_map):
# If the meta of an input symbol was changed in place, keep the original meta as the input of the SIR.
meta_map.update(self.symbol_meta_map)
4 changes: 4 additions & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -18,16 +18,19 @@
ENV_COST_MODEL,
ENV_MIN_GRAPH_SIZE,
ENV_SHOW_TRACKERS,
ENV_SOT_EXPORT,
ENV_SOT_LOG_LEVEL,
ENV_SOT_WITH_CONTROL_FLOW,
ENV_STRICT_MODE,
cost_model_guard,
min_graph_size_guard,
strict_mode_guard,
with_control_flow_guard,
with_export_guard,
)
from .exceptions import ( # noqa: F401
BreakGraphError,
ExportError,
FallbackError,
InnerError,
inner_error_default_handler,
@@ -41,6 +44,7 @@
)
from .utils import ( # noqa: F401
Cache,
ConstTypes,
GraphLogger,
NameGenerator,
OrderedSet,
7 changes: 7 additions & 0 deletions python/paddle/jit/sot/utils/envs.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@
ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable(
"SOT_WITH_CONTROL_FLOW", True
)
ENV_SOT_EXPORT = StringEnvironmentVariable("SOT_EXPORT", "")


@contextmanager
@@ -56,3 +57,9 @@ def min_graph_size_guard(value: int):
def with_control_flow_guard(value: bool):
with EnvironmentVariableGuard(ENV_SOT_WITH_CONTROL_FLOW, value):
yield


@contextmanager
def with_export_guard(value: str):
    """
    Run the guarded code with ``ENV_SOT_EXPORT`` (the ``SOT_EXPORT``
    environment variable) held by an ``EnvironmentVariableGuard``.

    Args:
        value (str): The value handed to the guard for ``ENV_SOT_EXPORT``.
    """
    with EnvironmentVariableGuard(ENV_SOT_EXPORT, value):
        yield
4 changes: 4 additions & 0 deletions python/paddle/jit/sot/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -62,3 +62,7 @@ def impl(*args, **kwargs):
) from e

return impl


class ExportError(SotErrorBase):
    """Raised when a StatementIR can not be exported as a python file."""
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@
)

T = TypeVar("T")
ConstTypes = (int, float, str, bool, type(None))


class Singleton(Generic[T]):
46 changes: 46 additions & 0 deletions test/sot/test_sot_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.jit.sot.utils import with_export_guard


class Net(paddle.nn.Layer):
    """A layer that owns a ``Linear`` sub layer and a parameter of its own."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(2, 2)
        # Created directly on this layer, in addition to the parameters
        # that belong to the sub layer.
        self.bias = self.create_parameter(
            shape=[2],
            attr=None,
            dtype="float32",
            is_bias=True,
        )

    def forward(self, x):
        a = self.linear(x)
        return a + self.bias


class TestSotExport(unittest.TestCase):
    """Run a to_static call while the SOT export guard is active."""

    def test_basic(self):
        # Imported here so that this test does not depend on the import
        # block at the top of the file.
        import tempfile

        net = Net()
        x = paddle.to_tensor([2, 3], dtype="float32", stop_gradient=True)
        # Export into a throwaway directory instead of a hard-coded "/tmp":
        # it also exists on platforms without "/tmp" and leaves no files
        # behind.
        with tempfile.TemporaryDirectory() as export_dir:
            with with_export_guard(export_dir):
                y = paddle.jit.to_static(net)(x)
        # The export must not disturb the computation itself.
        self.assertEqual(list(y.shape), [2])


if __name__ == "__main__":
unittest.main()

0 comments on commit b578350

Please sign in to comment.