diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py index da3751c690..fae81301db 100644 --- a/returnn/tensor/_tensor_extra.py +++ b/returnn/tensor/_tensor_extra.py @@ -17,8 +17,9 @@ from .dim import Dim, batch_dim, VerifyOutShapeException import returnn.tensor.tensor as _t import returnn.tensor.marked_dim as _m + +# noinspection PyProtectedMember import returnn.frontend._api as _frontend_api -import returnn._internal_frontend_api as _internal_frontend_api from ._tensor_mixin_base import _TensorMixinBase @@ -214,9 +215,9 @@ def _handle_extra_kwargs( # even in graph-based backends. # Rather, the logic to create placeholders should be done elsewhere. # noinspection PyProtectedMember - from returnn.tf.frontend_low_level._internal_frontend import TFInternalFrontend + from returnn.tf.frontend_low_level._frontend import TFFrontend - self.raw_tensor = TFInternalFrontend.create_placeholder(self) + self.raw_tensor = TFFrontend.create_placeholder(self) # Do that after same_dim_tags_as handling. _auto_create_size_placeholders_on_dim_tags(name=self.name, dim_tags=self._dims) @@ -233,13 +234,13 @@ def raw_frontend(self) -> Optional[Type[_frontend_api.Frontend]]: return _frontend_api.get_frontend_by_raw_tensor_type(type(self._raw_tensor)) @property - def _raw_internal_frontend(self) -> Optional[Type[_internal_frontend_api.InternalFrontend]]: + def _raw_internal_frontend(self) -> Optional[Type[_frontend_api.Frontend]]: """ :return: the internal backend for the raw tensor """ if self._raw_tensor is None: return None - return _internal_frontend_api.get_internal_frontend_by_tensor_type(type(self._raw_tensor)) + return _frontend_api.get_frontend_by_raw_tensor_type(type(self._raw_tensor)) @property def control_flow_ctx(self) -> Optional[ControlFlowContext]: @@ -2645,12 +2646,13 @@ def get_sequence_mask(self): :return: seq mask of shape (batch,time) if we are batch-major, else (time,batch) if we are time-major :rtype: tf.Tensor """ - from returnn._internal_frontend_api import get_internal_frontend_by_tensor_type + # noinspection PyProtectedMember + from returnn.frontend._api import get_frontend_by_raw_tensor_type assert self.time_dim_axis is not None assert self.batch_dim_axis is not None dyn_seq_len = self.get_sequence_lengths() - rf = get_internal_frontend_by_tensor_type(type(dyn_seq_len)) + rf = get_frontend_by_raw_tensor_type(type(dyn_seq_len)) if self.is_time_major: assert self.batch_dim_axis == 1 diff --git a/returnn/tensor/_tensor_op_overloads.py b/returnn/tensor/_tensor_op_overloads.py index 3807be0456..d30db2c14f 100644 --- a/returnn/tensor/_tensor_op_overloads.py +++ b/returnn/tensor/_tensor_op_overloads.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from .tensor import Tensor # just for type hints; otherwise use _t.Tensor +# noinspection PyProtectedMember import returnn.frontend._api as _frontend_api from ._tensor_mixin_base import _TensorMixinBase