From ed5462b861e381ff0f3d5c827c5e01bfdd58098c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 27 Feb 2023 23:57:01 +0100 Subject: [PATCH] move Tensor layer_dict to NameCtx This is for preparation of RETURNN Tensor usage. #252 --- nn/array_.py | 4 ++-- nn/base.py | 60 +++++++++++++++++++++++++++------------------------- nn/cond.py | 2 +- nn/loop.py | 56 +++++++++++++++++++++++++----------------------- nn/module.py | 8 +++---- nn/naming.py | 24 +++++++++++---------- 6 files changed, 81 insertions(+), 73 deletions(-) diff --git a/nn/array_.py b/nn/array_.py index 04e63344..ce6415c7 100644 --- a/nn/array_.py +++ b/nn/array_.py @@ -20,8 +20,8 @@ def constant_value(x: nn.Tensor) -> Optional[Union[int, float, complex, bool, st """ If the tensor is a constant, return its value. """ - if x.layer_dict and x.layer_dict["class"] == "constant": - return x.layer_dict["value"] + if x.raw_tensor.layer_dict and x.raw_tensor.layer_dict["class"] == "constant": + return x.raw_tensor.layer_dict["value"] return None diff --git a/nn/base.py b/nn/base.py index 91e3908d..45a5efbf 100644 --- a/nn/base.py +++ b/nn/base.py @@ -113,6 +113,7 @@ def __init__( assert name_ctx.layer_ref is None assert name_ctx.layer is None self.debug_layer = None + self.extra_dependencies = [] # type: List[Tensor] if is_ref: assert layer_dict is None @@ -139,12 +140,11 @@ def __init__( data.batch = name_ctx.root.global_batch self.data = data - self.layer_dict = layer_dict + self.raw_tensor.layer_dict = layer_dict name_ctx.layer_ref = self if not is_ref: name_ctx.layer = self self.is_ref = is_ref - self.extra_dependencies = [] # type: List[Tensor] self.remove_unused_cleanup_hooks = [] # type: List[Callable[[nn.Tensor], None]] def __repr__(self): @@ -160,7 +160,10 @@ def __repr__(self): parts.append(repr(self.data.placeholder)) if not self.is_ref: parts.append( - f"via {self.raw_tensor.module if self.raw_tensor.module else self.layer_dict.get('class', '?')!r}" + f"via " + + repr( + self.raw_tensor.module if self.raw_tensor.module else self.raw_tensor.layer_dict.get("class", "?") + ) ) if self.data and self.data.control_flow_ctx: parts.append(f"ctx={self.data.control_flow_ctx.repr_inner()}") @@ -352,10 +355,10 @@ def mark_as_loss( """ root_scope = self.raw_tensor.root res = nn.copy(self, name=root_scope.get_new_child(suggested_name=name)) - res.layer_dict["loss"] = "as_is" + res.raw_tensor.layer_dict["loss"] = "as_is" loss_opts = {} if scale is not None and scale != 1: - assert "loss_scale" not in res.layer_dict + assert "loss_scale" not in res.raw_tensor.layer_dict loss_opts["scale"] = scale if as_error: loss_opts["as_error"] = True @@ -366,7 +369,7 @@ def mark_as_loss( if custom_inv_norm_factor is not None: loss_opts["custom_inv_norm_factor"] = custom_inv_norm_factor if loss_opts: - res.layer_dict["loss_opts"] = loss_opts + res.raw_tensor.layer_dict["loss_opts"] = loss_opts # Add it to the root name scope marked_losses list. # Note that this logic might change. 
root_scope.marked_losses.append(res) @@ -388,11 +391,11 @@ def mark_as_output(self, *, _scope: Optional[nn.NameCtx] = None) -> Tensor: pass # not needed elif self.raw_tensor.parent is not scope: res = nn.copy(self, name=scope.get_new_child(suggested_name=self.raw_tensor.get_abs_name(join_str="_"))) - res.layer_dict["is_output_layer"] = True + res.raw_tensor.layer_dict["is_output_layer"] = True else: assert self.raw_tensor.parent is scope assert not self.is_ref - self.layer_dict["is_output_layer"] = True + self.raw_tensor.layer_dict["is_output_layer"] = True scope.marked_outputs.append(res) return res @@ -427,13 +430,13 @@ def _maybe_add_dep(x): if _extra_layer_dict: nest.map_structure(_maybe_add_dep, _extra_layer_dict) - if hasattr(self, "layer_dict") and self.layer_dict: # hasattr to be able to run this function early - nest.map_structure(_maybe_add_dep, self.layer_dict) + if self.raw_tensor.layer_dict: + nest.map_structure(_maybe_add_dep, self.raw_tensor.layer_dict) if self.raw_tensor.children and "output" in self.raw_tensor.children: _maybe_add_dep(self.raw_tensor.children["output"].layer_ref) if self.raw_tensor.parent and self.raw_tensor.parent.layer_ref: _maybe_add_dep(self.raw_tensor.parent.layer_ref) - if getattr(self, "extra_dependencies", None): + if self.extra_dependencies: dep_list.extend(self.extra_dependencies) return dep_list @@ -444,9 +447,8 @@ def _replace_by(self, tensor: nn.Tensor): """ assert isinstance(tensor, nn.Tensor) self.parent_modules = tensor.parent_modules - self.raw_tensor = tensor.raw_tensor + self.raw_tensor = tensor.raw_tensor # type: nn.NameCtx self.data = tensor.data - self.layer_dict = tensor.layer_dict self.is_ref = tensor.is_ref self.extra_dependencies = tensor.extra_dependencies self.remove_unused_cleanup_hooks.clear() @@ -456,7 +458,7 @@ def _sis_hash(self): if self.is_ref: return sis_hash_helper(self.raw_tensor.get_abs_name()) - return sis_hash_helper(self.layer_dict) + return sis_hash_helper(self.raw_tensor.layer_dict) def __add__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor: if isinstance(other, (int, float, numpy.number)) and other == 0: @@ -671,10 +673,10 @@ def initial(self, value: Optional[nn.init.ParamInitType]): if isinstance(value, nn.init.ParamInit): value = value(shape=self.dims, dtype=self.dtype) if value is None: - self.layer_dict.pop("init", None) - self.layer_dict.pop("init_by_layer", None) + self.raw_tensor.layer_dict.pop("init", None) + self.raw_tensor.layer_dict.pop("init_by_layer", None) elif isinstance(value, nn.Tensor): - self.layer_dict.pop("init", None) + self.raw_tensor.layer_dict.pop("init", None) if not value.raw_tensor.parent.can_access_children_from_root: accessible_parent = value.raw_tensor.parent while not accessible_parent.can_access_children_from_root: @@ -686,10 +688,10 @@ def initial(self, value: Optional[nn.init.ParamInitType]): assert ( dep.raw_tensor.parent.can_access_children_from_root ), f"dep {dep} of moved value {value} is not accessible" - self.layer_dict["init_by_layer"] = value + self.raw_tensor.layer_dict["init_by_layer"] = value else: - self.layer_dict.pop("init_by_layer", None) - self.layer_dict["init"] = value + self.raw_tensor.layer_dict.pop("init_by_layer", None) + self.raw_tensor.layer_dict["init"] = value if nn.is_debug_eager_mode_enabled(): shape = [d.get_dim_value() for d in self.dims] if isinstance(value, nn.Tensor): @@ -710,9 +712,9 @@ def initial_value(self) -> Optional[Union[nn.Tensor, nn.RawTensorTypes]]: """ In case initial is a ParamInit, this will return the actual value. 
""" - if self.layer_dict.get("init_by_layer", None) is not None: - return self.layer_dict["init_by_layer"] - return self.layer_dict.get("init", None) + if self.raw_tensor.layer_dict.get("init_by_layer", None) is not None: + return self.raw_tensor.layer_dict["init_by_layer"] + return self.raw_tensor.layer_dict.get("init", None) @property def weight_decay(self) -> float: @@ -722,26 +724,26 @@ def weight_decay(self) -> float: can be controlled via the ``decouple_constraints`` config option. https://github.com/rwth-i6/returnn_common/issues/59#issuecomment-1073913421 """ - return self.layer_dict.get("L2", 0.0) + return self.raw_tensor.layer_dict.get("L2", 0.0) @weight_decay.setter def weight_decay(self, value: Optional[float]): if value: - self.layer_dict["L2"] = value + self.raw_tensor.layer_dict["L2"] = value else: - self.layer_dict.pop("L2", None) + self.raw_tensor.layer_dict.pop("L2", None) @property def trainable(self) -> Optional[bool]: """trainable""" - return self.layer_dict.get("trainable", None) + return self.raw_tensor.layer_dict.get("trainable", None) @trainable.setter def trainable(self, value: Optional[bool]): if value is not None: - self.layer_dict["trainable"] = value + self.raw_tensor.layer_dict["trainable"] = value else: - self.layer_dict.pop("trainable", None) + self.raw_tensor.layer_dict.pop("trainable", None) class LayerState(dict): diff --git a/nn/cond.py b/nn/cond.py index cba11c52..0af9af50 100644 --- a/nn/cond.py +++ b/nn/cond.py @@ -167,7 +167,7 @@ def false(self, false_value): name = true_v.raw_tensor.name false_values_flat[i] = nn.copy(false_v, name=self.false_branch_name_ctx.get_child(name)) if name != "output": - false_values_flat[i].layer_dict["is_output_layer"] = True + false_values_flat[i].raw_tensor.layer_dict["is_output_layer"] = True false_value = nest.pack_sequence_as(false_value, false_values_flat) self.false_branch_name_ctx.__exit__(None, None, None) self._false_value = false_value diff --git a/nn/loop.py b/nn/loop.py index e1e02d23..a02607dc 100644 --- a/nn/loop.py +++ b/nn/loop.py @@ -181,7 +181,7 @@ def stack(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor: res = copy(source, name=name) assert isinstance(res, nn.Tensor) if res.raw_tensor.name != "output": - res.layer_dict["is_output_layer"] = True + res.raw_tensor.layer_dict["is_output_layer"] = True # We access the returned layer-ref from outside, thus fix the data template. 
res.data = res.data.copy_add_dim_by_tag(dim_tag=self.axis, unbroadcast=True, axis=0) res.data.time_dim_axis = 0 @@ -208,7 +208,7 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor: sub_layer_name = source.raw_tensor.get_name_in_ctx(self.name_ctx).replace("/", ".") source = nn.copy(source, name=self.name_ctx.get_new_child(sub_layer_name)) assert source.raw_tensor.parent is self.name_ctx - source.layer_dict["need_last"] = True + source.raw_tensor.layer_dict["need_last"] = True sub_layer_name = source.raw_tensor.get_name_in_ctx(self.name_ctx) with self.name_ctx.parent: # need to be outside the loop res = nn.make_layer( @@ -216,7 +216,7 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor: predefined_out_data=source.data, name=name or sub_layer_name.replace("/", "_"), ) - res.remove_unused_cleanup_hooks.append(lambda _: source.layer_dict.pop("need_last")) + res.remove_unused_cleanup_hooks.append(lambda _: source.raw_tensor.layer_dict.pop("need_last")) res.extra_dependencies.append(source) self._last_frames[source.raw_tensor] = res return res @@ -456,21 +456,21 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn ctx, ctx_ = layer_ctx_list[i : i + 2] assert isinstance(ctx, nn.NameCtx) and isinstance(ctx_, nn.NameCtx) if isinstance(ctx.module, nn.MaskedComputationModule): - ctx_.layer.layer_dict["is_output_layer"] = True + ctx_.layer.raw_tensor.layer_dict["is_output_layer"] = True break # Potential optimization for RETURNN layers. # See ReturnnWrappedLayerBase._get_recurrent_state. - if layer_ref.layer_dict: + if layer_ref.raw_tensor.layer_dict: _do_const_initial_value_opt = False _const_initial_value_opt_layer_white_list = {"cum_concat", "rec"} - if layer_ref.layer_dict["class"] in _const_initial_value_opt_layer_white_list: + if layer_ref.raw_tensor.layer_dict["class"] in _const_initial_value_opt_layer_white_list: _do_const_initial_value_opt = True - elif layer_ref.layer_dict["class"] == "get_last_hidden_state": - src = layer_ref.layer_dict["from"] + elif layer_ref.raw_tensor.layer_dict["class"] == "get_last_hidden_state": + src = layer_ref.raw_tensor.layer_dict["from"] assert isinstance(src, nn.Tensor) - if src.layer_dict: - if src.layer_dict["class"] in _const_initial_value_opt_layer_white_list: + if src.raw_tensor.layer_dict: + if src.raw_tensor.layer_dict["class"] in _const_initial_value_opt_layer_white_list: _do_const_initial_value_opt = True if _do_const_initial_value_opt: # Note: Only do this optimization for some layers because otherwise @@ -479,12 +479,12 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn if initial_const is not None: initial = initial_const - if layer_ref.layer_dict["class"] == "get_last_hidden_state": + if layer_ref.raw_tensor.layer_dict["class"] == "get_last_hidden_state": used_state_eliminate_optimization = False - key = layer_ref.layer_dict.get("key", "state") - src = layer_ref.layer_dict["from"] + key = layer_ref.raw_tensor.layer_dict.get("key", "state") + src = layer_ref.raw_tensor.layer_dict["from"] assert isinstance(src, nn.Tensor) - src_state_opt = src.layer_dict.get("state") if src.layer_dict else None + src_state_opt = src.raw_tensor.layer_dict.get("state") if src.raw_tensor.layer_dict else None if isinstance(src_state_opt, nn.LayerState): src_state_for_key = src_state_opt.get(key) if isinstance(src_state_for_key, PrevTensorRef): @@ -494,9 +494,11 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn 
used_state_eliminate_optimization = True src_state_opt[key] = None if all(opt is None for opt in nest.flatten(src_state_opt)): - del src.layer_dict["state"] + del src.raw_tensor.layer_dict["state"] # We need to pass the initial_state instead though. - src_initial_state_opt = src.layer_dict.setdefault("initial_state", nn.LayerState()) + src_initial_state_opt = src.raw_tensor.layer_dict.setdefault( + "initial_state", nn.LayerState() + ) src_initial_state_opt[key] = initial # If there is any other code which refers to this state, it can access the passed layer. # So anyway pass through. @@ -510,18 +512,18 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn else: # class != get_last_hidden_state - if layer_ref.layer_dict["class"] == "cum_concat": - layer_state_opt = layer_ref.layer_dict.get("state") + if layer_ref.raw_tensor.layer_dict["class"] == "cum_concat": + layer_state_opt = layer_ref.raw_tensor.layer_dict.get("state") if isinstance(layer_state_opt, nn.LayerState) and set(layer_state_opt.keys()) == {"state"}: layer_state = layer_state_opt.state if isinstance(layer_state, PrevTensorRef) and layer_state.cur_layer_name_ctx is name_ctx: # The 'state' argument refers to "prev:..." of itself. # This is redundant, so we don't need to pass it. - layer_ref.layer_dict.pop("state") + layer_ref.raw_tensor.layer_dict.pop("state") - assert "initial_state" not in layer_ref.layer_dict - assert "initial_output" not in layer_ref.layer_dict - layer_ref.layer_dict["initial_output"] = initial + assert "initial_state" not in layer_ref.raw_tensor.layer_dict + assert "initial_output" not in layer_ref.raw_tensor.layer_dict + layer_ref.raw_tensor.layer_dict["initial_output"] = initial else: # layer_ref not Layer raise NotImplementedError(f"{self}.assign to {layer_ref} but layer expected") @@ -535,10 +537,12 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn # Currently, RETURNN does not properly support a state in a subnet. # So we copy the layer to the loop root under the reserved existing name. nn.copy(layer_ref, name=name_ctx) - if layer_ref.layer_dict: - assert "initial_state" not in layer_ref.layer_dict # not supported/implemented - if "initial_output" in layer_ref.layer_dict: - name_ctx.layer.layer_dict["initial_output"] = layer_ref.layer_dict.pop("initial_output") + if layer_ref.raw_tensor.layer_dict: + assert "initial_state" not in layer_ref.raw_tensor.layer_dict # not supported/implemented + if "initial_output" in layer_ref.raw_tensor.layer_dict: + name_ctx.layer.raw_tensor.layer_dict[ + "initial_output" + ] = layer_ref.raw_tensor.layer_dict.pop("initial_output") else: prev_ref.assign_new_cur_layer_name_ctx(layer_ref.raw_tensor) diff --git a/nn/module.py b/nn/module.py index c335e9c1..984af7c0 100644 --- a/nn/module.py +++ b/nn/module.py @@ -365,15 +365,15 @@ def returnn_layer_get_recurrent_state(layer: nn.Tensor) -> nn.LayerState: # Note that this is actually layer specific. # We try to use a number of heuristics to get it right for the common cases. name = f"{layer.raw_tensor.name}_state" - layer_class = layer.layer_dict["class"] + layer_class = layer.raw_tensor.layer_dict["class"] if layer_class in {"cum_concat", "cumsum"}: return nn.LayerState(layer) # the layer output itself is its state if layer_class == "window": return nn.LayerState(_get_last_hidden_state(layer, out_dim=layer.feature_dim, name=name)) # This is some very generic fallback code, which probably does not work correctly in some cases. 
- out_dim = layer.layer_dict["out_dim"] - if layer_class == "rec" and isinstance(layer.layer_dict["unit"], str): - if "lstm" in layer.layer_dict["unit"].lower(): + out_dim = layer.raw_tensor.layer_dict["out_dim"] + if layer_class == "rec" and isinstance(layer.raw_tensor.layer_dict["unit"], str): + if "lstm" in layer.raw_tensor.layer_dict["unit"].lower(): h = _get_last_hidden_state(layer, out_dim=out_dim, key="h", name=f"{name}_h") c = _get_last_hidden_state(layer, out_dim=out_dim, key="c", name=f"{name}_c") return nn.LayerState(h=h, c=c) diff --git a/nn/naming.py b/nn/naming.py index 6393ffb7..9f661e55 100644 --- a/nn/naming.py +++ b/nn/naming.py @@ -191,6 +191,7 @@ def __init__( self.module = module self.layer_ref = None # type: Optional[nn.Tensor] self.layer = None # type: Optional[nn.Tensor] + self.layer_dict = None # type: Optional[nn.LayerDictRaw] self._enter_stack_frames = None # type: Optional[Set[types.FrameType]] self.is_subnet = False # it says whether it can have children self._subnet_main_output = None # type: Optional[nn.Tensor] # when this is via SubnetworkLayer @@ -276,7 +277,8 @@ def move_layer_ref_here(self: NameCtx, layer_ref: nn.Tensor): # Now reassign. layer_ref.raw_tensor = self self.layer_ref = layer_ref - self.layer = layer_ref if layer_ref.layer_dict else None + self.layer = layer_ref if old_name_ctx.layer_dict else None + self.layer_dict = old_name_ctx.layer_dict self.module = old_name_ctx.module self.is_subnet = old_name_ctx.is_subnet self._subnet_main_output = old_name_ctx._subnet_main_output @@ -294,14 +296,14 @@ def move_layer_ref_here(self: NameCtx, layer_ref: nn.Tensor): self.children[name] = child old_name_ctx.children = self.children # just in case there is some other reference to the old name ctx - if layer_ref.layer_dict: + if old_name_ctx.layer_dict: def _check_layer_opt_value(v): if isinstance(v, nn.Net): assert v.name_ctx is old_name_ctx v.name_ctx = self - nest.map_structure(_check_layer_opt_value, layer_ref.layer_dict) + nest.map_structure(_check_layer_opt_value, old_name_ctx.layer_dict) @property def root(self) -> NameCtx: @@ -360,8 +362,9 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self): ] # type: List[Tuple[nn.Tensor,List[nn.Tensor]]] while queue: tensor, src = queue.pop(0) - if tensor.raw_tensor is used_names: + if tensor.raw_tensor in used_names: continue + used_names.add(tensor.raw_tensor) src_ = src + [tensor] for dep in tensor.get_dependencies(): if dep.raw_tensor not in used_names: @@ -373,7 +376,7 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self): tensor._assign_parent_name_ctx(ref_ctx=root) # Handle subnetworks: Flatten away if just a single entry. Create layer if not created yet. - ctx = tensor.raw_tensor + ctx = tensor.raw_tensor # type: nn.NameCtx ctx.make_all_sub_networks_and_optimize() # Add tensor name including all parents. @@ -382,7 +385,6 @@ def _remove_unused_and_assign_parents_and_handle_subnets(self): for ctx in tensor.raw_tensor.get_abs_name_ctx_list(): if ctx in used_names: continue # skip early, to skip the extra checks below - used_names.add(ctx) if ctx.layer_ref is not None and ctx.layer_ref is not tensor: queue.append((ctx.layer_ref, src_)) @@ -432,10 +434,10 @@ def prepare_for_config_serialization(self, root_module: nn.Module): # root_mod_call.layer might be None if the subnet is not yet initialized. if root_mod_call.layer_ref is not None: assert not self.layer_ref # not sure. maybe just reset? 
- assert root_mod_call.layer.layer_dict["class"] == "subnetwork" + assert root_mod_call.layer_dict["class"] == "subnetwork" sub_out = root_mod_call.children.pop("output") - assert sub_out.layer.layer_dict["class"] == "copy" - sub_real_out = sub_out.layer.layer_dict["from"] + assert sub_out.layer_dict["class"] == "copy" + sub_real_out = sub_out.layer_dict["from"] assert isinstance(sub_real_out, nn.Tensor) # noinspection PyProtectedMember sub_out.layer._replace_by(sub_real_out) @@ -556,7 +558,7 @@ def get_name_in_ctx(self, ctx: NameCtx, *, middle_prefix: str = "", shorten_subn while len(self_name_abs) >= 2: ctx_, ctx__ = self_name_abs[-2:] assert isinstance(ctx_, NameCtx) and isinstance(ctx__, NameCtx) - if ctx_.layer is not None and ctx_.layer.layer_dict["class"] == "subnetwork": + if ctx_.layer is not None and ctx_.layer.raw_tensor.layer_dict["class"] == "subnetwork": if ctx_._subnet_main_output is ctx__.layer_ref or ctx_.children.get("output") is ctx__: self_name_abs.pop(-1) continue # check again @@ -1017,7 +1019,7 @@ def make_net_dict_raw(self, net: Net, *, _stack: Optional[_StackInfo] = None) -> if not sub_name_ctx.layer: continue - layer_dict = sub_name_ctx.layer.layer_dict.copy() + layer_dict = sub_name_ctx.layer.raw_tensor.layer_dict.copy() assert "class" in layer_dict data_template = sub_name_ctx.layer_ref.data.copy_template()
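
Note (not part of the applied patch): the sketch below is a minimal, simplified illustration of the access-pattern change this patch introduces, namely that the RETURNN layer dict now lives on the NameCtx and is reached as tensor.raw_tensor.layer_dict instead of tensor.layer_dict. The classes here are hypothetical stand-ins, not the real nn.Tensor / nn.NameCtx implementations; only constant_value mirrors the updated nn/array_.py logic.

    # Python sketch, assuming simplified stand-in classes for nn.NameCtx and nn.Tensor.
    from typing import Any, Dict, Optional


    class NameCtx:
        """Simplified stand-in for nn.NameCtx: after this patch it owns the layer dict."""

        def __init__(self) -> None:
            self.layer_dict = None  # type: Optional[Dict[str, Any]]


    class Tensor:
        """Simplified stand-in for nn.Tensor: keeps only a reference to its NameCtx."""

        def __init__(self, name_ctx: NameCtx, layer_dict: Optional[Dict[str, Any]] = None) -> None:
            self.raw_tensor = name_ctx
            # The dict is stored on the NameCtx, not on the Tensor itself:
            self.raw_tensor.layer_dict = layer_dict


    def constant_value(x: Tensor) -> Optional[Any]:
        """Mirrors the updated nn/array_.py: read the layer dict via x.raw_tensor."""
        d = x.raw_tensor.layer_dict
        if d and d["class"] == "constant":
            return d["value"]
        return None


    if __name__ == "__main__":
        t = Tensor(NameCtx(), {"class": "constant", "value": 42})
        assert constant_value(t) == 42

Under this assumption, code that previously wrote or read tensor.layer_dict (loss options, initial values, weight decay, etc.) goes through the same one-level indirection, which is why the bulk of the diff is a mechanical rewrite of self.layer_dict to self.raw_tensor.layer_dict.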