Merge pull request #130 from lscpku/vitatecs
Add task VITATECS
Luodian authored Jul 1, 2024
2 parents 383e7fe + a752259 commit 11fd7e3
Showing 9 changed files with 320 additions and 0 deletions.
9 changes: 9 additions & 0 deletions lmms_eval/tasks/vitatecs/_default_template_yaml
@@ -0,0 +1,9 @@
dataset_path: lscpku/VITATECS
dataset_kwargs:
  token: True
  video: True
  cache_dir: vitatecs
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nPlease response with a single letter (A or B):"
8 changes: 8 additions & 0 deletions lmms_eval/tasks/vitatecs/_vitatecs.yaml
@@ -0,0 +1,8 @@
group: vitatecs
task:
  - vitatecs_direction
  - vitatecs_intensity
  - vitatecs_sequence
  - vitatecs_compositionality
  - vitatecs_localization
  - vitatecs_type
225 changes: 225 additions & 0 deletions lmms_eval/tasks/vitatecs/utils.py
@@ -0,0 +1,225 @@
from decord import VideoReader, cpu
import numpy as np
import os
import sys
import datetime
import lmms_eval.tasks._task_utils.file_utils as file_utils
import json
import logging
import yaml
from pathlib import Path

import requests
import openai
from openai import OpenAI
import time
import ast
from tqdm import tqdm
import random

import re

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))


API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

# We will unzip all the zip files
# To HF HOME cache dir
# And load it here
HF_HOME = os.environ["HF_HOME"]
cache_dir = config["dataset_kwargs"]["cache_dir"]
cache_dir = os.path.join(HF_HOME, cache_dir)

eval_logger = logging.getLogger("lmms-eval")


# Pass in video path here
# Can only work correctly with video llm
def vitatecs_doc_to_visual(doc):
    video_path = os.path.join(cache_dir, doc["src_dataset"], doc["video_name"])
    if os.path.exists(video_path):
        video_path = video_path
    else:
        sys.exit(f"video path:{video_path} does not exist, please check")
    return [video_path]


# This is the place where you format your question
def vitatecs_doc_to_text(doc, model_specific_prompt_kwargs=None):
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]

    question, _, _ = format_question_and_answer(doc)
    return f"{pre_prompt}{question}{post_prompt}"


def process_option_for_question(sent):
    if not sent.endswith("."):
        sent += "."
    return sent.capitalize()


def process_option_for_matching(sent):
    if sent.endswith("."):
        sent = sent[:-1]
    return sent.lower()


def format_question_and_answer(doc):
    seed = sum(ord(c) for c in doc['caption'] + doc['counterfactual']) % 100
    random.seed(seed)
    if random.random() > 0.5:
        option_a = process_option_for_question(doc['caption'])
        option_b = process_option_for_question(doc['counterfactual'])
        answer = "(A) " + option_a
    else:
        option_a = process_option_for_question(doc['counterfactual'])
        option_b = process_option_for_question(doc['caption'])
        answer = "(B) " + option_b
    options = [process_option_for_matching(doc['caption']), process_option_for_matching(doc['counterfactual'])]

    question = f"Which of the following best describes the content of the video: \n(A) {option_a} \n(B) {option_b}"
    return question, answer, options


def vitatecs_doc_to_answer(doc):
    _, answer, _ = format_question_and_answer(doc)
    return answer


# Process result
def vitatecs_process_results(doc, result):
    pred = result[0]
    rating = 0
    match_success = True
    chatgpt_response = None
    question, answer, options = format_question_and_answer(doc)

    # Some hand-crafted matching rules
    if options[0] in pred.lower() and options[1] not in pred.lower():
        rating = 1
    elif options[1] in pred.lower() and options[0] not in pred.lower():
        rating = 0
    elif pred in ["A", "B"]:
        rating = 1 if pred == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["A.", "B."]):
        rating = 1 if pred.split(".")[0] == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["A)", "B)"]):
        rating = 1 if pred.split(")")[0] == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["(A)", "(B)"]):
        # compare the letter inside the parentheses with the ground-truth letter
        rating = 1 if pred[1] == answer[1] else 0
    else:
        # Fail to match answer in the video-llm response. Use ChatGPT to evaluate.
        match_success = False

        base_prompt = """You will receive a caption matching question, the ground-truth answer and the prediction from a question answering (QA) model. Your task is to determine whether QA model prediction is correct, based on the question and ground-truth answer. If the prediction is correct, respond "Correct". If the prediction is incorrect, respond "Incorrect". """
        prompt = f"""{base_prompt}\n\nCaption Matching Question: {question}\n\nGround-Truth Answer: {answer}\n\nModel Prediction: {pred}"""
        chatgpt_response, rating = get_eval_result(prompt)

    if not match_success:
        return {
            "accuracy": {
                "src_dataset": doc["src_dataset"],
                "video_id": doc["video_name"],
                "question": question,
                "gt-answer": answer,
                "video-llm-prediction": pred,
                "match_success": match_success,
                "rating": rating,
                # "chatgpt_prompt": prompt,
                "chatgpt_response": chatgpt_response,
                "aspect": doc["aspect"],
            },
        }
    else:
        return {
            "accuracy": {
                "src_dataset": doc["src_dataset"],
                "video_id": doc["video_name"],
                "question": question,
                "gt-answer": answer,
                "video-llm-prediction": pred,
                "match_success": match_success,
                "rating": rating,
                "aspect": doc["aspect"],
            },
        }


# utils function for gpt_evaluation when rule-based matching is unsuccessful
def get_eval_result(prompt, maxtry=10, sys_prompt=None):
    llm_output = None
    while True:
        try:
            llm_output = get_llm_output(prompt, sys_prompt)
            rating = llm_output_to_rating(llm_output)
            return llm_output, rating
        except:
            if maxtry <= 0:
                return llm_output, 0
            maxtry -= 1
            print(f"Not success! {maxtry} retries remaining...")
            time.sleep(random.uniform(1, 2))


# utils function for gpt evaluation
def get_llm_output(prompt, sys_prompt, max_tokens=128):
    if sys_prompt is None:
        sys_prompt = "You are an AI assistant for question answering."
    data = {"max_tokens": max_tokens, "model": "gpt-3.5-turbo-1106", "temperature": 1.0, "top_p": 1, "presence_penalty": 1, "messages": [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}]}
    response = requests.post(API_URL, headers=headers, data=json.dumps(data).encode("utf-8"))
    result = response.content.decode("utf-8")
    dict_result = json.loads(result)
    llm_output = dict_result["choices"][0]["message"]["content"].strip()
    return llm_output


# utils function that converts gpt evaluation into rating
def llm_output_to_rating(llm_output):
    assert "Correct" in llm_output or "Incorrect" in llm_output
    if llm_output.startswith("Correct"):
        rating = 1
    elif llm_output.startswith("Incorrect"):
        rating = 0
    elif ("Correct" in llm_output) and ("Incorrect" not in llm_output):
        rating = 1
    elif "Incorrect" in llm_output:
        rating = 0
    return rating


# Factory into different aggregate
def vitatecs_aggregate_rating(results, args):
    yes_count = 0

    # results is a list of dict
    for answer_dict in results:
        if answer_dict["rating"] == 1:
            yes_count += 1

    accuracy = yes_count / len(results)

    return accuracy * 100
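
For readers skimming the diff, here is a standalone sketch (not part of the commit) of how the deterministic A/B shuffle in format_question_and_answer and the rule-based letter matching in vitatecs_process_results behave; the caption/counterfactual fields mirror the dataset schema used above, but the sentences and predictions are made up for illustration.

# Standalone illustration (not part of the commit); sample doc and predictions are invented.
import random

doc = {"caption": "A person opens the door.", "counterfactual": "A person closes the door."}

# The seed depends only on the two sentences, so the (A)/(B) order is stable across runs.
seed = sum(ord(c) for c in doc["caption"] + doc["counterfactual"]) % 100
random.seed(seed)
if random.random() > 0.5:
    option_a, option_b, answer_letter = doc["caption"], doc["counterfactual"], "A"
else:
    option_a, option_b, answer_letter = doc["counterfactual"], doc["caption"], "B"

question = f"Which of the following best describes the content of the video: \n(A) {option_a} \n(B) {option_b}"

# Letter-style replies are scored directly; replies that match none of the patterns in
# vitatecs_process_results would be forwarded to the ChatGPT judge instead.
for pred in ["A", "(B)", "B) A person closes the door."]:
    letter = pred[1] if pred.startswith("(") else pred[0]
    print(pred, "->", int(letter == answer_letter))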
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_compositionality.yaml
@@ -0,0 +1,13 @@
dataset_name: "Compositionality"
task: "vitatecs_compositionality"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_direction.yaml
@@ -0,0 +1,13 @@
dataset_name: "Direction"
task: "vitatecs_direction"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_intensity.yaml
@@ -0,0 +1,13 @@
dataset_name: "Intensity"
task: "vitatecs_intensity"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_localization.yaml
@@ -0,0 +1,13 @@
dataset_name: "Localization"
task: "vitatecs_localization"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_sequence.yaml
@@ -0,0 +1,13 @@
dataset_name: "Sequence"
task: "vitatecs_sequence"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
13 changes: 13 additions & 0 deletions lmms_eval/tasks/vitatecs/vitatecs_type.yaml
@@ -0,0 +1,13 @@
dataset_name: "Type"
task: "vitatecs_type"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.vitatecs_doc_to_visual
doc_to_text: !function utils.vitatecs_doc_to_text
doc_to_target: !function utils.vitatecs_doc_to_answer
process_results: !function utils.vitatecs_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.vitatecs_aggregate_rating
    higher_is_better: true
include: _default_template_yaml
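
Finally, a minimal sketch (again not part of the commit, with made-up per-sample dicts) of what the accuracy metric declared in these task configs boils down to once vitatecs_aggregate_rating runs over the collected results.

# Hypothetical per-sample entries, shaped like the "accuracy" dicts returned by vitatecs_process_results.
results = [
    {"video_id": "v1.mp4", "rating": 1, "match_success": True},
    {"video_id": "v2.mp4", "rating": 0, "match_success": True},
    {"video_id": "v3.mp4", "rating": 1, "match_success": False},  # this one was scored by the LLM judge
]

# vitatecs_aggregate_rating: fraction of samples rated 1, reported as a percentage.
accuracy = 100 * sum(r["rating"] == 1 for r in results) / len(results)
print(f"accuracy: {accuracy:.1f}")  # 66.7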
