Skip to content

Commit

Permalink
fix: DEV-2087: Fix agreement calculation for Taxonomy with label weights
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstantinKorotaev committed Apr 7, 2022
1 parent 3384b01 commit b9b95d7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
14 changes: 13 additions & 1 deletion evalme/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,4 +584,16 @@ def test_htmltags_migration_per_label():
html_tags1 = HTMLTagsEvalItem(raw_data=item_old, shape_key="hypertextlabels")
html_tags2 = HTMLTagsEvalItem(raw_data=item_new, shape_key="hypertextlabels")
assert html_tags1.intersection(html_tags2, per_label=True) == {'Title': 0.9285714285714286}
assert html_tags2.intersection(html_tags1, per_label=True) == {'Title': 0.9285714285714286}
assert html_tags2.intersection(html_tags1, per_label=True) == {'Title': 0.9285714285714286}


def test_taxonomy_control_weights():
"""
Test taxonomy with label weights
:return:
"""
label_weights = {"B": 0.5, "B_A": 0.2}
pred = intersection_taxonomy(tree_subview_1, tree_subview_2, label_config=label_config_subview, label_weights=label_weights)
assert pred == 0.2
pred_vice = intersection_taxonomy(tree_subview_2, tree_subview_1, label_config=label_config_subview, label_weights=label_weights)
assert pred_vice == 0.05
10 changes: 9 additions & 1 deletion evalme/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def spans_iou(self, prediction, per_label=False, label_config=None, label_weight
taxonomy_gt_list.extend(TaxonomyEvalItem._transform_tree(master_tree, item_gt_tx))
for item in taxonomy_pred_list:
if item in taxonomy_gt_list:
temp += 1
temp += label_weights.get(item[-1], 1)
matches += (temp / max(len(taxonomy_gt_list), 1))
tasks += 1
return matches / max(tasks, 1)
Expand Down Expand Up @@ -287,6 +287,8 @@ def path_matches(self, prediction, per_label=False, label_weights=dict()):
def _tree(label_config):
"""
Creating Tree from label_config
Example for default config:
{'Archaea': {}, 'Bacteria': {}, 'Eukarya': {'Human': {}, 'Oppossum': {}, 'Extraterrestial': {}}}
"""

def recursive_lookup(d, k='Taxonomy'):
Expand Down Expand Up @@ -366,6 +368,12 @@ def paths(self):

@staticmethod
def _compare_list(gt, pred):
"""
Compare taxonomy path in depth
:param gt: ground truth value
:param pred: predicted value
:return: score [0..1]
"""
score = 0
for p, g in zip(pred, gt):
if p == g:
Expand Down

0 comments on commit b9b95d7

Please sign in to comment.