Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Log Improvment] add some log for sot #272

Merged
merged 22 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .opcode_translator.breakpoint import BM, add_breakpoint, add_event
from .opcode_translator.skip_files import skip_function
from .translate import symbolic_translate
from .utils import psdb_print
from .utils import psdb_breakpoint, psdb_print

__all__ = [
"symbolic_translate",
Expand All @@ -10,4 +10,5 @@
"BM",
"skip_function",
"psdb_print",
"psdb_breakpoint",
]
4 changes: 2 additions & 2 deletions sot/opcode_translator/executor/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, TypeVar

from ...utils import InnerError, NameGenerator
from ...utils import InnerError, NameGenerator, hashable

if TYPE_CHECKING:
T = TypeVar("T")
Expand Down Expand Up @@ -241,7 +241,7 @@ def dispatch(
args: The args of the function.
kwargs: The kwargs of the function.
"""
if fn not in cls.handlers:
if not hashable(fn) or fn not in cls.handlers:
return None
for pattern, handler in cls.handlers[fn]:
if pattern.match_inputs(*args, **kwargs):
Expand Down
12 changes: 11 additions & 1 deletion sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
inner_error_default_handler,
is_paddle_api,
log,
log_do,
map_if,
show_trackers,
)
Expand Down Expand Up @@ -196,6 +197,15 @@ def load(self, var):
name_gen = NameGenerator("__start_compile_saved_")
for var in to_store_vars:
index_for_load[var.id] = name_gen.next()

def _log_fn():
print(
f"[StartCompile] saved var: {index_for_load[var.id]} = ",
var,
)

log_do(4, _log_fn)

for var in to_store_vars[::-1]:
self.pycode_gen.gen_store_fast(index_for_load[var.id])
return VariableLoader(index_for_load, self.pycode_gen)
Expand Down Expand Up @@ -256,8 +266,8 @@ def start_compile(self, *ret_vars: VariableBase):
ret_var.reconstruct(self.pycode_gen)

# deal side effect
self.restore_side_effects(self.side_effects.variables)
self.restore_print_stmts(self._print_variables)
self.restore_side_effects(self.side_effects.variables)

tracker_output_path = show_trackers()
if tracker_output_path:
Expand Down
4 changes: 2 additions & 2 deletions sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import ast
import types
import weakref
from dataclasses import dataclass
Expand Down Expand Up @@ -35,7 +34,8 @@ def __post_init__(self):

def check_expr(self, expr: str):
try:
ast.parse(expr)
pass
# ast.parse(expr) # TODO(xiongkun): too slow
except SyntaxError as e:
raise InnerError(f"Invalid expression: {expr}") from e

Expand Down
3 changes: 2 additions & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None:
log(
2,
f"Unsupport Frame is {frame.f_code}, error message is: \n"
+ '\n'.join(traceback.format_exception_only(type(e), e)),
+ "".join(traceback.format_exception(type(e), e, e.__traceback__)),
)

# NOTE: If resume fn need fallback, we should replace DummyVariable using NULL otherwise will fail to run
Expand Down Expand Up @@ -557,6 +557,7 @@ def error_message_summary(original_error: Exception) -> str:
type(original_error), original_error
)
for line in error_message:
line = line.rstrip()
message_lines.append(f"{indent} {line}")
return "\n".join(message_lines)

Expand Down
26 changes: 17 additions & 9 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,23 +187,31 @@ def stacksize(instructions):
# Two list below shows the possible stack size before opcode is called
# The stack size might be different in different branch, so it has max and min
max_stack = [float("-inf")] * len(instructions)
min_stack = [float("inf")] * len(instructions)

max_stack[0] = 0
min_stack[0] = 0

queue = []
queue.append(0)

def update_stacksize(lasti, nexti, stack_effect):
old_max = max_stack[nexti]
max_stack[nexti] = max(
max_stack[nexti], max_stack[lasti] + stack_effect
)
min_stack[nexti] = min(
min_stack[nexti], max_stack[lasti] + stack_effect
)
if old_max != max_stack[nexti]:
if nexti not in queue: # may be slow, we can use a flag.
queue.append(nexti)

for idx in range(len(instructions)):
while len(queue) > 0:
idx = queue[0]
del queue[0]
instr = instructions[idx]

if idx + 1 < len(instructions):
opname = instr.opname
if idx + 1 < len(instructions) and instr.opname not in [
'JUMP_ABSOLUTE',
"JUMP_FORWARD",
"JUMP_BACKWRAD",
]:
stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=False)
update_stacksize(idx, idx + 1, stack_effect)

Expand All @@ -212,7 +220,7 @@ def update_stacksize(lasti, nexti, stack_effect):
target_idx = instructions.index(instr.jump_to)
update_stacksize(idx, target_idx, stack_effect)

assert min(min_stack) >= 0
# assert min(min_stack) >= 0 # min_stack may be a negative number when try: except is got.
return max(max_stack)


Expand Down
17 changes: 14 additions & 3 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_builtin_fn,
is_paddle_api,
magic_method_builtin_dispatch,
psdb_breakpoint,
psdb_print,
)
from ....utils.exceptions import BreakGraphError, FallbackErrorBase
Expand Down Expand Up @@ -87,9 +88,7 @@ def __init__(
):
super().__init__(fn, graph, tracker)

def call_function(self, /, *args, **kwargs) -> VariableBase:
from ..opcode_inline_executor import OpcodeInlineExecutor

def handle_psdb_function(self, /, *args, **kwargs):
# special function for inner debug.
if self.value is ASSERT:
# TODO: add comptime check mechanism
Expand All @@ -103,6 +102,18 @@ def call_function(self, /, *args, **kwargs) -> VariableBase:
)
return ConstantVariable.wrap_literal(None, self.graph)

if self.value is psdb_breakpoint:
# do nothing. just return None.
return ConstantVariable.wrap_literal(None, self.graph)
return None

def call_function(self, /, *args, **kwargs) -> VariableBase:
from ..opcode_inline_executor import OpcodeInlineExecutor

result = self.handle_psdb_function(*args, **kwargs)
if result is not None:
return result

checkpoint = self.graph.save_memo()
try:
inline_executor = OpcodeInlineExecutor(self, *args, **kwargs)
Expand Down
51 changes: 27 additions & 24 deletions sot/opcode_translator/transform.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,44 @@
from __future__ import annotations

import dis
import types
from functools import partial
from typing import TYPE_CHECKING

from ..utils import log, log_do
from .executor.opcode_executor import InstructionTranslatorCache
from .skip_files import need_skip

if TYPE_CHECKING:
from .executor.opcode_executor import CustomCode
pass


def eval_frame_callback(frame: types.FrameType, **kwargs) -> CustomCode | None:
"""
Callback function for the frame evaluation process.
It will be executed before the frame is to be performed.
def print_locals(frame):
local_key = [
key for key in frame.f_locals.keys() if not key.startswith("__")
]
print(
f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key}"
)
print(
f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars}"
)

def convert_obj(obj):
import paddle

Args:
frame (types.FrameType): The frame object that will be translate.
kwargs: The arguments of ``to_static``.
if isinstance(obj, paddle.Tensor):
return "Tensor(" + str(obj.shape) + ")"
if isinstance(obj, list):
return [convert_obj(i) for i in obj]
return obj

Returns:
new_code: The new instruction code object, or None if unable to be translated into a new code object.
"""
for key in local_key:
print(
f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} {key} = {convert_obj(frame.f_locals[key])}"
)


def eval_frame_callback(frame, **kwargs):
# is generator
if frame.f_code.co_flags & 0x20 > 0:
return None
Expand All @@ -36,18 +50,7 @@ def eval_frame_callback(frame: types.FrameType, **kwargs) -> CustomCode | None:
2,
"[eval_frame_callback] start to translate: " + str(frame.f_code) + "\n",
)
local_key = [
key for key in frame.f_locals.keys() if not key.startswith("__")
]
log(
4,
f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key} \n",
)
log(
4,
f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars} \n",
)

log_do(4, partial(print_locals, frame))
log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n")
log_do(8, lambda: dis.dis(frame.f_code))

Expand Down
4 changes: 4 additions & 0 deletions sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def __call__(self, *args, **kwargs):
self.concrete_program.main_program
),
)
log_do(
4,
lambda: print("[CompileCache] run sir forward success."),
)
return outputs


Expand Down
4 changes: 4 additions & 0 deletions sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
count_if,
execute_time,
get_unbound_method,
hashable,
in_paddle_module,
is_break_graph_api,
is_builtin_fn,
Expand All @@ -32,6 +33,7 @@
map_if,
meta_str,
no_eval_frame,
psdb_breakpoint,
psdb_print,
show_trackers,
)
Expand Down Expand Up @@ -61,11 +63,13 @@
"paddle_tensor_methods",
"ASSERT",
"psdb_print",
"psdb_breakpoint",
"ResumeFnNameFactory",
"list_contain_by_id",
"list_find_index_by_id",
"show_trackers",
"get_unbound_method",
"GraphLogger",
"UndefinedVar",
"hashable",
]
4 changes: 3 additions & 1 deletion sot/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def impl(*args, **kwargs):
return func(*args, **kwargs)
except Exception as e:
message = message_fn(*args, **kwargs)
raise InnerError(f"{message}.\nOrigin Exception is : \n {e}") from e
raise InnerError(
f"{message}.\nOrigin Exception is : \n {traceback.format_exception(type(e), e, e.__traceback__)}"
) from e

return impl
4 changes: 4 additions & 0 deletions sot/utils/magic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable

from .utils import hashable

if TYPE_CHECKING:
BinaryOp = Callable[[Any, Any], Any]
UnaryOp = Callable[[Any], Any]
Expand Down Expand Up @@ -89,6 +91,8 @@ class MagicMethod:


def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]:
if not hashable(fn):
return []
if fn in INPLACE_BINARY_OPS:
inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[
fn
Expand Down
16 changes: 16 additions & 0 deletions sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ def psdb_print(*args, **kwargs):
print("[Dygraph]", *args, **kwargs)


def psdb_breakpoint():
import paddle

old = paddle.fluid.core.set_eval_frame(None)
breakpoint()
paddle.fluid.core.set_eval_frame(old)


def list_find_index_by_id(li: list[Any], item: Any) -> int:
return [id(it) for it in li].index(id(item))

Expand Down Expand Up @@ -280,3 +288,11 @@ def print_info(self):
@Singleton
class UndefinedVar:
pass


def hashable(obj):
try:
hash(obj)
return True
except TypeError as e:
return False