Skip to content

Commit

Permalink
Resume as default
Browse files Browse the repository at this point in the history
Cool trick to always load the latest pkl file and kimg to continue training. Especially useful in Colab.
  • Loading branch information
andersonfaaria committed Aug 8, 2020
1 parent ba27ab0 commit 4285a84
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data,
sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().
G.impl = D.impl = ops

# resolutions
data_res = basename(tfr_file).split('-')[-1].split('x') # get resolution from dataset filename
data_res = list(reversed([int(x) for x in data_res])) # convert to int list
Expand Down Expand Up @@ -129,7 +129,7 @@ def main():
# main
parser.add_argument('--dataset', required=True, help='Training dataset path', metavar='DIR')
parser.add_argument('--train_dir', default='train', help='Root directory for training results (default: %(default)s)', metavar='DIR')
parser.add_argument('--resume', default=None, help='Resume checkpoint path. None = from scratch', metavar='DIR')
parser.add_argument('--resume', default='latest', help='Resume checkpoint path. None = from scratch', metavar='DIR')
parser.add_argument('--resume_kimg', default=0, type=int, help='Resume training from (in thousands of images)', metavar='N')
parser.add_argument('--lod_step_kimg', default=20, type=int, help='Thousands of images per LOD/layer step (default: %(default)s)', metavar='N')
parser.add_argument('--finetune', action='store_true', help='finetune trained model (start from 1e4 kimg, stop when enough)')
Expand Down
9 changes: 3 additions & 6 deletions src/training/dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import dnnlib

from util.progress_bar import ProgressBar
from tqdm import tqdm

class TFRecordExporter:
def __init__(self, data_dir, expected_images, print_progress=False, progress_interval=10):
Expand Down Expand Up @@ -124,11 +124,9 @@ def create_from_images(dataset, jpg=False, shuffle=True, size=None):

with TFRecordExporter(dataset, len(image_filenames)) as tfr:
order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
pbar = ProgressBar(order.size)
for idx in range(order.size):
for idx in tqdm(range(order.size)):
img_path = image_filenames[order[idx]]
tfr.add_image(img_path, jpg=jpg, size=size)
pbar.upd()
return tfr.tfr_file, len(image_filenames)


Expand All @@ -142,5 +140,4 @@ def main():
create_from_images(args.dataset, args.jpg, args.shuffle)

if __name__ == "__main__":
main()

main()
6 changes: 4 additions & 2 deletions src/training/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def load_pkl(file_or_url):
return pickle.load(file, encoding='latin1')

def locate_latest_pkl(train_dir):
    """Find the most recent snapshot pickle under *train_dir*.

    Searches run subdirectories (named '0*') for 'snapshot-*.pkl' files and
    returns the lexicographically last one together with its kimg counter.

    Args:
        train_dir: Root directory containing the numbered run subdirectories.

    Returns:
        (path, kimg): path to the latest snapshot pickle (None when no
        snapshot exists) and the training progress in thousands of images
        as a float (0. when it cannot be determined).
    """
    # NOTE(review): relies on zero-padded kimg counters in the filenames so
    # that lexicographic sort matches numeric order; 'snapshot-final.pkl'
    # sorts after the numbered snapshots ('f' > any digit), so when present
    # it is picked up as the latest.
    allpickles = sorted(glob.glob(os.path.join(train_dir, '0*', 'snapshot-*.pkl')))
    if len(allpickles) == 0:
        return None, 0.
    latest_pickle = allpickles[-1]
    kimg = os.path.splitext(latest_pickle)[0].split('-')[-1]
    if kimg == 'final':
        # The final snapshot carries no kimg counter; recover it from the
        # previous snapshot. Guard against 'final' being the only snapshot,
        # which previously raised IndexError on allpickles[-2].
        if len(allpickles) < 2:
            return latest_pickle, 0.
        kimg = os.path.splitext(allpickles[-2])[0].split('-')[-1]
    return latest_pickle, float(kimg)

def save_pkl(obj, filename):
Expand Down
4 changes: 3 additions & 1 deletion src/training/training_loop_diffaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def training_loop(
network_snapshot_ticks = 50, # How often to save network snapshots? None = only save 'networks-final.pkl'.
save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file?
save_weight_histograms = False, # Include weight histograms in the tfevents file?
resume_pkl = None, # Network pickle to resume training from, None = train from scratch.
resume_pkl = 'latest', # Network pickle to resume training from, None = train from scratch.
resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule.
resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting.
restore_partial_fn = None, # Filename of network for partial restore
Expand Down Expand Up @@ -75,6 +75,8 @@ def training_loop(

# Construct or load networks.
with tf.device('/gpu:0'):
if resume_pkl == 'latest':
resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
if resume_pkl is None or resume_with_new_nets:
print(' Constructing networks...')
G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
Expand Down

0 comments on commit 4285a84

Please sign in to comment.