Add subtitle evaluation to VideoMME (NVlabs#124)
* update

* update

---------

Co-authored-by: Efficient-Large-Language-Model <156256291+Efficient-Large-Language-Model@users.noreply.github.com>
Lyken17 and Efficient-Large-Language-Model authored Jul 10, 2024
1 parent e7677f4 commit 9950a97
Showing 6 changed files with 224 additions and 21 deletions.
34 changes: 33 additions & 1 deletion llava/eval/video_mme/eval.sh
@@ -24,6 +24,32 @@ done
ckpt=${1:-/home/jasonlu/workspace/VILA-Internal/checkpoints/vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2}
mname=$(echo $ckpt | rev | cut -d "/" -f 1 | rev)
# convert the checkpoints


# expected output file: eval_output/VILA1.5-40b/video_mme/frames--1.json

python llava/eval/video_mme/video_eval.py \
--model-path Efficient-Large-Model/VILA1.5-40b \
--output_dir eval_output/VILA1.5-40b/video_mme \
--output_name frames--1.json \
-c

YOUR_RESULTS_FILE=eval_output/VILA1.5-40b/video_mme/frames--1_converted.json
VIDEO_DURATION_TYPE=short,medium,long
python llava/eval/video_mme/mme_calc.py \
--results_file $YOUR_RESULTS_FILE \
--video_duration_type $VIDEO_DURATION_TYPE \
--your_answer_key response_w/_sub

python llava/eval/video_mme/mme_calc.py \
--results_file $YOUR_RESULTS_FILE \
--video_duration_type $VIDEO_DURATION_TYPE \
--your_answer_key response_w/o_sub


ckpt=${1:-Efficient-Large-Model/VILA1.5-40b}
ckpt=${1:-/home/jasonlu/workspace/VILA-Internal/checkpoints/vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2}
mname=$(echo $ckpt | rev | cut -d "/" -f 1 | rev)
python llava/eval/video_mme/video_eval.py \
--model-path $ckpt \
--output_dir eval_output/$mname/video_mme \
@@ -34,7 +60,13 @@ YOUR_RESULTS_FILE=eval_output/$mname/video_mme/${mname}_converted.json
VIDEO_DURATION_TYPE=short,medium,long
python llava/eval/video_mme/mme_calc.py \
--results_file $YOUR_RESULTS_FILE \
--video_duration_type $VIDEO_DURATION_TYPE
--video_duration_type $VIDEO_DURATION_TYPE \
--your_answer_key response_w/_sub

python llava/eval/video_mme/mme_calc.py \
--results_file $YOUR_RESULTS_FILE \
--video_duration_type $VIDEO_DURATION_TYPE \
--your_answer_key response_w/o_sub

exit 0
python llava/eval/video_mme/mme_calc.py \
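For reference, after video_eval.py -c converts the raw answers, each question entry in the *_converted.json file carries both answer keys that mme_calc.py scores. A minimal sketch of one entry; only the field names are taken from this diff, the values are illustrative assumptions:

entry = {
    "question_id": "001-1",  # illustrative id
    "question": "What is the video about?",  # illustrative text
    "choices": ["A. ...", "B. ...", "C. ...", "D. ..."],
    "response_w/o_sub": "A",  # model answer without subtitles
    "response_w/_sub": "B",  # model answer with subtitles
}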
9 changes: 6 additions & 3 deletions llava/eval/video_mme/mme_calc.py
@@ -109,6 +109,7 @@ def eval_your_results(
- gt_answer_key (Optional[str]): Key to access the ground truth answer in the results file.
- your_answer_key (Optional[str]): Key to access your answer in the results file.
"""
key_name = your_answer_key.replace("response_", "")
ckpt_name = osp.basename(your_results_path).replace("_converted.json", "")
wandb_project = os.environ.get("WANDB_PROJECT", "VILA-evaluation")
wandb_name = os.environ.get("WANDB_NAME", ckpt_name)
@@ -120,7 +121,7 @@ def hash_path(fpath):
sha.update(fpath.encode())
return sha.hexdigest()[:8]

wandb_id = hash_path(osp.realpath(your_results_path))
wandb_id = hash_path(osp.realpath(your_results_path) + "2024")
# wandb.require("core")
wandb.init(
project=wandb_project,
@@ -233,7 +234,7 @@ def hash_path(fpath):

overall_acc = 100 * total_correct / total_answered if total_answered > 0 else 0

wandb.log({f"videomme/{video_type}": overall_acc})
wandb.log({f"videomme/{video_type}-{key_name}": overall_acc})
print(f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

print("\n")
@@ -278,7 +279,7 @@ def hash_path(fpath):
total_answered = sum([sum([q_type_dict[video_type][q_type]["answered"] for q_type in TASK_CATEGORIES]) for video_type in video_types])
overall_acc = 100 * total_correct / total_answered if total_answered > 0 else 0

wandb.log({"videomme/entire": overall_acc})
wandb.log({f"videomme/entire-{key_name}": overall_acc})

print(f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

@@ -287,6 +288,7 @@ def hash_path(fpath):
parser = argparse.ArgumentParser()
parser.add_argument("--results_file", type=str, required=True)
parser.add_argument("--video_duration_type", type=str, required=True)
parser.add_argument("--your_answer_key", type=str, default="response_w/o_sub")
parser.add_argument("--return_categories_accuracy", action="store_true")
parser.add_argument("--return_sub_categories_accuracy", action="store_true")
parser.add_argument("--return_task_types_accuracy", action="store_true")
@@ -296,6 +298,7 @@ def hash_path(fpath):

eval_your_results(
args.results_file,
your_answer_key=args.your_answer_key,
video_types=args.video_duration_type,
return_categories_accuracy=args.return_categories_accuracy,
return_sub_categories_accuracy=args.return_sub_categories_accuracy,
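With the new --your_answer_key flag, each scoring pass logs its accuracies under a key-suffixed wandb metric. A quick sketch of the names this produces, derived from the key_name logic above:

# key_name strips the "response_" prefix from --your_answer_key:
key_name = "response_w/o_sub".replace("response_", "")  # -> "w/o_sub"
# per-duration metrics: videomme/short-w/o_sub, videomme/medium-w/o_sub, videomme/long-w/o_sub
# overall metric: videomme/entire-w/o_sub
# passing response_w/_sub instead yields videomme/short-w/_sub, ..., videomme/entire-w/_sub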
38 changes: 28 additions & 10 deletions llava/eval/video_mme/sbatch_eval.sh
@@ -21,14 +21,17 @@ idx=$SLURM_ARRAY_TASK_ID
total=$SLURM_ARRAY_TASK_COUNT
jname=seval-$idx-of-$total-random


ckpt=${1:-"Efficient-Large-Model/VILA1.5-3b"}
_model_name=$(echo $ckpt | rev | cut -d "/" -f 1 | rev)
# llava_v1
# hermes-2
model_name=${2:-"$_model_name"}
conv_mode=${3:-"hermes-2"}
temperature=${4:-"0.0"}
num_beams=${5:-1}
temperature=${temperature:-"0.0"}
num_beams=${num_beams:-1}

num_video_frames=${num_video_frames:-"-1"}

OUTDIR=slurm-logs/$ckpt
#_$wname
@@ -40,27 +43,42 @@ srun \
-e $OUTDIR/$jname.err -o $OUTDIR/$jname.out \
python llava/eval/video_mme/video_eval.py \
--model-path $ckpt --shard $idx --total $total --conv-mode $conv_mode \
--output_dir=./eval_output/$model_name/video_mme/ --output_name=$model_name --temperature $temperature --num-beams $num_beams
--output_dir=./eval_output/$model_name/video_mme/ --output_name=frames-$num_video_frames \
--num_video_frames $num_video_frames \
--temperature $temperature --num-beams $num_beams

exit 0

# usage examples

# debugging usage
python llava/data_aug/video_eval.py --model-path Efficient-Large-Model/VILA1.5-3b
python llava/eval/video_mme/video_eval.py --model-path Efficient-Large-Model/VILA1.5-3b

# sbatch launch
tmp=0
beam=1
sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:vila-15-13b \
llava/eval/video_mme/sbatch_eval.sh Efficient-Large-Model/VILA1.5-3b VILA1.5-3b llava_v1 $tmp $beam
export temperature=0
export num_beams=1
# export num_video_frames=12
sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:VILA1.5-40b \
llava/eval/video_mme/sbatch_eval.sh \
Efficient-Large-Model/VILA1.5-40b \
VILA1.5-40b \
hermes-2


sbatch llava/eval/video_mme/sbatch_eval.sh \
sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:VILA1.5-3b \
llava/eval/video_mme/sbatch_eval.sh \
Efficient-Large-Model/VILA1.5-3b \
VILA1.5-3b \
llava_v1

sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2 \

sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:VILA1.5-40b \
llava/eval/video_mme/sbatch_eval.sh \
Efficient-Large-Model/VILA1.5-40b \
VILA1.5-40b \
hermes-2

sbatch -A llmservice_nlp_fm -p interactive,$SLURM_PARTITION -J videomme:vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2 \
llava/eval/video_mme/sbatch_eval.sh \
/home/jasonlu/workspace/VILA-Internal/checkpoints/vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2 \
vila-yi-34b-intern-6b-stage2_5_r620_sft_more_r2 \
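Since temperature, num_beams, and num_video_frames are now read from the environment instead of positional arguments 4 and 5, a one-line override at submission time should be enough. A sketch assuming the same account and partition as the examples above, with 32 as an illustrative frame count (sbatch forwards the caller's environment by default):

num_video_frames=32 temperature=0 num_beams=1 \
sbatch -A nvr_elm_llm -p interactive,$SLURM_PARTITION -J videomme:VILA1.5-40b \
    llava/eval/video_mme/sbatch_eval.sh \
    Efficient-Large-Model/VILA1.5-40b \
    VILA1.5-40b \
    hermes-2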
45 changes: 39 additions & 6 deletions llava/eval/video_mme/video_eval.py
@@ -125,6 +125,14 @@ def get_model_output(
The best answer is:
"""

template_wsub = r""" This video's subtitles are listed below:
{subtitle}
Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.
{question}
The best answer is:
"""



def get_path(root_path):
from huggingface_hub import repo_exists, snapshot_download
@@ -152,7 +160,7 @@ def eval_model(args):
if not args.output_name.endswith(".json"):
args.output_name += ".json"

if args.num_video_frames is None:
if args.num_video_frames is None or args.num_video_frames < 0:
root_path = osp.join(get_path(model_path))
args.num_video_frames = json.load(open(osp.join(root_path, "config.json")))["num_video_frames"]
print(
@@ -174,19 +182,24 @@ def eval_model(args):
jinfo = json.load(open("/home/ligengz/workspace/video-mme/Video-MME.json"))
folder = "/home/ligengz/workspace/video-mme/ytb_videos"

# videomme v2, released in june 20 2024
# videomme v2, updated in june 20 2024
jinfo = json.load(open("/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/Video-MME/qa_old_format.json"))
folder = "/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/Video-MME/videos"
subtitle_folder = "/home/ligengz/nvr_elm_llm/dataset/Video-MME/subtitle"

if args.convert:
for vmeta in jinfo:
for question in vmeta["questions"]:
qid = question["question_id"]
if qid in labeled_key:
question["response"] = labeled_key[qid]["response"]
# question["response"] = labeled_key[qid]["response"]
question["response_w/o_sub"] = labeled_key[qid]["response_w/o_sub"]
question["response_w/_sub"] = labeled_key[qid]["response_w/_sub"]
else:
# if not answered, use "C" as the default answer.
print("missing", qid)
question["response"] = "C"
question["response_w/o_sub"] = "C"
question["response_w/_sub"] = "C"
with open(answers_file.replace(".json", "_converted.json"), "w") as fp:
json.dump(jinfo, fp, indent=2)
return 0
@@ -212,18 +225,23 @@ def eval_model(args):
url = vmeta["url"]
video_id = vmeta["video_id"]
uid = osp.basename(url).split("?v=")[-1]

vpath = osp.join(folder, f"{uid}.mp4")
subpath = osp.join(subtitle_folder, f"{uid}.srt")

from llava.eval.video_mme.w_sub_eval import slice_frames
video_frames, video_subtitles = slice_frames(vpath, subpath, num_frames=args.num_video_frames)
if not osp.exists(vpath):
print("[video not downloaded] Skip", vpath)
continue

for questions in vmeta["questions"]:
qid = questions["question_id"]
if qid in labeled_key:
print("[question id answered] Skip", qid, url)
continue
qa = questions["question"] + "\n" + "Answer the question by only outputing the choice.\n" + "\n".join(questions["choices"])

qs = template.format(question=qa)
output = get_model_output(
model,
@@ -236,7 +254,21 @@
num_beams=args.num_beams,
num_video_frames=args.num_video_frames,
)
questions["response"] = output
questions["response_w/o_sub"] = output

qs = template_wsub.format(question=qa, subtitle=video_subtitles)
output = get_model_output(
model,
image_processor,
tokenizer,
vpath,
qs,
conv_mode=args.conv_mode,
temperature=args.temperature,
num_beams=args.num_beams,
num_video_frames=args.num_video_frames,
)
questions["response_w/_sub"] = output
labeled_key[questions["question_id"]] = questions
# break
# output_json.append(vmeta)
@@ -249,6 +281,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("--num-beams", type=int, default=1)
parser.add_argument("-c", "--convert", action="store_true")
parser.add_argument("--with-sub", action="store_true")
parser.add_argument("--shard", type=int, default=0)
parser.add_argument("--total", type=int, default=-1)
parser.add_argument("--model-path", type=str, default="Efficient-Large-Model/VILA1.5-3b")
112 changes: 112 additions & 0 deletions llava/eval/video_mme/w_sub_eval.py
@@ -0,0 +1,112 @@
import pysubs2
import cv2
import numpy as np
import os
import shutil
from PIL import Image

def get_seq_frames(total_num_frames, desired_num_frames):
"""
Calculate the indices of frames to extract from a video.
Parameters:
total_num_frames (int): Total number of frames in the video.
desired_num_frames (int): Desired number of frames to extract.
Returns:
list: List of indices of frames to extract.
"""

# Calculate the size of each segment from which a frame will be extracted
seg_size = float(total_num_frames - 1) / desired_num_frames

seq = []
for i in range(desired_num_frames):
# Calculate the start and end indices of each segment
start = int(np.round(seg_size * i))
end = int(np.round(seg_size * (i + 1)))

# Append the middle index of the segment to the list
seq.append((start + end) // 2)

return seq

def create_frame_output_dir(output_dir):
"""
Create the output directory for storing the extracted frames.
Parameters:
output_dir (str): Path to the output directory.
"""

if not os.path.exists(output_dir):
os.makedirs(output_dir)
else:
shutil.rmtree(output_dir)
os.makedirs(output_dir)

def slice_frames(video_path, srt_path, num_frames=8):
"""
Extract frames from a video and save them to the output directory.
Parameters:
video_file_name (str): Path to the video file.
num_frames (int): Number of frames to extract.
res_path (str): Path to the output directory.
subtitles_file_name (str): Path to the subtitles file.
"""
# print(f"Extracting video: {video_path}")
# create_frame_output_dir(os.path.join(output_path, "frames"))

cv2_vr = cv2.VideoCapture(video_path)
duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cv2_vr.get(cv2.CAP_PROP_FPS)
selected_frame_ids = get_seq_frames(duration, num_frames)

output_file_prefix = os.path.basename(video_path).replace(".", "_")

count = 0

pil_frames = []
# Note(ligeng): frames are loaded using VILA's function, here we only load subtitles
# while cv2_vr.isOpened():
# success, frame = cv2_vr.read()
# if not success:
# break
# if count in selected_frame_ids:
# min = int(count / fps) // 60
# sec = int(count / fps) % 60
# img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# im_pil = Image.fromarray(img)
# pil_frames.append(im_pil)
# # time_string = f"{min:02d}:{sec:02d}"
# # image_name = f"{output_file_prefix}_frame_{time_string}.jpg"
# # cv2.imwrite(f"{output_path}/frames/{image_name}", frame)
# count += 1

subtitles = ""
if srt_path and os.path.exists(srt_path):
subs = pysubs2.load(srt_path, encoding="utf-8")
subtitles = []

for selected_frame_id in selected_frame_ids:
sub_text = ""
cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
for sub in subs:
if sub.start < cur_time and sub.end > cur_time:
sub_text = sub.text.replace("\\N", " ")
break
if sub_text.strip():
subtitles.append(sub_text)
subtitles = "\n".join(subtitles)

return pil_frames, subtitles

if __name__ == "__main__":
print(slice_frames(
"/home/ligengz/nvr_elm_llm/dataset/Video-MME/videos/_8lBR0E_Tx8.mp4",
"/home/ligengz/nvr_elm_llm/dataset/Video-MME/subtitle/_8lBR0E_Tx8.srt",
num_frames=12
))
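get_seq_frames splits the frame index range into desired_num_frames equal segments and keeps each segment's midpoint. A worked sketch of the arithmetic for a 100-frame video and 4 requested frames (numpy rounds exact halves to the nearest even integer):

# seg_size = (100 - 1) / 4 = 24.75
# i=0: (round(0.00)  + round(24.75)) // 2 = ( 0 + 25) // 2 = 12
# i=1: (round(24.75) + round(49.50)) // 2 = (25 + 50) // 2 = 37
# i=2: (round(49.50) + round(74.25)) // 2 = (50 + 74) // 2 = 62
# i=3: (round(74.25) + round(99.00)) // 2 = (74 + 99) // 2 = 86
assert get_seq_frames(100, 4) == [12, 37, 62, 86]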
