checkpoint.py
from __future__ import print_function
import os
import time
import shutil
import torch


class Checkpoint(object):
"""
The Checkpoint class manages the saving and loading of a model during training. It allows training to be suspended
and resumed at a later time (e.g. when running on a cluster using sequential jobs).
To make a checkpoint, initialize a Checkpoint object with the following args; then call that object's save() method
to write parameters to disk.
Args:
model (seq2seq): seq2seq model being trained
optimizer (Optimizer): stores the state of the optimizer
epoch (int): current epoch (an epoch is a loop through the full training vocab_data)
step (int): number of examples seen within the current epoch
input_vocab (Vocabulary): vocabulary for the input language
output_vocab (Vocabulary): vocabulary for the output language
Attributes:
CHECKPOINT_DIR_NAME (str): name of the checkpoint directory
TRAINER_STATE_NAME (str): name of the file storing trainer states
MODEL_NAME (str): name of the file storing model
INPUT_VOCAB_FILE (str): name of the input vocab file
OUTPUT_VOCAB_FILE (str): name of the output vocab file
"""
CHECKPOINT_DIR_NAME = 'checkpoints'
TRAINER_STATE_NAME = 'trainer_states.pt'
MODEL_NAME = 'model.pt'

    def __init__(self, model, opt, epoch, step, path=None):
self.model = model
self.opt = opt
self.epoch = epoch
self.step = step
self._path = path

    @property
def path(self):
if self._path is None:
raise LookupError("The checkpoint has not been saved.")
return self._path

    def save(self, experiment_dir):
"""
Saves the current model and related training parameters into a subdirectory of the checkpoint directory.
The name of the subdirectory is the current local time in Y_M_D_H_M_S format.
Args:
experiment_dir (str): path to the experiment root directory
Returns:
str: path to the saved checkpoint subdirectory
"""
date_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
self._path = os.path.join(experiment_dir, self.CHECKPOINT_DIR_NAME, date_time)
path = self._path
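        # A second save within the same clock second would collide with the
        # existing timestamped directory; remove it so the new state replaces it.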
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
torch.save({'epoch': self.epoch,
'step': self.step,
'opt': self.opt
},
os.path.join(path, self.TRAINER_STATE_NAME))
torch.save(self.model, os.path.join(path, self.MODEL_NAME))
#with open(os.path.join(path, self.INPUT_VOCAB_FILE), 'wb') as fout:
# dill.dump(self.input_vocab, fout)
#with open(os.path.join(path, self.OUTPUT_VOCAB_FILE), 'wb') as fout:
# dill.dump(self.output_vocab, fout)
return path

    @classmethod
def load(cls, path):
"""
Loads a Checkpoint object that was previously saved to disk.
Args:
path (str): path to the checkpoint subdirectory
Returns:
checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk
"""
if torch.cuda.is_available():
resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME))
model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage)
model.cuda()
# # Load all tensors onto the CPU
# torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# # Map tensors from GPU 1 to GPU 0
# torch.load('tensors.pt', map_location={'cuda:1': 'cuda:%d'%gpu})
else:
resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME), map_location=lambda storage, loc: storage)
model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage)
#model.flatten_parameters() # make RNN parameters contiguous
#with open(os.path.join(path, cls.INPUT_VOCAB_FILE), 'rb') as fin:
# input_vocab = dill.load(fin)
#with open(os.path.join(path, cls.OUTPUT_VOCAB_FILE), 'rb') as fin:
# output_vocab = dill.load(fin)
opt = resume_checkpoint['opt']
        print('Loaded model of type', type(model).__name__, 'from', path)
return Checkpoint(model=model,
opt=opt,
epoch=resume_checkpoint['epoch'],
step=resume_checkpoint['step'],
path=path)

    @classmethod
def get_latest_checkpoint(cls, experiment_path):
"""
Given the path to an experiment directory, returns the path to the last saved checkpoint's subdirectory.
Precondition: at least one checkpoint has been made (i.e., latest checkpoint subdirectory exists).
Args:
experiment_path (str): path to the experiment directory
Returns:
str: path to the last saved checkpoint's subdirectory
"""
checkpoints_path = os.path.join(experiment_path, cls.CHECKPOINT_DIR_NAME)
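        # Subdirectories are named with '%Y_%m_%d_%H_%M_%S' timestamps, so a
        # reverse lexicographic sort puts the most recent checkpoint first.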
all_times = sorted(os.listdir(checkpoints_path), reverse=True)
return os.path.join(checkpoints_path, all_times[0])

    @classmethod
def get_all_checkpoints(cls, experiment_path):
"""
Given the path to an experiment directory, returns the path to the last saved checkpoint's subdirectory.
Precondition: at least one checkpoint has been made (i.e., latest checkpoint subdirectory exists).
Args:
experiment_path (str): path to the experiment directory
Returns:
str: path to the last saved checkpoint's subdirectory
"""
checkpoints_path = os.path.join(experiment_path, cls.CHECKPOINT_DIR_NAME)
all_times = sorted(os.listdir(checkpoints_path))
return [os.path.join(checkpoints_path, ckpt) for ckpt in all_times]
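

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). It builds a toy torch.nn.Linear model so the save/resume flow can
# be exercised end to end; the temporary directory below stands in for a real
# experiment root.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile
    import torch.nn as nn
    import torch.optim as optim

    experiment_dir = tempfile.mkdtemp()          # assumed scratch experiment root
    model = nn.Linear(4, 2)                      # stand-in for the seq2seq model
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Write a checkpoint after a hypothetical epoch 1, step 100.
    ckpt_path = Checkpoint(model=model, opt=optimizer, epoch=1, step=100).save(experiment_dir)
    print('checkpoint written to', ckpt_path)

    # Resume from the most recent checkpoint, e.g. in a follow-up cluster job.
    latest = Checkpoint.get_latest_checkpoint(experiment_dir)
    resumed = Checkpoint.load(latest)
    print('resuming at epoch %d, step %d' % (resumed.epoch, resumed.step))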