From 4ffb29c3b226bbaed78e18dc3fe8ad60f4cb0215 Mon Sep 17 00:00:00 2001 From: Golovneva <103262907+Golovneva@users.noreply.github.com> Date: Thu, 17 Nov 2022 11:48:39 -0500 Subject: [PATCH] fixing reasoning perturbation issues (#4883) * fixing reasoning perturbation issues * update looping and comments --- parlai/tasks/math_dataset/agents.py | 32 ++++++++++--- parlai/tasks/proof_writer/agents.py | 10 +++- .../reasoning/reason_types/step_by_step.py | 46 +++++++------------ projects/roscoe/baselines/scores.py | 4 +- 4 files changed, 53 insertions(+), 39 deletions(-) diff --git a/parlai/tasks/math_dataset/agents.py b/parlai/tasks/math_dataset/agents.py index d4c7b81a07c..1e5bb1d273f 100644 --- a/parlai/tasks/math_dataset/agents.py +++ b/parlai/tasks/math_dataset/agents.py @@ -20,7 +20,7 @@ from parlai.core.opt import Opt from parlai.core.params import ParlaiParser from parlai.utils.io import PathManager -from typing import Optional +from typing import List, Optional from parlai.tasks.reasoning.agents import MWPStepsReasoningTeacher @@ -72,7 +72,7 @@ def __init__(self, opt, shared=None): self.math_random = random.Random(42) super().__init__(opt, shared) - def load_data(self, domains): + def load_data(self, domains) -> List[str]: data = [] data_path = self.opt['datafile'] for domain in domains: @@ -106,10 +106,7 @@ def get_data_for_fold(self, fold): answer_blob = self._clean_steps(answer_blob) steps = answer_blob.split(". ") if extrinsic_step: - rand_steps = self._clean_steps( - self.math_random.choice(data)["solution"] - ).split(". ") - random_step = self.math_random.choice(rand_steps) + random_step = self._find_nonempty_random_step(data) if convert: question = self._latex_conversion(question) final_answer = self._latex_conversion(final_answer) @@ -225,6 +222,29 @@ def _latex_conversion(self, final_answer: str) -> str: return final_answer + def _find_nonempty_random_step(self, dataset: List[str]) -> str: + '''Here we *ASSUME* that the whole dataset contains at least one non-empty step + Otherwise it will go into infinite loop looking for the one + ''' + # what we call an empty step + empty_steps = ["", " "] + # first find chain with at least one non-empty step + rand_steps = self._clean_steps( + self.math_random.choice(dataset)["solution"] + ).split(". ") + # make sure this chain has at least one non-empty step + i = 0 + while i < len(rand_steps) and rand_steps[i] in empty_steps: + i += 1 + # if it doesn't, try again + if i == len(rand_steps): + return self._find_nonempty_random_step(dataset) + random_step = empty_steps[0] + # find non-empty random step (and we know it exists in this chain) + while random_step in empty_steps: + random_step = self.math_random.choice(rand_steps) + return random_step + def get_boxed_answer(self, answer): boxed_idx = answer.find("boxed{") final_answer = answer[boxed_idx:] diff --git a/parlai/tasks/proof_writer/agents.py b/parlai/tasks/proof_writer/agents.py index 83ab62e7b72..bb5e999809f 100644 --- a/parlai/tasks/proof_writer/agents.py +++ b/parlai/tasks/proof_writer/agents.py @@ -328,8 +328,14 @@ def get_data_for_fold(self, fold): for m in messages: if extrinsic_step: - rand_steps = self.proofwriter_random.choice(messages)["steps"] - random_step = self.proofwriter_random.choice(rand_steps) + random_step = None + # make sure new step is from a different context + # here we aasume that there is at least one step in the set + # with different context, otherwise it will go in the + # infinite loop + while not random_step or random_step in m["question"]: + rand_steps = self.proofwriter_random.choice(messages)["steps"] + random_step = self.proofwriter_random.choice(rand_steps) m["extrinsic_step"] = random_step yield m else: diff --git a/parlai/tasks/reasoning/reason_types/step_by_step.py b/parlai/tasks/reasoning/reason_types/step_by_step.py index d0a9e509946..382a91f1799 100644 --- a/parlai/tasks/reasoning/reason_types/step_by_step.py +++ b/parlai/tasks/reasoning/reason_types/step_by_step.py @@ -301,14 +301,7 @@ def __init__(self, opt: Opt, cache: Optional[Dict[str, List[List[str]]]] = None) super().__init__(opt, cache) self.lemmatizer = WordNetLemmatizer() - def lemmatize_step(self, step): - try: - words = nltk.word_tokenize(str(step)) - except IndexError: - print( - f"WARNING: could not lemmatize step {str(step)}. Proceeding to the next perturbation." - ) - return str(step) + def lemmatize_step(self, words): lemmatized_output = ' '.join([self.lemmatizer.lemmatize(w, 'v') for w in words]) # remove extraneous spaces after joining strings back clean_lemmatized_output = re.sub( @@ -316,14 +309,7 @@ def lemmatize_step(self, step): ) return clean_lemmatized_output - def drop_verb(self, step): - try: - words = nltk.word_tokenize(str(step)) - except IndexError: - print( - f"WARNING: could not lemmatize step {str(step)}. Proceeding to the next perturbation." - ) - return str(step) + def drop_verb(self, words): tags = nltk.pos_tag(words) verb_indices = [] for i, tag in enumerate(tags): @@ -338,14 +324,7 @@ def drop_verb(self, step): clean_result = re.sub(r'\s([?.!"](?:\s|$))', r'\1', result) return clean_result - def swap_words(self, step): - try: - tokenized_step = nltk.word_tokenize(str(step)) - except IndexError: - print( - f"WARNING: could not lemmatize step {str(step)}. Proceeding to next perturbation." - ) - return str(step) + def swap_words(self, tokenized_step): tags = nltk.pos_tag(tokenized_step) word_indices = [] for i, tag in enumerate(tags): @@ -374,15 +353,22 @@ def perturb(self, example_dict: Dict) -> Dict: grammatical_error_steps = [] for i, step in enumerate(steps): + try: + tok_step = nltk.word_tokenize(str(step)) + except IndexError: + print( + f"WARNING: could not tokenize step {str(step)}. Proceeding to next chain." + ) + return str(step) # perform all possible grammatical errors on each step, then randomly choose 1 - lemmatized_step = self.lemmatize_step(step) - if str(step) != lemmatized_step: + lemmatized_step = self.lemmatize_step(tok_step) + if tok_step != lemmatized_step: grammatical_error_steps.append((i, lemmatized_step)) - dropped_verb_step = self.drop_verb(step) - if dropped_verb_step != "" and str(step) != dropped_verb_step: + dropped_verb_step = self.drop_verb(tok_step) + if dropped_verb_step != "" and tok_step != dropped_verb_step: grammatical_error_steps.append((i, dropped_verb_step)) - swapped_word_step = self.swap_words(step) - if swapped_word_step != "" and str(step) != swapped_word_step: + swapped_word_step = self.swap_words(tok_step) + if swapped_word_step != "" and tok_step != swapped_word_step: grammatical_error_steps.append((i, swapped_word_step)) if not grammatical_error_steps: diff --git a/projects/roscoe/baselines/scores.py b/projects/roscoe/baselines/scores.py index b0fca6a1a2e..63753611adf 100644 --- a/projects/roscoe/baselines/scores.py +++ b/projects/roscoe/baselines/scores.py @@ -223,7 +223,9 @@ def load(self, path=None): ) # Path here to fine-tuend BART Model try: - self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth") + self.scorer.load( + BART_SCORE_REPO + "/train/reproduce/trained/fine_tuned_bartscore.pth" + ) except FileNotFoundError: raise FileNotFoundError( f"Path here should be to fine tuned BART model from"