|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +TransformersUD |
| 4 | +
|
| 5 | +Author: Prof. Koichi Yasuoka |
| 6 | +
|
| 7 | +This tagger is provided under the terms of the apache-2.0 License. |
| 8 | +
|
| 9 | +The source: https://huggingface.co/KoichiYasuoka/deberta-base-thai-ud-head |
| 10 | +
|
| 11 | +GitHub: https://github.com/KoichiYasuoka |
| 12 | +""" |
| 13 | +import os |
| 14 | +import numpy |
| 15 | +import torch |
| 16 | +import ufal.chu_liu_edmonds |
| 17 | +from transformers import ( |
| 18 | + AutoTokenizer, |
| 19 | + AutoModelForQuestionAnswering, |
| 20 | + AutoModelForTokenClassification, |
| 21 | + AutoConfig, |
| 22 | + TokenClassificationPipeline |
| 23 | +) |
| 24 | +from transformers.utils import cached_file |
| 25 | + |
| 26 | + |
| 27 | +class Parse: |
| 28 | + def __init__(self, model: str="KoichiYasuoka/deberta-base-thai-ud-head") -> None: |
| 29 | + if model == None: |
| 30 | + model = "KoichiYasuoka/deberta-base-thai-ud-head" |
| 31 | + self.tokenizer=AutoTokenizer.from_pretrained(model) |
| 32 | + self.model=AutoModelForQuestionAnswering.from_pretrained(model) |
| 33 | + x=AutoModelForTokenClassification.from_pretrained |
| 34 | + if os.path.isdir(model): |
| 35 | + d,t=x(os.path.join(model,"deprel")),x(os.path.join(model,"tagger")) |
| 36 | + else: |
| 37 | + c=AutoConfig.from_pretrained(cached_file(model,"deprel/config.json")) |
| 38 | + d=x(cached_file(model,"deprel/pytorch_model.bin"),config=c) |
| 39 | + s=AutoConfig.from_pretrained(cached_file(model,"tagger/config.json")) |
| 40 | + t=x(cached_file(model,"tagger/pytorch_model.bin"),config=s) |
| 41 | + self.deprel=TokenClassificationPipeline( |
| 42 | + model=d, |
| 43 | + tokenizer=self.tokenizer, |
| 44 | + aggregation_strategy="simple" |
| 45 | + ) |
| 46 | + self.tagger=TokenClassificationPipeline( |
| 47 | + model=t, |
| 48 | + tokenizer=self.tokenizer |
| 49 | + ) |
| 50 | + |
| 51 | + def __call__(self, text: str)->str: |
| 52 | + w=[(t["start"],t["end"],t["entity_group"]) for t in self.deprel(text)] |
| 53 | + z,n={t["start"]:t["entity"].split("|") for t in self.tagger(text)},len(w) |
| 54 | + r,m=[text[s:e] for s,e,p in w],numpy.full((n+1,n+1),numpy.nan) |
| 55 | + v,c=self.tokenizer(r,add_special_tokens=False)["input_ids"],[] |
| 56 | + for i,t in enumerate(v): |
| 57 | + q=[self.tokenizer.cls_token_id]+t+[self.tokenizer.sep_token_id] |
| 58 | + c.append([q]+v[0:i]+[[self.tokenizer.mask_token_id]]+v[i+1:]+[[q[-1]]]) |
| 59 | + b=[[len(sum(x[0:j+1],[])) for j in range(len(x))] for x in c] |
| 60 | + with torch.no_grad(): |
| 61 | + d=self.model( |
| 62 | + input_ids=torch.tensor([sum(x,[]) for x in c]), |
| 63 | + token_type_ids=torch.tensor([[0]*x[0]+[1]*(x[-1]-x[0]) for x in b]) |
| 64 | + ) |
| 65 | + s,e=d.start_logits.tolist(),d.end_logits.tolist() |
| 66 | + for i in range(n): |
| 67 | + for j in range(n): |
| 68 | + m[i+1,0 if i==j else j+1]=s[i][b[i][j]]+e[i][b[i][j+1]-1] |
| 69 | + h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0] |
| 70 | + if [0 for i in h if i==0]!=[0]: |
| 71 | + i=([p for s,e,p in w]+["root"]).index("root") |
| 72 | + j=i+1 if i<n else numpy.nanargmax(m[:,0]) |
| 73 | + m[0:j,0]=m[j+1:,0]=numpy.nan |
| 74 | + h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0] |
| 75 | + u="" |
| 76 | + for i,(s,e,p) in enumerate(w,1): |
| 77 | + p="root" if h[i]==0 else "dep" if p=="root" else p |
| 78 | + u+="\t".join( |
| 79 | + [str(i),r[i-1],"_",z[s][0][2:],"_","|".join(z[s][1:]),str(h[i]),p,"_","_" if i<n and e<w[i][0] else "SpaceAfter=No"] |
| 80 | + )+"\n" |
| 81 | + return u+"\n" |
0 commit comments