Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
pesser committed Mar 17, 2018
1 parent a4df9d8 commit 152728b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
3 changes: 3 additions & 0 deletions batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import math


n_boxes = 8


class BufferedWrapper(object):
"""Fetch next batch asynchronuously to avoid bottleneck during GPU
training."""
Expand Down
15 changes: 5 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@
config.gpu_options.allow_growth = False
session = tf.Session(config = config)

import os, logging, shutil, datetime, time, math, pickle
import os, logging, shutil, datetime
import glob
import argparse
import numpy as np
from tqdm import tqdm, trange
import PIL

import nn
import models
from batches import get_batches, plot_batch, postprocess
from batches import get_batches, plot_batch, postprocess, n_boxes
import deeploss

N_BOXES = 8


def init_logging(out_base_dir):
# get unique output directory based on current time
Expand Down Expand Up @@ -45,7 +42,7 @@ def __init__(self, opt, out_dir, logger):
self.batch_size = opt.batch_size
self.img_shape = 2*[opt.spatial_size] + [3]
redux = 2
self.imgn_shape = 2*[opt.spatial_size//(2**redux)] + [N_BOXES*3]
self.imgn_shape = 2*[opt.spatial_size//(2**redux)] + [n_boxes*3]
self.init_batches = opt.init_batches

self.initial_lr = opt.lr
Expand Down Expand Up @@ -156,7 +153,6 @@ def define_graph(self):
self.lr_decay_end // 2, 3 * self.lr_decay_end // 4,
1e-6, 1.0,
1e-6, 1.0)
#kl_weight = tf.to_float(0.1)

# initialization
self.x_init = tf.placeholder(
Expand Down Expand Up @@ -362,7 +358,6 @@ def log_result(self, result, **kwargs):
for i in range(bs):
x_infer = XN_batch[i,...]
c_infer = CN_batch[i,...]
#imgs.append(x_infer)
imgs.append(X_batch[i,...])

x_infer_batch = x_infer[None,...].repeat(bs, axis = 0)
Expand Down Expand Up @@ -428,9 +423,9 @@ def transfer(self, x_encode, c_encode, c_decode):
default_log_dir = os.path.join(os.getcwd(), "log")

parser = argparse.ArgumentParser()
parser.add_argument("--data_index", required = True, help = "path to training or testing data index")
parser.add_argument("--data_index", required = True, help = "path to data index")
parser.add_argument("--mode", default = "train",
choices=["train", "test", "mcmc", "add_reconstructions", "transfer"])
choices=["train", "test", "add_reconstructions", "transfer"])
parser.add_argument("--log_dir", default = default_log_dir, help = "path to log into")
parser.add_argument("--batch_size", default = 8, type = int, help = "batch size")
parser.add_argument("--init_batches", default = 4, type = int, help = "number of batches for initialization")
Expand Down

0 comments on commit 152728b

Please sign in to comment.