diff --git a/utils.py b/utils.py index ee3dd928..727a6a08 100644 --- a/utils.py +++ b/utils.py @@ -1,21 +1,22 @@ import numpy as np import re import functools +from collections import deque class AverageMeter(object): """Computes and stores the average and current value""" - def __init__(self): + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) self.initialized = False self.val = None - self.avg = None self.sum = None self.count = None def initialize(self, val, weight): self.val = val - self.avg = val self.sum = val * weight self.count = weight + self.deque.append(val) self.initialized = True def update(self, val, weight=1): @@ -28,14 +29,16 @@ def add(self, val, weight): self.val = val self.sum += val * weight self.count += weight - self.avg = self.sum / self.count + self.deque.append(val) def value(self): return self.val def average(self): - return self.avg + return np.mean(self.deque) + def median(self): + return np.median(self.deque) def unique(ar, return_index=False, return_inverse=False, return_counts=False): ar = np.asanyarray(ar).flatten()