-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
115 lines (80 loc) · 3.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Birth: 2022-06-01 13:37:43.576184507 +0530
# Modify: 2022-06-19 09:35:40.222532267 +0530
"""Utilities for BertMultiLabel"""
import json
import logging
import os
import torch
import torch.nn as nn
import numpy as np
__author__ = "Upal Bhattacharya"
__copyright__ = ""
__license__ = ""
__version__ = "1.0"
__email__ = "upal.bhattacharya@gmail.com"
class Params:
    """Hyperparameter container backed by a JSON file.

    Each top-level key of the JSON object becomes an attribute, so
    ``params.learning_rate`` mirrors ``json["learning_rate"]``.
    """

    def __init__(self, json_path):
        """Load hyperparameters from the JSON file at `json_path`."""
        with open(json_path, 'r') as src:
            self.__dict__.update(json.load(src))

    def save(self, json_path):
        """Write the current hyperparameters to `json_path` as JSON."""
        with open(json_path, 'w') as dst:
            json.dump(self.__dict__, dst, indent=4)

    @property
    def dict(self):
        """Provide dictionary-like access to hyperparameters."""
        return self.__dict__
class Accumulate:
    """Collect model outputs and targets over the batches of an epoch.

    Batches are flattened into two growing per-sample lists; calling the
    instance stacks them into numpy arrays for metric computation.
    """

    def __init__(self):
        # One entry per sample seen so far.
        self.output_batch = []
        self.targets_batch = []

    def update(self, output_batch, targets_batch):
        """Append one batch of outputs and targets (anything with .tolist())."""
        for row in output_batch.tolist():
            self.output_batch.append(row)
        for row in targets_batch.tolist():
            self.targets_batch.append(row)

    def __call__(self):
        """Return (outputs, targets), each stacked along axis 0."""
        outputs = np.stack(self.output_batch, axis=0)
        targets = np.stack(self.targets_batch, axis=0)
        return outputs, targets
def load_checkpoint(restore_path, model, optimizer=None, device_id=None):
    """Load model (and optionally optimizer) state from a checkpoint.

    Parameters
    ----------
    restore_path : str
        Path of the checkpoint file to restore from.
    model : torch.nn.Module
        Model whose parameters are overwritten from ``ckpt['state_dict']``.
    optimizer : torch.optim.Optimizer, optional
        If given, its state is restored from ``ckpt['optim_dict']``.
    device_id : int, optional
        CUDA device index to map tensors onto; when None, tensors load
        onto the device they were saved from.

    Returns
    -------
    The epoch number stored in the checkpoint under ``'epoch'``.

    Raises
    ------
    FileNotFoundError
        If no file exists at `restore_path`.
    """
    if not os.path.exists(restore_path):
        # Fix: the original did `raise (f"...")`, raising a plain string,
        # which itself fails with "exceptions must derive from BaseException".
        raise FileNotFoundError(f"No restore file found at {restore_path}.")
    if device_id is None:
        ckpt = torch.load(restore_path)
    else:
        ckpt = torch.load(restore_path, map_location=f"cuda:{device_id}")
    model.load_state_dict(ckpt['state_dict'])
    if optimizer:
        optimizer.load_state_dict(ckpt['optim_dict'])
    return ckpt['epoch']
def save_dict_to_json(dict_obj, save_path):
    """Serialize `dict_obj` as indented JSON to `save_path`.

    Creates the parent directory if needed. Fixes two defects in the
    original: (1) a bare filename made ``os.path.split(...)[0]`` return
    "" and ``os.makedirs("")`` raise FileNotFoundError; (2) the
    check-then-create pattern raced with concurrent writers — replaced
    by ``exist_ok=True``.

    Parameters
    ----------
    dict_obj : dict
        JSON-serializable mapping to write.
    save_path : str
        Destination file path; parent directories are created as needed.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(dict_obj, f, indent=4)
def save_checkpoint(state, is_best, save_path, to_save=False):
    """Save a training checkpoint under `save_path`.

    Parameters
    ----------
    state : dict
        Checkpoint payload; must contain an ``'epoch'`` key when
        `to_save` is True (used in the filename).
    is_best : bool
        When True, write/overwrite ``best.pth.tar``.
    save_path : str
        Directory to save into; created if it does not exist.
    to_save : bool, optional
        When True, additionally write ``epoch_<epoch>.pth.tar``.
    """
    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(save_path, exist_ok=True)
    if is_best:
        torch.save(state, os.path.join(save_path, "best.pth.tar"))
    if to_save:
        torch.save(state, os.path.join(save_path,
                                       f"epoch_{state['epoch']}.pth.tar"))
def set_logger(log_path: str):
    """Configure the root logger to write to `log_path` and to the console.

    Sets the root level to DEBUG. Handlers are attached only when the
    root logger has none yet, so repeated calls do not duplicate output.
    """
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    if root.handlers:
        return
    datefmt = "%Y-%m-%d %H:%M:%S"
    # File handler
    to_file = logging.FileHandler(log_path)
    to_file.setFormatter(logging.Formatter(
        "%(asctime)s: [%(levelname)s] %(message)s", datefmt))
    root.addHandler(to_file)
    # Stream handler
    to_console = logging.StreamHandler()
    to_console.setFormatter(logging.Formatter(
        "%(asctime)s : [%(levelname)s] %(message)s", datefmt))
    root.addHandler(to_console)