-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathcompute_pairwise_mi.py
executable file
·38 lines (29 loc) · 1.33 KB
/
compute_pairwise_mi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
Executable to compute the mutual information across item variables.
Theory: https://nlp.stanford.edu/IR-book/html/htmledition/mutual-information-1.html
Impl Docs: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mutual_info_score.html
"""
import argparse
import numpy as np
from sklearn.metrics import mutual_info_score
from aaerec.datasets import Bags
from aaerec.condition import ConditionList, CountCondition
from aaerec.utils import compute_mutual_info
# Command-line interface: one required dataset path, optional vocabulary
# pruning settings, and optional bookkeeping for the result row.
PARSER = argparse.ArgumentParser()
PARSER.add_argument('dataset', type=str,
                    help='path to dataset')
PARSER.add_argument('-m', '--min-count', type=int,
                    help='Pruning parameter', default=None)
PARSER.add_argument('-M', '--max-features', type=int,
                    help='Max features', default=None)
# The row label and the output path used to be hard-coded even though the
# dataset itself is configurable. The defaults below reproduce the previous
# values, so existing invocations write exactly the same CSV row.
PARSER.add_argument('-l', '--label', type=str, default='CITREC',
                    help='Label written as the first column of the result row')
PARSER.add_argument('-o', '--outfile', type=str, default='mi.csv',
                    help='CSV file the result row is appended to')
ARGS = PARSER.parse_args()

# No conditions are passed to the mutual-information computation by default.
# Alternative configuration kept for reference:
# MI_CONDITIONS = ConditionList([('title', CountCondition(max_features=100000))])
MI_CONDITIONS = None

print("Computing Mutual Info with args")
print(ARGS)

# With no metadata or just titles
BAGS = Bags.load_tabcomma_format(ARGS.dataset, unique=True)\
    .build_vocab(min_count=ARGS.min_count, max_features=ARGS.max_features)

mi = compute_mutual_info(BAGS, MI_CONDITIONS, include_labels=True, normalize=True)

# Opened in append mode so that repeated runs accumulate rows in one file.
# Row layout: label, min_count, mutual information.
# NOTE(review): --max-features also shapes the vocabulary but is not recorded
# in the row; adding it would change the CSV layout, so it is left as is.
with open(ARGS.outfile, 'a') as mifile:
    print(ARGS.label, ARGS.min_count, mi, sep=',', file=mifile)