Skip to content

Commit

Permalink
Bugfix: WebSRC should be token-level F1 NOT character-level
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterheiden committed May 2, 2024
1 parent eef3aeb commit 626e8a9
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions lmms_eval/tasks/websrc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def websrc_process_results(doc, results):
"websrc_squad_f1": websrc_ans,
"submission": {
websrc_ans['question_id']: pred,
},
} if 'question_id' in websrc_ans else None
}


Expand Down Expand Up @@ -122,27 +122,39 @@ def _normalize_str(string):
# lower it
string = string.lower()

# strip non-alphanumeric characters
string = re.sub(r"[^a-zA-Z0-9]", "", string)

# strip leading and trailing whitespaces
string = string.strip()

return string

def _tokenize(text):
# Regex pattern to match words and isolate punctuation
pattern = r'\w+|[^\w\s]'
tokens = re.findall(pattern, text)
return tokens

def _compute_f1(sa, sb):
sa = _normalize_str(sa)
sb = _normalize_str(sb)

sa = _tokenize(sa)
sb = _tokenize(sb)

sa = set(sa)
sb = set(sb)

if len(sa) == 0 or len(sb) == 0:
return 0.0

comm = sa.intersection(sb)
prec = len(comm) / len(sb)
rec = len(comm) / len(sa)
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
return f1

judge_list = []
for sample in samples:
gold_i = set(_normalize_str(sample["answer"]))
pred_i = set(_normalize_str( sample["parsed_pred"]))
if len(pred_i) == 0:
judge_list.append(0.0)
continue

comm_i = gold_i.intersection(pred_i)
prec_i = len(comm_i) / len(pred_i)
rec_i = len(comm_i) / len(gold_i)
f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
judge_list.append(f1_i)
judge_list.append(_compute_f1(sample["answer"], sample["parsed_pred"]))

f1 = np.mean(judge_list)
return judge_list, {"f1": f1}

0 comments on commit 626e8a9

Please sign in to comment.