diff --git a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_BLOOMZ_ZeroShot.py b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_BLOOMZ_ZeroShot.py new file mode 100644 index 00000000..602253da --- /dev/null +++ b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_BLOOMZ_ZeroShot.py @@ -0,0 +1,102 @@ +import os + +from arabic_llm_benchmark.datasets import DialectADIDataset +from arabic_llm_benchmark.models import BLOOMPetalModel +from arabic_llm_benchmark.tasks import DialectIDTask + + +def config(): + return { + "dataset": DialectADIDataset, + "dataset_args": {}, + "task": DialectIDTask, + "task_args": {}, + "model": BLOOMPetalModel, + "model_args": { + "api_url": os.environ["API_URL"], + "class_labels": [ + "EGY", + "IRA", + "JOR", + "KSA", + "KUW", + "LEB", + "LIB", + "MOR", + "MSA", + "PAL", + "QAT", + "SUD", + "SYR", + "UAE", + "YEM", + ], + "max_tries": 3, + }, + "general_args": { + "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/all_v2.tsv", + }, + } + + +def prompt(input_sample): + arr = input_sample.split() + if len(arr) > 500: + input_sample = arr[:500] + + prompt_string = ( + f'Classify the following "text" into one of the following categories: "EGY", "IRA", "JOR", "KSA", "KUW", "LEB", "LIB", "MOR", "MSA", "PAL", "QAT", "SUD", "SYR", "UAE", "YEM"\n' + f"Please provide only the label.\n\n" + f"text: {input_sample}\n" + f"label: \n" + ) + + return { + "prompt": prompt_string, + } + + +def post_process(response): + label = response["outputs"].strip() + label = label.replace("", "") + label = label.replace("", "") + label = label.lower() + + # label_list = config()["model_args"]["class_labels"] + # label_list = [lab.lower() for lab in label_list] + # + # if "label: " in label: + # label_fixed = label.replace("label: ", "").lower() + # elif label.lower() in label_list: + # label_fixed = label.lower() + # else: + # label_fixed = None + label_list = config()["model_args"]["class_labels"] + label_list = [dialect.lower() for dialect in label_list] + + label = label.replace("label:", "").strip() + + if label in label_list: + label_fixed = label + elif "\n msa" in label: + label_fixed = "msa" + elif "\n ksa" in label: + label_fixed = "ksa" + elif "\n pal" in label: + label_fixed = "pal" + elif "\n egy" in label: + label_fixed = "egy" + elif "\n yem" in label: + label_fixed = "yem" + elif "\n syr" in label: + label_fixed = "syr" + elif "\n jor" in label: + label_fixed = "jor" + elif "\n ira" in label: + label_fixed = "ira" + elif "\n kuw" in label: + label_fixed = "kuw" + else: + label_fixed = None + + return label_fixed diff --git a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_ChatGPT_ZeroShot.py b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_ChatGPT_ZeroShot.py index bbc87c58..f3ba850c 100644 --- a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_ChatGPT_ZeroShot.py +++ b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_ChatGPT_ZeroShot.py @@ -35,7 +35,7 @@ def config(): "max_tries": 30, }, "general_args": { - "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/dialect_12_test_merged.tsv" + "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/all_v2.tsv" }, } @@ -63,11 +63,29 @@ def post_process(response): label = response["choices"][0]["text"].lower() label_list = config()["model_args"]["class_labels"] label_list = [dialect.lower() for dialect in label_list] - label = label.replace("label: ", "") + + label = label.replace("label:", "").strip() if label in label_list: label_fixed = label + elif "\n msa" in label: + label_fixed = "msa" + elif "\n ksa" in label: + label_fixed = "ksa" + elif "\n pal" in label: + label_fixed = "pal" + elif "\n egy" in label: + label_fixed = "egy" + elif "\n yem" in label: + label_fixed = "yem" + elif "\n syr" in label: + label_fixed = "syr" + elif "\n jor" in label: + label_fixed = "jor" + elif "\n ira" in label: + label_fixed = "ira" + elif "\n kuw" in label: + label_fixed = "kuw" else: label_fixed = None - return label_fixed diff --git a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_FewShot.py b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_FewShot.py index b775101a..705db4fa 100644 --- a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_FewShot.py +++ b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_FewShot.py @@ -35,9 +35,9 @@ def config(): "max_tries": 30, }, "general_args": { - "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/dialect_12_test_merged.tsv", + "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/all_v2.tsv", "fewshot": { - "train_data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/dialect_12_test_merged.tsv", # TODO update + "train_data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/fewshot_dev.tsv", # TODO update "deduplicate": False, }, }, @@ -90,6 +90,24 @@ def post_process(response): if label in label_list: label_fixed = label + elif "\n msa" in label: + label_fixed = "msa" + elif "\n ksa" in label: + label_fixed = "ksa" + elif "\n pal" in label: + label_fixed = "pal" + elif "\n egy" in label: + label_fixed = "egy" + elif "\n yem" in label: + label_fixed = "yem" + elif "\n syr" in label: + label_fixed = "syr" + elif "\n jor" in label: + label_fixed = "jor" + elif "\n ira" in label: + label_fixed = "ira" + elif "\n kuw" in label: + label_fixed = "kuw" else: label_fixed = None diff --git a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_ZeroShot.py b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_ZeroShot.py index 01404f87..a4c12bb1 100644 --- a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_ZeroShot.py +++ b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectADI_GPTChatCompletion_ZeroShot.py @@ -35,7 +35,7 @@ def config(): "max_tries": 30, }, "general_args": { - "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/dialect_12_test_merged.tsv" + "data_path": "data/sequence_tagging_ner_pos_etc/dialect_identification/all_v2.tsv" }, } @@ -69,6 +69,24 @@ def post_process(response): if label in label_list: label_fixed = label + elif "\n msa" in label: + label_fixed = "msa" + elif "\n ksa" in label: + label_fixed = "ksa" + elif "\n pal" in label: + label_fixed = "pal" + elif "\n egy" in label: + label_fixed = "egy" + elif "\n yem" in label: + label_fixed = "yem" + elif "\n syr" in label: + label_fixed = "syr" + elif "\n jor" in label: + label_fixed = "jor" + elif "\n ira" in label: + label_fixed = "ira" + elif "\n kuw" in label: + label_fixed = "kuw" else: label_fixed = None diff --git a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectID_QADI_ChatGPT_ZeroShot.py b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectID_QADI_ChatGPT_ZeroShot.py index 68e98387..cc4b3138 100644 --- a/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectID_QADI_ChatGPT_ZeroShot.py +++ b/assets/benchmark_v1/sequence_tagging_ner_pos_etc/DialectID_QADI_ChatGPT_ZeroShot.py @@ -45,20 +45,39 @@ def config(): def prompt(input_sample): + prompt_string = ( + f'Write only the country code of the Arabic country in which this sentence is written in its dialect without any explanation. Write only the country code in ISO 3166-1 alpha-2 format without explanation. Write "MSA" if the sentence is written in Modern Standard Arabic.\n' + f"Please provide only the label.\n\n" + f"text: {input_sample}\n" + f"label: \n" + ) + return { "system_message": "You are an AI assistant that helps people find information.", "messages": [ { "sender": "user", - "text": f"Write only the country code of the Arabic country in which this sentence is written in its dialect without any explanation. Write only the country code in ISO 3166-1 alpha-2 format without explanation. Write 'MSA' if the sentence is written in Modern Standard Arabic.\n {input_sample}", + "text": prompt_string, } ], } def post_process(response): - out = response["choices"][0]["text"] - j = out.find(".") - if j > 0: - out = out[0:j] - return out + label = response["choices"][0]["text"] + + label_list = config()["model_args"]["class_labels"] + label_list = [dialect for dialect in label_list] + + label = label.replace("label:", "").strip() + + # j = out.find(".") + # if j > 0: + # out = out[0:j] + + if label in label_list: + label_fixed = label + else: + label_fixed = None + + return label_fixed