diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index f8624472f..e4241374f 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -44,7 +44,7 @@ def convert_to_meta(inputs: Any): def func(x): if isinstance(x, TensorVariable): return x.meta - return x.get_value() + return x.get_py_value() return map_variables(func, inputs) @@ -57,7 +57,7 @@ def convert_to_symbol(inputs: Any): def func(x): if isinstance(x, (TensorVariable, PaddleLayerVariable)): return x.get_symbol() - return x.get_value() + return x.get_py_value() return map_variables(func, inputs) @@ -253,7 +253,11 @@ def start_compile(self, *ret_vars: VariableBase): self.pycode_gen.gen_store_fast(tensor_var.out_var_name) # restore the outputs. for ret_var in ret_vars: - ret_var.reconstruct(self.pycode_gen) + try: + ret_var.reconstruct(self.pycode_gen) + except: + breakpoint() + ret_var.reconstruct(self.pycode_gen) # deal side effect self.restore_side_effects(self.side_effects.variables) diff --git a/sot/opcode_translator/executor/guard.py b/sot/opcode_translator/executor/guard.py index 6dc6780a9..5c2722fda 100644 --- a/sot/opcode_translator/executor/guard.py +++ b/sot/opcode_translator/executor/guard.py @@ -102,9 +102,9 @@ def object_equal_stringify_guard(self) -> StringifyExpression: frame_value_tracer = self.tracker.trace_value_from_frame() obj_free_var_name = f"__{self.id}" - weak_ref_obj = self.get_value() + weak_ref_obj = self.get_py_value() if support_weak_ref(weak_ref_obj): - weak_ref_obj = weakref.ref(self.get_value()) + weak_ref_obj = weakref.ref(self.get_py_value()) return StringifyExpression( f"{obj_free_var_name}() is not None and {frame_value_tracer.expr} == {obj_free_var_name}()", union_free_vars( @@ -116,6 +116,6 @@ def object_equal_stringify_guard(self) -> StringifyExpression: f"{frame_value_tracer.expr} == {obj_free_var_name}", union_free_vars( frame_value_tracer.free_vars, - {obj_free_var_name: self.get_value()}, + {obj_free_var_name: self.get_py_value()}, ), ) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 97f829470..e4aed62ec 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -502,7 +502,7 @@ def get_var(self, name: str): elif name in self._builtins.keys(): return self._builtins[name] elif name in self._cells.keys(): # in closure - return self._cells[name].get_value() + return self._cells[name].cell_content() else: raise InnerError(f'Can not get var: {name}') @@ -769,7 +769,7 @@ def BINARY_SUBSCR(self, instr: Instruction): self._graph.add_global_guarded_variable(key) self.push( BuiltinVariable(operator.getitem, self._graph, DanglingTracker())( - container, key.get_value() + container, key.get_py_value() ) ) @@ -813,7 +813,7 @@ def LOAD_CLOSURE(self, instr): def LOAD_DEREF(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars name = namemap[instr.arg] - self.push(self._cells[name].get_value()) + self.push(self._cells[name].cell_content()) def LOAD_FAST(self, instr: Instruction): varname = self._code.co_varnames[instr.arg] @@ -880,7 +880,7 @@ def STORE_SUBSCR(self, instr: Instruction): f"Key is a TensorVariable in STORE_SUBSCR, {container}[{key}] = {value}" ) # TODO(xiongkun): support tensor[tensor] = tensor, dy2static is not the same with dygraph. - container[key.get_value()] = value + container[key.get_py_value()] = value value.debug_name = f"{container.debug_name}[{key.debug_name}]" def DELETE_SUBSCR(self, instr: Instruction): @@ -926,8 +926,8 @@ def BUILD_STRING(self, instr: Instruction): str_list = self.pop_n(count) new_str = '' for s in str_list: - assert isinstance(s.get_value(), str) - new_str += s.get_value() + assert s.get_py_type() == str + new_str += s.get_py_value() self.push( VariableFactory.from_value( new_str, self._graph, DummyTracker(str_list) @@ -945,7 +945,7 @@ def BUILD_SLICE(self, instr: Instruction): related_list = [start, stop, step] if step else [start, stop] - slice_ = slice(*(x.get_value() for x in related_list)) + slice_ = slice(*(x.get_py_value() for x in related_list)) self.push( VariableFactory.from_value( @@ -961,7 +961,7 @@ def build_map( assert isinstance(key, VariableBase) # Add key to global guarded variable to avoid missing the key guard self._graph.add_global_guarded_variable(key) - key = key.get_value() + key = key.get_py_value() built_map[key] = value return DictVariable( built_map, @@ -1027,7 +1027,7 @@ def BUILD_MAP_UNPACK(self, instr: Instruction): retval = {} for item in unpack_values: - assert isinstance(item.get_value(), dict) + assert item.get_py_type() == dict retval.update(item.get_wrapped_items()) self.push( @@ -1043,7 +1043,7 @@ def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): retval = {} for item in unpack_values: - assert isinstance(item.get_value(), dict) + assert item.get_py_type() == dict wrapped_item = item.get_wrapped_items() if wrapped_item.items() & retval.items(): raise InnerError( @@ -1074,8 +1074,8 @@ def CALL_FUNCTION_KW(self, instr: Instruction): assert isinstance(kwargs_keys, TupleVariable) assert len(kwargs_keys) > 0 kwargs_keys = [ - x.get_value() if isinstance(x, VariableBase) else x - for x in kwargs_keys.get_value() + x.get_py_value() if isinstance(x, VariableBase) else x + for x in kwargs_keys.get_py_value() ] # split arg_list to args and kwargs @@ -1119,7 +1119,7 @@ def CALL_METHOD(self, instr: Instruction): @call_break_graph_decorator( push_n=1 - ) # call instance, in, not in may call TensorVariable.get_value, which raise BreakGraphError + ) # call instance, in, not in may call TensorVariable.get_py_value, which raise BreakGraphError def COMPARE_OP(self, instr: Instruction): op = dis.cmp_op[instr.arg] right, left = self.pop(), self.pop() @@ -1186,9 +1186,9 @@ def g(z=x): default_args = () new_fn = types.FunctionType( - codeobj.get_value(), + codeobj.get_py_value(), global_dict, - fn_name.get_value(), + fn_name.get_py_value(), default_args, closure, ) @@ -1308,7 +1308,7 @@ def UNPACK_SEQUENCE(self, instr: Instruction): ''' TODO: To unpack iterator To unpack is easy, just like: - seq = tuple(sequence.get_value()) + seq = tuple(sequence.get_py_value()) But what is the `source` when iterator returned a value ? ''' @@ -1334,7 +1334,7 @@ def FORMAT_VALUE(self, instr: Instruction): which_conversion = flag & FV.FVC_MASK have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) - fmt_spec = self.pop().get_value() if have_fmt_spec else "" + fmt_spec = self.pop().get_py_value() if have_fmt_spec else "" value = self.pop() if which_conversion == FV.FVC_NONE: @@ -1352,7 +1352,7 @@ def FORMAT_VALUE(self, instr: Instruction): # different type will lead to different Tracker, so call self.push in different branch if isinstance(value, ConstantVariable): - result = value.get_value() + result = value.get_py_value() if convert_fn is not None: result = getattr(result, convert_fn)(result) diff --git a/sot/opcode_translator/executor/opcode_inline_executor.py b/sot/opcode_translator/executor/opcode_inline_executor.py index dbb2313f9..7c451a39a 100644 --- a/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/sot/opcode_translator/executor/opcode_inline_executor.py @@ -206,7 +206,7 @@ def _prepare_closure(self): """ from .variables import VariableFactory - closure = self._fn_var.get_value().__closure__ + closure = self._fn_var.get_py_value().__closure__ for name in self._code.co_cellvars + self._code.co_freevars: # create a cell for each variable. self._cells[name] = CellVariable() # put in cells. diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index fb1ee2f52..756db890c 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -44,25 +44,24 @@ lambda var, value: var.index(value), ) -# dict Dispatcher.register( operator_in, ("VariableBase", "VariableBase"), {}, + # if left is a tensor, will raise err at left.get_py_value() lambda left, right: VariableFactory.from_value( - left.get_value() in right.get_value(), + left.get_py_value() in right.get_py_value(allow_tensor=True), left.graph, tracker=DummyTracker([left, right]), ), ) -# dict Dispatcher.register( operator_not_in, ("VariableBase", "VariableBase"), {}, lambda left, right: VariableFactory.from_value( - left.get_value() not in right.get_value(), + left.get_py_value() not in right.get_py_value(allow_tensor=True), left.graph, tracker=DummyTracker([left, right]), ), @@ -73,13 +72,13 @@ dict.get, ("DictVariable", "ConstantVariable", "VariableBase"), {}, - lambda var, key, default: var.get(key.get_value(), default), + lambda var, key, default: var.get(key.get_py_value(), default), ) Dispatcher.register( dict.get, ("DictVariable", "ConstantVariable"), {}, - lambda var, key: var.get(key.get_value()), + lambda var, key: var.get(key.get_py_value()), ) Dispatcher.register( dict.keys, @@ -104,13 +103,13 @@ dict.setdefault, ("DictVariable", "ConstantVariable", "VariableBase"), {}, - lambda var, key, default: var.setdefault(key.get_value(), default), + lambda var, key, default: var.setdefault(key.get_py_value(), default), ) Dispatcher.register( dict.setdefault, ("DictVariable", "ConstantVariable"), {}, - lambda var, key: var.setdefault(key.get_value()), + lambda var, key: var.setdefault(key.get_py_value()), ) Dispatcher.register( dict.update, @@ -134,13 +133,13 @@ dict.pop, ("DictVariable", "ConstantVariable"), {}, - lambda var, key: var.pop(key.get_value()), + lambda var, key: var.pop(key.get_py_value()), ) Dispatcher.register( dict.pop, ("DictVariable", "ConstantVariable", "VariableBase"), {}, - lambda var, key, default: var.pop(key.get_value(), default), + lambda var, key, default: var.pop(key.get_py_value(), default), ) Dispatcher.register( dict.popitem, @@ -188,7 +187,7 @@ list.insert, ("ListVariable", "ConstantVariable", "VariableBase"), {}, - lambda var, index, obj: var.insert(index.get_value(), obj), + lambda var, index, obj: var.insert(index.get_py_value(), obj), ) Dispatcher.register( list.remove, @@ -286,7 +285,7 @@ {}, lambda var, name: ( var.graph.add_global_guarded_variable(name), - var.getattr(name.get_value()), + var.getattr(name.get_py_value()), )[1], ) Dispatcher.register( @@ -295,7 +294,7 @@ {}, lambda var, name, default: ( var.graph.add_global_guarded_variable(name), - var.getattr(name.get_value(), default), + var.getattr(name.get_py_value(), default), )[1], ) # len @@ -314,7 +313,9 @@ ("ConstantVariable",), {}, lambda stop: VariableFactory.from_value( - range(stop.get_value()), graph=stop.graph, tracker=DummyTracker([stop]) + range(stop.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([stop]), ), ) @@ -324,7 +325,7 @@ ("ConstantVariable", "ConstantVariable"), {}, lambda start, stop: VariableFactory.from_value( - range(start.get_value(), stop.get_value()), + range(start.get_py_value(), stop.get_py_value()), graph=stop.graph, tracker=DummyTracker([start, stop]), ), @@ -335,7 +336,7 @@ ("ConstantVariable", "ConstantVariable", "ConstantVariable"), {}, lambda start, stop, step: VariableFactory.from_value( - range(start.get_value(), stop.get_value(), step.get_value()), + range(start.get_py_value(), stop.get_py_value(), step.get_py_value()), graph=stop.graph, tracker=DummyTracker([start, stop, step]), ), @@ -359,7 +360,7 @@ ("VariableBase", "VariableBase"), {}, lambda left, right: ConstantVariable.wrap_literal( - isinstance(left.get_value(), right.get_value()), left.graph + isinstance(left.get_py_value(), right.get_py_value()), left.graph ), ) @@ -425,7 +426,7 @@ "ConstantVariable | SliceVariable", ), {}, - lambda var, key: var.getitem(key.get_value()), + lambda var, key: var.getitem(key.get_py_value()), ) # setitem @@ -437,7 +438,7 @@ "int | str | ConstantVariable | TensorVariable", ), {}, - lambda var, key, value: var.setitem(key.get_value(), value), + lambda var, key, value: var.setitem(key.get_py_value(), value), ) # delitem @@ -457,7 +458,7 @@ "ConstantVariable", ), {}, - lambda var, key: var.delitem(key.get_value()), + lambda var, key: var.delitem(key.get_py_value()), ) @@ -532,7 +533,7 @@ ("VariableBase", "VariableBase"), {}, lambda var, other: VariableFactory.from_value( - var.get_value() is other.get_value(), + var.get_py_value() is other.get_py_value(), var.graph, tracker=DummyTracker([var, other]), ), @@ -575,7 +576,9 @@ def is_not_func(var: VariableBase, other: VariableBase): {}, partial( lambda fn, var: VariableFactory.from_value( - fn(var.get_value()), var.graph, tracker=DummyTracker([var]) + fn(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), ), unary_fn, ), @@ -588,7 +591,7 @@ def is_not_func(var: VariableBase, other: VariableBase): {}, partial( lambda fn, var, other: VariableFactory.from_value( - fn(var.get_value(), other.get_value()), + fn(var.get_py_value(), other.get_py_value()), var.graph, tracker=DummyTracker([var, other]), ), @@ -665,7 +668,7 @@ def is_not_func(var: VariableBase, other: VariableBase): def tensor_mod_dispatcher( var: ConstantVariable, other: TensorVariable ): - if isinstance(var.get_value(), str): + if var.get_py_type() == str: raise BreakGraphError( "(ConstantVariable % TensorVariable) raise a callback. " ) @@ -712,7 +715,7 @@ def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable): # Register dispatch for DataVariable: directy call and return a wrapped variable. def data_variable_binary_dispatcher(var, other, operator): return VariableFactory.from_value( - operator(var.get_value(), other.get_value()), + operator(var.get_py_value(), other.get_py_value()), var.graph, DummyTracker([var, other]), ) @@ -738,7 +741,7 @@ def data_variable_binary_dispatcher(var, other, operator): def data_variable_unary_dispatcher(var: DataVariable, fn): return VariableFactory.from_value( - fn(var.get_value()), + fn(var.get_py_value()), var.graph, DummyTracker([var]), ) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index bba189622..c672133ff 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -308,21 +308,21 @@ def make_stringify_guard(self) -> StringifyExpression: frame_value_tracer = self.tracker.trace_value_from_frame() return StringifyExpression( - f"{frame_value_tracer.expr} == {self.get_value()!r}", + f"{frame_value_tracer.expr} == {self.get_py_value()!r}", union_free_vars(frame_value_tracer.free_vars), ) - def get_value(self) -> Any: + def get_py_value(self, allow_tensor=False) -> Any: """ Abstract method to get the value of the variable """ raise NotImplementedError() - def get_type(self): + def get_py_type(self): """ Method to get the type of the variable's value """ - return type(self.get_value()) + return type(self.get_py_value()) def reconstruct(self, codegen: PyCodeGen): if ( @@ -411,7 +411,7 @@ def getattr(self, name: str, default=None): getattr(attr.__self__.__class__, name, None) ): class_var = VariableFactory.from_value( - self.get_type(), + self.get_py_type(), self.graph, GetAttrTracker(self, "__class__"), ) @@ -452,12 +452,12 @@ def __getitem__(self, idx): def getitem(self, item): class_var = VariableFactory.from_value( - self.get_value().__class__, + self.get_py_value().__class__, self.graph, GetAttrTracker(self, '__class__'), ) fn_var = VariableFactory.from_value( - get_unbound_method(self.get_value(), '__getitem__'), + get_unbound_method(self.get_py_value(), '__getitem__'), self.graph, GetAttrTracker(class_var, '__getitem__'), ) @@ -478,15 +478,15 @@ def __call__(self, /, *args, **kwargs): from .callable import BuiltinVariable, UserDefinedFunctionVariable class_var = VariableFactory.from_value( - self.get_value().__class__, + self.get_py_value().__class__, self.graph, GetAttrTracker(self, '__class__'), ) assert class_var is not None # if __call__ is a method, we should add self to arguments. - if inspect.ismethod(self.get_value().__call__): + if inspect.ismethod(self.get_py_value().__call__): args = (self,) + args - unbound_method = get_unbound_method(self.get_value(), '__call__') + unbound_method = get_unbound_method(self.get_py_value(), '__call__') if hasattr(unbound_method, "__code__"): fn_var = UserDefinedFunctionVariable( unbound_method, diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 7d23b6a29..28e0ca41f 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -85,7 +85,7 @@ def __init__( super().__init__(graph, tracker) self.value = value - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value @property @@ -113,10 +113,10 @@ def bool(self): def bool_not(self): assert isinstance( - self.get_value(), bool + self.get_py_value(), bool ), "Bool_not can only be applied to a bool variable." return VariableFactory.from_value( - not bool(self.get_value()), self.graph, DummyTracker([self]) + not bool(self.get_py_value()), self.graph, DummyTracker([self]) ) def str(self): @@ -197,7 +197,7 @@ def __init__( super().__init__(graph, tracker) self.value = value - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value make_stringify_guard = object_equal_stringify_guard @@ -253,12 +253,15 @@ def bool(self): bool(self.value), self.graph, DummyTracker([self]) ) - def get_value(self): + def get_py_value(self, allow_tensor=False): + if allow_tensor: + return self.var_name + breakpoint() raise BreakGraphError( - "Called TensorVariable.get_value. Should not use Tensor's value in simulating." + "Called TensorVariable.get_py_value. Should not use Tensor's value in simulating." ) - def get_type(self): + def get_py_type(self): return paddle.Tensor def get_symbol(self) -> Symbol: @@ -455,7 +458,7 @@ def __init__(self, obj, graph, tracker): def main_info(self) -> dict[str, Any]: return {"value": self.value} - def get_value(self) -> Any: + def get_py_value(self, allow_tensor=False) -> Any: return self.value @@ -491,7 +494,7 @@ def debug_name(self, name): def main_info(self) -> dict[str, Any]: return {"value": self.value} - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value @VariableFactory.register_from_value() @@ -515,7 +518,7 @@ def __init__(self, func, graph, tracker): super().__init__(graph, tracker) self.value = func - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value @property @@ -538,7 +541,7 @@ def __init__(self, value, graph, tracker): super().__init__(graph, tracker) self.value = value - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value @check_guard @@ -576,12 +579,12 @@ def __init__(self, value, graph, tracker): def main_info(self) -> dict[str, Any]: return {"value": self.value} - def get_value(self) -> Any: + def get_py_value(self, allow_tensor=False) -> Any: return self.value @check_guard def make_stringify_guard(self) -> StringifyExpression: - if isinstance(self.get_value(), np.number): + if isinstance(self.get_py_value(), np.number): frame_value_tracer = self.tracker.trace_value_from_frame() def format_dtype(dtype: np.dtype): @@ -591,10 +594,10 @@ def format_number(number: np.number): return f"{format_dtype(number.dtype)}({str(number.item())})" return StringifyExpression( - f"{frame_value_tracer.expr} == {format_number(self.get_value())}", + f"{frame_value_tracer.expr} == {format_number(self.get_py_value())}", union_free_vars(frame_value_tracer.free_vars, {"np": np}), ) & StringifyExpression( - f"{frame_value_tracer.expr}.dtype == {format_dtype(self.get_value().dtype)}", + f"{frame_value_tracer.expr}.dtype == {format_dtype(self.get_py_value().dtype)}", union_free_vars(frame_value_tracer.free_vars, {"np": np}), ) else: @@ -631,7 +634,7 @@ def __init__(self, value=None): assert isinstance(value, (VariableBase, type(None))) self.set_value(value) - def get_value(self): + def cell_content(self): return self.value def set_value(self, value): diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index e6bd77920..69fcab0a9 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -62,7 +62,7 @@ def __init__( super().__init__(graph, tracker) self.value = fn - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value def get_code(self) -> types.CodeType: @@ -76,7 +76,7 @@ def bind(self, instance: VariableBase, name: str): tracker=GetAttrTracker(instance, name), ) class_var = VariableFactory.from_value( - instance.get_type(), + instance.get_py_type(), graph=self.graph, tracker=GetAttrTracker(instance, "__class__"), ) @@ -198,10 +198,10 @@ def __init__( self.fn = fn self.method_name = method_name - def get_value(self): - return self.fn.get_value().__get__( - self.bound_instance.get_value(), - self.bound_instance.get_value().__class__, + def get_py_value(self, allow_tensor=False): + return self.fn.get_py_value().__get__( + self.bound_instance.get_py_value(allow_tensor), + self.bound_instance.get_py_value(allow_tensor).__class__, ) def _reconstruct(self, pycode_gen): @@ -273,17 +273,17 @@ def __init__( super().__init__(graph, tracker) self.value = layer - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value @check_guard def make_stringify_guard(self) -> StringifyExpression: frame_value_tracer = self.tracker.trace_value_from_frame() return StringifyExpression( - f"id({frame_value_tracer.expr}) == {id(self.get_value())}", + f"id({frame_value_tracer.expr}) == {id(self.get_py_value())}", union_free_vars(frame_value_tracer.free_vars), ) & StringifyExpression( - f"{frame_value_tracer.expr}.training == {self.get_value().training}", + f"{frame_value_tracer.expr}.training == {self.get_py_value().training}", union_free_vars(frame_value_tracer.free_vars), ) @@ -337,7 +337,7 @@ def call_function(self, /, *args, **kwargs): sorted_args = args if magic_method.is_reverse: sorted_args = sorted_args[::-1] - arg_type = sorted_args[0].get_type() + arg_type = sorted_args[0].get_py_type() if hasattr(arg_type, magic_method.name): class_fn = getattr(arg_type, magic_method.name) class_var = VariableFactory.from_value( diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index 73a692602..3e041d933 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -16,7 +16,7 @@ Tracker, ) from .base import ConstTypes, VariableBase, VariableFactory -from .basic import ConstantVariable +from .basic import ConstantVariable, TensorVariable from .callable import BuiltinVariable if TYPE_CHECKING: @@ -98,11 +98,11 @@ def proxy_getter(self, data, key): data[key], self.graph, tracker=GetItemTracker(self, key) ) - def get_value(self): + def get_py_value(self, allow_tensor=False): items = self.proxy.get_all() - return [item.get_value() for item in items] + return [item.get_py_value(allow_tensor) for item in items] - def get_type(self): + def get_py_type(self): return list def _reconstruct(self, codegen: PyCodeGen): @@ -227,8 +227,8 @@ def repeat(self, length): def pop(self, index: ConstantVariable | None = None): if index is None: index = ConstantVariable.wrap_literal(-1, self.graph) - res = self.proxy.get(index.get_value()) - self.proxy.delete(index.get_value()) + res = self.proxy.get(index.get_py_value()) + self.proxy.delete(index.get_py_value()) self.graph.side_effects.record_variable(self) return res @@ -246,12 +246,20 @@ def clear(self): return ConstantVariable.wrap_literal(None, self.graph) def remove(self, value): - for idx in range(self.proxy.length): - if self[idx].get_value() == value.get_value(): - self.delitem(idx) - break + if isinstance(value, TensorVariable): + for idx in range(self.proxy.length): + if self[idx].id == value.id: + self.delitem(idx) + break + else: + raise InnerError(f"List {self} does not contain {value}") else: - raise InnerError(f"List {self} does not contain {value}") + for idx in range(self.proxy.length): + if self[idx].get_py_value() == value.get_py_value(): + self.delitem(idx) + break + else: + raise InnerError(f"List {self} does not contain {value}") self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) @@ -259,7 +267,7 @@ def sort(self, key=None, reverse=None): if ( key is None or isinstance(key, ConstantVariable) - and key.get_value() is None + and key.get_py_value() is None ): key = VariableFactory.from_value( lambda x: x, self.graph, DanglingTracker() @@ -270,8 +278,8 @@ def sort(self, key=None, reverse=None): permutation = list(range(self.proxy.length)) permutation.sort( - key=lambda x: key.get_value()(self.getitem(x).value), - reverse=reverse.get_value(), + key=lambda x: key.get_py_value()(self.getitem(x).value), + reverse=reverse.get_py_value(), ) self.proxy.permutate(permutation) self.graph.side_effects.record_variable(self) @@ -297,7 +305,7 @@ def count(self, value: VariableBase): assert isinstance( eq_bool, ConstantVariable ), "bool should return ConstantVariable" - if eq.get_value() is True: + if eq.get_py_value() is True: count += 1 continue @@ -319,7 +327,7 @@ def index(self, value: VariableBase): assert isinstance( eq_bool, ConstantVariable ), "bool should return ConstantVariable" - if eq.get_value() is True: + if eq.get_py_value() is True: return VariableFactory.from_value( res, self.graph, DummyTracker([self, value]) ) @@ -384,10 +392,12 @@ def proxy_getter(self, data, key): data[key], self.graph, tracker=GetItemTracker(self, key) ) - def get_value(self): - return tuple(self[idx].get_value() for idx in range(len(self))) + def get_py_value(self, allow_tensor=False): + return tuple( + self[idx].get_py_value(allow_tensor) for idx in range(len(self)) + ) - def get_type(self): + def get_py_type(self): return tuple def _reconstruct(self, codegen: PyCodeGen): @@ -474,7 +484,7 @@ def count(self, value: VariableBase): assert isinstance( eq_bool, ConstantVariable ), "bool should return ConstantVariable" - if eq.get_value() is True: + if eq.get_py_value() is True: count += 1 continue @@ -496,7 +506,7 @@ def index(self, value: VariableBase): assert isinstance( eq_bool, ConstantVariable ), "bool should return ConstantVariable" - if eq.get_value() is True: + if eq.get_py_value() is True: return VariableFactory.from_value( res, self.graph, DummyTracker([self, value]) ) @@ -523,10 +533,10 @@ def __init__( super().__init__(graph, tracker) self.value = val_range - def get_type(self): + def get_py_type(self): return range - def get_value(self): + def get_py_value(self, allow_tensor=False): return self.value def getitem(self, key): @@ -615,13 +625,13 @@ def proxy_getter(self, data, key): data[key], self.graph, tracker=GetItemTracker(self, key) ) - def get_value(self): + def get_py_value(self, allow_tensor=False): return { - key: value.get_value() + key: value.get_py_value(allow_tensor) for key, value in self.proxy.get_all().items() } - def get_type(self): + def get_py_type(self): return dict def _reconstruct(self, codegen: PyCodeGen): @@ -809,7 +819,7 @@ def pop(self, key, default=None): return temp_value def popitem(self): - key = self.keys().hold.get_value()[-1] + key = self.keys().hold.get_py_value()[-1] value = self.getitem(key) # TODO: key, value should be VariableBase but key maybe a int # assert isinstance(key, VariableBase), key