Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Fix Sequential and hook error
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Jul 17, 2023
1 parent 3b475ac commit 94b568f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,7 +1875,7 @@ def STORE_ATTR(self, instr):
)
else:
raise NotImplementException(
f"SETATTR don't support {obj}.{key}={val}"
f"STORE_ATTR don't support {obj}.{key}={val}"
)

def FOR_ITER(self, instr):
Expand Down
35 changes: 13 additions & 22 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@
object_equal_stringify_guard,
union_free_vars,
)
from ..tracker import (
DanglingTracker,
DummyTracker,
GetAttrTracker,
GetItemTracker,
Tracker,
)
from ..tracker import DanglingTracker, DummyTracker, GetAttrTracker, Tracker
from .base import VariableBase, VariableFactory
from .basic import ConstantVariable, PrintStmtVariable

Expand Down Expand Up @@ -417,26 +411,23 @@ def get_symbol(self) -> Symbol:
return Symbol(self.name)

def call_function(self, /, *args, **kwargs):
# TODO: Remove this trick after we support for-loop.
if isinstance(self.value, paddle.nn.Sequential):
assert len(args) == 1, "Sequential only accept one input"
input = args[0]
for i, layer in enumerate(self.value._sub_layers.values()):
layer_var = VariableFactory.from_value(
layer, self.graph, tracker=GetItemTracker(self, i)
)
assert isinstance(layer_var, LayerVariable)
input = layer_var(input)
return input
return self.graph.call_layer(self, *args, **kwargs)

@VariableFactory.register_from_value(successor="UserDefinedLayerVariable")
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
# TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer.
if isinstance(value, paddle.nn.Layer) and value.__module__.startswith(
"paddle.nn."
):
return PaddleLayerVariable(value, graph, tracker)
if isinstance(value, paddle.nn.Layer):
# If there is a user-defined behavior, such as a container class layer
# or a hook on the layer, it needs to be converted to UserDefinedLayerVariable,
# otherwise converted to PaddleLayerVariable
if (
value.__module__.startswith("paddle.nn.Sequential")
or value._forward_pre_hooks
or value._forward_post_hooks
):
return None
if value.__module__.startswith("paddle.nn."):
return PaddleLayerVariable(value, graph, tracker)
return None

@property
Expand Down
5 changes: 5 additions & 0 deletions sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _module_dir(m: types.ModuleType):
f"^({'|'.join(map(re.escape, skip_file_names))})"
)

no_skip_file_names = {paddle_path + 'nn/layer/container.py'}


customed_skip_code = set()


Expand All @@ -124,6 +127,8 @@ def need_skip_path(filepath: str) -> bool:
Returns:
bool: True if the file should be skipped.
"""
if filepath in no_skip_file_names:
return False
if not filepath.startswith("<"):
filepath = os.path.abspath(filepath)
return bool(skip_file_name_re.match(filepath))
Expand Down

0 comments on commit 94b568f

Please sign in to comment.