fix old internal frontend code
albertz committed Mar 16, 2023
1 parent 2e4758b commit f013baa
Showing 2 changed files with 10 additions and 7 deletions.
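The change routes every frontend lookup through the public registry in returnn.frontend._api instead of the removed returnn._internal_frontend_api. A minimal sketch of that dispatch pattern, using only the names visible in the diff below; the handling of unregistered raw tensor types is not shown in the diff and is an assumption here:

    # Sketch of the dispatch pattern this commit converges on; it mirrors the
    # _raw_internal_frontend property in the diff, it is not copied from the repo.
    # noinspection PyProtectedMember
    from returnn.frontend._api import get_frontend_by_raw_tensor_type


    def frontend_for(raw_tensor):
        """Return the registered frontend class for a raw tensor, or None if unset."""
        if raw_tensor is None:
            return None
        # Assumption: the raw tensor's type has a frontend registered for it.
        return get_frontend_by_raw_tensor_type(type(raw_tensor))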
16 changes: 9 additions & 7 deletions returnn/tensor/_tensor_extra.py
@@ -17,8 +17,9 @@
 from .dim import Dim, batch_dim, VerifyOutShapeException
 import returnn.tensor.tensor as _t
 import returnn.tensor.marked_dim as _m
+
+# noinspection PyProtectedMember
 import returnn.frontend._api as _frontend_api
-import returnn._internal_frontend_api as _internal_frontend_api
 from ._tensor_mixin_base import _TensorMixinBase


@@ -214,9 +215,9 @@ def _handle_extra_kwargs(
 # even in graph-based backends.
 # Rather, the logic to create placeholders should be done elsewhere.
 # noinspection PyProtectedMember
-from returnn.tf.frontend_low_level._internal_frontend import TFInternalFrontend
+from returnn.tf.frontend_low_level._frontend import TFFrontend

-self.raw_tensor = TFInternalFrontend.create_placeholder(self)
+self.raw_tensor = TFFrontend.create_placeholder(self)
 # Do that after same_dim_tags_as handling.
 _auto_create_size_placeholders_on_dim_tags(name=self.name, dim_tags=self._dims)

@@ -233,13 +234,13 @@ def raw_frontend(self) -> Optional[Type[_frontend_api.Frontend]]:
 return _frontend_api.get_frontend_by_raw_tensor_type(type(self._raw_tensor))

 @property
-def _raw_internal_frontend(self) -> Optional[Type[_internal_frontend_api.InternalFrontend]]:
+def _raw_internal_frontend(self) -> Optional[Type[_frontend_api.Frontend]]:
 """
 :return: the internal backend for the raw tensor
 """
 if self._raw_tensor is None:
 return None
-return _internal_frontend_api.get_internal_frontend_by_tensor_type(type(self._raw_tensor))
+return _frontend_api.get_frontend_by_raw_tensor_type(type(self._raw_tensor))

 @property
 def control_flow_ctx(self) -> Optional[ControlFlowContext]:

@@ -2645,12 +2646,13 @@ def get_sequence_mask(self):
 :return: seq mask of shape (batch,time) if we are batch-major, else (time,batch) if we are time-major
 :rtype: tf.Tensor
 """
-from returnn._internal_frontend_api import get_internal_frontend_by_tensor_type
+# noinspection PyProtectedMember
+from returnn.frontend._api import get_frontend_by_raw_tensor_type

 assert self.time_dim_axis is not None
 assert self.batch_dim_axis is not None
 dyn_seq_len = self.get_sequence_lengths()
-rf = get_internal_frontend_by_tensor_type(type(dyn_seq_len))
+rf = get_frontend_by_raw_tensor_type(type(dyn_seq_len))

 if self.is_time_major:
 assert self.batch_dim_axis == 1
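The (batch, time) mask described in the get_sequence_mask docstring marks which positions of the padded time axis lie inside each sequence. A small numpy illustration with made-up lengths, not part of this commit:

    import numpy

    seq_lens = numpy.array([3, 1])  # dynamic sequence length per batch entry
    max_time = 4                    # padded size of the time axis
    # Batch-major (batch, time) mask: True inside the sequence, False in the padding.
    mask = numpy.arange(max_time)[None, :] < seq_lens[:, None]
    # -> [[ True,  True,  True, False],
    #     [ True, False, False, False]]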
1 change: 1 addition & 0 deletions returnn/tensor/_tensor_op_overloads.py
@@ -8,6 +8,7 @@
 if TYPE_CHECKING:
 from .tensor import Tensor  # just for type hints; otherwise use _t.Tensor

+# noinspection PyProtectedMember
 import returnn.frontend._api as _frontend_api
 from ._tensor_mixin_base import _TensorMixinBase
