-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify_marvl_gpt.py
114 lines (102 loc) · 3.48 KB
/
classify_marvl_gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import requests
from datasets import load_dataset
from PIL import Image
import io
import argparse
from tqdm import tqdm
import os
from openai import AzureOpenAI
import json
import pandas as pd
import re
import math
import base64
from mimetypes import guess_type
# --- CLI: pick which MaRVL language split to classify ---------------------
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--country", type=str, default="id")
args = arg_parser.parse_args()
COUNTRY = args.country

# Load the MaRVL image index and keep only the rows for the chosen language,
# then materialize them as a list of plain dicts for the inference loop.
df = pd.read_csv("marvl/marvl_images.csv")
df = df.loc[df["language"] == COUNTRY]
dataset = df.to_dict(orient="records")
print(f"Number of images for {COUNTRY}: {len(dataset)}")
# Azure OpenAI credentials live in an untracked secrets file.
with open("secrets.json", "r") as secrets_file:
    secrets = json.load(secrets_file)

api_base = secrets["GPT4V_OPENAI_ENDPOINT"]
api_key = secrets["GPT4V_OPENAI_API_KEY"]

# Deployment name and API version are pinned for the GPT-4V preview endpoint.
deployment_name = "gpt4v"
api_version = "2024-02-15-preview"
def encode_image(image, image_format="JPEG"):
    """Serialize an image to a base64-encoded string.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to serialize. For JPEG it must be in an RGB-compatible mode;
        callers in this script convert RGBA/P/L/LA images before encoding.
    image_format : str, optional
        Pillow format name to encode with (default "JPEG"); exposed as a
        parameter so the helper also works for PNG etc.

    Returns
    -------
    str
        Base64 text of the encoded image bytes (suitable for a data: URL).
    """
    buffered = io.BytesIO()
    image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Azure OpenAI client targeting the GPT-4V deployment configured above.
client = AzureOpenAI(
    base_url=api_base,
    api_version=api_version,
    api_key=api_key,
)
# Query GPT-4V for every image in the split and collect the predicted
# UN-geoscheme subregion (or a placeholder when the request fails).
results = []
for example in tqdm(dataset):
    path = example["image_path"]
    concept = example["concept"]
    country = example["language"]
    image_id = path  # the image path doubles as a stable unique identifier

    # Load, normalize, and downscale inside a context manager so the
    # underlying file handle is always released.  Note: convert()/resize()
    # return NEW image objects; the original code's image.close() at the
    # end closed the derived copy and leaked the opened file.
    with Image.open(path) as image:
        # JPEG cannot encode alpha/palette/grayscale-alpha modes.
        if image.mode in ("RGBA", "P", "L", "LA"):
            image = image.convert("RGB")
        # Resize to 512x512 to keep the base64 payload small.
        image = image.resize((512, 512))
        base64_image = encode_image(image)

    try:
        response = client.chat.completions.create(
            model=deployment_name,
            messages=[
                { "role": "system", "content": "You are a helpful assistant. Your answer should include strictly only a valid geographical subregion according to the United Nations geoscheme developed by UNSD." },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Strictly follow the United Nations geoscheme for subregions. Which geographical subregion of the United Nations geoscheme is this image from? Make an educated guess. Answer in one to three words."
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}",
                                "detail": "low"
                            }
                        }
                    ]
                }
            ],
            max_tokens=300
        )
        # Read the completion text directly from the response object instead
        # of the original response.json() -> json.loads() round-trip.
        results.append({
            "id": image_id,
            "model": "gpt-4-turbo-vision-preview",
            "split": COUNTRY,
            "response": response.choices[0].message.content,
            "true_country": example["language"],
            "concept": example["concept"]
        })
    except Exception as e:
        # NOTE(review): this catches *any* failure (network errors, rate
        # limits, content filtering alike); the placeholder label below
        # assumes a content-filter rejection, so log the real cause.
        print(f"Request failed for {image_id}: {type(e).__name__}: {e}")
        results.append({
            "id": image_id,
            "model": "gpt-4-turbo-vision-preview",
            "split": COUNTRY,
            "response": "ResponsibleAIPolicyViolation",
            "true_country": example["language"],
            "concept": example["concept"]
        })
# Persist predictions; create the output directory on first run so
# to_csv does not fail with FileNotFoundError.
df = pd.DataFrame(results)
print(df.head())
os.makedirs("results/marvl", exist_ok=True)
df.to_csv(f"results/marvl/{COUNTRY}.csv", index=False)