diff --git a/evalme/tests/test_text.py b/evalme/tests/test_text.py index 1bbd6e4..1170483 100644 --- a/evalme/tests/test_text.py +++ b/evalme/tests/test_text.py @@ -584,4 +584,16 @@ def test_htmltags_migration_per_label(): html_tags1 = HTMLTagsEvalItem(raw_data=item_old, shape_key="hypertextlabels") html_tags2 = HTMLTagsEvalItem(raw_data=item_new, shape_key="hypertextlabels") assert html_tags1.intersection(html_tags2, per_label=True) == {'Title': 0.9285714285714286} - assert html_tags2.intersection(html_tags1, per_label=True) == {'Title': 0.9285714285714286} \ No newline at end of file + assert html_tags2.intersection(html_tags1, per_label=True) == {'Title': 0.9285714285714286} + + +def test_taxonomy_control_weights(): + """ + Test taxonomy with label weights + :return: + """ + label_weights = {"B": 0.5, "B_A": 0.2} + pred = intersection_taxonomy(tree_subview_1, tree_subview_2, label_config=label_config_subview, label_weights=label_weights) + assert pred == 0.2 + pred_vice = intersection_taxonomy(tree_subview_2, tree_subview_1, label_config=label_config_subview, label_weights=label_weights) + assert pred_vice == 0.05 diff --git a/evalme/text/text.py b/evalme/text/text.py index 3b42b60..8e58ef6 100644 --- a/evalme/text/text.py +++ b/evalme/text/text.py @@ -251,7 +251,7 @@ def spans_iou(self, prediction, per_label=False, label_config=None, label_weight taxonomy_gt_list.extend(TaxonomyEvalItem._transform_tree(master_tree, item_gt_tx)) for item in taxonomy_pred_list: if item in taxonomy_gt_list: - temp += 1 + temp += label_weights.get(item[-1], 1) matches += (temp / max(len(taxonomy_gt_list), 1)) tasks += 1 return matches / max(tasks, 1) @@ -287,6 +287,8 @@ def path_matches(self, prediction, per_label=False, label_weights=dict()): def _tree(label_config): """ Creating Tree from label_config + Example for default config: + {'Archaea': {}, 'Bacteria': {}, 'Eukarya': {'Human': {}, 'Oppossum': {}, 'Extraterrestial': {}}} """ def recursive_lookup(d, k='Taxonomy'): @@ -366,6 +368,12 @@ def paths(self): @staticmethod def _compare_list(gt, pred): + """ + Compare taxonomy path in depth + :param gt: ground truth value + :param pred: predicted value + :return: score [0..1] + """ score = 0 for p, g in zip(pred, gt): if p == g: