diff --git a/pipelines/inference/base.py b/pipelines/inference/base.py
index cc23bdd..9d6a204 100644
--- a/pipelines/inference/base.py
+++ b/pipelines/inference/base.py
@@ -13,6 +13,8 @@ class VidHalInferencePipeline:
     1. Formatting of prompt to be provided to the model.
     2. Generation of response from the given prompt and video.
     """
+    system_prompt_instruction = ""
+    main_prompt_instruction = ""
     def __init__(
         self, model,
@@ -65,27 +67,9 @@ def generate_response(
             response (str) : Response generated by the model.
         """
         raise NotImplementedError
-
-    def run(self, save_path=None):
-        raise NotImplementedError
-
-class VidHalMCQAInferencePipeline(VidHalInferencePipeline):
-    system_prompt_instruction = "You are provided with a video and a set of several captions. " \
-        "Your task is to watch the video provided carefully, and select the caption that best describes the video. " \
-        "Provide your answer only as a single letter representing the option whose caption that best describes the video, without any explanation."
-    main_prompt_instruction = "Watch the video provided, and choose the option whose caption describes the video most accurately."
-    def __init__(self, model, dataset, generation_config={}, *args, **kwargs):
-        super().__init__(model, dataset, generation_config, *args, **kwargs)
 
     def process_response(self, response):
-        """
-        Parses the generated response to extract only the selected option.
-        """
-        last_option = list(string.ascii_uppercase)[self.num_captions - 1]
-        match = re.search(fr"\b[a-{last_option.lower()}A-{last_option}]\b", response)
-        match = match.group(0).upper().strip(";:., ") if match else None
-
-        return match if match else response # If no match, keep original response in case model replies with caption instead of option
+        return response
 
     def run(self, save_path=None):
         responses = {}
@@ -112,7 +96,25 @@ def run(self, save_path=None):
         with open(save_path, "w") as f:
             json.dump(responses, f, indent=4)
 
-class VidHalRelativeCOInferencePipeline(VidHalMCQAInferencePipeline):
+class VidHalMCQAInferencePipeline(VidHalInferencePipeline):
+    system_prompt_instruction = "You are provided with a video and a set of several captions. " \
+        "Your task is to watch the video provided carefully, and select the caption that best describes the video. " \
+        "Provide your answer only as a single letter representing the option whose caption that best describes the video, without any explanation."
+    main_prompt_instruction = "Watch the video provided, and choose the option whose caption describes the video most accurately."
+    def __init__(self, model, dataset, generation_config={}, *args, **kwargs):
+        super().__init__(model, dataset, generation_config, *args, **kwargs)
+
+    def process_response(self, response):
+        """
+        Parses the generated response to extract only the selected option.
+ """ + last_option = list(string.ascii_uppercase)[self.num_captions - 1] + match = re.search(fr"\b[a-{last_option.lower()}A-{last_option}]\b", response) + match = match.group(0).upper().strip(";:., ") if match else None + + return match if match else response # If no match, keep original response in case model replies with caption instead of option + +class VidHalRelativeOrderingInferencePipeline(VidHalMCQAInferencePipeline): def reorder_options(self, captions, option_to_rank): """ Re-orders the option prefixes (A, B, C) if there are less then the total number of captions presented to the model @@ -132,10 +134,10 @@ def reorder_options(self, captions, option_to_rank): return option_to_rank def prompt_paired_question(self, video, captions, options, option_to_rank): - # Reorder keys (e.g. A, C -> A, B), track mapping + # Reorder keys (e.g. A, C -> A, B) and track the mapping display_options = list(string.ascii_uppercase)[:len(options)] - remapped_option_to_rank = {display_options[i] : rank for i, (_, rank) in enumerate(sorted(list(option_to_rank.items(), lambda x : x[0])))} - remapped_to_original = {display_options[i] : option for i, (option, _) in enumerate(sorted(list(option_to_rank.items(), lambda x : x[0])))} + remapped_option_to_rank = {display_options[i] : option_to_rank[option] for i, option in enumerate(options)} + remapped_to_original = {display_options[i] : option for i, option in enumerate(options)} # Format prompt and generate response options_prompt = self.format_options_prompt(captions=captions, option_to_rank=remapped_option_to_rank) @@ -153,7 +155,7 @@ def prompt_paired_question(self, video, captions, options, option_to_rank): return response def prompt_relative_ordering(self, video, video_id, captions): - order = [] + overall_order = [] # Transform from rank -> caption to option -> caption option_to_rank = self.option_display_order[video_id] rank_to_option = {v : k for k, v in option_to_rank.items()} @@ -163,11 +165,130 @@ def prompt_relative_ordering(self, video, video_id, captions): options = sorted(list(captions.keys())) for option_A, option_B in zip(options, options[1:]): response = self.prompt_paired_question(video, captions, [option_A, option_B], option_to_rank) + # Assign incorrect order if response is invalid or incorrect + correct_order = [x[0] for x in sorted([ + (option_A, option_to_rank[option_A]), (option_B, option_to_rank[option_B]) + ], lambda x : int(x[-1]))] + correct_answer = correct_order[0] + relative_order = correct_order if response == correct_answer else list(reversed(correct_order)) - return order + if len(overall_order) < 1: + overall_order = relative_order + elif overall_order[0] == relative_order[-1]: # Front prepend + overall_order = relative_order[:1] + overall_order + elif overall_order[-1] == relative_order[0]: # Back append + overall_order = overall_order + relative_order[1:] + # Intermediate insertion + else: + option_A, option_B = relative_order + # Determine start point of insertion based on position of which key is present + if option_A in overall_order: + index = overall_order.index(option_A) + elements_to_compare = overall_order[index + 1:] + else: + index = overall_order.index(option_B) + elements_to_compare = list(reversed(overall_order[:index])) + + target_option = option_B if option_A in overall_order else option_A + # Compare with candidates til unique ordering can be constructed + for i, candidate_option in enumerate(elements_to_compare): + response = self.prompt_paired_question(video, captions, sorted([target_option, 
candidate_option]), option_to_rank) + if not response: # Select wrong answer if invalid one provided + response = sorted([ + (target_option, option_to_rank[target_option]), (candidate_option, option_to_rank[candidate_option]) + ], lambda x : -int(x[-1]))[0][0] + + if (target_option == option_A and response != target_option) or (target_option == option_B and response == target_option): + new_subsequence = elements_to_compare[:i] + [target_option] + elements_to_compare[i:] + if target_option == option_B: + overall_order = overall_order[:index + 1] + new_subsequence + else: + overall_order = list(reversed(new_subsequence)) + overall_order[index:] + break + + # Insert at ends of list if not inserted + if target_option not in overall_order: + overall_order = [target_option] + overall_order if target_option == option_A else overall_order + [target_option] + + return overall_order def run(self, save_path=None): + responses = {} with torch.inference_mode(), torch.no_grad(): for i in tqdm(range(len(self.dataset))): example = self.dataset[i] video, video_id, captions = example["video"], example["video_id"], example["captions"] + predicted_order = self.prompt_relative_ordering(video, video_id, captions) + responses[video_id] = predicted_order + + if save_path is not None: + with open(save_path, "r") as f: + json.dump(responses, f, indent=4) + +class VidHalNaiveOrderingInferencePipeline(VidHalInferencePipeline): + def process_response(self, response): + def condense_sequence(sequence): + """ + Reduces consecutively repeating options, in cases where model explains option chosen + """ + condensed_sequence = [] + + for letter in sequence: + if not condensed_sequence or condensed_sequence[-1] != letter: + condensed_sequence.append(letter) + + return condensed_sequence + + # Insert commas if letter sequence doesn't have + response = re.sub(r'(?<=[A-Z])(?=[A-Z])', ', ', response) + + matches = re.findall(r"\b[A-Z]\b", response) + + # Convert matches to uppercase and remove duplicates while preserving order + matches = [match.upper().strip(";:., ") for match in matches] + matches = condense_sequence(matches) # Remove repeated consecutive letters due to explanations or descriptions + # Handle more options than expected (e.g A, B, C, D, E, F, ...) + valid_options = list(string.ascii_uppercase)[:self.num_captions] + matches = [x for x in matches if x in valid_options] + if len(matches) == self.num_captions: + return matches + else: + initial_match = matches + + # Handle response with more constraints + matches = re.findall(r"\b[A-Z][:\.,]", response) + matches = condense_sequence([match.upper().strip(";:., ") for match in matches]) + matches = [x for x in matches if x in valid_options] + if matches and len(matches) <= self.num_captions: + return matches + + # Capture more than 3 letters - Response contains descriptory/explanatory elements + if len(matches) > self.num_captions: + # Break down by paragraph-level parsing + sentences = response.split("\n") + matches = [re.findall(r"(? 1 and len(x) <= self.num_captions] + + # Break down into sentence-level parsing + sentences = response.split(".") + matches.extend([ + re.findall(r"(? 
1 + ) + ]) + matches = [[x for x in match if x in valid_options] for match in matches] + + # Condense duplicate orderings and get ordering with most + matches = sorted( + list(set([tuple(x) for x in matches])), key=lambda x: -len(x) + ) + # Handle no valid ordering at the end + try: + matches = list(matches[0]) + except: + matches = [] + + if len(matches) <= self.num_captions and len(initial_match) > len(matches): + return initial_match + + return matches diff --git a/pipelines/inference/random.py b/pipelines/inference/random.py new file mode 100644 index 0000000..55f9fd9 --- /dev/null +++ b/pipelines/inference/random.py @@ -0,0 +1,35 @@ +import re +import random +import numpy as np +from .base import ( + VidHalInferencePipeline, + VidHalMCQAInferencePipeline, + VidHalNaiveOrderingInferencePipeline, + VidHalRelativeOrderingInferencePipeline +) + +class RandomInferencePipeline(VidHalInferencePipeline): + def __init__(self, dataset, model=None, option_display_order = None, generation_config=..., *args, **kwargs): + super().__init__(model, dataset, option_display_order, generation_config, *args, **kwargs) + + def format_prompt(self, main_prompt, options_prompt, system_prompt=None, *args, **kwargs): + return f"{main_prompt}\n\n{options_prompt}" + + def generate_response(self, model, video, main_prompt, system_prompt=None, generation_config=..., *args, **kwargs): + if "order" in main_prompt: + return ", ".join(np.random.permutation(["A", "B", "C"]).tolist()) + else: + options = re.findall(r'\b[A-Z]\b', main_prompt) + return random.choice(options) + +class RandomMCQAInferencePipeline(RandomInferencePipeline, VidHalMCQAInferencePipeline): + def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs): + super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs) + +class RandomNaiveOrderingInferencePipeline(RandomInferencePipeline, VidHalNaiveOrderingInferencePipeline): + def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs): + super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs) + +class RandomRelativeOrderingInferencePipeline(RandomInferencePipeline, VidHalRelativeOrderingInferencePipeline): + def __init__(self, dataset, model=None, option_display_order=None, generation_config=..., *args, **kwargs): + super().__init__(dataset, model, option_display_order, generation_config, *args, **kwargs)
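For reviewers, a quick standalone sketch of the parsing behavior this patch moves into `VidHalMCQAInferencePipeline.process_response`. The `extract_option` helper and the sample strings are illustrative only (not part of the patch), and three captions per video are assumed; the last line also demonstrates the comma-insertion step from `VidHalNaiveOrderingInferencePipeline.process_response`:

```python
import re
import string

def extract_option(response, num_captions=3):
    """Illustrative mirror of VidHalMCQAInferencePipeline.process_response."""
    # Accept any standalone option letter between A and the last valid option,
    # in either case (e.g. "b", "(B)", "B."), as process_response does.
    last_option = string.ascii_uppercase[num_captions - 1]
    match = re.search(fr"\b[a-{last_option.lower()}A-{last_option}]\b", response)
    match = match.group(0).upper().strip(";:., ") if match else None
    # Fall back to the raw response in case the model answered with the
    # caption text rather than an option letter.
    return match if match else response

print(extract_option("I would choose option (b)."))        # -> B
print(extract_option("The video shows someone cooking."))  # -> unchanged fallback
print(re.sub(r'(?<=[A-Z])(?=[A-Z])', ', ', "BCA"))         # -> B, C, A
```

One caveat worth noting: because the pattern also matches lowercase letters, a caption-style reply beginning with a standalone article (e.g. "a dog runs") would be read as option A, so the fallback only triggers when no single-letter token appears at all.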