Skip to content

Commit

Permalink
add metrics in XLDST evaluation (#126)
Browse files Browse the repository at this point in the history
* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change
  • Loading branch information
罗崚骁 authored Sep 24, 2020
1 parent e8ae888 commit d70c0fc
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 30 deletions.
25 changes: 20 additions & 5 deletions convlab2/dst/dstc9/eval_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
import json
from copy import deepcopy

from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir

Expand All @@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt):
if not os.path.exists(filepath):
continue
pred = json.load(open(filepath))
results[i] = eval_states(gt, pred)
results[i] = eval_states(gt, pred, subtask)

print(json.dumps(results, indent=4))
json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)


# generate submission examples
def dump_example(subtask, split):
test_data = prepare_data(subtask, split)
gt = extract_gt(test_data)
json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
for dialog_id, states in gt.items():
pred = extract_gt(test_data)
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
import random
for dialog_id, states in pred.items():
for state in states:
for domain in state.values():
for slot, value in domain.items():
if value:
if random.randint(0, 2) == 0:
domain[slot] = ""
else:
if random.randint(0, 4) == 0:
domain[slot] = "2333"
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
for dialog_id, states in pred.items():
for state in states:
for domain in state.values():
for slot in domain:
domain[slot] = ""
json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions convlab2/dst/dstc9/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt):
for dialog_id, turns in test_data.items():
model.init_session()
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred)
print(result)
result = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4))
json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)


Expand Down
61 changes: 38 additions & 23 deletions convlab2/dst/dstc9/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from convlab2 import DATA_ROOT


def get_subdir(subtask):
subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
return subdir


def prepare_data(subtask, split, data_root=DATA_ROOT):
data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
data_dir = os.path.join(data_root, get_subdir(subtask))
zip_filename = os.path.join(data_dir, f'{split}.json.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
data = {}
Expand Down Expand Up @@ -57,7 +62,20 @@ def extract_gt(test_data):
return gt


def eval_states(gt, pred):
def eval_states(gt, pred, subtask):
# for unifying values with the same meaning to the same expression
value_unifier = {
'multiwoz': {

},
'crosswoz': {
'未提及': '',
}
}[subtask]

def unify_value(value):
return value_unifier.get(value, value)

def exception(description, **kargs):
ret = {
'status': 'exception',
Expand Down Expand Up @@ -89,35 +107,32 @@ def exception(description, **kargs):
for slot_name, gt_value in gt_domain.items():
if slot_name not in pred_domain:
return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
pred_value = pred_domain[slot_name]
gt_value = unify_value(gt_value)
pred_value = unify_value(pred_domain[slot_name])
slot_tot += 1
if gt_value == pred_value:
slot_acc += 1
tp += 1
if gt_value:
tp += 1
else:
turn_result = False
# for class of gt_value
fn += 1
# for class of pred_value
fp += 1
if gt_value:
fn += 1
if pred_value:
fp += 1
joint_acc += turn_result

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)
precision = tp / (tp + fp) if tp + fp else 1
recall = tp / (tp + fn) if tp + fn else 1
f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1
return {
'status': 'ok',
'joint accuracy': joint_acc / joint_tot,
'slot accuracy': slot_acc / slot_tot,
# 'slot': {
# 'accuracy': slot_acc / slot_tot,
# 'precision': precision,
# 'recall': recall,
# 'f1': f1,
# }
# 'slot accuracy': slot_acc / slot_tot,
'slot': {
'accuracy': slot_acc / slot_tot,
'precision': precision,
'recall': recall,
'f1': f1,
}
}


def get_subdir(subtask):
subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
return subdir

0 comments on commit d70c0fc

Please sign in to comment.