-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
145 lines (119 loc) · 5.54 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import concurrent.futures
import json
import os
import random
import re
import time
from pathlib import Path
from typing import Dict

import openai
from openai import OpenAI
from rich.console import Console
from rich.theme import Theme
from tqdm import tqdm
# Rich console styles used by console.log calls below.
custom_theme = Theme({
"info": "bold dim cyan",
"warning": "bold magenta",
"danger": "bold red",
"debugging": "bold sandy_brown"
})
console = Console(theme=custom_theme)
# Repository root (directory containing this file) and the output directory
# under which per-run subdirectories are created.
PROJECT_HOME = Path(__file__).parent.resolve()
OUTPUT_DIR = os.path.join(PROJECT_HOME, 'output')
# NOTE(review): the trailing comma below makes TEMPLATE a 1-element *tuple*,
# not a string — callers index it as TEMPLATE[0] before .format(). The prompt
# asks the LLM to pick utterances after which an image should be shared and to
# answer as "<UTTERANCE> | <SPEAKER> | <RATIONALE> | <IMAGE DESCRIPTION>".
TEMPLATE = 'The following is a dialogue between {spk1} and {spk2}. The dialogue is provided line-by-line. In the given dialogue, select all utterances that are appropriate for sharing the image in the next turn, and write the speaker who will share the image after the selected utterance. You should also provide a rationale for your decision and describe the relevant image concisely.\n\nDialogue:\n{dialogue}\n\nRestrictions:\n(1) your answer should be in the format of "<UTTERANCE> | <SPEAKER> | <RATIONALE> | <IMAGE DESCRIPTION>".\n(2) you MUST select the utterance in the given dialogue, NOT generate a new utterance.\n(3) the rationale should be written starting with "To".\n\nAnswer:\n1.',
# Parses one answer line of the "A | B | C | D" format. Groups: utterance,
# speaker, and (optionally) rationale and description; an optional leading
# "N. " enumeration and surrounding quotes on the utterance are stripped.
PATTERN = r'^(?:\d+\.\s+)?\"?(?P<utterance>.*?)\"?\s+\|\s+(?P<speaker>.*?)(?:\s+\|\s+(?P<rationale>.*?))?(?:\s+\|\s+(?P<description>.*?))?$'
class Runner():
    """Runs one prompt through an OpenAI chat model and dumps parsed output.

    Given a dialogue, builds the image-sharing selection prompt, queries the
    model with retry-on-transient-error, parses the "A | B | C | D" answer
    lines, and writes them as JSONL under output/<run_id>:<model>/.
    """

    def __init__(self, args):
        """Store CLI args, create the per-run output dir, build the client.

        Args:
            args: argparse.Namespace with run_id, model, and sampling params.
        """
        self.args = args
        # One directory per (run id, model) pair, e.g. output/vanilla:gpt-3.5-turbo
        self.output_base_dir = os.path.join(
            OUTPUT_DIR, self.args.run_id + ":{}".format(self.args.model))
        os.makedirs(self.output_base_dir, exist_ok=True)
        # Prefer the standard OPENAI_API_KEY env var; fall back to the
        # original hard-coded placeholder so behavior is unchanged when the
        # variable is unset (the placeholder was never a working key anyway).
        self.client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY", "<OPENAI_API_KEY>")
        )

    def run(self, dialogue, spk1, spk2):
        """Build the prompt for one dialogue, query the model, dump results.

        Args:
            dialogue: line-by-line dialogue text inserted into the template.
            spk1, spk2: the two speaker names referenced by the template.
        """
        # TEMPLATE is a 1-tuple (trailing comma at its definition), hence [0].
        prompt = TEMPLATE[0].format(spk1=spk1, spk2=spk2, dialogue=dialogue)
        console.log(f'Prompt input: {prompt}', style='debugging')
        output = self.interact(prompt)
        console.log(f'Output: {output}', style='debugging')
        self.dump_output(output, os.path.join(self.output_base_dir, 'test.jsonl'))

    def generate(self, prompt):
        """Call the chat completions API, retrying transient failures forever.

        Returns:
            The stripped text of the first completion choice.
        """
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.args.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.args.temperature,
                    max_tokens=self.args.max_tokens,
                    top_p=self.args.top_p,
                    frequency_penalty=self.args.frequency_penalty,
                    presence_penalty=self.args.presence_penalty
                )
                break
            except (RuntimeError, openai.RateLimitError,
                    openai.APIStatusError, openai.APIConnectionError) as e:
                # Rate limits / connectivity blips: brief pause, then retry.
                # (Requires the module-level `import time`.)
                print("Error: {}".format(e))
                time.sleep(2)
                continue
        return response.choices[0].message.content.strip()

    def parse(self, generation):
        """Parse the model's multi-line answer into a list of dicts.

        Args:
            generation: raw completion text; each answer line is expected in
                the "<utterance> | <speaker> | <rationale> | <description>"
                format enforced by the prompt.

        Returns:
            List of dicts with keys 'utterance', 'speaker', 'rationale',
            'description'; the last two are None when the model omits them.
        """
        results = []
        for match in re.finditer(PATTERN, generation, re.MULTILINE):
            results.append({
                'utterance': match.group('utterance'),
                'speaker': match.group('speaker'),
                'rationale': match.group('rationale'),
                'description': match.group('description')
            })
        return results

    def interact(self, prompt_input):
        """Generate a completion for `prompt_input` and return it parsed."""
        generation = self.generate(prompt_input)
        parsed_output = self.parse(generation)
        return parsed_output

    def dump_output(self, outputs, file_name=None):
        """Write each parsed result as one JSON line to `file_name`.

        Uses a context manager so the handle is closed even if a write raises
        (the original left the file open on error).
        """
        with open(file_name, 'w') as f:
            for output in outputs:
                f.write(json.dumps(output) + '\n')
def main(args):
    """Load the bundled sample dialogue and run it through the Runner."""
    runner = Runner(args)
    with open('sample.txt', 'r') as fh:
        dialogue_text = fh.read()
    # The sample dialogue is between these two fixed speakers.
    runner.run(dialogue_text, 'Jennifer', 'Keyana')
if __name__ == '__main__':
    # CLI for the multi-modal dialogue generation run: run naming, model
    # choice, and the usual OpenAI sampling knobs.
    parser = argparse.ArgumentParser(
        description='arguments for generating multi-modal dialogues using LLM')
    parser.add_argument('--run-id', type=str, default='vanilla',
                        help='the name of the directory where the output will be dumped')
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo',
                        help='which LLM to use')
    parser.add_argument('--temperature', type=float, default=0.9,
                        help="control randomness: lowering results in less random completion")
    parser.add_argument('--top-p', type=float, default=0.95,
                        help="nucleus sampling")
    parser.add_argument('--frequency-penalty', type=float, default=1.0,
                        help="decreases the model's likelihood to repeat the same line verbatim")
    parser.add_argument('--presence-penalty', type=float, default=0.6,
                        help="increases the model's likelihood to talk about new topics")
    parser.add_argument('--max-tokens', type=int, default=1024,
                        help='maximum number of tokens to generate')
    main(parser.parse_args())