diff --git a/batches.py b/batches.py index 5cde8f16..c4639373 100644 --- a/batches.py +++ b/batches.py @@ -7,6 +7,9 @@ import math +n_boxes = 8 + + class BufferedWrapper(object): """Fetch next batch asynchronuously to avoid bottleneck during GPU training.""" diff --git a/main.py b/main.py index 483e5cf9..56f9f9b8 100644 --- a/main.py +++ b/main.py @@ -3,20 +3,17 @@ config.gpu_options.allow_growth = False session = tf.Session(config = config) -import os, logging, shutil, datetime, time, math, pickle +import os, logging, shutil, datetime import glob import argparse import numpy as np from tqdm import tqdm, trange -import PIL import nn import models -from batches import get_batches, plot_batch, postprocess +from batches import get_batches, plot_batch, postprocess, n_boxes import deeploss -N_BOXES = 8 - def init_logging(out_base_dir): # get unique output directory based on current time @@ -45,7 +42,7 @@ def __init__(self, opt, out_dir, logger): self.batch_size = opt.batch_size self.img_shape = 2*[opt.spatial_size] + [3] redux = 2 - self.imgn_shape = 2*[opt.spatial_size//(2**redux)] + [N_BOXES*3] + self.imgn_shape = 2*[opt.spatial_size//(2**redux)] + [n_boxes*3] self.init_batches = opt.init_batches self.initial_lr = opt.lr @@ -156,7 +153,6 @@ def define_graph(self): self.lr_decay_end // 2, 3 * self.lr_decay_end // 4, 1e-6, 1.0, 1e-6, 1.0) - #kl_weight = tf.to_float(0.1) # initialization self.x_init = tf.placeholder( @@ -362,7 +358,6 @@ def log_result(self, result, **kwargs): for i in range(bs): x_infer = XN_batch[i,...] c_infer = CN_batch[i,...] - #imgs.append(x_infer) imgs.append(X_batch[i,...]) x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) @@ -428,9 +423,9 @@ def transfer(self, x_encode, c_encode, c_decode): default_log_dir = os.path.join(os.getcwd(), "log") parser = argparse.ArgumentParser() - parser.add_argument("--data_index", required = True, help = "path to training or testing data index") + parser.add_argument("--data_index", required = True, help = "path to data index") parser.add_argument("--mode", default = "train", - choices=["train", "test", "mcmc", "add_reconstructions", "transfer"]) + choices=["train", "test", "add_reconstructions", "transfer"]) parser.add_argument("--log_dir", default = default_log_dir, help = "path to log into") parser.add_argument("--batch_size", default = 8, type = int, help = "batch size") parser.add_argument("--init_batches", default = 4, type = int, help = "number of batches for initialization")