Add Diacritization dataset/task/asset #128

Merged · 5 commits · Jun 21, 2023
47 changes: 47 additions & 0 deletions arabic_llm_benchmark/datasets/ArabicDiacritization.py
@@ -0,0 +1,47 @@
from arabic_llm_benchmark.datasets.dataset_base import DatasetBase


class ArabicDiacritizationDataset(DatasetBase):
    def __init__(self, **kwargs):
        super(ArabicDiacritizationDataset, self).__init__(**kwargs)

    def citation(self):
        return """@article{10.1145/3434235,
            author = {Darwish, Kareem and Abdelali, Ahmed and Mubarak, Hamdy and Eldesouki, Mohamed},
            title = {Arabic Diacritic Recovery Using a Feature-Rich BiLSTM Model},
            year = {2021},
            issue_date = {March 2021},
            publisher = {Association for Computing Machinery},
            address = {New York, NY, USA},
            volume = {20},
            number = {2},
            issn = {2375-4699},
            url = {https://doi.org/10.1145/3434235},
            doi = {10.1145/3434235},
            journal = {ACM Trans. Asian Low-Resour. Lang. Inf. Process.},
            month = {apr},
            articleno = {33},
            numpages = {18},
        }"""

    def get_data_sample(self):
        return {
            "input": "Original sentence",
            "label": "Sentence with diacritized words",
        }

    def load_data(self, data_path, no_labels=False):
        data = []

        # Each line holds the undiacritized text, a tab, and the diacritized text
        with open(data_path, "r", encoding="utf-8") as fp:
            for line_idx, line in enumerate(fp):
                text, diacritized_text = line.split("\t")
                data.append(
                    {
                        "input": text.strip(),
                        "label": diacritized_text.strip(),
                        "line_number": line_idx,
                    }
                )

        return data
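
A minimal usage sketch for the loader (hedged: the file path and its contents are illustrative, and `DatasetBase` is assumed to require no constructor arguments):

# Hypothetical tab-separated file: "<plain sentence>\t<diacritized sentence>" per line.
dataset = ArabicDiacritizationDataset()
samples = dataset.load_data("data/diacritization/sample.txt")  # illustrative path
print(samples[0]["input"])   # undiacritized sentence
print(samples[0]["label"])   # its diacritized counterpart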
1 change: 1 addition & 0 deletions arabic_llm_benchmark/datasets/__init__.py
@@ -3,6 +3,7 @@
from .Aqmar import AqmarDataset
from .AraBench import AraBenchDataset
from .ArabGend import ArabGendDataset
from .ArabicDiacritization import ArabicDiacritizationDataset
from .ArabicSegmentation import ArabicSegmentationDataset
from .ArapTweet import ArapTweetDataset
from .ARCD import ARCDDataset
143 changes: 143 additions & 0 deletions arabic_llm_benchmark/tasks/ArabicDiacritization.py
@@ -0,0 +1,143 @@
import re

from arabic_llm_benchmark.tasks.task_base import TaskBase


#
# repo: https://pyzone.dev/word-error-rate-in-python
#
def wer(ref, hyp, debug=True):
    r = ref
    h = hyp
    # costs holds the edit costs, as in the Levenshtein distance algorithm
    costs = [[0 for inner in range(len(h) + 1)] for outer in range(len(r) + 1)]
    # backtrace holds the operations performed, so we can later walk back
    # through them, as the WER algorithm requires
    backtrace = [[0 for inner in range(len(h) + 1)] for outer in range(len(r) + 1)]

    OP_OK = 0
    OP_SUB = 1
    OP_INS = 2
    OP_DEL = 3
    DEL_PENALTY = 1
    INS_PENALTY = 1
    SUB_PENALTY = 1

    # The first column represents the case where we reach zero
    # hypothesis words by deleting all reference words.
    for i in range(1, len(r) + 1):
        costs[i][0] = DEL_PENALTY * i
        backtrace[i][0] = OP_DEL

    # The first row represents the case where we reach the hypothesis
    # by inserting all hypothesis words into a zero-length reference.
    for j in range(1, len(h) + 1):
        costs[0][j] = INS_PENALTY * j
        backtrace[0][j] = OP_INS

    # computation
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                costs[i][j] = costs[i - 1][j - 1]
                backtrace[i][j] = OP_OK
            else:
                substitutionCost = (
                    costs[i - 1][j - 1] + SUB_PENALTY
                )  # penalty is always 1
                insertionCost = costs[i][j - 1] + INS_PENALTY  # penalty is always 1
                deletionCost = costs[i - 1][j] + DEL_PENALTY  # penalty is always 1

                costs[i][j] = min(substitutionCost, insertionCost, deletionCost)
                if costs[i][j] == substitutionCost:
                    backtrace[i][j] = OP_SUB
                elif costs[i][j] == insertionCost:
                    backtrace[i][j] = OP_INS
                else:
                    backtrace[i][j] = OP_DEL

    # backtrace through the best route:
    i = len(r)
    j = len(h)
    numSub = 0
    numDel = 0
    numIns = 0
    numCor = 0
    if debug:
        print("OP\tREF\tHYP")
        lines = []
    while i > 0 or j > 0:
        if backtrace[i][j] == OP_OK:
            numCor += 1
            i -= 1
            j -= 1
            if debug:
                lines.append("OK\t" + r[i] + "\t" + h[j])
        elif backtrace[i][j] == OP_SUB:
            numSub += 1
            i -= 1
            j -= 1
            if debug:
                lines.append("SUB\t" + r[i] + "\t" + h[j])
        elif backtrace[i][j] == OP_INS:
            numIns += 1
            j -= 1
            if debug:
                lines.append("INS\t" + "****" + "\t" + h[j])
        elif backtrace[i][j] == OP_DEL:
            numDel += 1
            i -= 1
            if debug:
                lines.append("DEL\t" + r[i] + "\t" + "****")
    if debug:
        for line in reversed(lines):
            print(line)
        print("#cor " + str(numCor))
        print("#sub " + str(numSub))
        print("#del " + str(numDel))
        print("#ins " + str(numIns))
    wer_result = round((numSub + numDel + numIns) / float(len(r)), 3)
    if debug:
        return {
            "WER": wer_result,
            "numCor": numCor,
            "numSub": numSub,
            "numIns": numIns,
            "numDel": numDel,
            "numCount": len(r),
        }
    else:
        return {"WER": wer_result}


class ArabicDiacritizationTask(TaskBase):
    def __init__(self, **kwargs):
        super(ArabicDiacritizationTask, self).__init__(**kwargs)

    def evaluate(self, true_labels, predicted_labels):
        # Flatten sentences into one long list of words
        hyp = []
        ref = []
        for t, p in zip(true_labels, predicted_labels):
            if p is None:
                # Use the undiacritized words in case of prediction failure
                p = re.sub(r"[ًٌٍَُِّْ]", "", t).split()
            else:
                p = p.split()

            t = t.split()

            # If the prediction is missing tokens, pad it with empty tokens
            if len(p) < len(t):
                p += [""] * (len(t) - len(p))

            # If the prediction has extra tokens, only consider the first
            # N tokens, where N == number of gold tokens
            hyp += p[: len(t)]
            ref += t
        return wer(ref, hyp, False)
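
A hedged usage sketch of the task's fallback path (assuming `TaskBase` requires no constructor arguments; the sentence pair is illustrative):

task = ArabicDiacritizationTask()
gold = ["ذَهَبَ الوَلَدُ"]        # diacritized reference
pred = [None]                     # a failed prediction falls back to the undiacritized words
print(task.evaluate(gold, pred))  # {'WER': 1.0} -- both words count as substitutions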
1 change: 1 addition & 0 deletions arabic_llm_benchmark/tasks/__init__.py
@@ -1,4 +1,5 @@
from .Adult import AdultTask
from .ArabicDiacritization import ArabicDiacritizationTask
from .ArabicSegmentation import ArabicSegmentationTask
from .Attentionworthy import AttentionworthyTask
from .Checkworthiness import CheckworthinessTask
@@ -0,0 +1,42 @@
import os

from arabic_llm_benchmark.datasets import ArabicDiacritizationDataset
from arabic_llm_benchmark.models import GPTModel
from arabic_llm_benchmark.tasks import ArabicDiacritizationTask


def config():
    return {
        "dataset": ArabicDiacritizationDataset,
        "dataset_args": {},
        "task": ArabicDiacritizationTask,
        "task_args": {},
        "model": GPTModel,
        "model_args": {
            "api_type": "azure",
            "api_version": "2023-03-15-preview",
            "api_base": os.environ["AZURE_API_URL"],
            "api_key": os.environ["AZURE_API_KEY"],
            "engine_name": os.environ["ENGINE_NAME"],
            "max_tries": 3,
        },
        "general_args": {
            "data_path": "data/sequence_tagging_ner_pos_etc/diacritization/WikiNewsTruth.txt"
        },
    }


def prompt(input_sample):
    return {
        "system_message": "You are an AI assistant that helps people find information.",
        "messages": [
            {
                "sender": "user",
                "text": f"Diacritize fully the following Arabic sentence: {input_sample}",
            }
        ],
    }


def post_process(response):
    return response["choices"][0]["text"]