forked from TopoXLab/consistency-ranking-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
78 lines (66 loc) · 2.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
try:
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Public attributes:
        val:   the value passed to the most recent ``update`` call.
        sum:   the accumulated ``val * n`` over all updates since ``reset``.
        count: the accumulated ``n`` over all updates since ``reset``.
        avg:   ``sum / count`` (0 while ``count`` is still 0).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new measurement.
            n: weight of the measurement (e.g. a batch size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update before any real sample leaves ``count`` at 0;
        # keep the current average instead of raising ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count
class Logger(object):
    """Append rows of values to a space-separated text file and read them back.

    Each call to ``write`` appends one line to ``path``. Integers and floats
    are rendered with ``int_form`` / ``float_form`` (format-spec strings that
    include the leading colon, e.g. ``':04d'``); strings are written as-is.
    Every row must have the same number of items as the first row written
    through this instance.

    Note: fields are separated by a single space, so a string value that
    itself contains spaces is read back as several fields.
    """

    def __init__(self, path, int_form=':04d', float_form=':.6f'):
        self.path = path
        self.int_form = int_form
        self.float_form = float_form
        # Number of items per row; 0 means "not fixed yet".
        self.width = 0

    def __len__(self):
        """Return the number of logged rows, or 0 if the file cannot be read."""
        try:
            return len(self.read())
        except (IOError, OSError):
            # Typically the log file has not been created yet.
            return 0

    def write(self, values):
        """Append one row to the log file.

        Args:
            values: a single value or an iterable of values; each item must
                be an ``int``, ``float`` or ``str``.

        Raises:
            AssertionError: if the row length differs from earlier rows.
            TypeError: if an item is not an ``int``, ``float`` or ``str``.
        """
        # A lone string is one value, not a sequence of characters.
        if isinstance(values, str) or not isinstance(values, Iterable):
            values = [values]
        else:
            # Materialize so that generators and other unsized iterables work.
            values = list(values)

        if self.width == 0:
            self.width = len(values)
        # Explicit raise (rather than ``assert``) so the check survives ``-O``;
        # AssertionError is kept for backward compatibility with callers.
        if self.width != len(values):
            raise AssertionError('Inconsistent number of items.')

        fields = []
        for v in values:
            if isinstance(v, int):
                fields.append('{{{}}}'.format(self.int_form).format(v))
            elif isinstance(v, float):
                fields.append('{{{}}}'.format(self.float_form).format(v))
            elif isinstance(v, str):
                fields.append('{}'.format(v))
            else:
                raise TypeError('Not supported type.')

        with open(self.path, 'a') as f:
            f.write(' '.join(fields) + '\n')

    def read(self):
        """Return the whole log as a list of rows.

        Every field that parses as a number is returned as a ``float`` (so
        integers come back as floats); any other field is returned as the
        original string.
        """
        log = []
        with open(self.path, 'r') as f:
            for line in f:
                row = []
                # Drop the line terminator so it does not stick to the last
                # field when that field is a string.
                for field in line.rstrip('\n').split(' '):
                    try:
                        field = float(field)
                    except ValueError:
                        pass
                    row.append(field)
                log.append(row)
        return log
def accuracy(output, target, topk=(1,)):
    """Compute precision@k and the per-sample top-k hit mask.

    Args:
        output: score tensor; the top-k search runs along dim 1
            (presumably shaped (batch, classes) -- confirm against callers).
        target: tensor of ground-truth class indices; ``target.size(0)`` is
            used as the batch size.
        topk: values of k to evaluate. The largest one decides how many
            predictions per sample are kept.

    Returns:
        A pair ``(precision, correct)``:
        * ``precision`` is the precision@k in percent for the FIRST entry of
          ``topk`` only; the other entries are computed but not returned.
        * ``correct`` is the boolean hit mask of shape (max(topk), batch)
          with all size-1 dimensions squeezed away.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample, largest first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` can inherit the transposed, non-contiguous layout of
        # ``pred``; ``view(-1)`` then raises for k > 1, whereas ``reshape``
        # copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res[0], correct.squeeze()