Skip to content

Commit

Permalink
move Tensor remove_unused_cleanup_hooks to NameCtx
Browse files Browse the repository at this point in the history
This is for preparation of RETURNN Tensor usage.
#252
  • Loading branch information
albertz committed Feb 28, 2023
1 parent 5c3a4b6 commit fa4d2bc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 0 additions & 2 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def __init__(
self.data = data
name_ctx.layer_dict = layer_dict
name_ctx.tensor = self
self.remove_unused_cleanup_hooks = [] # type: List[Callable[[nn.Tensor], None]]

def __repr__(self):
parts = [self.__class__.__name__, self.raw_tensor.get_abs_name_repr()]
Expand Down Expand Up @@ -455,7 +454,6 @@ def _replace_by(self, tensor: nn.Tensor):
assert isinstance(tensor, nn.Tensor)
self.raw_tensor = tensor.raw_tensor # type: nn.NameCtx
self.data = tensor.data
self.remove_unused_cleanup_hooks.clear()

def _sis_hash(self):
from sisyphus.hash import sis_hash_helper # noqa
Expand Down
4 changes: 3 additions & 1 deletion nn/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor:
predefined_out_data=source.data,
name=name or sub_layer_name.replace("/", "_"),
)
res.remove_unused_cleanup_hooks.append(lambda _: source.raw_tensor.layer_dict.pop("need_last"))
res.raw_tensor.tensor_remove_unused_cleanup_hooks.append(
lambda _: source.raw_tensor.layer_dict.pop("need_last")
)
res.raw_tensor.layer_extra_dependencies.append(source)
self._last_frames[source.raw_tensor] = res
return res
Expand Down
5 changes: 3 additions & 2 deletions nn/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@ def __init__(
"""
self.module = module
self.tensor = None # type: Optional[nn.Tensor]
self.tensor_parent_modules = [] # type: List[Tuple[nn.Module, str]] # via parent module attrib
self.tensor_remove_unused_cleanup_hooks = [] # type: List[Callable[[nn.Tensor], None]]
self.layer_dict = None # type: Optional[nn.LayerDictRaw]
self.layer_extra_dependencies = [] # type: List[nn.Tensor]
self.tensor_parent_modules = [] # type: List[Tuple[nn.Module, str]] # via parent module attrib
self.debug_layer = None # type: Optional[nn.LayerBase]
self._enter_stack_frames = None # type: Optional[Set[types.FrameType]]
self.is_subnet = False # it says whether it can have children
Expand Down Expand Up @@ -405,7 +406,7 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self):
assert name_ctx.parent
name_ctx.parent.children.pop(name_ctx.name)
if name_ctx.tensor is not None:
for hook in name_ctx.tensor.remove_unused_cleanup_hooks:
for hook in name_ctx.tensor_remove_unused_cleanup_hooks:
hook(name_ctx.tensor)
else:
for name, child in name_ctx.children.items():
Expand Down

0 comments on commit fa4d2bc

Please sign in to comment.