Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated name info assets #149

Merged
merged 6 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions assets/benchmark_v1/demography/name_info/NameInfo_BLOOMZ_ZeroShot.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some cases that we can handle easily I think:

{
  "input": "جورج بوش الابن",
  "label": "us",
  "model_output": "name: جورج بوش الأب country: united states of america</s>",
  "filtered_output": "جورج بوش الأب country: united states of america"
}
{
  "input": "جورج ووكر بوش",
  "label": "us",
  "model_output": "country_code: US</s>",
  "filtered_output": null
}
{
  "input": "George W. Bush",
  "label": "us",
  "model_output": "name: George W. Bush country:us</s>",
  "filtered_output": "george w. bush country:us"
}
{
  "input": "أوغستو بينوشيه",
  "label": "cl",
  "model_output": "name: أوغستو بينوشيه</s>",
  "filtered_output": "أوغستو بينوشيه"
}
{
  "input": "Augusto Pinochet",
  "label": "cl",
  "model_output": "country:CL</s>",
  "filtered_output": null
}

Also should we explicitly set to None if none of the labels are found? (e.g. when it outputs the original name back - sample 4 above)

Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import os
import re

from arabic_llm_benchmark.datasets import NameInfoDataset
from arabic_llm_benchmark.models import BLOOMPetalModel
from arabic_llm_benchmark.tasks import DemographyNameInfoTask


def config():
return {
"dataset": NameInfoDataset,
"dataset_args": {},
"task": DemographyNameInfoTask,
"task_args": {},
"model": BLOOMPetalModel,
"model_args": {
"api_url": os.environ["API_URL"],
"class_labels": [
"gb",
"us",
"cl",
"fr",
"ru",
"pl",
"in",
"it",
"kr",
"gh",
"ca",
"sa",
"at",
"de",
"cn",
"br",
"dk",
"se",
"bd",
"cu",
"jp",
"be",
"es",
"co",
"id",
"iq",
"pk",
"tr",
"il",
"ch",
"ar",
"ro",
"nl",
"ps",
"ug",
"ir",
"cg",
"do",
"ee",
"tn",
"gr",
"np",
"ie",
"sy",
"hu",
"eg",
"ma",
"ve",
"ph",
"no",
"bg",
"si",
"ke",
"au",
"et",
"py",
"af",
"pt",
"th",
"bo",
"mx",
"lb",
"za",
"fi",
"hr",
"vn",
"ly",
"nz",
"qa",
"kh",
"ci",
"ng",
"sg",
"cm",
"dz",
"tz",
"ae",
"pe",
"az",
"lu",
"ec",
"cz",
"ua",
"uy",
"sd",
"ao",
"my",
"lv",
"kw",
"tw",
"bh",
"lk",
"ye",
"cr",
"jo",
"pa",
"om",
"uz",
"by",
"kz",
],
"max_tries": 3,
},
"general_args": {
"data_path": "data/demographic_attributes/name_info/wikidata_test.txt"
},
}


def prompt(input_sample):
prompt_string = (
f"You are an expert annotator who can identify the country of a person based on name.\n"
f"Label the country of the following person 'name'. Write ONLY the country code in ISO 3166-1 alpha-2 format.\n"
f"Provide only label.\n\n"
f"name: {input_sample}\n"
f"country: \n"
)

return {
"prompt": prompt_string,
}


def post_process(response):
label = (
response["outputs"]
.strip()
.replace("<s>", "")
.replace("</s>", "")
.replace("ISO 3166-1:", "")
.replace("ISO 3166-1", "")
.lower()
)
label_list = config()["model_args"]["class_labels"]

# Regular expressions to catch the pattern
match = re.search(r"(country|country_code):\s*(.*)", label)
if match:
label = match.group(2).strip().lower()
if label in label_list:
label_fixed = label
elif (
"I'm sorry, but I cannot predict the country" in label
or "I cannot predict the country" in label
):
label_fixed = None
else:
label_fixed = None

# Consolidating the check for None or empty string
if not label_fixed:
label_fixed = None

return label_fixed

return label_fixed
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,16 @@ def config():


def few_shot_prompt(input_sample, base_prompt, examples):
out_prompt = base_prompt + "\n\n"
for example in examples:
out_prompt = base_prompt + "\n"
out_prompt = out_prompt + "Here are some examples:\n\n"

for index, example in enumerate(examples):
out_prompt = (
out_prompt
+ "Example "
+ str(index)
+ ":"
+ "\n"
+ "name: "
+ example["input"]
+ "\ncountry: "
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some cases like:

{
  "input": "جورج بوش الابن",
  "label": "us",
  "model_output": "country: US",
  "filtered_output": null
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gpt4 fewshot asset also has similar outputs, but handles it, perhaps we just need to use the same postprocessing fn?

Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,47 @@ def config():
"by",
"kz",
],
"max_tries": 3,
"max_tries": 30,
},
"general_args": {
"data_path": "data/demographic_attributes/name_info/wikidata_test.txt",
"data_path": "data/demographic_attributes/name_info/wikidata_test.txt"
},
}


def prompt(input_sample):
prompt_string = (
f"Label the country of the following person 'name'. Write ONLY the country code in ISO 3166-1 alpha-2 format.\n\n"
f"name: {input_sample}\n"
f"country: \n"
)
return [
{
"role": "system",
"content": "You are an AI assistant that helps people find information on locations.",
"content": "You are an expert annotator who can identify the country of a person based on name.",
},
{
"role": "user",
"content": f"Predict the country of citizenship of the following person name. Write ONLY the country code in ISO 3166-1 alpha-2 format without explananation.\n {input_sample}",
"content": prompt_string,
},
]


def post_process(response):
out = response["choices"][0]["message"]["content"]
return out.lower()
label = response["choices"][0]["message"]["content"]

label_list = config()["model_args"]["class_labels"]

if "name: " in label:
label_fixed = label.replace("name: ", "").lower()
elif label.lower() in label_list:
label_fixed = label.lower()
elif (
"I'm sorry, but I cannot predict the country" in label
or "I cannot predict the country" in label
):
label_fixed = None
else:
label_fixed = None

return label_fixed
Loading