Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Sep 25, 2023
1 parent bc351d2 commit e4b2b1d
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,15 @@ def call_function(self, /, *args, **kwargs):

return fn_var(*(self, *args), **kwargs)

@VariableFactory.register_from_value(successor="PaddleApiVariable")
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, paddle.nn.Layer):
return UserDefinedLayerVariable(value, graph, tracker)
return None

@property
def main_info(self) -> dict[str, Any]:
return {
"name": self.value.__class__.__name__,
}
def getitem(self, key):
if isinstance(self.value, paddle.nn.LayerList) and isinstance(
key, SliceVariable
):
raise BreakGraphError(
"call LayerList.__getitem__ with slice as key"
)
else:
return super().getitem(key)

def get_iter(self):
if isinstance(self.value, PD_SEQ_CONTAINERS):
Expand All @@ -456,6 +454,18 @@ def get_iter(self):
else:
return super().get_iter()

@property
def main_info(self) -> dict[str, Any]:
return {
"name": self.value.__class__.__name__,
}

@VariableFactory.register_from_value(successor="PaddleApiVariable")
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, paddle.nn.Layer):
return UserDefinedLayerVariable(value, graph, tracker)
return None


class BuiltinVariable(FunctionVariable):
"""
Expand Down Expand Up @@ -625,16 +635,6 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
else:
return super().make_stringify_guard()

def getitem(self, key):
if isinstance(self.value, paddle.nn.LayerList) and isinstance(
key, SliceVariable
):
raise BreakGraphError(
"call LayerList.__getitem__ with slice as key"
)
else:
return super().getitem(key)


class ClassVariable(CallableVariable):
def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker):
Expand Down

0 comments on commit e4b2b1d

Please sign in to comment.