Skip to content

Commit

Permalink
update evaluator protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
zehuichen123 committed Feb 22, 2024
1 parent 35851bc commit 56e2121
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions teval/evaluators/reason_retrieve_understand_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,29 +211,29 @@ def _post_process(self, results_list):
metrics_results[id]['name'] = 0

if 'args' in data.pred and 'args' in data.gt:
batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])])
batch_arg_id.extend([id])
if len(batch_arg_data) >= BATCH_LIMIT:
pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
for i in range(0, len(batch_arg_data), 2):
cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
batch_arg_data = []
batch_arg_id = []
# batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])])
# batch_arg_id.extend([id])
# if len(batch_arg_data) >= BATCH_LIMIT:
# pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
# for i in range(0, len(batch_arg_data), 2):
# cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
# metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
# batch_arg_data = []
# batch_arg_id = []

# NOTE we adopt a stricter evaluation protocol in v2
# if isinstance(data.gt['args'], dict):
# for gt_arg_name in data.gt['args']:
# if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]):
# metrics_results[id]['args'] += 1
# metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5)
# if len(data.gt['args']) == 0 and len(data.pred['args']) == 0:
# metrics_results[id]['args'] = 1
# if len(data.gt['args']) == 0 and len(data.pred['args']) != 0:
# metrics_results[id]['args'] = 0
# else:
# data.pred['args'] = data.pred['args'].strip("'").strip('"')
# metrics_results[id]['args'] = float(data.gt['args'] == data.pred['args'])
if isinstance(data.gt['args'], dict):
for gt_arg_name in data.gt['args']:
if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]):
metrics_results[id]['args'] += 1
metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5)
if len(data.gt['args']) == 0 and len(data.pred['args']) == 0:
metrics_results[id]['args'] = 1
if len(data.gt['args']) == 0 and len(data.pred['args']) != 0:
metrics_results[id]['args'] = 0
else:
data.pred['args'] = data.pred['args'].strip("'").strip('"')
metrics_results[id]['args'] = float(data.gt['args'] == data.pred['args'])

if len(batch_data) > 0:
pred_emb = self.sentence_model.encode(batch_data, convert_to_tensor=True)
Expand All @@ -243,13 +243,13 @@ def _post_process(self, results_list):
batch_data = []
batch_id = []

if len(batch_arg_data) > 0:
pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
for i in range(0, len(batch_arg_data), 2):
cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
batch_arg_data = []
batch_arg_id = []
# if len(batch_arg_data) > 0:
# pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True)
# for i in range(0, len(batch_arg_data), 2):
# cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0)
# metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0]
# batch_arg_data = []
# batch_arg_id = []

results = dict()
for key in metric_keys:
Expand Down

0 comments on commit 56e2121

Please sign in to comment.