From 56e21210bd0d451f654051002c81cdfae714ea19 Mon Sep 17 00:00:00 2001 From: chenzehui Date: Thu, 22 Feb 2024 03:48:12 +0000 Subject: [PATCH] update evaluator protocal --- .../reason_retrieve_understand_evaluator.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/teval/evaluators/reason_retrieve_understand_evaluator.py b/teval/evaluators/reason_retrieve_understand_evaluator.py index 981a1cc2..244674ae 100644 --- a/teval/evaluators/reason_retrieve_understand_evaluator.py +++ b/teval/evaluators/reason_retrieve_understand_evaluator.py @@ -211,29 +211,29 @@ def _post_process(self, results_list): metrics_results[id]['name'] = 0 if 'args' in data.pred and 'args' in data.gt: - batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])]) - batch_arg_id.extend([id]) - if len(batch_arg_data) >= BATCH_LIMIT: - pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True) - for i in range(0, len(batch_arg_data), 2): - cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0) - metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0] - batch_arg_data = [] - batch_arg_id = [] + # batch_arg_data.extend([str(data.pred['args']), str(data.gt['args'])]) + # batch_arg_id.extend([id]) + # if len(batch_arg_data) >= BATCH_LIMIT: + # pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True) + # for i in range(0, len(batch_arg_data), 2): + # cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0) + # metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0] + # batch_arg_data = [] + # batch_arg_id = [] # NOTE we adopt a more strict evaluation protocal in v2 - # if isinstance(data.gt['args'], dict): - # for gt_arg_name in data.gt['args']: - # if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]): - # metrics_results[id]['args'] += 1 - # metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5) - # if len(data.gt['args']) == 0 and len(data.pred['args']) == 0: - # metrics_results[id]['args'] = 1 - # if len(data.gt['args']) == 0 and len(data.pred['args']) != 0: - # metrics_results[id]['args'] = 0 - # else: - # data.pred['args'] = data.pred['args'].strip("'").strip('"') - # metrics_results[id]['args'] = float(data.gt['args'] == data.pred['args']) + if isinstance(data.gt['args'], dict): + for gt_arg_name in data.gt['args']: + if gt_arg_name in data.pred['args'] and str(data.pred['args'][gt_arg_name]) == str(data.gt['args'][gt_arg_name]): + metrics_results[id]['args'] += 1 + metrics_results[id]['args'] /= (len(data.gt['args']) + 1e-5) + if len(data.gt['args']) == 0 and len(data.pred['args']) == 0: + metrics_results[id]['args'] = 1 + if len(data.gt['args']) == 0 and len(data.pred['args']) != 0: + metrics_results[id]['args'] = 0 + else: + data.pred['args'] = data.pred['args'].strip("'").strip('"') + metrics_results[id]['args'] = float(data.gt['args'] == data.pred['args']) if len(batch_data) > 0: pred_emb = self.sentence_model.encode(batch_data, convert_to_tensor=True) @@ -243,13 +243,13 @@ def _post_process(self, results_list): batch_data = [] batch_id = [] - if len(batch_arg_data) > 0: - pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True) - for i in range(0, len(batch_arg_data), 2): - cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0) - metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0] - batch_arg_data = [] - batch_arg_id = [] + # if len(batch_arg_data) > 0: + # pred_emb = self.sentence_model.encode(batch_arg_data, convert_to_tensor=True) + # for i in range(0, len(batch_arg_data), 2): + # cosine_score = np.maximum(util.cos_sim(pred_emb[i], pred_emb[i+1]).cpu().numpy(), 0) + # metrics_results[batch_arg_id[i // 2]]['args'] = cosine_score[0, 0] + # batch_arg_data = [] + # batch_arg_id = [] results = dict() for key in metric_keys: