diff --git a/common.py b/common.py
index b6b4c0e1..97846c2b 100644
--- a/common.py
+++ b/common.py
@@ -25,7 +25,7 @@
 ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
 ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
 MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
-    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
+    "(?i)(?=(?:{}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])))"
 )
 # All the different ways "Answer" is written in different languages
 MULTILINGUAL_ANSWER_REGEXES = [
diff --git a/mmlu_eval.py b/mmlu_eval.py
index 9423c660..11a4044a 100644
--- a/mmlu_eval.py
+++ b/mmlu_eval.py
@@ -104,9 +104,10 @@ def fn(row: dict):
             extracted_answer = None
             for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
                 regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-                match = re.search(regex, response_text)
-                if match:
-                    extracted_answer = normalize_extracted_answer(match.group(1))
+                matches = re.findall(regex, response_text)
+                if matches:
+                    match = matches[-1]
+                    extracted_answer = normalize_extracted_answer(match)
                     break
             score = 1.0 if extracted_answer == row["Answer"] else 0.0
             html = common.jinja_env.from_string(HTML_JINJA).render(