move Tensor layer_dict to NameCtx
This is in preparation for RETURNN Tensor usage. #252
albertz committed Feb 27, 2023
1 parent 9e58dd0 · commit ed5462b
Showing 6 changed files with 81 additions and 73 deletions.
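The core of the change: a tensor's RETURNN layer dict no longer lives on the Tensor object itself but on its NameCtx, which the Tensor exposes as raw_tensor. A minimal sketch of the before/after access pattern, assuming the usual `from returnn_common import nn` import (illustrative only, not part of the diff):

from typing import Optional
from returnn_common import nn  # assumed import path

def get_layer_class(x: nn.Tensor) -> Optional[str]:
    """Return the RETURNN layer class of a tensor, or None if it is only a reference."""
    # Before this commit: layer_dict = x.layer_dict
    # After this commit: the dict is stored on the NameCtx behind the tensor.
    layer_dict = x.raw_tensor.layer_dict
    return layer_dict["class"] if layer_dict else None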
4 changes: 2 additions & 2 deletions nn/array_.py
@@ -20,8 +20,8 @@ def constant_value(x: nn.Tensor) -> Optional[Union[int, float, complex, bool, st
"""
If the tensor is a constant, return its value.
"""
if x.layer_dict and x.layer_dict["class"] == "constant":
return x.layer_dict["value"]
if x.raw_tensor.layer_dict and x.raw_tensor.layer_dict["class"] == "constant":
return x.raw_tensor.layer_dict["value"]
return None


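As a usage illustration of the updated constant_value, a hedged sketch; it assumes nn.constant builds a layer dict of the form {"class": "constant", "value": ...}, which is what the check above relies on:

c = nn.constant(value=42)  # assumed helper creating a "constant" layer
assert c.raw_tensor.layer_dict["class"] == "constant"
assert nn.constant_value(c) == 42  # now read via c.raw_tensor.layer_dict["value"]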
60 changes: 31 additions & 29 deletions nn/base.py
@@ -113,6 +113,7 @@ def __init__(
assert name_ctx.layer_ref is None
assert name_ctx.layer is None
self.debug_layer = None
self.extra_dependencies = [] # type: List[Tensor]

if is_ref:
assert layer_dict is None
@@ -139,12 +140,11 @@ def __init__(
data.batch = name_ctx.root.global_batch

self.data = data
self.layer_dict = layer_dict
self.raw_tensor.layer_dict = layer_dict
name_ctx.layer_ref = self
if not is_ref:
name_ctx.layer = self
self.is_ref = is_ref
self.extra_dependencies = [] # type: List[Tensor]
self.remove_unused_cleanup_hooks = [] # type: List[Callable[[nn.Tensor], None]]

def __repr__(self):
@@ -160,7 +160,10 @@ def __repr__(self):
parts.append(repr(self.data.placeholder))
if not self.is_ref:
parts.append(
f"via {self.raw_tensor.module if self.raw_tensor.module else self.layer_dict.get('class', '?')!r}"
f"via "
+ repr(
self.raw_tensor.module if self.raw_tensor.module else self.raw_tensor.layer_dict.get("class", "?")
)
)
if self.data and self.data.control_flow_ctx:
parts.append(f"ctx={self.data.control_flow_ctx.repr_inner()}")
@@ -352,10 +355,10 @@ def mark_as_loss(
"""
root_scope = self.raw_tensor.root
res = nn.copy(self, name=root_scope.get_new_child(suggested_name=name))
res.layer_dict["loss"] = "as_is"
res.raw_tensor.layer_dict["loss"] = "as_is"
loss_opts = {}
if scale is not None and scale != 1:
assert "loss_scale" not in res.layer_dict
assert "loss_scale" not in res.raw_tensor.layer_dict
loss_opts["scale"] = scale
if as_error:
loss_opts["as_error"] = True
@@ -366,7 +369,7 @@ def mark_as_loss(
if custom_inv_norm_factor is not None:
loss_opts["custom_inv_norm_factor"] = custom_inv_norm_factor
if loss_opts:
res.layer_dict["loss_opts"] = loss_opts
res.raw_tensor.layer_dict["loss_opts"] = loss_opts
# Add it to the root name scope marked_losses list.
# Note that this logic might change.
root_scope.marked_losses.append(res)
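A hedged usage sketch of the updated mark_as_loss, assuming `frame_loss` is an existing nn.Tensor and that the method returns the copied tensor `res` as the code above suggests; the loss options now end up in the layer dict on the copy's NameCtx:

loss = frame_loss.mark_as_loss(name="ce", scale=0.5)
assert loss.raw_tensor.layer_dict["loss"] == "as_is"
assert loss.raw_tensor.layer_dict["loss_opts"] == {"scale": 0.5}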
@@ -388,11 +391,11 @@ def mark_as_output(self, *, _scope: Optional[nn.NameCtx] = None) -> Tensor:
pass # not needed
elif self.raw_tensor.parent is not scope:
res = nn.copy(self, name=scope.get_new_child(suggested_name=self.raw_tensor.get_abs_name(join_str="_")))
res.layer_dict["is_output_layer"] = True
res.raw_tensor.layer_dict["is_output_layer"] = True
else:
assert self.raw_tensor.parent is scope
assert not self.is_ref
self.layer_dict["is_output_layer"] = True
self.raw_tensor.layer_dict["is_output_layer"] = True
scope.marked_outputs.append(res)
return res

@@ -427,13 +430,13 @@ def _maybe_add_dep(x):

if _extra_layer_dict:
nest.map_structure(_maybe_add_dep, _extra_layer_dict)
if hasattr(self, "layer_dict") and self.layer_dict: # hasattr to be able to run this function early
nest.map_structure(_maybe_add_dep, self.layer_dict)
if self.raw_tensor.layer_dict:
nest.map_structure(_maybe_add_dep, self.raw_tensor.layer_dict)
if self.raw_tensor.children and "output" in self.raw_tensor.children:
_maybe_add_dep(self.raw_tensor.children["output"].layer_ref)
if self.raw_tensor.parent and self.raw_tensor.parent.layer_ref:
_maybe_add_dep(self.raw_tensor.parent.layer_ref)
if getattr(self, "extra_dependencies", None):
if self.extra_dependencies:
dep_list.extend(self.extra_dependencies)
return dep_list

@@ -444,9 +447,8 @@ def _replace_by(self, tensor: nn.Tensor):
"""
assert isinstance(tensor, nn.Tensor)
self.parent_modules = tensor.parent_modules
self.raw_tensor = tensor.raw_tensor
self.raw_tensor = tensor.raw_tensor # type: nn.NameCtx
self.data = tensor.data
self.layer_dict = tensor.layer_dict
self.is_ref = tensor.is_ref
self.extra_dependencies = tensor.extra_dependencies
self.remove_unused_cleanup_hooks.clear()
@@ -456,7 +458,7 @@ def _sis_hash(self):

if self.is_ref:
return sis_hash_helper(self.raw_tensor.get_abs_name())
return sis_hash_helper(self.layer_dict)
return sis_hash_helper(self.raw_tensor.layer_dict)

def __add__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 0:
@@ -671,10 +673,10 @@ def initial(self, value: Optional[nn.init.ParamInitType]):
if isinstance(value, nn.init.ParamInit):
value = value(shape=self.dims, dtype=self.dtype)
if value is None:
self.layer_dict.pop("init", None)
self.layer_dict.pop("init_by_layer", None)
self.raw_tensor.layer_dict.pop("init", None)
self.raw_tensor.layer_dict.pop("init_by_layer", None)
elif isinstance(value, nn.Tensor):
self.layer_dict.pop("init", None)
self.raw_tensor.layer_dict.pop("init", None)
if not value.raw_tensor.parent.can_access_children_from_root:
accessible_parent = value.raw_tensor.parent
while not accessible_parent.can_access_children_from_root:
@@ -686,10 +688,10 @@ def initial(self, value: Optional[nn.init.ParamInitType]):
assert (
dep.raw_tensor.parent.can_access_children_from_root
), f"dep {dep} of moved value {value} is not accessible"
self.layer_dict["init_by_layer"] = value
self.raw_tensor.layer_dict["init_by_layer"] = value
else:
self.layer_dict.pop("init_by_layer", None)
self.layer_dict["init"] = value
self.raw_tensor.layer_dict.pop("init_by_layer", None)
self.raw_tensor.layer_dict["init"] = value
if nn.is_debug_eager_mode_enabled():
shape = [d.get_dim_value() for d in self.dims]
if isinstance(value, nn.Tensor):
@@ -710,9 +712,9 @@ def initial_value(self) -> Optional[Union[nn.Tensor, nn.RawTensorTypes]]:
"""
In case initial is a ParamInit, this will return the actual value.
"""
if self.layer_dict.get("init_by_layer", None) is not None:
return self.layer_dict["init_by_layer"]
return self.layer_dict.get("init", None)
if self.raw_tensor.layer_dict.get("init_by_layer", None) is not None:
return self.raw_tensor.layer_dict["init_by_layer"]
return self.raw_tensor.layer_dict.get("init", None)

@property
def weight_decay(self) -> float:
Expand All @@ -722,26 +724,26 @@ def weight_decay(self) -> float:
can be controlled via the ``decouple_constraints`` config option.
https://github.com/rwth-i6/returnn_common/issues/59#issuecomment-1073913421
"""
return self.layer_dict.get("L2", 0.0)
return self.raw_tensor.layer_dict.get("L2", 0.0)

@weight_decay.setter
def weight_decay(self, value: Optional[float]):
if value:
self.layer_dict["L2"] = value
self.raw_tensor.layer_dict["L2"] = value
else:
self.layer_dict.pop("L2", None)
self.raw_tensor.layer_dict.pop("L2", None)

@property
def trainable(self) -> Optional[bool]:
"""trainable"""
return self.layer_dict.get("trainable", None)
return self.raw_tensor.layer_dict.get("trainable", None)

@trainable.setter
def trainable(self, value: Optional[bool]):
if value is not None:
self.layer_dict["trainable"] = value
self.raw_tensor.layer_dict["trainable"] = value
else:
self.layer_dict.pop("trainable", None)
self.raw_tensor.layer_dict.pop("trainable", None)


class LayerState(dict):
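To make the Parameter property changes concrete, a small sketch that assumes an existing nn.Parameter `param` (for example the weight of a linear layer); the setters above now write through to the layer dict on the parameter's NameCtx:

param.weight_decay = 0.01  # stored as param.raw_tensor.layer_dict["L2"]
param.trainable = False    # stored as param.raw_tensor.layer_dict["trainable"]
assert param.raw_tensor.layer_dict["L2"] == 0.01
assert param.raw_tensor.layer_dict["trainable"] is False
param.weight_decay = 0.0   # falsy, so the setter pops "L2" again
assert "L2" not in param.raw_tensor.layer_dict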
2 changes: 1 addition & 1 deletion nn/cond.py
@@ -167,7 +167,7 @@ def false(self, false_value):
name = true_v.raw_tensor.name
false_values_flat[i] = nn.copy(false_v, name=self.false_branch_name_ctx.get_child(name))
if name != "output":
false_values_flat[i].layer_dict["is_output_layer"] = True
false_values_flat[i].raw_tensor.layer_dict["is_output_layer"] = True
false_value = nest.pack_sequence_as(false_value, false_values_flat)
self.false_branch_name_ctx.__exit__(None, None, None)
self._false_value = false_value
56 changes: 30 additions & 26 deletions nn/loop.py
@@ -181,7 +181,7 @@ def stack(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor:
res = copy(source, name=name)
assert isinstance(res, nn.Tensor)
if res.raw_tensor.name != "output":
res.layer_dict["is_output_layer"] = True
res.raw_tensor.layer_dict["is_output_layer"] = True
# We access the returned layer-ref from outside, thus fix the data template.
res.data = res.data.copy_add_dim_by_tag(dim_tag=self.axis, unbroadcast=True, axis=0)
res.data.time_dim_axis = 0
@@ -208,15 +208,15 @@ def last(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor:
sub_layer_name = source.raw_tensor.get_name_in_ctx(self.name_ctx).replace("/", ".")
source = nn.copy(source, name=self.name_ctx.get_new_child(sub_layer_name))
assert source.raw_tensor.parent is self.name_ctx
source.layer_dict["need_last"] = True
source.raw_tensor.layer_dict["need_last"] = True
sub_layer_name = source.raw_tensor.get_name_in_ctx(self.name_ctx)
with self.name_ctx.parent: # need to be outside the loop
res = nn.make_layer(
{"class": "rec_last_output", "rec_layer": self.name_ctx.layer_ref, "sub_layer_name": sub_layer_name},
predefined_out_data=source.data,
name=name or sub_layer_name.replace("/", "_"),
)
res.remove_unused_cleanup_hooks.append(lambda _: source.layer_dict.pop("need_last"))
res.remove_unused_cleanup_hooks.append(lambda _: source.raw_tensor.layer_dict.pop("need_last"))
res.extra_dependencies.append(source)
self._last_frames[source.raw_tensor] = res
return res
@@ -456,21 +456,21 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn
ctx, ctx_ = layer_ctx_list[i : i + 2]
assert isinstance(ctx, nn.NameCtx) and isinstance(ctx_, nn.NameCtx)
if isinstance(ctx.module, nn.MaskedComputationModule):
ctx_.layer.layer_dict["is_output_layer"] = True
ctx_.layer.raw_tensor.layer_dict["is_output_layer"] = True
break

# Potential optimization for RETURNN layers.
# See ReturnnWrappedLayerBase._get_recurrent_state.
if layer_ref.layer_dict:
if layer_ref.raw_tensor.layer_dict:
_do_const_initial_value_opt = False
_const_initial_value_opt_layer_white_list = {"cum_concat", "rec"}
if layer_ref.layer_dict["class"] in _const_initial_value_opt_layer_white_list:
if layer_ref.raw_tensor.layer_dict["class"] in _const_initial_value_opt_layer_white_list:
_do_const_initial_value_opt = True
elif layer_ref.layer_dict["class"] == "get_last_hidden_state":
src = layer_ref.layer_dict["from"]
elif layer_ref.raw_tensor.layer_dict["class"] == "get_last_hidden_state":
src = layer_ref.raw_tensor.layer_dict["from"]
assert isinstance(src, nn.Tensor)
if src.layer_dict:
if src.layer_dict["class"] in _const_initial_value_opt_layer_white_list:
if src.raw_tensor.layer_dict:
if src.raw_tensor.layer_dict["class"] in _const_initial_value_opt_layer_white_list:
_do_const_initial_value_opt = True
if _do_const_initial_value_opt:
# Note: Only do this optimization for some layers because otherwise
@@ -479,12 +479,12 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn
if initial_const is not None:
initial = initial_const

if layer_ref.layer_dict["class"] == "get_last_hidden_state":
if layer_ref.raw_tensor.layer_dict["class"] == "get_last_hidden_state":
used_state_eliminate_optimization = False
key = layer_ref.layer_dict.get("key", "state")
src = layer_ref.layer_dict["from"]
key = layer_ref.raw_tensor.layer_dict.get("key", "state")
src = layer_ref.raw_tensor.layer_dict["from"]
assert isinstance(src, nn.Tensor)
src_state_opt = src.layer_dict.get("state") if src.layer_dict else None
src_state_opt = src.raw_tensor.layer_dict.get("state") if src.raw_tensor.layer_dict else None
if isinstance(src_state_opt, nn.LayerState):
src_state_for_key = src_state_opt.get(key)
if isinstance(src_state_for_key, PrevTensorRef):
@@ -494,9 +494,11 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn
used_state_eliminate_optimization = True
src_state_opt[key] = None
if all(opt is None for opt in nest.flatten(src_state_opt)):
del src.layer_dict["state"]
del src.raw_tensor.layer_dict["state"]
# We need to pass the initial_state instead though.
src_initial_state_opt = src.layer_dict.setdefault("initial_state", nn.LayerState())
src_initial_state_opt = src.raw_tensor.layer_dict.setdefault(
"initial_state", nn.LayerState()
)
src_initial_state_opt[key] = initial
# If there is any other code which refers to this state, it can access the passed layer.
# So anyway pass through.
@@ -510,18 +512,18 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn

else: # class != get_last_hidden_state

if layer_ref.layer_dict["class"] == "cum_concat":
layer_state_opt = layer_ref.layer_dict.get("state")
if layer_ref.raw_tensor.layer_dict["class"] == "cum_concat":
layer_state_opt = layer_ref.raw_tensor.layer_dict.get("state")
if isinstance(layer_state_opt, nn.LayerState) and set(layer_state_opt.keys()) == {"state"}:
layer_state = layer_state_opt.state
if isinstance(layer_state, PrevTensorRef) and layer_state.cur_layer_name_ctx is name_ctx:
# The 'state' argument refers to "prev:..." of itself.
# This is redundant, so we don't need to pass it.
layer_ref.layer_dict.pop("state")
layer_ref.raw_tensor.layer_dict.pop("state")

assert "initial_state" not in layer_ref.layer_dict
assert "initial_output" not in layer_ref.layer_dict
layer_ref.layer_dict["initial_output"] = initial
assert "initial_state" not in layer_ref.raw_tensor.layer_dict
assert "initial_output" not in layer_ref.raw_tensor.layer_dict
layer_ref.raw_tensor.layer_dict["initial_output"] = initial

else: # layer_ref not Layer
raise NotImplementedError(f"{self}.assign to {layer_ref} but layer expected")
@@ -535,10 +537,12 @@ def _map_ref_to_name_ctx(layer_ref: nn.Tensor, name_ctx: nn.NameCtx, initial: nn
# Currently, RETURNN does not properly support a state in a subnet.
# So we copy the layer to the loop root under the reserved existing name.
nn.copy(layer_ref, name=name_ctx)
if layer_ref.layer_dict:
assert "initial_state" not in layer_ref.layer_dict # not supported/implemented
if "initial_output" in layer_ref.layer_dict:
name_ctx.layer.layer_dict["initial_output"] = layer_ref.layer_dict.pop("initial_output")
if layer_ref.raw_tensor.layer_dict:
assert "initial_state" not in layer_ref.raw_tensor.layer_dict # not supported/implemented
if "initial_output" in layer_ref.raw_tensor.layer_dict:
name_ctx.layer.raw_tensor.layer_dict[
"initial_output"
] = layer_ref.raw_tensor.layer_dict.pop("initial_output")
else:
prev_ref.assign_new_cur_layer_name_ctx(layer_ref.raw_tensor)

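A hedged helper sketch summarizing where the loop-state assignment above now records initial values, namely in the layer dict held by the tensor's NameCtx (hypothetical helper, not part of this commit):

def get_assigned_initial(t: nn.Tensor):
    """Return the initial value recorded by loop-state assignment, if any (illustrative)."""
    layer_dict = t.raw_tensor.layer_dict
    if not layer_dict:
        return None
    # Most layer classes get "initial_output"; an eliminated get_last_hidden_state source
    # instead receives an "initial_state" LayerState, as handled above.
    return layer_dict.get("initial_output", layer_dict.get("initial_state"))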
8 changes: 4 additions & 4 deletions nn/module.py
@@ -365,15 +365,15 @@ def returnn_layer_get_recurrent_state(layer: nn.Tensor) -> nn.LayerState:
# Note that this is actually layer specific.
# We try to use a number of heuristics to get it right for the common cases.
name = f"{layer.raw_tensor.name}_state"
layer_class = layer.layer_dict["class"]
layer_class = layer.raw_tensor.layer_dict["class"]
if layer_class in {"cum_concat", "cumsum"}:
return nn.LayerState(layer) # the layer output itself is its state
if layer_class == "window":
return nn.LayerState(_get_last_hidden_state(layer, out_dim=layer.feature_dim, name=name))
# This is some very generic fallback code, which probably does not work correctly in some cases.
out_dim = layer.layer_dict["out_dim"]
if layer_class == "rec" and isinstance(layer.layer_dict["unit"], str):
if "lstm" in layer.layer_dict["unit"].lower():
out_dim = layer.raw_tensor.layer_dict["out_dim"]
if layer_class == "rec" and isinstance(layer.raw_tensor.layer_dict["unit"], str):
if "lstm" in layer.raw_tensor.layer_dict["unit"].lower():
h = _get_last_hidden_state(layer, out_dim=out_dim, key="h", name=f"{name}_h")
c = _get_last_hidden_state(layer, out_dim=out_dim, key="c", name=f"{name}_c")
return nn.LayerState(h=h, c=c)
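For reference, a hedged sketch that mirrors the LSTM heuristic above using the new layer-dict location (illustrative helper, not part of this commit):

def is_lstm_rec_layer(layer: nn.Tensor) -> bool:
    """True if the tensor's layer dict describes a rec layer with a string LSTM unit."""
    d = layer.raw_tensor.layer_dict
    return bool(d) and d.get("class") == "rec" and isinstance(d.get("unit"), str) and "lstm" in d["unit"].lower()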
(diff for the remaining changed file not loaded)
