From 0f606d90a4ed24616e03831d562f7b120bf0ecb4 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 10 Jul 2024 17:16:14 +0800 Subject: [PATCH 1/3] fix --- python/paddle/text/viterbi_decode.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/paddle/text/viterbi_decode.py b/python/paddle/text/viterbi_decode.py index a1f2fa5192f97..4199e36f9509a 100644 --- a/python/paddle/text/viterbi_decode.py +++ b/python/paddle/text/viterbi_decode.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING from paddle import _C_ops @@ -19,12 +22,19 @@ from ..base.layer_helper import LayerHelper from ..nn import Layer +if TYPE_CHECKING: + from paddle import Tensor + __all__ = ['viterbi_decode', 'ViterbiDecoder'] def viterbi_decode( - potentials, transition_params, lengths, include_bos_eos_tag=True, name=None -): + potentials: Tensor, + transition_params: Tensor, + lengths: Tensor, + include_bos_eos_tag: bool = True, + name: str | None = None, +) -> Tensor: """ Decode the highest scoring sequence of tags computed by transitions and potentials and get the viterbi path. @@ -141,13 +151,18 @@ class ViterbiDecoder(Layer): [1, 1]]) """ - def __init__(self, transitions, include_bos_eos_tag=True, name=None): + def __init__( + self, + transitions: Tensor, + include_bos_eos_tag: bool = True, + name: str | None = None, + ) -> None: super().__init__() self.transitions = transitions self.include_bos_eos_tag = include_bos_eos_tag self.name = name - def forward(self, potentials, lengths): + def forward(self, potentials: Tensor, lengths: Tensor) -> Tensor: return viterbi_decode( potentials, self.transitions, From a4c29f0fa65f050262c9daa672998886591516d4 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 10 Jul 2024 17:17:53 +0800 Subject: [PATCH 2/3] fix --- python/paddle/text/viterbi_decode.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/text/viterbi_decode.py b/python/paddle/text/viterbi_decode.py index 4199e36f9509a..daec99ffe0a35 100644 --- a/python/paddle/text/viterbi_decode.py +++ b/python/paddle/text/viterbi_decode.py @@ -46,7 +46,7 @@ def viterbi_decode( lengths (Tensor): The input tensor of length of each sequence. This is a 1-D tensor with shape of [batch_size]. The data type is int64. include_bos_eos_tag (`bool`, optional): If set to True, the last row and the last column of transitions will be considered as start tag, the second to last row and the second to last column of transitions will be considered as stop tag. Defaults to ``True``. - name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please + name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -115,7 +115,7 @@ class ViterbiDecoder(Layer): transitions (`Tensor`): The transition matrix. Its dtype is float32 and has a shape of `[num_tags, num_tags]`. include_bos_eos_tag (`bool`, optional): If set to True, the last row and the last column of transitions will be considered as start tag, the second to last row and the second to last column of transitions will be considered as stop tag. Defaults to ``True``. - name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please + name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Shape: @@ -151,6 +151,12 @@ class ViterbiDecoder(Layer): [1, 1]]) """ + transitions: Tensor + include_bos_eos_tag: bool + name: str | None + potentials: Tensor + lengths: Tensor + def __init__( self, transitions: Tensor, From 8d43635c056567a082de7f69d255bfef87dab489 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 12 Jul 2024 10:33:58 +0800 Subject: [PATCH 3/3] fix --- python/paddle/text/viterbi_decode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/text/viterbi_decode.py b/python/paddle/text/viterbi_decode.py index daec99ffe0a35..6d6533b73df33 100644 --- a/python/paddle/text/viterbi_decode.py +++ b/python/paddle/text/viterbi_decode.py @@ -154,8 +154,6 @@ class ViterbiDecoder(Layer): transitions: Tensor include_bos_eos_tag: bool name: str | None - potentials: Tensor - lengths: Tensor def __init__( self,