-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
47 lines (38 loc) · 1.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import logging
import numpy as np
def save_dict_to_json(d, json_path):
    """Save a dict of metrics to a JSON file.

    String values are written unchanged; every other value is cast to
    ``float`` so that numpy scalars and similar float-castable objects
    become JSON-serializable.

    The input dict is not modified: converted values go into a new dict.

    Args:
        d: (dict) mapping keys to strings or float-castable values
            (np.float, int, float, etc.)
        json_path: (string) path to json file

    Raises:
        TypeError, ValueError: if a non-string value cannot be cast to float.
    """
    # Convert before opening the file: opening with 'w' truncates it, so a
    # failed conversion would otherwise leave an empty file behind.
    serializable = {
        key: value if isinstance(value, str) else float(value)
        for key, value in d.items()
    }
    with open(json_path, 'w') as f:
        json.dump(serializable, f, indent=4)
def set_logger(log_path):
    """Configure the root logger to log INFO messages to a file and the console.

    The root logger's level is always set to INFO. Handlers are attached only
    when the root logger has none yet, so calling this again does not
    duplicate output. File records are timestamped and carry the level name;
    console records contain the bare message.

    Example:
        set_logger("train.log")
        logging.info("Starting training...")

    Args:
        log_path: (string) where to log
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Handlers already installed: leave them as they are.
    if root.handlers:
        return

    sinks = (
        # Logging to a file
        (logging.FileHandler(log_path), '%(asctime)s:%(levelname)s: %(message)s'),
        # Logging to console
        (logging.StreamHandler(), '%(message)s'),
    )
    for handler, fmt in sinks:
        handler.setFormatter(logging.Formatter(fmt))
        root.addHandler(handler)
def get_params(model):
    """Count the trainable parameters of a model, print the total, and return it.

    Only parameters whose ``requires_grad`` is truthy are counted; each
    contributes the product of the dimensions reported by its ``size()``.

    Args:
        model: object exposing ``parameters()``, which yields items that have
            a ``requires_grad`` attribute and a ``size()`` method

    Returns:
        Total number of elements across all trainable parameters.
    """
    element_counts = []
    for tensor in model.parameters():
        if tensor.requires_grad:
            element_counts.append(np.prod(tensor.size()))
    params = sum(element_counts)
    print(f"# of trainable parameters: {params}")
    return params