Skip to content

Commit 7782b4c

Browse files
committed
Fix compare for sqlite
1 parent e4ef1f8 commit 7782b4c

File tree

1 file changed

+118
-111
lines changed

1 file changed

+118
-111
lines changed

bootstrap/compare.py

Lines changed: 118 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,128 @@
1-
import json
2-
import numpy as np
31
import argparse
42
from os import path as osp
53
from tabulate import tabulate
6-
7-
8-
def load_values(dir_logs, metrics, nb_epochs=-1, best=None):
    """Load the (epoch, value) pair of interest for each metric of one run.

    Args:
        dir_logs: directory containing the '<json>.json' log files.
        metrics: dict of metric descriptors, each with 'json' (log file
            basename), 'name' (key inside the json) and 'order'
            ('max' or 'min').
        nb_epochs: restrict the search to the first nb_epochs entries;
            -1 means use every available epoch.
        best: optional metric descriptor; when given, the selected index
            is the arg-best of this reference metric and every other
            metric is read at that same index.

    Returns:
        dict mapping metric name -> (epoch, value).
    """
    json_files = {}
    values = {}

    # When a reference metric is given, find the index of its best value
    # once; all metrics are then sampled at that index.
    if best:
        if best['json'] not in json_files:
            with open(osp.join(dir_logs, f'{best["json"]}.json')) as f:
                json_files[best['json']] = json.load(f)

        jfile = json_files[best['json']]
        vals = jfile[best['name']]
        end = len(vals) if nb_epochs == -1 else nb_epochs
        # getattr instead of np.__dict__[...]: same lookup (np.argmax /
        # np.argmin), but idiomatic and robust to numpy namespace changes.
        argsup = getattr(np, f'arg{best["order"]}')(vals[:end])

    # Load each metric's log file (cached in json_files) and extract values.
    for _key, metric in metrics.items():
        if metric['json'] not in json_files:
            with open(osp.join(dir_logs, f'{metric["json"]}.json')) as f:
                json_files[metric['json']] = json.load(f)

        jfile = json_files[metric['json']]

        # Epoch indices live under a split-specific key in newer logs,
        # under a plain 'epoch' key in older ones.
        if 'train' in metric['name']:
            epoch_key = 'train_epoch.epoch'
        else:
            epoch_key = 'eval_epoch.epoch'

        if epoch_key in jfile:
            epochs = jfile[epoch_key]
        else:
            epochs = jfile['epoch']

        vals = jfile[metric['name']]
        if not best:
            # No reference metric: pick this metric's own best index.
            end = len(vals) if nb_epochs == -1 else nb_epochs
            argsup = getattr(np, f'arg{metric["order"]}')(vals[:end])

        try:
            values[metric['name']] = epochs[argsup], vals[argsup]
        except IndexError:
            # The chosen index can exceed this metric's shorter history
            # (e.g. one split logged an epoch ahead of the other);
            # fall back to the previous entry.
            values[metric['name']] = epochs[argsup - 1], vals[argsup - 1]
    return values
52-
53-
54-
def main(args):
    """Compare experiment runs and print one ranking table per metric
    (JSON-log based comparison)."""
    # Map a display name to each log directory: 'name:path' or bare path
    # (in which case the directory's basename is used as the name).
    dir_logs = {}
    for raw in args.dir_logs:
        parts = raw.split(':')
        if len(parts) == 2:
            key, path = parts
        elif len(parts) == 1:
            path = parts[0]
            key = osp.basename(osp.normpath(path))
        else:
            raise ValueError(raw)
        dir_logs[key] = path

    # Build metric descriptors keyed by '<json>_<name>'.
    metrics = {
        f'{json_obj}_{name}': {'json': json_obj, 'name': name, 'order': order}
        for json_obj, name, order in args.metrics
    }

    if args.best:
        json_obj, name, order = args.best
        best = {'json': json_obj, 'name': name, 'order': order}
    else:
        best = None

    # Extract the selected (epoch, value) pair of every metric, per run.
    logs = {
        key: load_values(path, metrics, nb_epochs=args.nb_epochs, best=best)
        for key, path in dir_logs.items()
    }

    for metric in metrics.values():
        # Collect (value, run_name, epoch) triples for runs that logged
        # this metric; sorting the tuples ranks by value first.
        rows = []
        for run_name, vals in logs.items():
            if metric['name'] in vals:
                epoch, value = vals[metric['name']]
                rows.append((value, run_name, epoch))
        if rows:
            rows.sort(reverse=metric['order'] == 'max')
            ranked = [[i + 1, run_name, value, epoch]
                      for i, (value, run_name, epoch) in enumerate(rows)]
            print('\n\n## {}\n'.format(metric['name']))
            print(tabulate(ranked, headers=['Place', 'Method', 'Score', 'Epoch']))
4+
import sqlite3
5+
from contextlib import closing
6+
7+
# def get_internal_table_name(table_name):
8+
# return f'_{table_name}'
9+
10+
# def run_query(conn, query, parameters=None, cursor=None):
11+
# return execute(conn, query, parameters, commit=False, cursor=cursor)
12+
13+
# def list_columns(conn, table_name):
14+
# table_name = get_internal_table_name(table_name)
15+
# query = "SELECT name FROM PRAGMA_TABLE_INFO(?)"
16+
# with closing(conn.cursor()) as cursor:
17+
# qry_cur = run_query(conn, query, (table_name,), cursor=cursor)
18+
# columns = (res[0] for res in qry_cur)
19+
# # remove __id and __timestamp columns
20+
# columns = [c for c in columns if not c.startswith('__')]
21+
# return columns
22+
23+
# def select(conn, group, columns=None, where=None):
24+
# table_name = get_internal_table_name(group)
25+
# table_columns = list_columns(conn, group)
26+
# if columns is None:
27+
# column_string = '*'
28+
# else:
29+
# for c in columns:
30+
# if c not in table_columns:
31+
# Logger()(f'Unknown column "{c}"', log_level=Logger.ERROR)
32+
# column_string = ', '.join([f'"{c}"' for c in columns])
33+
# statement = f'SELECT {column_string} FROM {table_name}'
34+
# with closing(conn.cursor()) as cursor:
35+
# return execute(conn, statement, cursor=cursor, commit=False).fetchall()
36+
37+
38+
def execute(conn, statement, parameters=None, commit=True, cursor=None):
    """Run a SQL statement on an sqlite3 connection.

    Args:
        conn: open sqlite3 connection.
        statement: SQL text, with '?' placeholders for bound values.
        parameters: optional tuple of values bound to the placeholders.
        commit: commit the transaction after executing (default True).
        cursor: optional cursor to execute on; when None the connection's
            one-shot execute() is used.

    Returns:
        The cursor holding the statement's result rows.
    """
    assert parameters is None or isinstance(parameters, tuple)
    parameters = parameters or ()
    if cursor is None:
        # Bug fix: the original dereferenced `cursor` unconditionally,
        # so the documented default (None) crashed with AttributeError.
        # Connection.execute() creates and returns a cursor itself.
        return_value = conn.execute(statement, parameters)
    else:
        return_value = cursor.execute(statement, parameters)
    if commit:
        conn.commit()
    return return_value
45+
46+
47+
def load_table(list_dir, metric, nb_epochs=None, best=None):
    """Build a ranking table of experiment directories for one metric.

    For each directory, the epoch at which the `best` metric peaks is
    looked up first, then `metric` is read at that same epoch.

    Args:
        list_dir: list of experiment log directories.
        metric: dict with 'fname', 'group', 'column' and 'order' keys
            (see metric_str_to_dict).
        nb_epochs: if given, only epochs < nb_epochs are considered.
        best: metric dict used to select the reference epoch.

    Returns:
        List of [place, dir, score, epoch] rows, sorted best-first.

    Raises:
        ValueError: if best['order'] is neither 'max' nor 'min'.
    """
    # Validate the order up front: the original left `order`/`reverse`
    # unbound on a bad value and died later with a confusing NameError.
    if best['order'] == 'max':
        sql_order, reverse = 'DESC', True
    elif best['order'] == 'min':
        sql_order, reverse = 'ASC', False
    else:
        raise ValueError(f'unknown order: {best["order"]}')

    table = []
    for dir_logs in list_dir:
        # 1) Find the epoch at which the reference metric is best.
        # NOTE: table/column names come from trusted CLI args and must be
        # interpolated (sqlite cannot bind identifiers); values are bound
        # as '?' parameters instead of being formatted into the SQL.
        path_sql = osp.join(dir_logs, f'{best["fname"]}.sqlite')
        statement = f'SELECT {best["column"]}, epoch FROM _{best["group"]}'
        parameters = None
        if nb_epochs:
            statement += ' WHERE epoch < ?'
            parameters = (nb_epochs,)
        statement += f' ORDER BY {best["column"]} {sql_order} LIMIT 1'
        # closing() releases the connection and cursor deterministically;
        # the original leaked one connection per directory per metric.
        with closing(sqlite3.connect(path_sql, check_same_thread=False,
                                     isolation_level='IMMEDIATE')) as conn:
            with closing(conn.cursor()) as cursor:
                _best_score, best_epoch = execute(
                    conn, statement, parameters, cursor=cursor).fetchone()

        # 2) Read the displayed metric at that epoch.
        path_sql = osp.join(dir_logs, f'{metric["fname"]}.sqlite')
        statement = (f'SELECT {metric["column"]}, epoch'
                     f' FROM _{metric["group"]} WHERE epoch == ?')
        with closing(sqlite3.connect(path_sql, check_same_thread=False,
                                     isolation_level='IMMEDIATE')) as conn:
            with closing(conn.cursor()) as cursor:
                score, epoch = execute(
                    conn, statement, (best_epoch,), cursor=cursor).fetchone()

        table.append([dir_logs, score, epoch])

    # Rank rows by score, best first according to the reference order.
    table.sort(key=lambda row: row[1], reverse=reverse)
    for rank, row in enumerate(table, start=1):
        row.insert(0, f'# {rank}')
    return table
95+
96+
97+
def metric_str_to_dict(metric):
    """Parse a dotted metric specifier 'fname.group.column.order'
    into its four named components."""
    parts = metric.split('.')
    # Positional indexing (not tuple unpacking) keeps the original
    # behavior: extra components are ignored, missing ones IndexError.
    return {key: parts[i]
            for i, key in enumerate(('fname', 'group', 'column', 'order'))}
105+
106+
107+
def display_metrics(list_dir, metrics, nb_epochs=None, best=None):
    """Print one ranking table per metric specifier.

    Args:
        list_dir: list of experiment log directories to compare.
        metrics: list of dotted metric specifiers ('fname.group.column.order').
        nb_epochs: optional cap on the epochs considered.
        best: dotted specifier of the metric that selects the epoch.
    """
    best_dict = metric_str_to_dict(best)
    headers = ['Place', 'Method', 'Score', 'Epoch']
    for spec in metrics:
        rows = load_table(list_dir, metric_str_to_dict(spec),
                          nb_epochs=nb_epochs, best=best_dict)
        print(f'\n\n## {spec}\n')
        print(tabulate(rows, headers=headers))
106114

107115

108116
if __name__ == '__main__':
    # CLI: compare several experiment directories on sqlite-logged metrics.
    arg_parser = argparse.ArgumentParser(description='')
    arg_parser.add_argument('-n', '--nb_epochs', default=-1, type=int)
    arg_parser.add_argument('-d', '--dir_logs', default='', type=str, nargs='*')
    arg_parser.add_argument('-m', '--metrics', type=str, nargs='*',
                            default=['logs.eval_epoch.accuracy.max',
                                     'logs.train_epoch.loss.min',
                                     'logs.train_epoch.accuracy.max'])
    arg_parser.add_argument('-b', '--best', type=str,
                            default='logs.eval_epoch.accuracy.max')
    cli_args = arg_parser.parse_args()
    # -1 is the CLI sentinel for "all epochs"; internally that is None.
    epoch_cap = None if cli_args.nb_epochs == -1 else cli_args.nb_epochs
    display_metrics(cli_args.dir_logs, cli_args.metrics, epoch_cap,
                    cli_args.best)

0 commit comments

Comments
 (0)