-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcaptions.py
95 lines (82 loc) · 3.06 KB
/
captions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import queue
import utils
from keys import keys
from openai import AzureOpenAI, OpenAI
import typing
import multiprocessing
from io import BytesIO
import base64
import PIL.Image
import os
from tqdm.contrib.concurrent import process_map
import functools
import json
import argparse
# Module-level queue of credentials: filled by init_client() and handed to
# utils.multi_ask() by caption_img().
# NOTE(review): worker processes only see the queued entries if this object is
# inherited from the parent (fork start method). Under "spawn" each worker
# re-imports this module and would get a fresh, empty queue -- confirm.
available_keys = multiprocessing.Queue()
def default_argument_parser():
    """Build the command-line parser for the captioning script.

    Returns:
        argparse.ArgumentParser: parser with two required options,
        ``--format`` (one of ``svg``, ``tikz``, ``graphviz``) and
        ``--png-path``.
    """
    parser = argparse.ArgumentParser(
        description="generate captions for rendered vector-graphics images")
    # No ``default`` here: the option is required, so a default could never
    # take effect (and "" is not one of the valid choices anyway).
    parser.add_argument(
        "--format",
        choices=["svg", "tikz", "graphviz"],
        required=True,
        help="the format of the vector graphics",
    )
    parser.add_argument(
        "--png-path",
        required=True,
        help="directory containing the images to caption",
    )
    return parser
def init_client(model: typing.Literal["gpt-4", "gpt-4v", "Mixtral-8x7B-Instruct-v0.1"]):
    """Queue every entry of ``keys[model]`` onto ``available_keys``.

    Args:
        model: name used to index the imported ``keys`` object.
    """
    model_keys = keys[model]
    for entry in model_keys:
        available_keys.put(entry)
def caption_img(path: str, vformat: str) -> typing.Optional[str]:
    """Return a caption for the image at ``path``, using an on-disk cache.

    The caption is cached in ``data/<vformat>/tmp_captions/<basename>.txt``.
    If that file already exists its content is returned without contacting
    the model. Otherwise the image is PNG-encoded, sent to the model through
    ``utils.multi_ask`` and, when a caption comes back, written to the cache.

    Args:
        path: path of the image file to caption.
        vformat: vector-graphics format name; selects the cache directory.

    Returns:
        The caption text, or ``None`` when loading/encoding the image or the
        model request failed, or when ``utils.multi_ask`` returned ``None``.
    """
    import logging  # function-scope import keeps this block self-contained

    log = logging.getLogger(__name__)
    cache_dir = f"data/{vformat}/tmp_captions"
    caption_file = os.path.join(cache_dir, f"{os.path.basename(path)}.txt")

    if os.path.exists(caption_file):
        with open(caption_file, encoding="utf-8") as file:
            return file.read()

    try:
        buffered = BytesIO()
        with PIL.Image.open(path) as img:
            if img.size[0] > 1024 or img.size[1] > 1024:
                # Presumably bounds the image to 1024 px -- see utils.scale_image.
                img = img.copy()
                img = utils.scale_image(img, 1024)
            img.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode()
        messages = [
            {
                "role": "system",
                "content": "Generate a detailed caption for the given image. The reader of your caption should be able to replicate this picture."
            },
            {
                "role": "user",
                "content": [{
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                }
                ]
            }
        ]
        caption = utils.multi_ask(available_keys, messages, model="gpt-4v")
    except Exception:
        # Best effort: one bad image or failed request must not abort the
        # whole batch, but the failure is reported instead of being hidden.
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        log.warning("could not caption %s", path, exc_info=True)
        return None

    if caption is not None:
        # A cache-write failure must not discard a caption that was already
        # obtained, so it is handled separately from the request above.
        try:
            os.makedirs(cache_dir, exist_ok=True)
            with open(caption_file, "w", encoding="utf-8") as file:
                file.write(caption)
        except Exception:
            log.warning("could not cache caption for %s", path, exc_info=True)
    return caption
def main():
    """Caption the images in ``--png-path`` and write them to a JSON file.

    Parses the command line, queues the "gpt-4v" entries via ``init_client``,
    captions up to the first 2000 directory entries in parallel (8 workers)
    and writes a ``{file name: caption}`` mapping to
    ``data/<format>/captions.json``. Files for which ``caption_img`` returned
    ``None`` are left out of the mapping.
    """
    args = default_argument_parser().parse_args()
    init_client("gpt-4v")
    in_dir = args.png_path
    out_file = "data/%s/captions.json" % args.format

    # Only the first 2000 entries returned by os.listdir are processed.
    file_list = os.listdir(in_dir)[:2000]
    file_paths = [os.path.join(in_dir, name) for name in file_list]

    captions = process_map(
        functools.partial(caption_img, vformat=args.format),
        file_paths,
        max_workers=8,
    )
    # Results are matched to file names by position below.
    assert len(captions) == len(file_list)

    result = {
        name: caption
        for name, caption in zip(file_list, captions)
        if caption is not None
    }

    # Make sure the destination exists, and close the file deterministically
    # so the JSON is fully flushed to disk.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, "w") as out:
        json.dump(result, out)
# Run the captioning pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()