Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#14 from qingqing01/save_checkpoint
Browse files Browse the repository at this point in the history
Do not save checkpoint if not set save_dir
  • Loading branch information
qingqing01 authored Mar 24, 2020
2 parents 839db7b + 9d7760c commit b1862cf
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 87 deletions.
24 changes: 15 additions & 9 deletions callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from progressbar import ProgressBar
from distributed import get_local_rank


def config_callbacks(callbacks=None,
model=None,
batch_size=None,
Expand All @@ -26,6 +27,7 @@ def config_callbacks(callbacks=None,
log_freq=2,
verbose=2,
save_freq=1,
save_dir=None,
metrics=None,
mode='train'):
cbks = callbacks or []
Expand All @@ -34,7 +36,7 @@ def config_callbacks(callbacks=None,
cbks = cbks + [ProgBarLogger(log_freq, verbose=verbose)]

if not any(isinstance(k, ModelCheckpoint) for k in cbks):
cbks = cbks + [ModelCheckpoint(save_freq)]
cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]

cbk_list = CallbackList(cbks)
cbk_list.set_model(model)
Expand Down Expand Up @@ -209,9 +211,10 @@ def _updates(self, logs, mode):

def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step = step
self.train_step += 1

if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank() == 0:
if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank(
) == 0:
# if steps is not None, last step will update in on_epoch_end
if self.steps and self.train_step < self.steps:
self._updates(logs, 'train')
Expand Down Expand Up @@ -247,21 +250,24 @@ def on_eval_end(self, logs=None):


class ModelCheckpoint(Callback):
    """Callback that periodically saves model checkpoints during training.

    Checkpoints are written under ``save_dir`` as ``{save_dir}/{epoch}`` every
    ``save_freq`` epochs, plus a final ``{save_dir}/final`` snapshot when
    training ends. If ``save_dir`` is None, nothing is ever saved — this is
    the behavior this commit introduces ("Do not save checkpoint if not set
    save_dir").

    Args:
        save_freq (int): save a checkpoint every this many epochs.
        save_dir (str|None): directory to write checkpoints to; None disables
            saving entirely.
    """

    def __init__(self, save_freq=1, save_dir=None):
        self.save_freq = save_freq
        self.save_dir = save_dir

    def on_epoch_begin(self, epoch=None, logs=None):
        # Remember the current epoch so on_epoch_end can build the path.
        self.epoch = epoch

    def _is_save(self):
        # Only save when a model is attached, a target directory was given,
        # and we are the rank-0 process (avoid duplicate writes in
        # distributed training).
        return self.model and self.save_dir and get_local_rank() == 0

    def on_epoch_end(self, epoch, logs=None):
        if self._is_save() and self.epoch % self.save_freq == 0:
            path = '{}/{}'.format(self.save_dir, epoch)
            print('save checkpoint at {}'.format(path))
            self.model.save(path)

    def on_train_end(self, logs=None):
        # Always snapshot the final state (independent of save_freq).
        if self._is_save():
            path = '{}/final'.format(self.save_dir)
            print('save checkpoint at {}'.format(path))
            self.model.save(path)
20 changes: 11 additions & 9 deletions mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,26 @@ def forward(self, inputs):

def main():
    """Train an MNIST classifier with the high-level Model API.

    Builds the datasets, model, and optimizer from command-line FLAGS,
    optionally resumes from a checkpoint, then fits — saving checkpoints
    into 'mnist_checkpoint' (the save_dir added by this commit).
    """
    # Select dynamic-graph or static-graph execution mode.
    init_context('dynamic' if FLAGS.dynamic else 'static')

    train_dataset = MnistDataset(mode='train')
    val_dataset = MnistDataset(mode='test')

    # Static-graph mode needs declared input/label specs: flattened 28x28
    # images and integer class labels with a variable batch dimension.
    inputs = [Input([None, 784], 'float32', name='image')]
    labels = [Input([None, 1], 'int64', name='label')]

    model = MNIST()
    optim = Momentum(
        learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())

    model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
    if FLAGS.resume is not None:
        model.load(FLAGS.resume)

    model.fit(train_dataset,
              val_dataset,
              epochs=FLAGS.epoch,
              batch_size=FLAGS.batch_size,
              save_dir='mnist_checkpoint')


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit b1862cf

Please sign in to comment.