Commit 97711e4
Added caption ordering inference logic
Lookuz committed Nov 21, 2024
1 parent 63c1c41 commit 97711e4
Showing 2 changed files with 181 additions and 25 deletions.
171 changes: 146 additions & 25 deletions pipelines/inference/base.py
@@ -13,6 +13,8 @@ class VidHalInferencePipeline:
1. Formatting of prompt to be provided to the model.
2. Generation of response from the given prompt and video.
"""
system_prompt_instruction = ""
main_prompt_instruction = ""
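# Prompt templates, overridden by task-specific subclasses (e.g. VidHalMCQAInferencePipeline below)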
def __init__(
self,
model,
@@ -65,27 +67,9 @@ def generate_response(
response (str) : Response generated by the model.
"""
raise NotImplementedError

def run(self, save_path=None):
raise NotImplementedError


def run(self, save_path=None):
responses = {}
@@ -112,7 +96,25 @@ def run(self, save_path=None):
with open(save_path, "r") as f:
json.dump(responses, f, indent=4)

class VidHalMCQAInferencePipeline(VidHalInferencePipeline):
system_prompt_instruction = "You are provided with a video and a set of several captions. " \
"Your task is to watch the video provided carefully, and select the caption that best describes the video. " \
"Provide your answer only as a single letter representing the option whose caption that best describes the video, without any explanation."
main_prompt_instruction = "Watch the video provided, and choose the option whose caption describes the video most accurately."
def __init__(self, model, dataset, generation_config={}, *args, **kwargs):
super().__init__(model, dataset, generation_config, *args, **kwargs)

def process_response(self, response):
"""
Parses the generated response to extract only the selected option.
"""
last_option = list(string.ascii_uppercase)[self.num_captions - 1]
match = re.search(fr"\b[a-{last_option.lower()}A-{last_option}]\b", response)
match = match.group(0).upper().strip(";:., ") if match else None

return match if match else response # If no match, keep original response in case model replies with caption instead of option
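
# Illustrative behaviour (not part of the commit), assuming num_captions = 3:
#   process_response("The answer is (B).")  -> "B"
#   process_response("b: the best caption") -> "B"
#   process_response("No option fits.")     -> returned unchanged (no standalone A-C letter)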

class VidHalRelativeOrderingInferencePipeline(VidHalMCQAInferencePipeline):
def reorder_options(self, captions, option_to_rank):
"""
Re-orders the option prefixes (A, B, C) if there are fewer than the total number of captions presented to the model
@@ -132,10 +134,10 @@ def prompt_paired_question(self, video, captions, options, option_to_rank):
return option_to_rank
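
# Illustrative example (not part of the commit): when only options A and C are
# presented, their prefixes are remapped to the contiguous range A, B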

def prompt_paired_question(self, video, captions, options, option_to_rank):
# Reorder keys (e.g. A, C -> A, B) and track the mapping
display_options = list(string.ascii_uppercase)[:len(options)]
remapped_option_to_rank = {display_options[i] : option_to_rank[option] for i, option in enumerate(options)}
remapped_to_original = {display_options[i] : option for i, option in enumerate(options)}
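# e.g. (illustrative) options == ["A", "C"] gives remapped_to_original == {"A": "A", "B": "C"},
# so a model answer of "B" maps back to the original option "C"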

# Format prompt and generate response
options_prompt = self.format_options_prompt(captions=captions, option_to_rank=remapped_option_to_rank)
@@ -153,7 +155,7 @@ def prompt_paired_question(self, video, captions, options, option_to_rank):
return response

def prompt_relative_ordering(self, video, video_id, captions):
overall_order = []
# Transform from rank -> caption to option -> caption
option_to_rank = self.option_display_order[video_id]
rank_to_option = {v : k for k, v in option_to_rank.items()}
@@ -163,11 +165,130 @@ def prompt_relative_ordering(self, video, video_id, captions):
options = sorted(list(captions.keys()))
for option_A, option_B in zip(options, options[1:]):
response = self.prompt_paired_question(video, captions, [option_A, option_B], option_to_rank)
# Derive the correct pairwise order from the ranks; assign the reversed
# (incorrect) order if the response is invalid or incorrect
correct_order = [x[0] for x in sorted([
(option_A, option_to_rank[option_A]), (option_B, option_to_rank[option_B])
], key=lambda x: int(x[-1]))]
correct_answer = correct_order[0]
relative_order = correct_order if response == correct_answer else list(reversed(correct_order))
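
# Merge this pairwise result into the running overall order: extend at either
# end when the pair shares an endpoint with overall_order, otherwise insert
# the new option between existing ones via additional pairwise comparisons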

if len(overall_order) < 1:
overall_order = relative_order
elif overall_order[0] == relative_order[-1]: # Front prepend
overall_order = relative_order[:1] + overall_order
elif overall_order[-1] == relative_order[0]: # Back append
overall_order = overall_order + relative_order[1:]
# Intermediate insertion
else:
option_A, option_B = relative_order
# Determine the insertion start point based on the position of whichever option is already present
if option_A in overall_order:
index = overall_order.index(option_A)
elements_to_compare = overall_order[index + 1:]
else:
index = overall_order.index(option_B)
elements_to_compare = list(reversed(overall_order[:index]))

target_option = option_B if option_A in overall_order else option_A
# Compare with candidates until a unique ordering can be constructed
for i, candidate_option in enumerate(elements_to_compare):
response = self.prompt_paired_question(video, captions, sorted([target_option, candidate_option]), option_to_rank)
if not response: # Fall back to the wrong answer if an invalid response is provided
response = sorted([
(target_option, option_to_rank[target_option]), (candidate_option, option_to_rank[candidate_option])
], key=lambda x: -int(x[-1]))[0][0]

if (target_option == option_A and response != target_option) or (target_option == option_B and response == target_option):
new_subsequence = elements_to_compare[:i] + [target_option] + elements_to_compare[i:]
if target_option == option_B:
overall_order = overall_order[:index + 1] + new_subsequence
else:
overall_order = list(reversed(new_subsequence)) + overall_order[index:]
break

# If still not inserted, place the option at the appropriate end of the list
if target_option not in overall_order:
overall_order = [target_option] + overall_order if target_option == option_A else overall_order + [target_option]

return overall_order
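
# Illustrative walk-through (not part of the commit): if the pair (A, B) yields
# relative_order ["A", "B"] and the pair (B, C) yields ["B", "C"], the shared
# endpoint B triggers the back-append branch, giving overall_order ["A", "B", "C"]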

def run(self, save_path=None):
responses = {}
with torch.inference_mode(), torch.no_grad():
for i in tqdm(range(len(self.dataset))):
example = self.dataset[i]
video, video_id, captions = example["video"], example["video_id"], example["captions"]
predicted_order = self.prompt_relative_ordering(video, video_id, captions)
responses[video_id] = predicted_order

if save_path is not None:
with open(save_path, "r") as f:
json.dump(responses, f, indent=4)

class VidHalNaiveOrderingInferencePipeline(VidHalInferencePipeline):
def process_response(self, response):
def condense_sequence(sequence):
"""
Collapses consecutively repeated options, for cases where the model repeats a letter while explaining its choice
"""
condensed_sequence = []

for letter in sequence:
if not condensed_sequence or condensed_sequence[-1] != letter:
condensed_sequence.append(letter)

return condensed_sequence
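
# e.g. condense_sequence(["B", "B", "A", "C", "C"]) -> ["B", "A", "C"] (illustrative)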

# Insert commas if the letter sequence lacks them
response = re.sub(r'(?<=[A-Z])(?=[A-Z])', ', ', response)
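# e.g. "BAC" becomes "B, A, C" (illustrative)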

matches = re.findall(r"\b[A-Z]\b", response)

# Normalize matches to uppercase and strip stray punctuation
matches = [match.upper().strip(";:., ") for match in matches]
matches = condense_sequence(matches) # Remove repeated consecutive letters due to explanations or descriptions
# Handle more options than expected (e.g. A, B, C, D, E, F, ...)
valid_options = list(string.ascii_uppercase)[:self.num_captions]
matches = [x for x in matches if x in valid_options]
if len(matches) == self.num_captions:
return matches
else:
initial_match = matches

# Re-parse the response with stricter constraints (option letters followed by punctuation)
matches = re.findall(r"\b[A-Z][:\.,]", response)
matches = condense_sequence([match.upper().strip(";:., ") for match in matches])
matches = [x for x in matches if x in valid_options]
if matches and len(matches) <= self.num_captions:
return matches

# More letters captured than expected: the response contains descriptive/explanatory elements
if len(matches) > self.num_captions:
# Break down by paragraph-level parsing
sentences = response.split("\n")
matches = [re.findall(r"(?<![a-zA-Z'])[A-Z]\b", x) for x in sentences]
matches = [x for x in matches if len(x) > 1 and len(x) <= self.num_captions]

# Break down into sentence-level parsing
sentences = response.split(".")
matches.extend([
re.findall(r"(?<![a-zA-Z'])[A-Z]\b", x) for x in sentences if (
len(re.findall(r"(?<![a-zA-Z'])[A-Z]\b", x)) > 1
)
])
matches = [[x for x in match if x in valid_options] for match in matches]

# Deduplicate candidate orderings and keep the one with the most options
matches = sorted(
list(set([tuple(x) for x in matches])), key=lambda x: -len(x)
)
# Fall back to an empty ordering if no valid candidate remains
try:
    matches = list(matches[0])
except IndexError:
    matches = []

if len(matches) <= self.num_captions and len(initial_match) > len(matches):
return initial_match

return matches
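
# Illustrative parses (not part of the commit), assuming num_captions = 3:
#   "B, A, C"                            -> ["B", "A", "C"] (first pass)
#   "Order: B, A, C. I think B is best." -> ["B", "A", "C"] (second, punctuation-anchored pass)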
35 changes: 35 additions & 0 deletions pipelines/inference/random.py
@@ -0,0 +1,35 @@
import re
import random
import numpy as np
from .base import (
VidHalInferencePipeline,
VidHalMCQAInferencePipeline,
VidHalNaiveOrderingInferencePipeline,
VidHalRelativeOrderingInferencePipeline
)

class RandomInferencePipeline(VidHalInferencePipeline):
def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs):
super().__init__(model, dataset, option_display_order, generation_config, *args, **kwargs)

def format_prompt(self, main_prompt, options_prompt, system_prompt=None, *args, **kwargs):
return f"{main_prompt}\n\n{options_prompt}"

def generate_response(self, model, video, main_prompt, system_prompt=None, generation_config=..., *args, **kwargs):
if "order" in main_prompt:
return ", ".join(np.random.permutation(["A", "B", "C"]).tolist())
else:
options = re.findall(r'\b[A-Z]\b', main_prompt)
return random.choice(options)
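
# Note (illustrative caveat): the ordering branch assumes exactly three displayed options (A, B, C)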

class RandomMCQAInferencePipeline(RandomInferencePipeline, VidHalMCQAInferencePipeline):
def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs):
super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs)

class RandomNaiveOrderingInferencePipeline(RandomInferencePipeline, VidHalNaiveOrderingInferencePipeline):
def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs):
super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs)

class RandomRelativeOrderingInferencePipeline(RandomInferencePipeline, VidHalRelativeOrderingInferencePipeline):
def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs):
super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs)
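# Illustrative usage (hypothetical; constructor and dataset details may differ):
#   pipeline = RandomMCQAInferencePipeline(dataset=vidhal_dataset)
#   pipeline.run(save_path="random_mcqa_responses.json")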
