From e496b775ff5d6a360b91d1516297d0f26e872016 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Tue, 11 Jul 2023 08:38:48 +0000 Subject: [PATCH 01/14] fix unpack sequence bugs when introduce getvalue() --- .../executor/opcode_executor.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index c08b71e36..fb6f7c493 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -895,8 +895,8 @@ def BUILD_STRING(self, instr: Instruction): str_list = self.pop_n(count) new_str = '' for s in str_list: - assert isinstance(s.value, str) - new_str += s.value + assert isinstance(s.get_value(), str) + new_str += s.get_value() self.push(ConstantVariable.wrap_literal(new_str, self._graph)) def BUILD_SLICE(self, instr: Instruction): @@ -909,7 +909,7 @@ def BUILD_SLICE(self, instr: Instruction): related_list = [start, stop, step] if step else [start, stop] - slice_ = slice(*(x.value for x in related_list)) + slice_ = slice(*(x.get_value() for x in related_list)) self.push( VariableFactory.from_value( @@ -925,7 +925,7 @@ def build_map( assert isinstance(key, VariableBase) # Add key to global guarded variable to avoid missing the key guard self._graph.add_global_guarded_variable(key) - key = key.value + key = key.get_value() built_map[key] = value return DictVariable( built_map, @@ -991,7 +991,7 @@ def BUILD_MAP_UNPACK(self, instr: Instruction): retval = {} for item in unpack_values: - assert isinstance(item.value, dict) + assert isinstance(item.get_value(), dict) retval.update(item.get_wrapped_items()) self.push( @@ -1007,7 +1007,7 @@ def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): retval = {} for item in unpack_values: - assert isinstance(item.value, dict) + assert isinstance(item.get_value(), dict) wrapped_item = item.get_wrapped_items() if wrapped_item.items() & retval.items(): raise InnerError( @@ -1038,8 +1038,8 @@ def CALL_FUNCTION_KW(self, instr: Instruction): assert isinstance(kwargs_keys, TupleVariable) assert len(kwargs_keys) > 0 kwargs_keys = [ - x.value if isinstance(x, VariableBase) else x - for x in kwargs_keys.value + x.get_value() if isinstance(x, VariableBase) else x + for x in kwargs_keys.get_value() ] # split arg_list to args and kwargs @@ -1147,7 +1147,11 @@ def g(z=x): default_args = () new_fn = types.FunctionType( - codeobj.value, global_dict, fn_name.value, default_args, closure + codeobj.get_value(), + global_dict, + fn_name.get_value(), + default_args, + closure, ) self.push( UserDefinedFunctionVariable( @@ -1265,15 +1269,17 @@ def UNPACK_SEQUENCE(self, instr: Instruction): ''' TODO: To unpack iterator To unpack is easy, just like: - seq = tuple(sequence.value) + seq = tuple(sequence.get_value()) But what is the `source` when iterator returned a value ? ''' if isinstance(sequence, TensorVariable): # TODO: If need to unpack a Tensor, should have different logic. - raise NotImplementException("Unpack a iterator is not implemented.") + raise NotImplementException( + "Unpack a tensor variable is not implemented." + ) elif isinstance(sequence, (ListVariable, TupleVariable)): - seq = sequence.value + seq = sequence.get_value() else: raise NotImplementException( f"Unpack {sequence} is not implemented." @@ -1297,7 +1303,7 @@ def FORMAT_VALUE(self, instr: Instruction): which_conversion = flag & FV.FVC_MASK have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) - fmt_spec = self.pop().value if have_fmt_spec else "" + fmt_spec = self.pop().get_value() if have_fmt_spec else "" value = self.pop() if which_conversion == FV.FVC_NONE: @@ -1315,7 +1321,7 @@ def FORMAT_VALUE(self, instr: Instruction): # different type will lead to different Tracker, so call self.push in different branch if isinstance(value, ConstantVariable): - result = value.value + result = value.get_value() if convert_fn is not None: result = getattr(result, convert_fn)(result) From 2179b3f82e925377cbcddc4e515a5da806707cf9 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Tue, 11 Jul 2023 10:00:14 +0000 Subject: [PATCH 02/14] fix ci erros --- .../executor/opcode_executor.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index fb6f7c493..6183d1343 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -36,7 +36,6 @@ ConstTracker, DanglingTracker, DummyTracker, - GetItemTracker, GetIterTracker, GlobalTracker, LocalTracker, @@ -1278,25 +1277,17 @@ def UNPACK_SEQUENCE(self, instr: Instruction): raise NotImplementException( "Unpack a tensor variable is not implemented." ) - elif isinstance(sequence, (ListVariable, TupleVariable)): - seq = sequence.get_value() - else: + if not isinstance(sequence, (ListVariable, TupleVariable)): raise NotImplementException( f"Unpack {sequence} is not implemented." ) assert ( - len(seq) == instr.arg - ), f"Want unpack {seq} to {instr.arg}, but the len is {len(seq)}." + len(sequence) == instr.arg + ), f"Want unpack {sequence} to {instr.arg}, but the len is {len(sequence)}." for i in range(instr.arg - 1, -1, -1): - self.push( - VariableFactory.from_value( - seq[i], - graph=self._graph, - tracker=GetItemTracker(sequence, i), - ) - ) + self.push(sequence[i]) def FORMAT_VALUE(self, instr: Instruction): flag = instr.arg From 546311d441960d7a03df73a5f082a71f36cc4502 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 12 Jul 2023 02:56:38 +0000 Subject: [PATCH 03/14] Fix more bugs --- sot/infer_meta.py | 9 ++++---- .../executor/function_graph.py | 6 +++++- .../executor/opcode_executor.py | 21 +++++++++++++++++++ .../executor/pycode_generator.py | 3 ++- .../executor/variables/basic.py | 5 +---- sot/opcode_translator/transform.py | 11 ++++++++++ 6 files changed, 45 insertions(+), 10 deletions(-) diff --git a/sot/infer_meta.py b/sot/infer_meta.py index f1752625a..af0a07c86 100644 --- a/sot/infer_meta.py +++ b/sot/infer_meta.py @@ -142,7 +142,10 @@ def convert_meta_to_input_spec(args): args, pred=lambda x: isinstance(x, MetaInfo), true_fn=lambda x: x.to_input_spec(), - false_fn=lambda x: paddle.static.InputSpec.from_tensor(x), + # TODO(xiongkun): can x be tensor ? + false_fn=lambda x: paddle.static.InputSpec.from_tensor(x) + if isinstance(x, paddle.Tensor) + else x, ) @@ -168,9 +171,7 @@ def infer_meta_for_layer(layer, *args, **kwargs): ), f"Expect a Layer, but got {layer}." layer = paddle.jit.to_static(layer, enable_fallback=False) - args_, kwargs_ = convert_meta_to_input_spec( - args - ), convert_meta_to_input_spec(kwargs) + args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) ( concrete_program, diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 6ece0c374..79d5af3cf 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -185,6 +185,10 @@ def start_compile(self, *ret_vars: VariableBase): - Restore the output - Return the top of the stack """ + from ..breakpoint import BreakpointManager + + BreakpointManager().on_event("start_compile") + ret_items = [ ret_item for ret_var in ret_vars @@ -347,7 +351,7 @@ def message_handler(*args, **kwargs): return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?" return inner_error_default_handler(self.symbolic_call, message_handler)( - infer_meta_fn, compute_fn, layer, *[layer, *args] + infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs ) def _put_inner(self, var: VariableBase): diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 6183d1343..7f8c15b93 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -743,6 +743,22 @@ def BINARY_SUBSCR(self, instr: Instruction): key = self.pop() container = self.pop() assert isinstance(key, VariableBase) + # TODO(xiongkun): getitem / getattr support key and attr as variable. + if isinstance(key, TensorVariable) and isinstance( + container, TensorVariable + ): + # NOTE(xiongkun): tensor[tensor] should support. + output = self._graph.call_tensor_method( + "__getitem__", container, key + ) + self.push(output) + return + + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in BINARY_SUBSCR, {container}[{key}]" + ) + self._graph.add_global_guarded_variable(key) self.push( BuiltinVariable(operator.getitem, self._graph, DanglingTracker())( @@ -848,6 +864,11 @@ def STORE_SUBSCR(self, instr: Instruction): value = self.pop() assert isinstance(key, VariableBase) self._graph.add_global_guarded_variable(key) + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in STORE_SUBSCR, {container}[{key}] = {value}" + ) + # TODO(xiongkun): support tensor[tensor] = tensor, dy2static is not the same with dygraph. container[key.get_value()] = value value.debug_name = f"{container.debug_name}[{key.debug_name}]" diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index bdd40d71f..7f2a63d61 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -361,7 +361,8 @@ def gen_loop_body_between(self, for_iter, start, end): instr.jump_to = nop_for_break # outputs is the same as inputs - return self._gen_fn(inputs, inputs), inputs + generated_fn = self._gen_fn(inputs, inputs) + return generated_fn, inputs def gen_for_loop_fn_between(self, iterator, start, end, exist_names): origin_instrs = get_instructions(self._origin_code) diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 37bce98dc..df8608c5d 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -225,10 +225,7 @@ def __init__( super().__init__(tracker) if isinstance(tensor, paddle.Tensor): self.value = tensor - try: - self.meta = MetaInfo.from_tensor(tensor) - except: - breakpoint() + self.meta = MetaInfo.from_tensor(tensor) elif isinstance(tensor, MetaInfo): self.value = None self.meta = tensor diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 65a7aff07..d559de957 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -17,6 +17,17 @@ def eval_frame_callback(frame, **kwargs): + str(frame.f_code) + "\n", ) + local_key = [ + key for key in frame.f_locals.keys() if not key.startswith("__") + ] + log( + 4, + f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key} \n", + ) + log( + 4, + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars} \n", + ) log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n") log_do(8, lambda: dis.dis(frame.f_code)) From b8978456c4dcb62bbdbe9ecd945b93468cbd878f Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 12 Jul 2023 06:30:22 +0000 Subject: [PATCH 04/14] fix --- sot/opcode_translator/executor/guard.py | 21 +++++- .../executor/opcode_executor.py | 3 +- .../executor/variables/basic.py | 3 + .../executor/variables/callable.py | 25 ++++--- sot/opcode_translator/skip_files.py | 4 ++ sot/opcode_translator/transform.py | 69 +++++++++---------- tests/test_str_format.py | 21 ++++++ 7 files changed, 97 insertions(+), 49 deletions(-) create mode 100644 tests/test_str_format.py diff --git a/sot/opcode_translator/executor/guard.py b/sot/opcode_translator/executor/guard.py index 0939bed0f..8d05ab0c6 100644 --- a/sot/opcode_translator/executor/guard.py +++ b/sot/opcode_translator/executor/guard.py @@ -65,6 +65,12 @@ def make_guard(stringify_guards: list[StringifyExpression]) -> Guard: return guard +def support_weak_ref(obj): + if isinstance(obj, types.FunctionType): + return True + return False + + def object_equal_stringify_guard(self) -> StringifyExpression: assert ( self.tracker.is_traceable() @@ -78,11 +84,20 @@ def object_equal_stringify_guard(self) -> StringifyExpression: ), ) obj_free_var_name = f"__{self.id}" - weak_ref_obj = weakref.ref(self.get_value()) + weak_ref_obj = self.get_value() + if support_weak_ref(weak_ref_obj): + weak_ref_obj = weakref.ref(self.get_value()) + return StringifyExpression( + f"{obj_free_var_name}() is not None and {frame_value_tracer.expr} == {obj_free_var_name}()", + union_free_vars( + frame_value_tracer.free_vars, + {obj_free_var_name: weak_ref_obj}, + ), + ) return StringifyExpression( - f"{obj_free_var_name}() is not None and {frame_value_tracer.expr} == {obj_free_var_name}()", + f"{frame_value_tracer.expr} == {obj_free_var_name}", union_free_vars( frame_value_tracer.free_vars, - {obj_free_var_name: weak_ref_obj}, + {obj_free_var_name: self.get_value()}, ), ) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index c97247359..3d417b566 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1674,10 +1674,11 @@ def _break_graph_in_for_loop( self.indexof(for_iter.jump_to), len(self._stack) ) + total_inputs = set(list(fn_inputs) + list(loop_inputs)) # 1. part before for-loop, start compile ret_names = [ name - for name in loop_inputs[:-1] + for name in total_inputs if name in chain(self._locals, self._cells) ] # the last one is _break_flag ret_vars = [self.get_var(name) for name in ret_names] diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index df8608c5d..c1ba9114c 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -516,6 +516,9 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): return ModuleVariable(value, graph, tracker) return None + # Happened in a inline import statement. + make_stringify_guard = object_equal_stringify_guard + class DygraphTracerVariable(VariableBase): # TODO(SigureMo): Remove this trick after we add CompareTracker diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index 9cacfa2d5..34b5b1eb3 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -44,10 +44,15 @@ def __init__(self, graph: FunctionGraph, tracker: Tracker): super().__init__(tracker) self.graph = graph - def __call__(self, *args, **kwargs) -> VariableBase: + def __call__(self, /, *args, **kwargs) -> VariableBase: + """Why we need '/' to make self positional only? + + If kwargs have {'self': xxx}, this function call raise a error. + See: test_str_format.py for details. + """ return self.call_function(*args, **kwargs) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): raise NotImplementedError("call_function is not implemented.") @@ -88,7 +93,7 @@ def __init__( ): super().__init__(fn, graph, tracker) - def call_function(self, *args, **kwargs) -> VariableBase: + def call_function(self, /, *args, **kwargs) -> VariableBase: from ..opcode_inline_executor import OpcodeInlineExecutor # special function for inner debug. @@ -134,7 +139,7 @@ def __init__( ): super().__init__(fn, graph, tracker) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): if is_break_graph_api(self.value): raise BreakGraphError( f"breakgraph by unsupport function: {self.value.__name__}" @@ -166,7 +171,7 @@ def __init__( super().__init__(fn, graph, tracker) self.method_name = method_name - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): if is_break_graph_tensor_methods(self.method_name): raise BreakGraphError() return self.graph.call_tensor_method(self.method_name, *args, **kwargs) @@ -205,7 +210,7 @@ def _reconstruct(self, pycode_gen): self.tensor.reconstruct(pycode_gen) pycode_gen.gen_load_attr(self.method_name) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): return self.fn(*(self.bound_instance, *args), **kwargs) @staticmethod @@ -299,7 +304,7 @@ def __init__( ): super().__init__(layer, graph, tracker) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): fn_var = UserDefinedFunctionVariable( self.value.__class__.__call__, self.graph, @@ -330,7 +335,7 @@ def __init__( super().__init__(fn, graph, tracker) self.value = fn - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): # Lookup the handler from dispatcher handler = Dispatcher.dispatch(self.value, *args, **kwargs) if handler is not None: @@ -382,7 +387,7 @@ def __init__( ): super().__init__(fn, graph, tracker) - def call_function(self, *args, **kwargs) -> VariableBase: + def call_function(self, /, *args, **kwargs): iter_ = self.value() return VariableFactory.from_value( iter_, self.graph, DummyTracker([self]) @@ -413,7 +418,7 @@ def __init__( def get_symbol(self) -> Symbol: return Symbol(self.name) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): # TODO: Remove this trick after we support for-loop. if isinstance(self.value, paddle.nn.Sequential): assert len(args) == 1, "Sequential only accept one input" diff --git a/sot/opcode_translator/skip_files.py b/sot/opcode_translator/skip_files.py index ecbac3514..ac27a13d2 100644 --- a/sot/opcode_translator/skip_files.py +++ b/sot/opcode_translator/skip_files.py @@ -5,6 +5,7 @@ import copy import copyreg import dataclasses +import distutils import enum import functools import importlib @@ -37,6 +38,7 @@ import decorator import google.protobuf import numpy +import setuptools from ..utils import log @@ -89,6 +91,8 @@ def _module_dir(m: types.ModuleType): decorator, codecs, uuid, + setuptools, + distutils, ) } diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index d559de957..97662be5a 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -10,38 +10,37 @@ def eval_frame_callback(frame, **kwargs): if frame.f_code.co_flags & 0x20 > 0: return None - if not need_skip(frame.f_code): - log( - 2, - "[eval_frame_callback] start to translate: " - + str(frame.f_code) - + "\n", - ) - local_key = [ - key for key in frame.f_locals.keys() if not key.startswith("__") - ] - log( - 4, - f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key} \n", - ) - log( - 4, - f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars} \n", - ) - - log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n") - log_do(8, lambda: dis.dis(frame.f_code)) - - new_code = InstructionTranslatorCache()(frame, **kwargs) - - log( - 7, - "\n[transform_opcode] new_opcode: " + frame.f_code.co_name + "\n", - ) - if new_code is not None: - log_do(7, lambda: dis.dis(new_code.code)) - else: - log(7, f"Skip frame: {frame.f_code.co_name}") - - return new_code - return None + if need_skip(frame.f_code): + return None + + log( + 2, + "[eval_frame_callback] start to translate: " + str(frame.f_code) + "\n", + ) + local_key = [ + key for key in frame.f_locals.keys() if not key.startswith("__") + ] + log( + 4, + f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key} \n", + ) + log( + 4, + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars} \n", + ) + + log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n") + log_do(8, lambda: dis.dis(frame.f_code)) + + new_code = InstructionTranslatorCache()(frame, **kwargs) + + log( + 7, + "\n[transform_opcode] new_opcode: " + frame.f_code.co_name + "\n", + ) + if new_code is not None: + log_do(7, lambda: dis.dis(new_code.code)) + else: + log(7, f"Skip frame: {frame.f_code.co_name}") + + return new_code diff --git a/tests/test_str_format.py b/tests/test_str_format.py new file mode 100644 index 000000000..348067ed9 --- /dev/null +++ b/tests/test_str_format.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + + +# copy from python library _distutils_hack/__init__.py +def find_spec(self, fullname, path, target=None): + method_name = 'spec_for_{fullname}'.format(**locals()) + method = getattr(self, method_name, lambda: None) + return method() + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(find_spec, "self", "fullname", "path", None) + + +if __name__ == "__main__": + unittest.main() From d8fb904041227d1c18e86ce7bf24070fda5d8aa8 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 12 Jul 2023 08:18:41 +0000 Subject: [PATCH 05/14] fix gen load error and add test cast. --- .../executor/opcode_executor.py | 9 ++++-- .../executor/pycode_generator.py | 10 +++---- tests/test_paddle_cast.py | 28 +++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) create mode 100644 tests/test_paddle_cast.py diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 55a555564..d65ac878c 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1422,7 +1422,10 @@ def _prepare_virtual_env(self): Prepare the virtual environment for execution by adding variables from locals, globals, builtins, and constants. """ - log(3, f"[Executor] code options: {self._frame.f_code.co_cellvars}\n") + log( + 3, + f"[Executor] code options: co_cellvars={self._frame.f_code.co_cellvars}\n", + ) free_or_cell_vars = ( self._frame.f_code.co_cellvars + self._frame.f_code.co_freevars ) @@ -1718,7 +1721,7 @@ def _break_graph_in_for_loop( # 5.2 load loop body inputs for name in loop_inputs[:-1]: - self._graph.pycode_gen.gen_load(name, self._code) + self._graph.pycode_gen.gen_load(name) # 5.3 load break flag self._graph.pycode_gen.gen_load_const(True) @@ -1751,7 +1754,7 @@ def _break_graph_in_for_loop( for stack_arg in self._stack: stack_arg.reconstruct(self._graph.pycode_gen) for name in fn_inputs: - self._graph.pycode_gen.gen_load(name, self._code) + self._graph.pycode_gen.gen_load(name) self._graph.pycode_gen.gen_call_function( argc=after_loop_fn.__code__.co_argcount, with_eval_frame=True diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 74ce165b8..c932fa4ff 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -304,7 +304,7 @@ def create_fn_with_specific_io(self, inputs, outputs): the main codes should be created before call create_fn_with_specific_io ''' for name in outputs: - self.gen_load_fast(name) + self.gen_load(name) self.gen_build_tuple(len(outputs)) self._code_options['co_argcount'] = len(inputs) self._code_options['co_varnames'] = list( @@ -427,12 +427,12 @@ def dbg_func(): self.gen_call_function(1) self.gen_pop_top() - def gen_load(self, name, code): - if name in code.co_cellvars: + def gen_load(self, name): + if name in self._code_options["co_cellvars"]: self.gen_load_deref(name) - elif name in code.co_varnames: + elif name in self._code_options["co_varnames"]: self.gen_load_fast(name) - elif name in code.co_names: + elif name in self._code_options["co_names"]: self.gen_load_global(name) else: raise InnerError( diff --git a/tests/test_paddle_cast.py b/tests/test_paddle_cast.py new file mode 100644 index 000000000..512aa5eee --- /dev/null +++ b/tests/test_paddle_cast.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_paddle_cast(x): + y = x + 1 + return y.cast("int") + + +def test_paddle_cast2(x): + y = x + 1 + return paddle.cast(y, "int") + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + self.assert_results(test_paddle_cast, a) + self.assert_results(test_paddle_cast2, a) + + +if __name__ == "__main__": + unittest.main() From a0ecb16f6016713c0564f6a029f06f3c40dde8ce Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 12 Jul 2023 09:11:12 +0000 Subject: [PATCH 06/14] support delete fast and add unittest --- .../executor/opcode_executor.py | 4 +++ .../instruction_utils/opcode_analysis.py | 34 ++++++++++++++++--- tests/test_delete_fast.py | 24 +++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 tests/test_delete_fast.py diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index d65ac878c..8976734a7 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -818,6 +818,10 @@ def LOAD_FAST(self, instr: Instruction): var = self._locals[varname] self.push(var) + def DELETE_FAST(self, instr: Instruction): + varname = self._code.co_varnames[instr.arg] + del self._locals[varname] + def LOAD_GLOBAL(self, instr: Instruction): name = self._code.co_names[instr.arg] if name in self._globals.keys(): diff --git a/sot/opcode_translator/instruction_utils/opcode_analysis.py b/sot/opcode_translator/instruction_utils/opcode_analysis.py index f8b9b8939..82830cf97 100644 --- a/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -13,6 +13,32 @@ class State: visited: set[int] +def is_read_opcode(opname): + if opname in ["LOAD_FAST", "LOAD_DEREF", "LOAD_NAME", "LOAD_GLOBAL"]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + +def is_write_opcode(opname): + if opname in ["STORE_FAST", "STORE_NAME", "STORE_DEREF", "STORE_GLOBAL"]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + def analysis_inputs( instructions: list[Instruction], current_instr_idx: int, @@ -38,11 +64,11 @@ def walk(state: State, start: int) -> set[str]: instr = instructions[i] if instr.opname in HAS_LOCAL | HAS_FREE: - if instr.opname.startswith("LOAD") and instr.argval not in ( + if is_read_opcode(instr.opname) and instr.argval not in ( state.writes ): state.reads.add(instr.argval) - elif instr.opname.startswith("STORE"): + elif is_write_opcode(instr.opname): state.writes.add(instr.argval) elif instr.opname in ALL_JUMP: assert instr.jump_to is not None @@ -87,11 +113,11 @@ def walk(state: State, start: int) -> set[str]: instr = instructions[i] if instr.opname in HAS_LOCAL | HAS_FREE: - if instr.opname.startswith("LOAD") and instr.argval not in ( + if is_read_opcode(instr.opname) and instr.argval not in ( state.writes ): state.reads.add(instr.argval) - elif instr.opname.startswith("STORE"): + elif is_write_opcode(instr.opname): state.writes.add(instr.argval) elif instr.opname in ALL_JUMP: assert instr.jump_to is not None diff --git a/tests/test_delete_fast.py b/tests/test_delete_fast.py new file mode 100644 index 000000000..decad8328 --- /dev/null +++ b/tests/test_delete_fast.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_delete_fast(a): + a = a + 2 + t = a * 3 + del t + return a + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + self.assert_results(test_delete_fast, a) + + +if __name__ == "__main__": + unittest.main() From ee118b863a4e62db66bd643c2d6f23e71106d427 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Thu, 13 Jul 2023 07:46:12 +0000 Subject: [PATCH 07/14] fix side effect problem --- .../executor/function_graph.py | 21 +++++++- .../executor/opcode_executor.py | 52 ++++++++++++++----- sot/symbolic/compile_cache.py | 9 ++++ 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 79d5af3cf..69e1b25e1 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -11,6 +11,7 @@ from ...symbolic.statement_ir import Symbol from ...symbolic.symbolic_context import SymbolicTraceContext from ...utils import ( + NameGenerator, inner_error_default_handler, is_paddle_api, log, @@ -173,6 +174,25 @@ def guard_fn(self) -> Guard: return make_guard(guards) + def start_compile_with_name_store(self, ret_vars, to_store_vars): + class VariableLoader: + def __init__(self, index_for_load, pycode_gen): + self._index_for_load = index_for_load + self._pycode_gen = pycode_gen + + def load(self, var): + self._pycode_gen.gen_load_fast(self._index_for_load[var.id]) + + # var_id -> local_name mapping + index_for_load = {} + self.start_compile(*(ret_vars + to_store_vars)) + name_gen = NameGenerator("__start_compile_saved_") + for var in to_store_vars: + index_for_load[var.id] = name_gen.next() + for var in to_store_vars[::-1]: + self.pycode_gen.gen_store_fast(index_for_load[var.id]) + return VariableLoader(index_for_load, self.pycode_gen) + def start_compile(self, *ret_vars: VariableBase): """ Generate bytecode based on the information collected by the simulation execution. @@ -203,7 +223,6 @@ def start_compile(self, *ret_vars: VariableBase): compiled_fn_name = f"__compiled_fn_{statment_ir.name}" # prepare function and inputs self.pycode_gen.gen_load_object(compiled_fn, compiled_fn_name) - for name in input_names: found = False for variable in self.input_variables: diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 8976734a7..e1ac233c6 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1091,7 +1091,7 @@ def CALL_FUNCTION_EX(self, instr: Instruction): kwargs = {} args_variable = self.pop() - assert isinstance(args_variable, TupleVariable) + assert isinstance(args_variable, (TupleVariable, ListVariable)) args = args_variable.get_wrapped_items() fn = self.pop() @@ -1511,7 +1511,16 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): ret_vars = [ result, ] + inputs_var - self._graph.start_compile(*ret_vars) + # Collect all the to store variables. + store_vars = [] + for stack_arg in self._stack: + store_vars.append(stack_arg) + for name in if_inputs + else_inputs: + store_vars.append(self.get_var(name)) + + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) # only pop the input of if/else resume fn, and keep the bool tensor result on the stack for _ in inputs_var: self._graph.pycode_gen.gen_pop_top() @@ -1523,9 +1532,9 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): ) insert_index = len(self._graph.pycode_gen._instructions) - 1 for stack_arg in self._stack: - stack_arg.reconstruct(self._graph.pycode_gen) + var_loader.load(stack_arg) for name in if_inputs: - self.get_var(name).reconstruct(self._graph.pycode_gen) + var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( argc=if_fn.__code__.co_argcount, with_eval_frame=True, @@ -1541,9 +1550,9 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): ) jump_to = self._graph.pycode_gen._instructions[-1] for stack_arg in self._stack: - stack_arg.reconstruct(self._graph.pycode_gen) + var_loader.load(stack_arg) for name in else_inputs: - self.get_var(name).reconstruct(self._graph.pycode_gen) + var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( argc=else_fn.__code__.co_argcount, with_eval_frame=True, @@ -1589,7 +1598,17 @@ def _break_graph_in_call( for name in resume_input_name if self.get_var(name) not in ret_vars ] - self._graph.start_compile(*ret_vars) + + # Collect all the to store variables. + store_vars = [] + for stack_arg in self._stack: + store_vars.append(stack_arg) + for name in resume_input_name: + store_vars.append(self.get_var(name)) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + for _ in ret_vars: self._graph.pycode_gen.gen_pop_top() @@ -1606,7 +1625,7 @@ def _break_graph_in_call( DummyVariable(), f'dummy_var{i}' ) else: - stack_arg.reconstruct(self._graph.pycode_gen) + var_loader.load(stack_arg) self._graph.pycode_gen.add_pure_instructions([instr]) # gen call resume fn opcode @@ -1619,7 +1638,7 @@ def _break_graph_in_call( ) self._graph.pycode_gen.gen_rot_n(stack_size + 1) for name in resume_input_name: - self._locals[name].reconstruct(self._graph.pycode_gen) + var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( argc=resume_fn.__code__.co_argcount, with_eval_frame=True, @@ -1695,13 +1714,22 @@ def _break_graph_in_for_loop( if name in chain(self._locals, self._cells) ] # the last one is _break_flag ret_vars = [self.get_var(name) for name in ret_names] - self._graph.start_compile(*ret_vars) + # Collect all the to store variables. + store_vars = [] + for idx in range(len(ret_names)): + store_vars.append(ret_vars[idx]) + for stack_arg in self._stack: + store_vars.append(stack_arg) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + for _ in ret_vars: self._graph.pycode_gen.gen_pop_top() # 2. restore vars for idx in range(len(ret_names)): - ret_vars[idx].reconstruct(self._graph.pycode_gen) + var_loader.load(ret_vars[idx]) self._graph.pycode_gen.gen_store(ret_names[idx], self._code) # 3. setup vars which is created in loop @@ -1756,7 +1784,7 @@ def _break_graph_in_for_loop( ) for stack_arg in self._stack: - stack_arg.reconstruct(self._graph.pycode_gen) + var_loader.load(stack_arg) for name in fn_inputs: self._graph.pycode_gen.gen_load(name) diff --git a/sot/symbolic/compile_cache.py b/sot/symbolic/compile_cache.py index 7d1773c9d..27cde4f59 100644 --- a/sot/symbolic/compile_cache.py +++ b/sot/symbolic/compile_cache.py @@ -32,6 +32,14 @@ def __call__(self, *args, **kwargs): log_do( 2, lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR) ) + log_do( + 4, + lambda: print( + self.compiled_fn.get_concrete_program(*args, **kwargs)[ + 1 + ].train_program + ), + ) if self.partial_program is None or True: outputs = self.compiled_fn(*args, **kwargs) ( @@ -41,6 +49,7 @@ def __call__(self, *args, **kwargs): else: # Speed up Resnet from 0.0068 --> 0.0057 outputs = self.partial_program(*args, **kwargs) + clear_eager_tensor_name(outputs) log_do( 1, From 322a123ec4527d74223dca061a6dfc5f9f41630b Mon Sep 17 00:00:00 2001 From: xiogkun Date: Thu, 13 Jul 2023 08:19:21 +0000 Subject: [PATCH 08/14] fix side effect error. --- sot/opcode_translator/executor/function_graph.py | 7 +++++++ sot/opcode_translator/executor/opcode_executor.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 2987c3bcf..f8624472f 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -25,6 +25,7 @@ from .variables import ( ContainerVariable, DictVariable, + DummyVariable, ListVariable, PaddleLayerVariable, TensorVariable, @@ -181,10 +182,16 @@ def __init__(self, index_for_load, pycode_gen): self._pycode_gen = pycode_gen def load(self, var): + if isinstance(var, DummyVariable): + var.reconstruct(self._pycode_gen) + return self._pycode_gen.gen_load_fast(self._index_for_load[var.id]) # var_id -> local_name mapping index_for_load = {} + to_store_vars = list( + filter(lambda x: not isinstance(x, DummyVariable), to_store_vars) + ) self.start_compile(*(ret_vars + to_store_vars)) name_gen = NameGenerator("__start_compile_saved_") for var in to_store_vars: diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 09b267a9b..5e901f54f 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1519,7 +1519,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): store_vars = [] for stack_arg in self._stack: store_vars.append(stack_arg) - for name in if_inputs + else_inputs: + for name in inputs_name: store_vars.append(self.get_var(name)) var_loader = self._graph.start_compile_with_name_store( From 9e823d7e0b10385a71051471f503b30f84903aef Mon Sep 17 00:00:00 2001 From: xiogkun Date: Thu, 13 Jul 2023 09:56:01 +0000 Subject: [PATCH 09/14] fix --- .../instruction_utils/opcode_analysis.py | 8 +++++- tests/error_test_paddle_cast.py | 0 tests/test_paddle_cast.py | 28 ------------------- 3 files changed, 7 insertions(+), 29 deletions(-) create mode 100644 tests/error_test_paddle_cast.py delete mode 100644 tests/test_paddle_cast.py diff --git a/sot/opcode_translator/instruction_utils/opcode_analysis.py b/sot/opcode_translator/instruction_utils/opcode_analysis.py index 82830cf97..9a7e835d9 100644 --- a/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -14,7 +14,13 @@ class State: def is_read_opcode(opname): - if opname in ["LOAD_FAST", "LOAD_DEREF", "LOAD_NAME", "LOAD_GLOBAL"]: + if opname in [ + "LOAD_FAST", + "LOAD_DEREF", + "LOAD_NAME", + "LOAD_GLOBAL", + "LOAD_CLOSURE", + ]: return True if opname in ( "DELETE_FAST", diff --git a/tests/error_test_paddle_cast.py b/tests/error_test_paddle_cast.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_paddle_cast.py b/tests/test_paddle_cast.py deleted file mode 100644 index 512aa5eee..000000000 --- a/tests/test_paddle_cast.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import unittest - -from test_case_base import TestCaseBase - -import paddle - - -def test_paddle_cast(x): - y = x + 1 - return y.cast("int") - - -def test_paddle_cast2(x): - y = x + 1 - return paddle.cast(y, "int") - - -class TestExecutor(TestCaseBase): - def test_simple(self): - a = paddle.to_tensor(1) - self.assert_results(test_paddle_cast, a) - self.assert_results(test_paddle_cast2, a) - - -if __name__ == "__main__": - unittest.main() From 7039ebc69f4b890a1b409f1a349ee506b9440ca5 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Thu, 13 Jul 2023 12:40:14 +0000 Subject: [PATCH 10/14] fix ci error --- sot/opcode_translator/executor/variables/base.py | 4 ++-- tests/test_str_format.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index b555dcde5..f11aed4d9 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -375,7 +375,7 @@ def get_traceable_inputs(self) -> list[VariableBase]: filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) ) - def call_function(self, *args, **kwargs): + def call_function(self, /, *args, **kwargs): pass def getattr(self, name: str, default=None): @@ -463,7 +463,7 @@ def getitem(self, item): output = fn_var(self, item) return output - def __call__(self, *args, **kwargs): + def __call__(self, /, *args, **kwargs): """ Call the object represented by this variable with the given arguments. diff --git a/tests/test_str_format.py b/tests/test_str_format.py index 348067ed9..6eaf6fa79 100644 --- a/tests/test_str_format.py +++ b/tests/test_str_format.py @@ -7,7 +7,9 @@ # copy from python library _distutils_hack/__init__.py def find_spec(self, fullname, path, target=None): - method_name = 'spec_for_{fullname}'.format(**locals()) + method_name = 'spec_for_{fullname}'.format( + **{'self': self, 'fullname': fullname} + ) method = getattr(self, method_name, lambda: None) return method() From d202d2f6152a16aa106c9bb90cfd0491759b7dc6 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Fri, 14 Jul 2023 07:40:57 +0000 Subject: [PATCH 11/14] add some log for sot --- .../executor/function_graph.py | 12 +++++- sot/opcode_translator/transform.py | 40 +++++++++++++------ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index f8624472f..b4925006e 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -15,6 +15,7 @@ inner_error_default_handler, is_paddle_api, log, + log_do, map_if, show_trackers, ) @@ -196,6 +197,15 @@ def load(self, var): name_gen = NameGenerator("__start_compile_saved_") for var in to_store_vars: index_for_load[var.id] = name_gen.next() + + def _log_fn(): + print( + f"[StartCompile] saved var: {index_for_load[var.id]} = ", + var, + ) + + log_do(4, _log_fn) + for var in to_store_vars[::-1]: self.pycode_gen.gen_store_fast(index_for_load[var.id]) return VariableLoader(index_for_load, self.pycode_gen) @@ -256,8 +266,8 @@ def start_compile(self, *ret_vars: VariableBase): ret_var.reconstruct(self.pycode_gen) # deal side effect - self.restore_side_effects(self.side_effects.variables) self.restore_print_stmts(self._print_variables) + self.restore_side_effects(self.side_effects.variables) tracker_output_path = show_trackers() if tracker_output_path: diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 97662be5a..7a959c92c 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -1,10 +1,37 @@ import dis +from functools import partial from ..utils import log, log_do from .executor.opcode_executor import InstructionTranslatorCache from .skip_files import need_skip +def print_locals(frame): + local_key = [ + key for key in frame.f_locals.keys() if not key.startswith("__") + ] + print( + f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key}" + ) + print( + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars}" + ) + + def convert_obj(obj): + import paddle + + if isinstance(obj, paddle.Tensor): + return obj.shape + if isinstance(obj, list): + return [convert_obj(i) for i in obj] + return obj + + for key in local_key: + print( + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} {key} = {convert_obj(frame.f_locals[key])}" + ) + + def eval_frame_callback(frame, **kwargs): # is generator if frame.f_code.co_flags & 0x20 > 0: @@ -17,18 +44,7 @@ def eval_frame_callback(frame, **kwargs): 2, "[eval_frame_callback] start to translate: " + str(frame.f_code) + "\n", ) - local_key = [ - key for key in frame.f_locals.keys() if not key.startswith("__") - ] - log( - 4, - f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key} \n", - ) - log( - 4, - f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars} \n", - ) - + log_do(4, partial(print_locals, frame)) log(8, "[transform_opcode] old_opcode: " + frame.f_code.co_name + "\n") log_do(8, lambda: dis.dis(frame.f_code)) From d06c68767b7400bf7e3f36cde9cbb1be18a14caf Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 19 Jul 2023 03:19:19 +0000 Subject: [PATCH 12/14] Fix some error --- sot/__init__.py | 3 ++- sot/opcode_translator/executor/guard.py | 4 +-- .../executor/opcode_executor.py | 2 +- .../executor/pycode_generator.py | 26 ++++++++++++------- .../executor/variables/callable.py | 17 +++++++++--- sot/opcode_translator/transform.py | 2 +- sot/symbolic/compile_cache.py | 4 +++ sot/utils/__init__.py | 2 ++ sot/utils/exceptions.py | 4 ++- sot/utils/utils.py | 8 ++++++ 10 files changed, 54 insertions(+), 18 deletions(-) diff --git a/sot/__init__.py b/sot/__init__.py index 2bd54d405..a20f914c2 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,7 +1,7 @@ from .opcode_translator.breakpoint import BM, add_breakpoint, add_event from .opcode_translator.skip_files import skip_function from .translate import symbolic_translate -from .utils import psdb_print +from .utils import psdb_breakpoint, psdb_print __all__ = [ "symbolic_translate", @@ -10,4 +10,5 @@ "BM", "skip_function", "psdb_print", + "psdb_breakpoint", ] diff --git a/sot/opcode_translator/executor/guard.py b/sot/opcode_translator/executor/guard.py index 6dc6780a9..489b85adc 100644 --- a/sot/opcode_translator/executor/guard.py +++ b/sot/opcode_translator/executor/guard.py @@ -1,6 +1,5 @@ from __future__ import annotations -import ast import types import weakref from dataclasses import dataclass @@ -35,7 +34,8 @@ def __post_init__(self): def check_expr(self, expr: str): try: - ast.parse(expr) + pass + # ast.parse(expr) # TODO(xiongkun): too slow except SyntaxError as e: raise InnerError(f"Invalid expression: {expr}") from e diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 511260207..9687b00ee 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -244,7 +244,7 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None: log( 2, f"Unsupport Frame is {frame.f_code}, error message is: \n" - + '\n'.join(traceback.format_exception_only(type(e), e)), + + "".join(traceback.format_exception(type(e), e, e.__traceback__)), ) # NOTE: If resume fn need fallback, we should replace DummyVariable using NULL otherwise will fail to run diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index c932fa4ff..dc58d23c5 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -187,23 +187,31 @@ def stacksize(instructions): # Two list below shows the possible stack size before opcode is called # The stack size might be different in different branch, so it has max and min max_stack = [float("-inf")] * len(instructions) - min_stack = [float("inf")] * len(instructions) max_stack[0] = 0 - min_stack[0] = 0 + + queue = [] + queue.append(0) def update_stacksize(lasti, nexti, stack_effect): + old_max = max_stack[nexti] max_stack[nexti] = max( max_stack[nexti], max_stack[lasti] + stack_effect ) - min_stack[nexti] = min( - min_stack[nexti], max_stack[lasti] + stack_effect - ) + if old_max != max_stack[nexti]: + if nexti not in queue: # may be slow, we can use a flag. + queue.append(nexti) - for idx in range(len(instructions)): + while len(queue) > 0: + idx = queue[0] + del queue[0] instr = instructions[idx] - - if idx + 1 < len(instructions): + opname = instr.opname + if idx + 1 < len(instructions) and instr.opname not in [ + 'JUMP_ABSOLUTE', + "JUMP_FORWARD", + "JUMP_BACKWRAD", + ]: stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=False) update_stacksize(idx, idx + 1, stack_effect) @@ -212,7 +220,7 @@ def update_stacksize(lasti, nexti, stack_effect): target_idx = instructions.index(instr.jump_to) update_stacksize(idx, target_idx, stack_effect) - assert min(min_stack) >= 0 + # assert min(min_stack) >= 0 # min_stack may be a negative number when try: except is got. return max(max_stack) diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index e6bd77920..6c1c2090f 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -15,6 +15,7 @@ is_builtin_fn, is_paddle_api, magic_method_builtin_dispatch, + psdb_breakpoint, psdb_print, ) from ....utils.exceptions import BreakGraphError, FallbackErrorBase @@ -93,9 +94,7 @@ def __init__( ): super().__init__(fn, graph, tracker) - def call_function(self, /, *args, **kwargs) -> VariableBase: - from ..opcode_inline_executor import OpcodeInlineExecutor - + def handle_psdb_function(self, /, *args, **kwargs): # special function for inner debug. if self.value is ASSERT: # TODO: add comptime check mechanism @@ -109,6 +108,18 @@ def call_function(self, /, *args, **kwargs) -> VariableBase: ) return ConstantVariable.wrap_literal(None, self.graph) + if self.value is psdb_breakpoint: + # do nothing. just return None. + return ConstantVariable.wrap_literal(None, self.graph) + return None + + def call_function(self, /, *args, **kwargs) -> VariableBase: + from ..opcode_inline_executor import OpcodeInlineExecutor + + result = self.handle_psdb_function(*args, **kwargs) + if result is not None: + return result + checkpoint = self.graph.save_memo() try: inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 08717bc21..a8954881c 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -27,7 +27,7 @@ def convert_obj(obj): import paddle if isinstance(obj, paddle.Tensor): - return obj.shape + return "Tensor(" + str(obj.shape) + ")" if isinstance(obj, list): return [convert_obj(i) for i in obj] return obj diff --git a/sot/symbolic/compile_cache.py b/sot/symbolic/compile_cache.py index c9fdff889..322987c06 100644 --- a/sot/symbolic/compile_cache.py +++ b/sot/symbolic/compile_cache.py @@ -57,6 +57,10 @@ def __call__(self, *args, **kwargs): self.concrete_program.main_program ), ) + log_do( + 4, + lambda: print("[CompileCache] run sir forward success."), + ) return outputs diff --git a/sot/utils/__init__.py b/sot/utils/__init__.py index d7e097496..03a70d747 100644 --- a/sot/utils/__init__.py +++ b/sot/utils/__init__.py @@ -32,6 +32,7 @@ map_if, meta_str, no_eval_frame, + psdb_breakpoint, psdb_print, show_trackers, ) @@ -61,6 +62,7 @@ "paddle_tensor_methods", "ASSERT", "psdb_print", + "psdb_breakpoint", "ResumeFnNameFactory", "list_contain_by_id", "list_find_index_by_id", diff --git a/sot/utils/exceptions.py b/sot/utils/exceptions.py index 20f660fe3..74c7c6429 100644 --- a/sot/utils/exceptions.py +++ b/sot/utils/exceptions.py @@ -34,6 +34,8 @@ def impl(*args, **kwargs): return func(*args, **kwargs) except Exception as e: message = message_fn(*args, **kwargs) - raise InnerError(f"{message}.\nOrigin Exception is : \n {e}") from e + raise InnerError( + f"{message}.\nOrigin Exception is : \n {traceback.format_exception(type(e), e, e.__traceback__)}" + ) from e return impl diff --git a/sot/utils/utils.py b/sot/utils/utils.py index 8e3ed5dce..25442f302 100644 --- a/sot/utils/utils.py +++ b/sot/utils/utils.py @@ -203,6 +203,14 @@ def psdb_print(*args, **kwargs): print("[Dygraph]", *args, **kwargs) +def psdb_breakpoint(): + import paddle + + old = paddle.fluid.core.set_eval_frame(None) + breakpoint() + paddle.fluid.core.set_eval_frame(old) + + def list_find_index_by_id(li: list[Any], item: Any) -> int: return [id(it) for it in li].index(id(item)) From 91ba8fffc7fd970eba21b76acf668e129b9f84d8 Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 19 Jul 2023 07:17:20 +0000 Subject: [PATCH 13/14] fix dispatch meets unhashable function errors. --- sot/opcode_translator/executor/dispatcher.py | 4 ++-- sot/opcode_translator/executor/opcode_executor.py | 5 +++-- sot/utils/__init__.py | 2 ++ sot/utils/magic_methods.py | 4 ++++ sot/utils/utils.py | 8 ++++++++ 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sot/opcode_translator/executor/dispatcher.py b/sot/opcode_translator/executor/dispatcher.py index add169daf..11b42fc4f 100644 --- a/sot/opcode_translator/executor/dispatcher.py +++ b/sot/opcode_translator/executor/dispatcher.py @@ -5,7 +5,7 @@ from functools import cached_property, reduce from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, TypeVar -from ...utils import InnerError +from ...utils import InnerError, hashable if TYPE_CHECKING: T = TypeVar("T") @@ -209,7 +209,7 @@ def dispatch( args: The args of the function. kwargs: The kwargs of the function. """ - if fn not in cls.handlers: + if not hashable(fn) or fn not in cls.handlers: return None for pattern, handler in cls.handlers[fn]: if pattern.match_inputs(*args, **kwargs): diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 9687b00ee..f2561d8bf 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -553,10 +553,11 @@ def error_message_summary(original_error: Exception) -> str: message_lines.append( f"{indent} {lines[current_line-start].rstrip()}" ) - error_message = traceback.format_exception_only( - type(original_error), original_error + error_message = traceback.format_exception( + type(original_error), original_error, original_error.__traceback__ ) for line in error_message: + line = line.rstrip() message_lines.append(f"{indent} {line}") return "\n".join(message_lines) diff --git a/sot/utils/__init__.py b/sot/utils/__init__.py index 03a70d747..51e505e7e 100644 --- a/sot/utils/__init__.py +++ b/sot/utils/__init__.py @@ -20,6 +20,7 @@ count_if, execute_time, get_unbound_method, + hashable, in_paddle_module, is_break_graph_api, is_builtin_fn, @@ -70,4 +71,5 @@ "get_unbound_method", "GraphLogger", "UndefinedVar", + "hashable", ] diff --git a/sot/utils/magic_methods.py b/sot/utils/magic_methods.py index bca42ea46..19c78520c 100644 --- a/sot/utils/magic_methods.py +++ b/sot/utils/magic_methods.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable +from .utils import hashable + if TYPE_CHECKING: BinaryOp = Callable[[Any, Any], Any] UnaryOp = Callable[[Any], Any] @@ -89,6 +91,8 @@ class MagicMethod: def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]: + if not hashable(fn): + return [] if fn in INPLACE_BINARY_OPS: inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[ fn diff --git a/sot/utils/utils.py b/sot/utils/utils.py index 25442f302..4227bf83e 100644 --- a/sot/utils/utils.py +++ b/sot/utils/utils.py @@ -289,3 +289,11 @@ def print_info(self): @Singleton class UndefinedVar: pass + + +def hashable(obj): + try: + hash(obj) + return True + except TypeError as e: + return False From e1a94db3b1b087c1ebc51227ff6f8e71119a131a Mon Sep 17 00:00:00 2001 From: xiogkun Date: Wed, 19 Jul 2023 07:46:15 +0000 Subject: [PATCH 14/14] fix --- sot/opcode_translator/executor/opcode_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 022436c2a..5c697aeee 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -553,8 +553,8 @@ def error_message_summary(original_error: Exception) -> str: message_lines.append( f"{indent} {lines[current_line-start].rstrip()}" ) - error_message = traceback.format_exception( - type(original_error), original_error, original_error.__traceback__ + error_message = traceback.format_exception_only( + type(original_error), original_error ) for line in error_message: line = line.rstrip()