-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Diacritization dataset/task/asset (#128)
* Add diacritization module * Update ArabicDiacritization.py Use undiacritized tokens as fallback for None results. * Format code * Add comments and minor fixes * More fixes to dataloader --------- Co-authored-by: Ahmed Abdelali <ahmed.abdelali@gmail.com>
- Loading branch information
Showing
5 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from arabic_llm_benchmark.datasets.dataset_base import DatasetBase | ||
|
||
|
||
class ArabicDiacritizationDataset(DatasetBase):
    """Dataset of Arabic sentences paired with their fully diacritized forms.

    Each sample maps an undiacritized "input" sentence to a "label" sentence
    carrying the diacritics (WikiNews diacritization corpus).
    """

    def __init__(self, **kwargs):
        super(ArabicDiacritizationDataset, self).__init__(**kwargs)

    def citation(self):
        """Return the BibTeX citation for the corpus/reference paper."""
        return """@article{10.1145/3434235,
author = {Darwish, Kareem and Abdelali, Ahmed and Mubarak, Hamdy and Eldesouki, Mohamed},
title = {Arabic Diacritic Recovery Using a Feature-Rich BiLSTM Model},
year = {2021},
issue_date = {March 2021},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
volume = {20},
number = {2},
issn = {2375-4699},
url = {https://doi.org/10.1145/3434235},
doi = {10.1145/3434235},
journal = {ACM Trans. Asian Low-Resour. Lang. Inf. Process.},
month = {apr},
articleno = {33},
numpages = {18},
}"""

    def get_data_sample(self):
        """Return a schematic example of one sample's structure."""
        return {
            "input": "Original sentence",
            "label": "Sentence with diacritized words",
        }

    def load_data(self, data_path, no_labels=False):
        """Load tab-separated (text, diacritized_text) pairs.

        Args:
            data_path: Path to a UTF-8 text file with one
                "text<TAB>diacritized_text" pair per line.
            no_labels: Unused here; kept for interface compatibility with
                other datasets.

        Returns:
            List of dicts with "input", "label", and "line_number" keys.
        """
        data = []

        # Explicit UTF-8: Arabic text would be mangled (or raise) under a
        # non-UTF-8 platform default encoding.
        with open(data_path, "r", encoding="utf-8") as fp:
            for line_idx, line in enumerate(fp):
                # maxsplit=1 tolerates stray tab characters inside the
                # diacritized half of the line instead of crashing on unpack.
                text, diacritized_text = line.split("\t", 1)
                data.append(
                    {
                        "input": text.strip(),
                        "label": diacritized_text.strip(),
                        "line_number": line_idx,
                    }
                )

        return data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import re | ||
|
||
from sklearn.metrics import f1_score | ||
|
||
from arabic_llm_benchmark.tasks.task_base import TaskBase | ||
|
||
|
||
#
# repo: https://pyzone.dev/word-error-rate-in-python
#
def wer(ref, hyp, debug=True):
    """Compute the word error rate of hypothesis tokens against reference tokens.

    Builds the Levenshtein cost matrix over the two token lists, then
    backtraces the optimal edit path to count correct, substituted,
    inserted, and deleted tokens.

    Args:
        ref: List of reference tokens.
        hyp: List of hypothesis tokens.
        debug: When True, print the alignment and per-operation counts, and
            include those counts in the returned dict.

    Returns:
        {"WER": float} when debug is False; with debug True, also
        "numCor"/"numSub"/"numIns"/"numDel"/"numCount".
    """
    r = ref
    h = hyp
    # costs[i][j]: minimal edit cost of aligning r[:i] with h[:j]
    # (classic Levenshtein DP table).
    costs = [[0 for inner in range(len(h) + 1)] for outer in range(len(r) + 1)]
    # backtrace[i][j]: which operation produced cell (i, j), so the optimal
    # alignment can be recovered afterwards, as WER accounting requires.
    backtrace = [[0 for inner in range(len(h) + 1)] for outer in range(len(r) + 1)]

    OP_OK = 0
    OP_SUB = 1
    OP_INS = 2
    OP_DEL = 3
    DEL_PENALTY = 1
    INS_PENALTY = 1
    SUB_PENALTY = 1

    # First column: reach an empty hypothesis by deleting all reference words.
    for i in range(1, len(r) + 1):
        costs[i][0] = DEL_PENALTY * i
        backtrace[i][0] = OP_DEL

    # First row: build the hypothesis by inserting every hypothesis word
    # into a zero-length reference.
    for j in range(1, len(h) + 1):
        costs[0][j] = INS_PENALTY * j
        backtrace[0][j] = OP_INS

    # Dynamic-programming fill.
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                costs[i][j] = costs[i - 1][j - 1]
                backtrace[i][j] = OP_OK
            else:
                substitutionCost = (
                    costs[i - 1][j - 1] + SUB_PENALTY
                )  # penalty is always 1
                insertionCost = costs[i][j - 1] + INS_PENALTY  # penalty is always 1
                deletionCost = costs[i - 1][j] + DEL_PENALTY  # penalty is always 1

                costs[i][j] = min(substitutionCost, insertionCost, deletionCost)
                # Tie-break order: substitution, then insertion, then deletion.
                if costs[i][j] == substitutionCost:
                    backtrace[i][j] = OP_SUB
                elif costs[i][j] == insertionCost:
                    backtrace[i][j] = OP_INS
                else:
                    backtrace[i][j] = OP_DEL

    # Backtrace through the best route, counting each operation.
    i = len(r)
    j = len(h)
    numSub = 0
    numDel = 0
    numIns = 0
    numCor = 0
    if debug:
        print("OP\tREF\tHYP")
    lines = []
    while i > 0 or j > 0:
        if backtrace[i][j] == OP_OK:
            numCor += 1
            i -= 1
            j -= 1
            if debug:
                lines.append("OK\t" + r[i] + "\t" + h[j])
        elif backtrace[i][j] == OP_SUB:
            numSub += 1
            i -= 1
            j -= 1
            if debug:
                lines.append("SUB\t" + r[i] + "\t" + h[j])
        elif backtrace[i][j] == OP_INS:
            numIns += 1
            j -= 1
            if debug:
                lines.append("INS\t" + "****" + "\t" + h[j])
        elif backtrace[i][j] == OP_DEL:
            numDel += 1
            i -= 1
            if debug:
                lines.append("DEL\t" + r[i] + "\t" + "****")
    if debug:
        # Alignment was collected back-to-front; print in sentence order.
        for line in reversed(lines):
            print(line)
        print("#cor " + str(numCor))
        print("#sub " + str(numSub))
        print("#del " + str(numDel))
        print("#ins " + str(numIns))
    # Guard against an empty reference, which would otherwise divide by zero;
    # with no reference words, every hypothesis word counts as an insertion
    # against a denominator of 1.
    denominator = len(r) if r else 1
    wer_result = round((numSub + numDel + numIns) / float(denominator), 3)
    if debug:
        return {
            "WER": wer_result,
            "numCor": numCor,
            "numSub": numSub,
            "numIns": numIns,
            "numDel": numDel,
            "numCount": len(r),
        }
    else:
        return {"WER": wer_result}
|
||
|
||
class ArabicDiacritizationTask(TaskBase):
    """Task that scores diacritization predictions with word error rate."""

    def __init__(self, **kwargs):
        super(ArabicDiacritizationTask, self).__init__(**kwargs)

    def evaluate(self, true_labels, predicted_labels):
        """Compute WER between gold and predicted diacritized sentences.

        Sentences are flattened into one long list of words so the error
        rate is computed over all tokens at once.

        Args:
            true_labels: Diacritized gold sentences (one string each).
            predicted_labels: Model outputs aligned with true_labels; a
                None entry means the model produced no usable prediction.

        Returns:
            The dict produced by wer(), i.e. {"WER": float}.
        """
        hyp = []
        ref = []
        for t, p in zip(true_labels, predicted_labels):
            if p is None:
                # Use undiacritized gold words in case of prediction failure,
                # so the sample still counts against the score.
                p = re.sub(r"[ًٌٍَُِّْ]", "", t).split()
            else:
                p = p.split()

            t = t.split()

            # If the prediction is missing tokens, pad it at the END with
            # empty tokens so token positions stay aligned with the gold.
            # (The original `range(len(p) - len(t))` iterated over a negative
            # count — i.e. never — and would have prepended the padding,
            # misaligning every following token.)
            if len(p) < len(t):
                p = p + [""] * (len(t) - len(p))

            # If the prediction has extra tokens, only consider the first
            # N tokens, where N == number of gold tokens.
            hyp += p[: len(t)]
            ref += t
        return wer(ref, hyp, False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
assets/benchmark_v1/sequence_tagging_ner_pos_etc/diacritization_ChatGPT_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
|
||
from arabic_llm_benchmark.datasets import ArabicDiacritizationDataset | ||
from arabic_llm_benchmark.models import GPTModel | ||
from arabic_llm_benchmark.tasks import ArabicDiacritizationTask | ||
|
||
|
||
def config():
    """Assemble the benchmark configuration for zero-shot ChatGPT diacritization.

    Azure OpenAI credentials are read from the AZURE_API_URL, AZURE_API_KEY,
    and ENGINE_NAME environment variables.
    """
    model_args = {
        "api_type": "azure",
        "api_version": "2023-03-15-preview",
        "api_base": os.environ["AZURE_API_URL"],
        "api_key": os.environ["AZURE_API_KEY"],
        "engine_name": os.environ["ENGINE_NAME"],
        "max_tries": 3,
    }
    general_args = {
        "data_path": "data/sequence_tagging_ner_pos_etc/diacritization/WikiNewsTruth.txt"
    }
    return {
        "dataset": ArabicDiacritizationDataset,
        "dataset_args": {},
        "task": ArabicDiacritizationTask,
        "task_args": {},
        "model": GPTModel,
        "model_args": model_args,
        "general_args": general_args,
    }
|
||
|
||
def prompt(input_sample):
    """Build the zero-shot diacritization request for one input sentence."""
    user_text = f"Diacritize fully the following Arabic sentence: {input_sample}"
    return {
        "system_message": "You are an AI assistant that helps people find information.",
        "messages": [{"sender": "user", "text": user_text}],
    }
|
||
|
||
def post_process(response):
    """Extract the generated text from the first completion choice."""
    first_choice = response["choices"][0]
    return first_choice["text"]