Skip to content

Commit

Permalink
Merge pull request #396 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Add docstrings to misc subpackage
  • Loading branch information
yoshitomo-matsubara authored Sep 4, 2023
2 parents ee4b00b + 70bc985 commit 863697e
Showing 1 changed file with 64 additions and 4 deletions.
68 changes: 64 additions & 4 deletions torchdistill/misc/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,27 @@


def setup_log_file(log_file_path):
"""
Sets a file handler with ``log_file_path`` to write a log file.
:param log_file_path: log file path.
:type log_file_path: str
"""
make_parent_dirs(log_file_path)
fh = FileHandler(filename=log_file_path, mode='w')
fh.setFormatter(Formatter(LOGGING_FORMAT))
def_logger.addHandler(fh)


class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
A deque-based value object tracks a series of values and provides access to smoothed values
over a window or the global series average. The original implementation is https://github.com/pytorch/vision/blob/main/references/classification/utils.py
:param window_size: window size.
:type window_size: int
:param fmt: text format.
:type fmt: str or None
"""

def __init__(self, window_size=20, fmt=None):
Expand All @@ -34,13 +46,24 @@ def __init__(self, window_size=20, fmt=None):
self.fmt = fmt

def update(self, value, n=1):
"""
Appends ``value``.
:param value: value to be added.
:type value: float or int
:param n: sample count.
:type n: int
"""
self.deque.append(value)
self.count += n
self.total += value * n

def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
Synchronizes between processes.
.. warning::
It does not synchronize the deque.
"""
if not is_dist_avail_and_initialized():
return
Expand Down Expand Up @@ -80,15 +103,29 @@ def __str__(self):
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
value=self.value
)


class MetricLogger(object):
"""
A metric logger with :class:`SmoothedValue`.
The original implementation is https://github.com/pytorch/vision/blob/main/references/classification/utils.py
:param delimiter: delimiter in a log message.
:type delimiter: str
"""
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter

def update(self, **kwargs):
"""
Updates a metric dict whose values are :class:`SmoothedValue`.
:param kwargs: keys and values.
:type kwargs: dict
"""
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
Expand All @@ -113,13 +150,36 @@ def __str__(self):
return self.delimiter.join(loss_str)

def synchronize_between_processes(self):
"""
Synchronizes between processes.
"""
for meter in self.meters.values():
meter.synchronize_between_processes()

def add_meter(self, name, meter):
"""
Add a new metric name and value.
:param name: metric name.
:type name: str
:param meter: smoothed value.
:type meter: SmoothedValue
"""
self.meters[name] = meter

def log_every(self, iterable, log_freq, header=None):
"""
Add a new metric name and value.
:param iterable: iterable object (e.g., data loader).
:type iterable: typing.Iterable
:param log_freq: log frequency.
:type log_freq: int
:param header: log message header.
:type header: str
:return: item in ``iterative``.
:rtype: Any
"""
i = 0
if not header:
header = ''
Expand Down

0 comments on commit 863697e

Please sign in to comment.