Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Jul 17, 2023
1 parent 8c729e3 commit 61c937e
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 114 deletions.
10 changes: 7 additions & 3 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_to_meta(inputs: Any):
def func(x):
if isinstance(x, TensorVariable):
return x.meta
return x.get_value()
return x.get_py_value()

return map_variables(func, inputs)

Expand All @@ -57,7 +57,7 @@ def convert_to_symbol(inputs: Any):
def func(x):
if isinstance(x, (TensorVariable, PaddleLayerVariable)):
return x.get_symbol()
return x.get_value()
return x.get_py_value()

return map_variables(func, inputs)

Expand Down Expand Up @@ -253,7 +253,11 @@ def start_compile(self, *ret_vars: VariableBase):
self.pycode_gen.gen_store_fast(tensor_var.out_var_name)
# restore the outputs.
for ret_var in ret_vars:
ret_var.reconstruct(self.pycode_gen)
try:
ret_var.reconstruct(self.pycode_gen)
except:
breakpoint()
ret_var.reconstruct(self.pycode_gen)

# deal side effect
self.restore_side_effects(self.side_effects.variables)
Expand Down
6 changes: 3 additions & 3 deletions sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def object_equal_stringify_guard(self) -> StringifyExpression:
frame_value_tracer = self.tracker.trace_value_from_frame()

obj_free_var_name = f"__{self.id}"
weak_ref_obj = self.get_value()
weak_ref_obj = self.get_py_value()
if support_weak_ref(weak_ref_obj):
weak_ref_obj = weakref.ref(self.get_value())
weak_ref_obj = weakref.ref(self.get_py_value())
return StringifyExpression(
f"{obj_free_var_name}() is not None and {frame_value_tracer.expr} == {obj_free_var_name}()",
union_free_vars(
Expand All @@ -116,6 +116,6 @@ def object_equal_stringify_guard(self) -> StringifyExpression:
f"{frame_value_tracer.expr} == {obj_free_var_name}",
union_free_vars(
frame_value_tracer.free_vars,
{obj_free_var_name: self.get_value()},
{obj_free_var_name: self.get_py_value()},
),
)
36 changes: 18 additions & 18 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def get_var(self, name: str):
elif name in self._builtins.keys():
return self._builtins[name]
elif name in self._cells.keys(): # in closure
return self._cells[name].get_value()
return self._cells[name].cell_content()
else:
raise InnerError(f'Can not get var: {name}')

Expand Down Expand Up @@ -769,7 +769,7 @@ def BINARY_SUBSCR(self, instr: Instruction):
self._graph.add_global_guarded_variable(key)
self.push(
BuiltinVariable(operator.getitem, self._graph, DanglingTracker())(
container, key.get_value()
container, key.get_py_value()
)
)

Expand Down Expand Up @@ -813,7 +813,7 @@ def LOAD_CLOSURE(self, instr):
def LOAD_DEREF(self, instr):
namemap = self._code.co_cellvars + self._code.co_freevars
name = namemap[instr.arg]
self.push(self._cells[name].get_value())
self.push(self._cells[name].cell_content())

def LOAD_FAST(self, instr: Instruction):
varname = self._code.co_varnames[instr.arg]
Expand Down Expand Up @@ -880,7 +880,7 @@ def STORE_SUBSCR(self, instr: Instruction):
f"Key is a TensorVariable in STORE_SUBSCR, {container}[{key}] = {value}"
)
# TODO(xiongkun): support tensor[tensor] = tensor, dy2static is not the same with dygraph.
container[key.get_value()] = value
container[key.get_py_value()] = value
value.debug_name = f"{container.debug_name}[{key.debug_name}]"

def DELETE_SUBSCR(self, instr: Instruction):
Expand Down Expand Up @@ -926,8 +926,8 @@ def BUILD_STRING(self, instr: Instruction):
str_list = self.pop_n(count)
new_str = ''
for s in str_list:
assert isinstance(s.get_value(), str)
new_str += s.get_value()
assert s.get_py_type() == str
new_str += s.get_py_value()
self.push(
VariableFactory.from_value(
new_str, self._graph, DummyTracker(str_list)
Expand All @@ -945,7 +945,7 @@ def BUILD_SLICE(self, instr: Instruction):

related_list = [start, stop, step] if step else [start, stop]

slice_ = slice(*(x.get_value() for x in related_list))
slice_ = slice(*(x.get_py_value() for x in related_list))

self.push(
VariableFactory.from_value(
Expand All @@ -961,7 +961,7 @@ def build_map(
assert isinstance(key, VariableBase)
# Add key to global guarded variable to avoid missing the key guard
self._graph.add_global_guarded_variable(key)
key = key.get_value()
key = key.get_py_value()
built_map[key] = value
return DictVariable(
built_map,
Expand Down Expand Up @@ -1027,7 +1027,7 @@ def BUILD_MAP_UNPACK(self, instr: Instruction):

retval = {}
for item in unpack_values:
assert isinstance(item.get_value(), dict)
assert item.get_py_type() == dict
retval.update(item.get_wrapped_items())

self.push(
Expand All @@ -1043,7 +1043,7 @@ def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction):

retval = {}
for item in unpack_values:
assert isinstance(item.get_value(), dict)
assert item.get_py_type() == dict
wrapped_item = item.get_wrapped_items()
if wrapped_item.items() & retval.items():
raise InnerError(
Expand Down Expand Up @@ -1074,8 +1074,8 @@ def CALL_FUNCTION_KW(self, instr: Instruction):
assert isinstance(kwargs_keys, TupleVariable)
assert len(kwargs_keys) > 0
kwargs_keys = [
x.get_value() if isinstance(x, VariableBase) else x
for x in kwargs_keys.get_value()
x.get_py_value() if isinstance(x, VariableBase) else x
for x in kwargs_keys.get_py_value()
]

# split arg_list to args and kwargs
Expand Down Expand Up @@ -1119,7 +1119,7 @@ def CALL_METHOD(self, instr: Instruction):

@call_break_graph_decorator(
push_n=1
) # call instance, in, not in may call TensorVariable.get_value, which raise BreakGraphError
) # call instance, in, not in may call TensorVariable.get_py_value, which raise BreakGraphError
def COMPARE_OP(self, instr: Instruction):
op = dis.cmp_op[instr.arg]
right, left = self.pop(), self.pop()
Expand Down Expand Up @@ -1186,9 +1186,9 @@ def g(z=x):
default_args = ()

new_fn = types.FunctionType(
codeobj.get_value(),
codeobj.get_py_value(),
global_dict,
fn_name.get_value(),
fn_name.get_py_value(),
default_args,
closure,
)
Expand Down Expand Up @@ -1308,7 +1308,7 @@ def UNPACK_SEQUENCE(self, instr: Instruction):
'''
TODO: To unpack iterator
To unpack is easy, just like:
seq = tuple(sequence.get_value())
seq = tuple(sequence.get_py_value())
But what is the `source` when iterator returned a value ?
'''
Expand All @@ -1334,7 +1334,7 @@ def FORMAT_VALUE(self, instr: Instruction):
which_conversion = flag & FV.FVC_MASK
have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC)

fmt_spec = self.pop().get_value() if have_fmt_spec else ""
fmt_spec = self.pop().get_py_value() if have_fmt_spec else ""
value = self.pop()

if which_conversion == FV.FVC_NONE:
Expand All @@ -1352,7 +1352,7 @@ def FORMAT_VALUE(self, instr: Instruction):

# different type will lead to different Tracker, so call self.push in different branch
if isinstance(value, ConstantVariable):
result = value.get_value()
result = value.get_py_value()
if convert_fn is not None:
result = getattr(result, convert_fn)(result)

Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _prepare_closure(self):
"""
from .variables import VariableFactory

closure = self._fn_var.get_value().__closure__
closure = self._fn_var.get_py_value().__closure__
for name in self._code.co_cellvars + self._code.co_freevars:
# create a cell for each variable.
self._cells[name] = CellVariable() # put in cells.
Expand Down
Loading

0 comments on commit 61c937e

Please sign in to comment.