diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index b53b46a2..a89c2e3e 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -4,6 +4,7 @@ import os import json +from copy import deepcopy from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir @@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt): if not os.path.exists(filepath): continue pred = json.load(open(filepath)) - results[i] = eval_states(gt, pred) + results[i] = eval_states(gt, pred, subtask) + print(json.dumps(results, indent=4)) json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) +# generate submission examples def dump_example(subtask, split): test_data = prepare_data(subtask, split) - gt = extract_gt(test_data) - json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) - for dialog_id, states in gt.items(): + pred = extract_gt(test_data) + json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) + import random + for dialog_id, states in pred.items(): + for state in states: + for domain in state.values(): + for slot, value in domain.items(): + if value: + if random.randint(0, 2) == 0: + domain[slot] = "" + else: + if random.randint(0, 4) == 0: + domain[slot] = "2333" + json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) + for dialog_id, states in pred.items(): for state in states: for domain in state.values(): for slot in domain: domain[slot] = "" - json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) + json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4) if __name__ == '__main__': diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 1b84ea05..377ecdd3 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt): for dialog_id, turns in test_data.items(): model.init_session() pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] - result = eval_states(gt, pred) - print(result) + result = eval_states(gt, pred, subtask) + print(json.dumps(result, indent=4)) json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 80cd25d9..b378fd2e 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -5,8 +5,13 @@ from convlab2 import DATA_ROOT +def get_subdir(subtask): + subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' + return subdir + + def prepare_data(subtask, split, data_root=DATA_ROOT): - data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') + data_dir = os.path.join(data_root, get_subdir(subtask)) zip_filename = os.path.join(data_dir, f'{split}.json.zip') test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) data = {} @@ -57,7 +62,20 @@ def extract_gt(test_data): return gt -def eval_states(gt, pred): +def eval_states(gt, pred, subtask): + # for unifying values with the same meaning to the same expression + value_unifier = { + 'multiwoz': { + + }, + 'crosswoz': { + '未提及': '', + } + }[subtask] + + def unify_value(value): + return value_unifier.get(value, value) + def exception(description, **kargs): ret = { 'status': 'exception', @@ -89,35 +107,32 @@ def exception(description, **kargs): for slot_name, gt_value in gt_domain.items(): if slot_name not in pred_domain: return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) - pred_value = pred_domain[slot_name] + gt_value = unify_value(gt_value) + pred_value = unify_value(pred_domain[slot_name]) slot_tot += 1 if gt_value == pred_value: slot_acc += 1 - tp += 1 + if gt_value: + tp += 1 else: turn_result = False - # for class of gt_value - fn += 1 - # for class of pred_value - fp += 1 + if gt_value: + fn += 1 + if pred_value: + fp += 1 joint_acc += turn_result - precision = tp / (tp + fp) - recall = tp / (tp + fn) - f1 = 2 * tp / (2 * tp + fp + fn) + precision = tp / (tp + fp) if tp + fp else 1 + recall = tp / (tp + fn) if tp + fn else 1 + f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1 return { 'status': 'ok', 'joint accuracy': joint_acc / joint_tot, - 'slot accuracy': slot_acc / slot_tot, - # 'slot': { - # 'accuracy': slot_acc / slot_tot, - # 'precision': precision, - # 'recall': recall, - # 'f1': f1, - # } + # 'slot accuracy': slot_acc / slot_tot, + 'slot': { + 'accuracy': slot_acc / slot_tot, + 'precision': precision, + 'recall': recall, + 'f1': f1, + } } - - -def get_subdir(subtask): - subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' - return subdir