Commit
* Refactor code formatting in setup.py, .gitignore, __init__.py, and data_summary.ipynb
* Refactor DefaultWebsite class in website.py
* Refactor dataset name and limit processed images to 2
* Refactor extract_infomation.py and update prompts
* update
* Refactor live_bench_2409.yaml and live_bench.yaml
Showing 45 changed files with 4,645 additions and 3,606 deletions.
@@ -4,5 +4,5 @@ task:
   - live_bench_2407
 
 metadata:
-  api_type : openai
+  api_type: azure
   eval_with_mini: false
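With `api_type` switched to `azure`, the judging backend is ultimately selected by the environment variables that the new `get_openai_client` helper (added in `utils_v2.py` later in this commit) inspects. A minimal sketch of the two setups, with placeholder values:

```python
import os

# Azure backend: the helper requires both variables and raises a ValueError otherwise.
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"  # placeholder
os.environ["AZURE_OPENAI_API_KEY"] = "<azure-api-key>"  # placeholder

# Plain OpenAI backend: leave AZURE_OPENAI_ENDPOINT unset and provide only this key.
# os.environ["OPENAI_API_KEY"] = "<openai-api-key>"  # placeholder
```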
@@ -0,0 +1,3 @@
task: "live_bench_2409"
dataset_name: 2024-09
include: live_bench_template_yaml_v2
@@ -0,0 +1,29 @@
dataset_path: lmms-lab/LiveBench
dataset_kwargs:
  token: True
test_split: test
dataset_name: 2024-07
output_type: generate_until
doc_to_visual: !function utils_v2.livebench_doc_to_visual
doc_to_text: !function utils_v2.livebench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 1024
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils_v2.livebench_process_results
process_results_use_image: true
metric_list:
  - metric: gpt4_eval_score
    aggregation: !function utils_v2.livebench_aggregate_results
    higher_is_better: true
  # - metric: gpt4_eval_score_mini
  #   aggregation: !function utils.livebench_aggregate_results
  #   higher_is_better: true

lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
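The three-line `live_bench_2409.yaml` above pulls in this template via `include:` and overrides only `dataset_name`. As a rough mental model of that composition (a hedged sketch, not the harness's actual config loader; it strips the `!function` tags the same way `utils_v2.py` does):

```python
from pathlib import Path

import yaml


def load_yaml_without_functions(path):
    # Drop "!function" lines so yaml.safe_load can parse the rest,
    # mirroring the trick used in utils_v2.py below.
    lines = [line for line in Path(path).read_text().splitlines(keepends=True) if "!function" not in line]
    return yaml.safe_load("".join(lines)) or {}


# File names as they appear in this commit; the paths are illustrative.
template = load_yaml_without_functions("live_bench_template_yaml_v2")
task = load_yaml_without_functions("live_bench_2409.yaml")

# Task-level keys win over the template, so dataset_name becomes "2024-09"
# while everything else (generation_kwargs, metric_list, ...) comes from the template.
resolved = {**template, **task}
print(resolved["dataset_name"])  # -> 2024-09
```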
@@ -0,0 +1,216 @@
import base64
import json
import logging
import os
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import openai
import pandas as pd
import requests
import yaml
from tqdm import tqdm

eval_logger = logging.getLogger("lmms-eval")


with open(Path(__file__).parent / "live_bench.yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))

API_TYPE = config["metadata"]["api_type"]
EVAL_WITH_MINI = config["metadata"]["eval_with_mini"]

def get_openai_client(api_version="2024-02-15-preview") -> openai.OpenAI:
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    if endpoint:
        key = os.getenv("AZURE_OPENAI_API_KEY")
        if not key:
            raise ValueError("AZURE_OPENAI_API_KEY environment variable not set.")
        return openai.AzureOpenAI(azure_endpoint=endpoint, api_key=key, api_version=api_version)
    else:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set.")
        return openai.OpenAI(api_key=api_key)
client = get_openai_client()

_PROMPT_WITH_IMAGE = """\
[Question]
{prompt}
[Assistant Response]
{generation}
[Ground Truth Response]
{reference}
[System]
Rate whether the assistant response correctly matches the ground truth, in regards to the image above.
The rating should be 0-10, where 0 is incorrect and 10 is correct.
Below is the specific criteria for rating:
{criteria}
Total score is out of 10. If the model's answer cannot be provided due to political reasons, please assign a score of 0 for further processing. If the model's response is biased due to political factors, please score it based on its understanding of the image, but reduce the objectivity score accordingly.
Your response should be in the JSON format:
```json
{{
    "Explanation": "(your explanation)",
    "Rating": "(int)"
}}
```
"""


def format_prompt(question, ground_truth_answer, answer, criteria):
    return _PROMPT_WITH_IMAGE.format(prompt=question, generation=answer, reference=ground_truth_answer, criteria=criteria)


def get_chat_response(gpt_model_name, base64_images, question, ground_truth_answer, answer, criteria, max_retries=5, wait_time=10):
    # client = openai.OpenAI(api_key=API_KEY)

    content = []
    for base64_image in base64_images:
        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
    prompt = format_prompt(question, ground_truth_answer, answer, criteria)
    content.append(
        {
            "type": "text",
            "text": prompt,
        }
    )

    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]

    # payload = {
    #     "model": GPT_EVAL_MODEL_NAME,
    #     "response_format": {"type": "json_object"},
    #     "max_tokens": 1024,
    #     "temperature": 0.0,
    # }

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(model=gpt_model_name, messages=messages, max_tokens=1024, response_format={"type": "json_object"}, temperature=0.0)
            response_data = response.choices[0].message.content
            # print(response_data)
            response_data = json.loads(response_data)
            rating = response_data["Rating"]
            explanation = response_data["Explanation"]
            return rating, explanation, gpt_model_name
        except requests.exceptions.RequestException as e:
            eval_logger.warning(f"Request failed on attempt {attempt + 1}: {e}")
            time.sleep(wait_time)
            if attempt == max_retries - 1:
                eval_logger.error(f"Failed to get response after {max_retries} attempts")
                return -1, str(e), gpt_model_name
        except Exception as e:
            eval_logger.error(f"Error on attempt {attempt + 1}: {e}")
            return -1, str(e), gpt_model_name


def image_to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


_images = {}

dataset = None


def livebench_doc_to_visual(doc):
    img_list = [image.convert("RGB") for image in doc["images"]]
    return img_list


def livebench_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    if lmms_eval_specific_kwargs is None:
        lmms_eval_specific_kwargs = {}
    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")
    return f"{pre_prompt}{doc['question']}{post_prompt}"


SUBTASKS = ["Basic Understanding", "Analytical Questions", "Divergent Thinking", "Real-world Assistance"]


def livebench_process_results_for_name(doc, results, model, eval_name):
    base64_images = [image_to_base64(image) for image in livebench_doc_to_visual(doc)]
    subtask = doc["subtask"]
    criteria = doc["criteria"]
    if subtask not in SUBTASKS:
        subtask = "further insights"
    if not results or results[0] == "":
        return {eval_name: {"rating": 0, "explanation": "No response", "model_name": "N/A", "subtask": subtask}}
    rating, explanation, model_name = get_chat_response(gpt_model_name=model, base64_images=base64_images, question=doc["question"], ground_truth_answer=doc["answer"], answer=results[0] if results else "", criteria=criteria)
    if rating >= 0:
        return {eval_name: {"rating": rating, "explanation": explanation, "model_name": model_name, "subtask": subtask, "id": doc["id"]}}
    else:
        return {eval_name: {"rating": -1, "explanation": explanation, "model_name": "N/A", "subtask": subtask, "id": doc["id"]}}


def livebench_process_results_4o(doc, results):
    return livebench_process_results_for_name(doc, results, "gpt-4o", "gpt4_eval_score")


def livebench_process_results_4o_mini(doc, results):
    return livebench_process_results_for_name(doc, results, "gpt-4o-mini", "gpt4_eval_score_mini")


def livebench_process_results(doc, results):
    res = livebench_process_results_4o(doc, results)
    if EVAL_WITH_MINI:
        res.update(livebench_process_results_4o_mini(doc, results))
    return res

def livebench_aggregate_results(results):
    sum_score, count = 0, 0
    score = {}
    for subtask in SUBTASKS:
        score[subtask] = []
    for result in results:
        if result["rating"] == -1:
            continue
        sum_score += result["rating"] / 10
        count += 1
        subtask = result["subtask"]
        if subtask not in SUBTASKS:
            subtask = "OTHER_SUBTASK"
        # index by the normalized subtask so an unknown subtask does not raise a KeyError
        score.setdefault(subtask, []).append(result["rating"] / 10)
    res = [(subtask, len(score[subtask]), np.mean(score[subtask]) * 100 if score[subtask] else 0) for subtask in SUBTASKS]
    # guard against division by zero when no result received a valid rating
    res.append(("Total", count, sum_score / count * 100 if count > 0 else 0))
    # print("count:", count)
    res = pd.DataFrame(res, columns=["Subtask", "Count", "Score"])
    print("=" * 50)
    print(res)
    print("=" * 50)
    if count == 0:
        eval_logger.warning("No valid scores to aggregate")
    return sum_score / count * 100 if count > 0 else None
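For reference, a toy illustration of what `livebench_aggregate_results` computes; the ratings and subtasks below are invented purely to show the shape of the input and output:

```python
# Invented judge outputs, in the shape produced by livebench_process_results_for_name.
toy_results = [
    {"rating": 8, "subtask": "Basic Understanding", "explanation": "...", "model_name": "gpt-4o", "id": 0},
    {"rating": 5, "subtask": "Analytical Questions", "explanation": "...", "model_name": "gpt-4o", "id": 1},
    {"rating": -1, "subtask": "Divergent Thinking", "explanation": "judge failed", "model_name": "N/A", "id": 2},
]

# The -1 entry is skipped, so the overall score is (0.8 + 0.5) / 2 * 100 = 65.0;
# the printed table additionally lists per-subtask counts and mean scores.
print(livebench_aggregate_results(toy_results))  # 65.0
```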
@@ -1,11 +1,10 @@
-from live_bench import LiveBench
-from live_bench.websites import load_websites, load_websites_from_file
-
-if __name__ == "__main__":
-    website = load_websites()
-    dataset = LiveBench()
-    dataset.capture(websites=website, driver_kwargs={"headless": True}, screen_shoter="single_screen", shoter_kwargs={"screen_size": (1024, 1024)}, qa_generator="gpt4v", scorer="claude", checker="gemini")
-
-    website = load_websites_from_file("/data/pufanyi/project/lmms-eval/temp/images")
-    dataset.capture(websites=website, screen_shoter="human", qa_generator="gpt4v", scorer="claude", checker="gemini", driver_kwargs={}, shoter_kwargs={}, generator_kwargs={})
-    dataset.upload()
+from live_bench import LiveBench
+from live_bench.websites import load_websites, load_websites_from_file
+
+if __name__ == "__main__":
+    website = load_websites()
+    dataset = LiveBench(name="2024-09")
+
+    website = load_websites_from_file("/data/pufanyi/project/lmms-eval/tools/temp/processed_images/selected")
+    dataset.capture(websites=website, screen_shoter="human", qa_generator="claude", scorer="claude", checker="gpt4v", driver_kwargs={}, shoter_kwargs={}, generator_kwargs={})
+    dataset.upload()