Skip to content
This repository has been archived by the owner on May 8, 2024. It is now read-only.

Commit

Permalink
refactor: use Huggingface pipeline for resegmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ninpnin committed Feb 9, 2024
1 parent 0d97284 commit b24d00c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
10 changes: 2 additions & 8 deletions pyriksdagen/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyparlaclarin.read import element_hash
import dateparser
import pandas as pd
from .utils import elem_iter, infer_metadata, parse_date, XML_NS, get_formatted_uuid
from .utils import elem_iter, infer_metadata, parse_date, XML_NS, get_formatted_uuid, write_protocol
from .db import load_expressions, filter_db, load_patterns, load_metadata
from .segmentation import (
detect_mp,
Expand Down Expand Up @@ -81,13 +81,7 @@ def redetect_protocol(metadata, protocol):
unknown_variables=["gender", "party", "other"],
)

b = etree.tostring(
root, pretty_print=True, encoding="utf-8", xml_declaration=True
)

f = open(protocol, "wb")
f.write(b)
f.close()
write_protocol(root, protocol)
return unk


Expand Down
65 changes: 39 additions & 26 deletions scripts/resegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,42 @@
from pyparlaclarin.refine import format_texts
from pyriksdagen.db import load_patterns
from pyriksdagen.refine import (
detect_mps,
find_introductions,
update_ids,
)
from pyriksdagen.utils import infer_metadata
from pyriksdagen.utils import protocol_iterators
from pyriksdagen.utils import protocol_iterators, elem_iter,XML_NS

from lxml import etree
import pandas as pd
import os, progressbar, argparse

from transformers import pipeline

intro_classifier = pipeline("text-classification", model="jesperjmb/parlaBERT", top_k=None)
LABEL_MAP = {"LABEL0": "u", "LABEL1": "intro"}
BATCH_SIZE = 16

def get_labels(texts):
labels = []
texts = [t if t is not None else "" for t in texts]
for i in progressbar.progressbar(range(0, len(texts), BATCH_SIZE)):
texts_i = texts[i:i+BATCH_SIZE]
raw_labels = intro_classifier(texts_i, truncation=True, max_length=512)
for t, label_dict in zip(texts_i, raw_labels):
label_dict = [l for l in label_dict if l["label"] == "LABEL_0"][0]
if label_dict["score"] >= 0.5:
labels.append("u")
else:
labels.append("intro")

return labels

def get_text(elem):
if elem.text is None:
return ""
else:
return " ".join(elem.text.split())

def main(args):
if args.protocol:
protocols = [args.protocol]
Expand All @@ -25,36 +50,24 @@ def main(args):
protocols = list(protocol_iterators("corpus/protocols/", start=args.start, end=args.end))

parser = etree.XMLParser(remove_blank_text=True)
intro_df = pd.read_csv('input/segmentation/intros.csv')

for protocol in progressbar.progressbar(protocols):
intro_ids = intro_df.loc[intro_df['file_path'] == protocol, 'id'].tolist()

metadata = infer_metadata(protocol)
protocol_id = protocol.split("/")[-1]
year = metadata["year"]
root = etree.parse(protocol, parser).getroot()
paragraphs = []
ids = []
for tag, elem in elem_iter(root):
if tag == "u":
for seg in elem:
paragraphs.append(get_text(seg))
ids.append(seg.attrib[f"{XML_NS}id"])
elif tag != "pb":
paragraphs.append(get_text(elem))
ids.append(elem.attrib[f"{XML_NS}id"])

years = [
int(elem.attrib.get("when").split("-")[0])
for elem in root.findall(
".//{http://www.tei-c.org/ns/1.0}docDate"
)
]

if not year in years:
year = years[0]

pattern_db = load_patterns()
pattern_db = pattern_db[
(pattern_db["start"] <= year) & (pattern_db["end"] >= year)
]
root = find_introductions(root, pattern_db, intro_ids, minister_db=None)
root = format_texts(root, padding=10)
labels = get_labels(paragraphs)
b = etree.tostring(
root, pretty_print=True, encoding="utf-8", xml_declaration=True
)

with open(protocol, "wb") as f:
f.write(b)

Expand Down

0 comments on commit b24d00c

Please sign in to comment.