Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 29, 2024
1 parent c507389 commit c9f759e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 61 deletions.
2 changes: 1 addition & 1 deletion lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_dataset(self) -> datasets.Dataset:
if self.fewshot_indices:
self.dataset = self.dataset.select(self.fewshot_indices)
return self.dataset

def sample(self, n, rnd):
indices = rnd.sample(range(len(self.get_dataset())), n)
return self.get_dataset().select(indices)
Expand Down
18 changes: 9 additions & 9 deletions lmms_eval/tasks/olympiadbench/cn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

import logging

eval_logger = logging.getLogger("lmms-eval")
dir_name = os.path.dirname(os.path.abspath(__file__))

olympiadbench_evaluator = OlympiadBenchEvaluator()


def olympiadbench_doc_to_visual(doc):
return [image.convert("RGB") for image in doc["images"]]


def olympiadbench_doc_to_text(doc):
question = doc["question"]
subject = doc["subfield"]
Expand All @@ -36,28 +39,26 @@ def olympiadbench_doc_to_text(doc):
else:
post_prompt += '"所以最终答案是\\boxed{用英⽂逗号连接的多个答案}。"\n'

final_question = pre_prompt + question + '\n' + post_prompt
final_question = pre_prompt + question + "\n" + post_prompt
return final_question


def olympiadbench_process_results(doc, results):
precision = doc["error"]
is_proving = "TP" in doc["source"]
is_proving = "TP" in doc["source"]
if precision is None:
precision = 0
prediction = results[0].strip()

if is_proving:
return {
"submission": prediction
}
return {"submission": prediction}
else:
prediction = prediction.split("所以最终答案是")[-1]
prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
accuracy = int(accuracy)
return {
"exact_match": accuracy
}
return {"exact_match": accuracy}


def olympiadbench_aggregate_results(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
Expand All @@ -66,4 +67,3 @@ def olympiadbench_aggregate_results(results, args):
with open(path, "w") as f:
json.dump(results, f, ensure_ascii=False)
print(f"Submission file saved to {path}")

24 changes: 13 additions & 11 deletions lmms_eval/tasks/olympiadbench/en_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

import logging

eval_logger = logging.getLogger("lmms-eval")
dir_name = os.path.dirname(os.path.abspath(__file__))

olympiadbench_evaluator = OlympiadBenchEvaluator()


def olympiadbench_doc_to_visual(doc):
return [image.convert("RGB") for image in doc["images"]]


def olympiadbench_doc_to_text(doc):
question = doc["question"]
subject = doc["subfield"]
Expand All @@ -30,34 +33,34 @@ def olympiadbench_doc_to_text(doc):
post_prompt += f"The answer of the question should be {ans_type}.\n"
else:
post_prompt += f"The question has multiple answers, each of them should be {ans_type}.\n"
post_prompt += "Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "
post_prompt += (
"Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "
)
if not mul_ans:
post_prompt += '"So the final answer is \\boxed{answer}."\n'
else:
post_prompt += 'So the final answer is \\boxed{multiple answers connected with commas}.\n'
post_prompt += "So the final answer is \\boxed{multiple answers connected with commas}.\n"

final_question = pre_prompt + question + '\n' + post_prompt
final_question = pre_prompt + question + "\n" + post_prompt
return final_question


def olympiadbench_process_results(doc, results):
precision = doc["error"]
is_proving = "TP" in doc["source"]
is_proving = "TP" in doc["source"]
if precision is None:
precision = 0
prediction = results[0].strip()

if is_proving:
return {
"submission": prediction
}
return {"submission": prediction}
else:
prediction = prediction.split("final answer is")[-1]
prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
accuracy = int(accuracy)
return {
"exact_match": accuracy
}
return {"exact_match": accuracy}


def olympiadbench_aggregate_results(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
Expand All @@ -66,4 +69,3 @@ def olympiadbench_aggregate_results(results, args):
with open(path, "w") as f:
json.dump(results, f, ensure_ascii=False)
print(f"Submission file saved to {path}")

80 changes: 40 additions & 40 deletions lmms_eval/tasks/olympiadbench/olympiadbench_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# precision = 1e-4
# res = scorer.judge(exp1, exp2, precision)


class OlympiadBenchEvaluator:
def __init__(self):
# Map of special symbols to their replacements
Expand Down Expand Up @@ -46,8 +47,8 @@ def split_by_comma(self, expr: str):
start_idx = i + 1

if start_idx < len(expr):
splitted_expr.append(expr[start_idx:].strip())
splitted_expr.append(expr[start_idx:].strip())

return splitted_expr

def trans_plus_minus_sign(self, expr_list: list):
Expand All @@ -59,9 +60,9 @@ def trans_plus_minus_sign(self, expr_list: list):
new_expr_list.append(expr.replace("\\pm", "-"))
else:
new_expr_list.append(expr)

return new_expr_list

def judge(self, expression1, expression2, precision=1e-8):
# Judge if two expressions are equal (expression1 is considered as the Ground Truth)
# Default precision is a list for supporting multiple expressions
Expand All @@ -74,11 +75,11 @@ def judge(self, expression1, expression2, precision=1e-8):
if expression1 == expression2:
# print("Exactly equal")
return True

# Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
expression1 = re.sub(r"[\u4e00-\u9fff]+", "", expression1)
expression2 = re.sub(r"[\u4e00-\u9fff]+", "", expression2)

expression1 = self.split_by_comma(expression1)
expression2 = self.split_by_comma(expression2)

Expand All @@ -88,7 +89,7 @@ def judge(self, expression1, expression2, precision=1e-8):
# Set up a list for allowed errors
if len(precision) <= 1:
precision = precision * len(temp_list1)

if len(temp_list1) != len(temp_list2):
return False

Expand All @@ -112,15 +113,15 @@ def judge(self, expression1, expression2, precision=1e-8):

# If all elements are matched, return True
return True

def is_interval(self, expr):
# Checks if an expression is an interval
return expr.startswith(("(", "[")) and expr.endswith((")", "]"))

def sympy_sub_pi(self, expression_sympy):
# Replaces the symbol for pi in sympy expressions with its numerical value
return expression_sympy.subs(self.pi, math.pi)

def is_equal(self, expression1, expression2):
# Default first expression is ground truth. Check if expressions are equal in different aspects
if expression1 == expression2 and expression1 != "" and expression2 != "":
Expand All @@ -143,41 +144,40 @@ def is_equal(self, expression1, expression2):
return True
except:
pass

# Then check if expressions are mathematically equal
try:
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
# print("Expression equivalent")
return True
except:
pass

# Lastly, check for equation equality
try:
if self.equation_equal(expression1, expression2):
# print("Equation equivalent")
return True
except:
pass

return False

def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
# Check if two numerical values are equal within an allowed error range
# Includes possible percentage cases
reference = float(expression1)
prediction = float(expression2)

if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]

for item in gt_result:
if abs(item - prediction) <= self.precision * 1.01:
return True
return False


def expression_equal(self, exp1, exp2):
# Check if two expressions are mathematically equivalent
Expand All @@ -186,7 +186,7 @@ def extract_expression(expression):
if "=" in expression:
expression = expression.split("=")[1]
return expression.strip()

exp1 = extract_expression(exp1)
exp2 = extract_expression(exp2)

Expand All @@ -204,7 +204,7 @@ def extract_expression(expression):
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
try:
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
print(f"These two numbers cannot be calculated by the current computer for: \"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
print(f'These two numbers cannot be calculated by the current computer for: "{str(expr1_sym)}" and "{str(expr2_sym)}"')
return False

if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
Expand All @@ -218,7 +218,7 @@ def extract_expression(expression):
simplified_expr = simplify(expr1_sym - expr2_sym)

num_value = simplified_expr.evalf()

return abs(num_value) < 1e-3
except:
return False
Expand All @@ -227,7 +227,7 @@ def equation_equal(self, expression1, expression2):
# Check if two equations are mathematically equivalent
# Simplify equations and use sympy for equivalence checking
def simplify_equation(latex_eq):
lhs, rhs = latex_eq.split('=')
lhs, rhs = latex_eq.split("=")

lhs_expr = parse_latex(lhs)
rhs_expr = parse_latex(rhs)
Expand All @@ -254,18 +254,18 @@ def interval_equal(self, expression1, expression2):
def compare_two_interval(inter1, inter2):
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
return False

inter1 = inter1.strip('[]()')
inter2 = inter2.strip('[]()')

items_1 = inter1.split(',')
items_2 = inter2.split(',')
inter1 = inter1.strip("[]()")
inter2 = inter2.strip("[]()")

items_1 = inter1.split(",")
items_2 = inter2.split(",")

for item_1, item_2 in zip(items_1, items_2):
if not self.expression_equal(item_1, item_2):
return False
return True

interval1 = expression1
interval2 = expression2

Expand All @@ -274,7 +274,7 @@ def compare_two_interval(inter1, inter2):
else:
inter_list1 = interval1.split("\\cup")
inter_list2 = interval2.split("\\cup")

if len(inter_list1) != len(inter_list2):
return False
else:
Expand All @@ -286,7 +286,7 @@ def compare_two_interval(inter1, inter2):
def preprocess(self, expression1, expression2):
# Preprocess expressions to extract and replace special symbols
def extract_boxed_content(latex_str):
boxed_matches = re.finditer(r'\\boxed{', latex_str)
boxed_matches = re.finditer(r"\\boxed{", latex_str)
results = ""

for match in boxed_matches:
Expand All @@ -295,14 +295,14 @@ def extract_boxed_content(latex_str):
stack = 1

while stack > 0 and end_index < len(latex_str):
if latex_str[end_index] == '{':
if latex_str[end_index] == "{":
stack += 1
elif latex_str[end_index] == '}':
elif latex_str[end_index] == "}":
stack -= 1
end_index += 1

if stack == 0:
content = latex_str[start_index:end_index - 1]
content = latex_str[start_index : end_index - 1]
results += content + ","
else:
raise ValueError("Mismatched braces in LaTeX string.")
Expand All @@ -317,28 +317,28 @@ def extract_boxed_content(latex_str):
results += ans + ","
else:
results = latex_str

return results

def sepcial_symbol_replace(expression):
if "\\in " in expression:
expression = expression.split("\\in ")[1]

for signal in self.special_signal_map:
expression = expression.replace(signal, self.special_signal_map[signal])

expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。")

pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
expression = re.sub(pattern, r'\1', expression)
pattern = r"\\(?:mathrm|mathbf)\{~?([^}]*)\}"
expression = re.sub(pattern, r"\1", expression)

return expression

exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2)

return exp1, exp2

def can_compute_power(self, expr):
# Checks if a power expression can be computed
if isinstance(expr, Pow):
Expand All @@ -352,4 +352,4 @@ def can_compute_power(self, expr):
else:
return False
else:
return True # Not a power expression, can compute
return True # Not a power expression, can compute

0 comments on commit c9f759e

Please sign in to comment.