From fa4d2bc24a54db1576ffe49f44a9ce64dfd40cb7 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 28 Feb 2023 10:34:19 +0100
Subject: [PATCH] move Tensor remove_unused_cleanup_hooks to NameCtx

This is in preparation for RETURNN Tensor usage. #252
---
 nn/base.py   | 2 --
 nn/loop.py   | 4 +++-
 nn/naming.py | 5 +++--
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/nn/base.py b/nn/base.py
index 3ca3b9f8..0b8ce35c 100644
--- a/nn/base.py
+++ b/nn/base.py
@@ -152,7 +152,6 @@ def __init__(
         self.data = data
         name_ctx.layer_dict = layer_dict
         name_ctx.tensor = self
-        self.remove_unused_cleanup_hooks = []  # type: List[Callable[[nn.Tensor], None]]
 
     def __repr__(self):
         parts = [self.__class__.__name__, self.raw_tensor.get_abs_name_repr()]
@@ -455,7 +454,6 @@ def _replace_by(self, tensor: nn.Tensor):
         assert isinstance(tensor, nn.Tensor)
         self.raw_tensor = tensor.raw_tensor  # type: nn.NameCtx
         self.data = tensor.data
-        self.remove_unused_cleanup_hooks.clear()
 
     def _sis_hash(self):
         from sisyphus.hash import sis_hash_helper  # noqa
diff --git a/nn/loop.py b/nn/loop.py
index 9d5d7cf3..2996f26b 100644
--- a/nn/loop.py
+++ b/nn/loop.py
@@ -214,7 +214,9 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor:
             predefined_out_data=source.data,
             name=name or sub_layer_name.replace("/", "_"),
         )
-        res.remove_unused_cleanup_hooks.append(lambda _: source.raw_tensor.layer_dict.pop("need_last"))
+        res.raw_tensor.tensor_remove_unused_cleanup_hooks.append(
+            lambda _: source.raw_tensor.layer_dict.pop("need_last")
+        )
         res.raw_tensor.layer_extra_dependencies.append(source)
         self._last_frames[source.raw_tensor] = res
         return res
diff --git a/nn/naming.py b/nn/naming.py
index 4fc881ae..add0c28a 100644
--- a/nn/naming.py
+++ b/nn/naming.py
@@ -190,9 +190,10 @@ def __init__(
         """
         self.module = module
         self.tensor = None  # type: Optional[nn.Tensor]
+        self.tensor_parent_modules = []  # type: List[Tuple[nn.Module, str]]  # via parent module attrib
+        self.tensor_remove_unused_cleanup_hooks = []  # type: List[Callable[[nn.Tensor], None]]
         self.layer_dict = None  # type: Optional[nn.LayerDictRaw]
         self.layer_extra_dependencies = []  # type: List[nn.Tensor]
-        self.tensor_parent_modules = []  # type: List[Tuple[nn.Module, str]]  # via parent module attrib
         self.debug_layer = None  # type: Optional[nn.LayerBase]
         self._enter_stack_frames = None  # type: Optional[Set[types.FrameType]]
         self.is_subnet = False  # it says whether it can have children
@@ -405,7 +406,7 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self):
                 assert name_ctx.parent
                 name_ctx.parent.children.pop(name_ctx.name)
                 if name_ctx.tensor is not None:
-                    for hook in name_ctx.tensor.remove_unused_cleanup_hooks:
+                    for hook in name_ctx.tensor_remove_unused_cleanup_hooks:
                         hook(name_ctx.tensor)
             else:
                 for name, child in name_ctx.children.items():
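
Usage sketch: after this patch, cleanup hooks are registered on the NameCtx behind a tensor
(tensor.raw_tensor) rather than on the Tensor itself. The helper below is hypothetical; only the
tensor_remove_unused_cleanup_hooks attribute and the hook call convention are taken from the
hunks above.

    # Sketch only, assuming the returnn_common nn API as used in the diff.
    def register_need_last_cleanup(res, source):
        # `res` and `source` are assumed to be nn.Tensor instances.
        # If `res` is later removed as unused, also drop the "need_last"
        # flag from the source layer dict (mirrors the nn/loop.py hunk).
        res.raw_tensor.tensor_remove_unused_cleanup_hooks.append(
            lambda _tensor: source.raw_tensor.layer_dict.pop("need_last")
        )

    # Cleanup side (see the nn/naming.py hunk): each hook is called with the tensor,
    #     for hook in name_ctx.tensor_remove_unused_cleanup_hooks:
    #         hook(name_ctx.tensor)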