1111GitHub: https://github.com/KoichiYasuoka
1212"""
1313import os
14+ from typing import List , Union
1415import numpy
1516import torch
1617import ufal .chu_liu_edmonds
@@ -48,7 +49,7 @@ def __init__(self, model: str="KoichiYasuoka/deberta-base-thai-ud-head") -> None
4849 tokenizer = self .tokenizer
4950 )
5051
51- def __call__ (self , text : str )-> str :
52+ def __call__ (self , text : str , tag : str = "str" )-> Union [ List [ List [ str ]], str ] :
5253 w = [(t ["start" ],t ["end" ],t ["entity_group" ]) for t in self .deprel (text )]
5354 z ,n = {t ["start" ]:t ["entity" ].split ("|" ) for t in self .tagger (text )},len (w )
5455 r ,m = [text [s :e ] for s ,e ,p in w ],numpy .full ((n + 1 ,n + 1 ),numpy .nan )
@@ -73,6 +74,12 @@ def __call__(self, text: str)->str:
7374 m [0 :j ,0 ]= m [j + 1 :,0 ]= numpy .nan
7475 h = ufal .chu_liu_edmonds .chu_liu_edmonds (m )[0 ]
7576 u = ""
77+ if tag == "list" :
78+ _tag_data = []
79+ for i ,(s ,e ,p ) in enumerate (w ,1 ):
80+ p = "root" if h [i ]== 0 else "dep" if p == "root" else p
81+ _tag_data .append ([str (i ),r [i - 1 ],"_" ,z [s ][0 ][2 :],"_" ,"|" .join (z [s ][1 :]),str (h [i ]),p ,"_" ,"_" if i < n and e < w [i ][0 ] else "SpaceAfter=No" ])
82+ return _tag_data
7683 for i ,(s ,e ,p ) in enumerate (w ,1 ):
7784 p = "root" if h [i ]== 0 else "dep" if p == "root" else p
7885 u += "\t " .join (
0 commit comments