Merge pull request NVlabs#87 from yukang2017/main
VILA Benchmark using GPT-4 for evaluation
Efficient-Large-Language-Model authored Jun 6, 2024
2 parents 4507918 + 70588d2 commit 8dbee2c
Showing 4 changed files with 78 additions and 30 deletions.
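The three eval_benchmark_* scripts below share one argparse interface, and the --api_key flag becomes optional because credentials now come from the AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT environment variables. As a quick orientation, a hypothetical driver (not part of this commit; every path and the --num_tasks value are placeholders) could run the whole GPT-4 evaluation like this:

```python
# Hypothetical driver for the three eval scripts in this commit. The flags
# come from their argparse definitions; all paths here are placeholders.
import subprocess

SCRIPTS = [
    "llava/eval/video/eval_benchmark_1_correctness.py",
    "llava/eval/video/eval_benchmark_2_detailed_orientation.py",
    "llava/eval/video/eval_benchmark_3_context.py",
]

for script in SCRIPTS:
    subprocess.run(
        [
            "python", script,
            "--pred_path", "./eval_output/Gemini/pred.json",  # written by convert_pred_to_json.py
            "--output_dir", "./eval_output/Gemini/annotations",
            "--output_json", "./eval_output/Gemini/results.json",
            "--num_tasks", "8",
        ],
        check=True,  # stop early if any script exits non-zero
    )
```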
23 changes: 23 additions & 0 deletions llava/eval/video/convert_pred_to_json.py
@@ -0,0 +1,23 @@
import os
import json

path_to_pred = "/home/ligengz/workspace/VILA-internal/vertex-ai-gemini-15_pexel_1k_new_prompt.json"
preds_json = json.load(open(path_to_pred))  # load predictions; `preds_json` is used below but was never defined
labels_json = json.load(open("/lustre/fs2/portfolios/nvr/users/yukangc/datasets/Video-Benchmark-Label-0605.json"))

videos = []
preds_dict = {}
for item in preds_json:
    videos.append(item.split("/")[-1].split(".")[0])
    preds_dict[item.split("/")[-1].split(".")[0]] = preds_json[item]

model = "Gemini"
pred_path = "./eval_output/%s" % model
os.makedirs(pred_path, exist_ok=True)  # ensure the output directory exists before writing

output_json = []
for item in labels_json:
    video_name = item['video_name']
    item_output = item
    item_output['pred'] = preds_dict[video_name]['output']
    output_json.append(item_output)

json.dump(output_json, open(os.path.join(pred_path, "pred.json"), "w"))
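For readers skimming the new converter: it assumes the prediction file maps each video path to a dict with an "output" field, and the label file is a list of dicts carrying a "video_name" key. A toy round-trip of the same matching logic (all values invented for illustration):

```python
# Toy illustration of the converter's key matching; all values are made up.
import json

preds_json = {"/data/videos/pexel_0001.mp4": {"output": "A dog runs across a field."}}
labels_json = [{"video_name": "pexel_0001", "question": "What is the animal doing?"}]

# Same normalization as the script: basename without the file extension.
preds_dict = {p.split("/")[-1].split(".")[0]: v for p, v in preds_json.items()}

output_json = []
for item in labels_json:
    item["pred"] = preds_dict[item["video_name"]]["output"]
    output_json.append(item)

print(json.dumps(output_json, indent=2))
```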
28 changes: 18 additions & 10 deletions llava/eval/video/eval_benchmark_1_correctness.py
@@ -1,17 +1,24 @@
import openai
#import openai
import os
import argparse
import json
import ast
from multiprocessing.pool import Pool

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-02-01",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)

def parse_args():
    parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
    parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
    parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
    parser.add_argument("--api_key", required=True, help="OpenAI API key.")
    parser.add_argument("--api_key", help="OpenAI API key.")
    parser.add_argument("--api_base", default="", type=str, help="OpenAI API base.")
    parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
    args = parser.parse_args()
@@ -24,9 +31,6 @@ def annotate(prediction_set, caption_files, output_dir, args):
    Returns a score for correctness.
    """
    # Set the OpenAI API key.
    openai.api_key = args.api_key
    if args.api_base is not None:
        openai.api_base = args.api_base
    for file in caption_files:
        key = file[:-5]  # Strip file extension
        qa_set = prediction_set[key]
@@ -35,8 +39,8 @@ def annotate(prediction_set, caption_files, output_dir, args):
        pred = qa_set['pred']
        try:
            # Compute the correctness score
            completion = openai.chat.completions.create(
                model="gpt-3.5-turbo",
            completion = client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {
                        "role": "system",
@@ -66,6 +70,7 @@ def annotate(prediction_set, caption_files, output_dir, args):
            )
            # Convert response to a Python dictionary.
            response_message = completion.choices[0].message.content
            #response_message = completion["choices"][0]["message"]["content"]
            response_dict = ast.literal_eval(response_message)
            result_qa_pair = [response_dict, qa_set]

@@ -124,7 +129,7 @@ def main():
        prediction_set[id] = qa_set

    # Set the OpenAI API key.
    openai.api_key = args.api_key
    #openai.api_key = args.api_key
    num_tasks = args.num_tasks

    # While loop to ensure that all captions are processed.
@@ -150,8 +155,11 @@
            task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]

            # Use a pool of workers to process the files in parallel.
            with Pool() as pool:
                pool.starmap(annotate, task_args)
            #with Pool() as pool:
            #    pool.starmap(annotate, task_args)
            from tqdm import tqdm
            for task_arg in tqdm(task_args):
                annotate(*task_arg)

        except Exception as e:
            print(f"Error: {e}")
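One design note on the change just above: replacing multiprocessing.Pool with a serial tqdm loop sidesteps pickling the module-level AzureOpenAI client and makes rate limits easier to respect, at the cost of throughput. Since the work is I/O-bound API calls, a thread pool is a plausible middle ground; the sketch below is an untested alternative, not part of the commit, and assumes the annotate() defined in this script:

```python
# Untested alternative to the serial loop: threads reuse the module-level
# AzureOpenAI client without the pickling problems multiprocessing would hit.
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

def annotate_concurrently(task_args, max_workers=4):
    # task_args is the same list of (prediction_set, part, output_dir, args)
    # tuples built in main(); annotate is the function defined above.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(annotate, *ta) for ta in task_args]
        for future in tqdm(futures):
            future.result()  # re-raise any exception from the worker
```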
27 changes: 17 additions & 10 deletions llava/eval/video/eval_benchmark_2_detailed_orientation.py
@@ -1,19 +1,25 @@
# This file is originated from: https://github.com/mbzuai-oryx/Video-ChatGPT

import openai
#import openai
import os
import argparse
import json
import ast
from multiprocessing.pool import Pool
from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-02-01",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)

def parse_args():
    parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
    parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
    parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
    parser.add_argument("--api_key", required=True, help="OpenAI API key.")
    parser.add_argument("--api_key", help="OpenAI API key.")
    parser.add_argument("--api_base", default="", type=str, help="OpenAI API base.")
    parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
    args = parser.parse_args()
@@ -26,9 +32,6 @@ def annotate(prediction_set, caption_files, output_dir, args):
    returns a score for detailed orientation.
    """
    # Set the OpenAI API key.
    openai.api_key = args.api_key
    if args.api_base is not None:
        openai.api_base = args.api_base
    for file in caption_files:
        key = file[:-5]  # Strip file extension
        qa_set = prediction_set[key]
@@ -37,8 +40,8 @@ def annotate(prediction_set, caption_files, output_dir, args):
        pred = qa_set['pred']
        try:
            # Compute the detailed-orientation score
            completion = openai.chat.completions.create(
                model="gpt-3.5-turbo",
            completion = client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {
                        "role": "system",
@@ -68,6 +71,7 @@ def annotate(prediction_set, caption_files, output_dir, args):
            )
            # Convert response to a Python dictionary.
            response_message = completion.choices[0].message.content
            #response_message = completion["choices"][0]["message"]["content"]
            response_dict = ast.literal_eval(response_message)
            result_qa_pair = [response_dict, qa_set]

@@ -126,7 +130,7 @@ def main():
        prediction_set[id] = qa_set

    # Set the OpenAI API key.
    openai.api_key = args.api_key
    #openai.api_key = args.api_key
    num_tasks = args.num_tasks

    # While loop to ensure that all captions are processed.
@@ -152,8 +156,11 @@
            task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]

            # Use a pool of workers to process the files in parallel.
            with Pool() as pool:
                pool.starmap(annotate, task_args)
            #with Pool() as pool:
            #    pool.starmap(annotate, task_args)
            from tqdm import tqdm
            for task_arg in tqdm(task_args):
                annotate(*task_arg)

        except Exception as e:
            print(f"Error: {e}")
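All three scripts parse the judge's reply with ast.literal_eval, so GPT-4 is expected to answer with a Python dict literal (in the upstream Video-ChatGPT prompts this looks like {'pred': 'yes', 'score': 4}; the exact keys live in the prompt text collapsed out of this diff). A defensive wrapper, offered as an assumption rather than code from this commit, keeps one malformed reply from raising inside annotate():

```python
# Sketch: tolerant parsing of the judge's reply. The example reply format
# follows Video-ChatGPT conventions; the exact keys are in the collapsed prompt.
import ast

def parse_review(response_message: str):
    try:
        return ast.literal_eval(response_message.strip())
    except (ValueError, SyntaxError):
        return None  # leave this caption file unwritten so the retry loop redoes it

print(parse_review("{'pred': 'yes', 'score': 4}"))  # -> {'pred': 'yes', 'score': 4}
print(parse_review("Sure! Here is my rating..."))   # -> None
```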
30 changes: 20 additions & 10 deletions llava/eval/video/eval_benchmark_3_context.py
@@ -1,19 +1,28 @@
# This file is originated from: https://github.com/mbzuai-oryx/Video-ChatGPT

import openai
#import openai
import os
import argparse
import json
import ast
from multiprocessing.pool import Pool
from tqdm import tqdm

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-02-01",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)


def parse_args():
    parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
    parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
    parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
    parser.add_argument("--api_key", required=True, help="OpenAI API key.")
    parser.add_argument("--api_key", help="OpenAI API key.")
    parser.add_argument("--api_base", default="", type=str, help="OpenAI API base.")
    parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
    args = parser.parse_args()
@@ -26,9 +35,6 @@ def annotate(prediction_set, caption_files, output_dir, args):
    returns a score for contextual understanding.
    """
    # Set the OpenAI API key.
    openai.api_key = args.api_key
    if args.api_base is not None:
        openai.api_base = args.api_base
    for file in caption_files:
        key = file[:-5]  # Strip file extension
        qa_set = prediction_set[key]
@@ -37,8 +43,8 @@ def annotate(prediction_set, caption_files, output_dir, args):
        pred = qa_set['pred']
        try:
            # Compute the contextual understanding score
            completion = openai.chat.completions.create(
                model="gpt-3.5-turbo",
            completion = client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {
                        "role": "system",
@@ -68,6 +74,7 @@ def annotate(prediction_set, caption_files, output_dir, args):
            )
            # Convert response to a Python dictionary.
            response_message = completion.choices[0].message.content
            #response_message = completion["choices"][0]["message"]["content"]
            response_dict = ast.literal_eval(response_message)
            result_qa_pair = [response_dict, qa_set]

@@ -126,7 +133,7 @@ def main():
        prediction_set[id] = qa_set

    # Set the OpenAI API key.
    openai.api_key = args.api_key
    #openai.api_key = args.api_key
    num_tasks = args.num_tasks

    # While loop to ensure that all captions are processed.
@@ -152,8 +159,11 @@
            task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]

            # Use a pool of workers to process the files in parallel.
            with Pool() as pool:
                pool.starmap(annotate, task_args)
            #with Pool() as pool:
            #    pool.starmap(annotate, task_args)

            for task_arg in tqdm(task_args):
                annotate(*task_arg)

        except Exception as e:
            print(f"Error: {e}")
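The tail of main() in each script (collapsed in this diff and untouched by the commit) combines the per-video annotation files into --output_json and reports an average score, following the upstream Video-ChatGPT code these files originate from. A sketch of that step, under the assumption that each file in --output_dir holds the [response_dict, qa_set] pair written by annotate():

```python
# Assumed shape of the combine step (not shown in this diff): merge per-video
# annotation files and average the 'score' field of each response_dict.
import json
import os

def combine_annotations(output_dir: str, output_json: str) -> float:
    combined = {}
    for name in os.listdir(output_dir):
        if name.endswith(".json"):
            with open(os.path.join(output_dir, name)) as f:
                combined[name[:-5]] = json.load(f)  # [response_dict, qa_set]
    with open(output_json, "w") as f:
        json.dump(combined, f)
    scores = [pair[0]["score"] for pair in combined.values()]
    return sum(scores) / max(len(scores), 1)
```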
