diff --git a/batches_pg2.py b/batches_pg2.py index ff157fd8..5d297638 100644 --- a/batches_pg2.py +++ b/batches_pg2.py @@ -93,6 +93,9 @@ def tile(X, rows, cols): def plot_batch(X, out_path): """Save batch of images tiled.""" + n_channels = X.shape[3] + if n_channels > 3: + X = X[:,:,:,np.random.choice(n_channels, size = 3)] X = postprocess(X) rc = math.sqrt(X.shape[0]) rows = cols = math.ceil(rc) @@ -166,6 +169,193 @@ def make_joint_img(img_shape, jo, joints): return img +def valid_joints(*joints): + j = np.stack(joints) + return (j >= 0).all() + + +def zoom(img, factor, center = None): + shape = img.shape[:2] + if center is None or not valid_joints(center): + center = np.array(shape) / 2 + e1 = np.array([1,0]) + e2 = np.array([0,1]) + + dst_center = np.array(center) + dst_e1 = e1 * factor + dst_e2 = e2 * factor + + src = np.float32([center, center+e1, center+e2]) + dst = np.float32([dst_center, dst_center+dst_e1, dst_center+dst_e2]) + M = cv2.getAffineTransform(src, dst) + + return cv2.warpAffine(img, M, shape, flags = cv2.INTER_AREA, borderMode = cv2.BORDER_REPLICATE) + + +def get_crop(bpart, joints, jo, wh, o_w, o_h, ar = 1.0): + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + + # fall backs + if not valid_joints(part_src): + if bpart[0] == "lhip" and bpart[1] == "lknee": + bpart = ["lhip"] + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + elif bpart[0] == "rhip" and bpart[1] == "rknee": + bpart = ["rhip"] + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + + if not valid_joints(part_src): + return None + + if part_src.shape[0] == 1: + # leg fallback + a = part_src[0] + b = np.float32([a[0],o_h - 1]) + part_src = np.float32([a,b]) + + if part_src.shape[0] == 4: + pass + elif part_src.shape[0] == 3: + # lshoulder, rshoulder, cnose + segment = part_src[1] - part_src[0] + normal = np.array([-segment[1],segment[0]]) + if normal[1] > 0.0: + normal = -normal + + a = part_src[0] + normal + b = part_src[0] + c = part_src[1] + d = part_src[1] + normal + part_src = np.float32([a,b,c,d]) + else: + assert part_src.shape[0] == 2 + + segment = part_src[1] - part_src[0] + normal = np.array([-segment[1],segment[0]]) + alpha = ar / 2.0 + a = part_src[0] + alpha*normal + b = part_src[0] - alpha*normal + c = part_src[1] - alpha*normal + d = part_src[1] + alpha*normal + part_src = np.float32([a,b,c,d]) + + dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]]) + part_dst = np.float32(wh * dst) + + M = cv2.getPerspectiveTransform(part_src, part_dst) + return M + + +def normalize(imgs, coords, stickmen, jo): + + out_imgs = list() + out_stickmen = list() + + bs = len(imgs) + for i in range(bs): + img = imgs[i] + joints = coords[i] + stickman = stickmen[i] + + h,w = img.shape[:2] + o_h = h + o_w = w + h = h // 4 + w = w // 4 + wh = np.array([w,h]) + wh = np.expand_dims(wh, 0) + + bparts = [ + ["lshoulder","lhip","rhip","rshoulder"], + ["lshoulder", "rshoulder", "rshoulder"], + ["lshoulder","lelbow"], + ["lelbow", "lwrist"], + ["rshoulder","relbow"], + ["relbow", "rwrist"], + ["lhip", "lknee"], + ["lknee", "lankle"], + ["rhip", "rknee"], + ["rknee", "rankle"]] + ar = 0.5 + + part_imgs = list() + part_stickmen = list() + for bpart in bparts: + part_img = np.zeros((h,w,3)) + part_stickman = np.zeros((h,w,3)) + M = get_crop(bpart, joints, jo, wh, o_w, o_h, ar) + + if M is not None: + part_img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + part_stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + + part_imgs.append(part_img) + part_stickmen.append(part_stickman) + img = np.concatenate(part_imgs, axis = 2) + stickman = np.concatenate(part_stickmen, axis = 2) + + """ + bpart = ["lshoulder","lhip","rhip","rshoulder"] + dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]]) + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + part_dst = np.float32(wh * dst) + + M = cv2.getPerspectiveTransform(part_src, part_dst) + img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + """ + + """ + # center of possible rescaling + c = joints[jo.index("cneck")] + + # find valid body part for scale estimation + a = joints[jo.index("lshoulder")] + b = joints[jo.index("lhip")] + target_length = 33.0 + if not valid_joints(a,b): + a = joints[jo.index("rshoulder")] + b = joints[jo.index("rhip")] + target_length = 33.0 + if not valid_joints(a,b): + a = joints[jo.index("rshoulder")] + b = joints[jo.index("relbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("lshoulder")] + b = joints[jo.index("lelbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("lwrist")] + b = joints[jo.index("lelbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("rwrist")] + b = joints[jo.index("relbow")] + target_length = 33.0 / 2 + + if valid_joints(a,b): + body_length = np.linalg.norm(b - a) + factor = target_length / body_length + img = zoom(img, factor, center = c) + stickman = zoom(stickman, factor, center = c) + else: + factor = 0.25 + img = zoom(img, factor, center = c) + stickman = zoom(stickman, factor, center = c) + """ + + out_imgs.append(img) + out_stickmen.append(stickman) + out_imgs = np.stack(out_imgs) + out_stickmen = np.stack(out_stickmen) + return out_imgs, out_stickmen + + def make_mask_img(img_shape, jo, joints): scale_factor = img_shape[1] / 128 masks = 3*[None] @@ -316,6 +506,10 @@ def __next__(self): batch["imgs"] = batch["imgs"] * batch["masks"] + imgs, joints = normalize(batch["imgs"], batch["joints_coordinates"], batch["joints"], self.jo) + batch["norm_imgs"] = imgs + batch["norm_joints"] = joints + batch_list = [batch[k] for k in self.return_keys] return batch_list @@ -333,7 +527,7 @@ def get_batches( mask, fill_batches = True, shuffle = True, - return_keys = ["imgs", "joints"]): + return_keys = ["imgs", "joints", "norm_imgs", "norm_joints"]): """Buffered IndexFlow.""" flow = IndexFlow(shape, index_path, train, mask, fill_batches, shuffle, return_keys) return BufferedWrapper(flow) diff --git a/batches_pg2_vis.py b/batches_pg2_vis.py index a196d1c6..5ea4092f 100644 --- a/batches_pg2_vis.py +++ b/batches_pg2_vis.py @@ -5,85 +5,6 @@ import os import cv2 import math -from numpy.random import RandomState - - -def get_orientation(joints, jo): - return (min(joints[jo.index("lhip"),0],joints[jo.index("lshoulder"),0]) < - max(joints[jo.index("rhip"),0],joints[jo.index("rshoulder"),0])) - - -def flip(j,x,c): - x = cv2.flip(x, 1) - c = cv2.flip(c, 1) - width = x.shape[1] - j[:,0] = width - 1 - j[:,0] - return j,x,c - - -def register(xs,cs,srcs,targets,ys,jo): - #print("Registering") - bs = xs.shape[0] - - xx = list() - cc = list() - for i in range(bs): - x = xs[i] - c = cs[i] - src = srcs[i] - target = targets[i] - - valid_mask = (src >= 0.0) & (target >= 0.0) - valid_mask = np.all(valid_mask, axis = 1) - - valid_src = src[valid_mask] - valid_target = target[valid_mask] - - fall_back = False - - if np.sum(valid_mask) >= 4: - # figure out orientation and flip if necessary to find restricted - # affine transforms - src_orient = get_orientation(src, jo) - dst_orient = get_orientation(target, jo) - if src_orient != dst_orient: - valid_src, x, c = flip(valid_src, x, c) - - affine = True - if affine: - M = cv2.estimateRigidTransform(valid_src, valid_target, fullAffine = False) - if M is None: - fall_back = True - else: - warped_x = cv2.warpAffine(x, M, x.shape[:2], borderMode = cv2.BORDER_REPLICATE) - xx.append(warped_x) - - warped_c = cv2.warpAffine(c, M, x.shape[:2], borderMode = cv2.BORDER_REPLICATE) - cc.append(warped_c) - else: - M, mask = cv2.findHomography(valid_src, valid_target, cv2.RANSAC,5.0) - #M, mask = cv2.findHomography(valid_src, valid_target) - - warped_x = cv2.warpPerspective(x, M, x.shape[:2], borderMode = cv2.BORDER_REPLICATE) - xx.append(warped_x) - - warped_c = cv2.warpPerspective(c, M, x.shape[:2], borderMode = cv2.BORDER_REPLICATE) - cc.append(warped_c) - else: - fall_back = True - - if fall_back: - xx.append(x) - cc.append(c) - - xx = np.stack(xx) - cc = np.stack(cc) - - #plot_batch(xs,"xs.png") - #plot_batch(xx,"xx.png") - #plot_batch(ys,"ys.png") - - return xx,cc class BufferedWrapper(object): @@ -92,7 +13,6 @@ class BufferedWrapper(object): def __init__(self, gen): self.gen = gen self.n = gen.n - self.jo = gen.jo self.pool = ThreadPool(1) self._async_next() @@ -173,6 +93,9 @@ def tile(X, rows, cols): def plot_batch(X, out_path): """Save batch of images tiled.""" + n_channels = X.shape[3] + if n_channels > 3: + X = X[:,:,:,np.random.choice(n_channels, size = 3)] X = postprocess(X) rc = math.sqrt(X.shape[0]) rows = cols = math.ceil(rc) @@ -246,6 +169,193 @@ def make_joint_img(img_shape, jo, joints): return img +def valid_joints(*joints): + j = np.stack(joints) + return (j >= 0).all() + + +def zoom(img, factor, center = None): + shape = img.shape[:2] + if center is None or not valid_joints(center): + center = np.array(shape) / 2 + e1 = np.array([1,0]) + e2 = np.array([0,1]) + + dst_center = np.array(center) + dst_e1 = e1 * factor + dst_e2 = e2 * factor + + src = np.float32([center, center+e1, center+e2]) + dst = np.float32([dst_center, dst_center+dst_e1, dst_center+dst_e2]) + M = cv2.getAffineTransform(src, dst) + + return cv2.warpAffine(img, M, shape, flags = cv2.INTER_AREA, borderMode = cv2.BORDER_REPLICATE) + + +def get_crop(bpart, joints, jo, wh, o_w, o_h, ar = 1.0): + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + + # fall backs + if not valid_joints(part_src): + if bpart[0] == "lhip" and bpart[1] == "lknee": + bpart = ["lhip"] + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + elif bpart[0] == "rhip" and bpart[1] == "rknee": + bpart = ["rhip"] + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + + if not valid_joints(part_src): + return None + + if part_src.shape[0] == 1: + # leg fallback + a = part_src[0] + b = np.float32([a[0],o_h - 1]) + part_src = np.float32([a,b]) + + if part_src.shape[0] == 4: + pass + elif part_src.shape[0] == 3: + # lshoulder, rshoulder, cnose + segment = part_src[1] - part_src[0] + normal = np.array([-segment[1],segment[0]]) + if normal[1] > 0.0: + normal = -normal + + a = part_src[0] + normal + b = part_src[0] + c = part_src[1] + d = part_src[1] + normal + part_src = np.float32([a,b,c,d]) + else: + assert part_src.shape[0] == 2 + + segment = part_src[1] - part_src[0] + normal = np.array([-segment[1],segment[0]]) + alpha = ar / 2.0 + a = part_src[0] + alpha*normal + b = part_src[0] - alpha*normal + c = part_src[1] - alpha*normal + d = part_src[1] + alpha*normal + part_src = np.float32([a,b,c,d]) + + dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]]) + part_dst = np.float32(wh * dst) + + M = cv2.getPerspectiveTransform(part_src, part_dst) + return M + + +def normalize(imgs, coords, stickmen, jo): + + out_imgs = list() + out_stickmen = list() + + bs = len(imgs) + for i in range(bs): + img = imgs[i] + joints = coords[i] + stickman = stickmen[i] + + h,w = img.shape[:2] + o_h = h + o_w = w + h = h // 4 + w = w // 4 + wh = np.array([w,h]) + wh = np.expand_dims(wh, 0) + + bparts = [ + ["lshoulder","lhip","rhip","rshoulder"], + ["lshoulder", "rshoulder", "rshoulder"], + ["lshoulder","lelbow"], + ["lelbow", "lwrist"], + ["rshoulder","relbow"], + ["relbow", "rwrist"], + ["lhip", "lknee"], + ["lknee", "lankle"], + ["rhip", "rknee"], + ["rknee", "rankle"]] + ar = 0.5 + + part_imgs = list() + part_stickmen = list() + for bpart in bparts: + part_img = np.zeros((h,w,3)) + part_stickman = np.zeros((h,w,3)) + M = get_crop(bpart, joints, jo, wh, o_w, o_h, ar) + + if M is not None: + part_img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + part_stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + + part_imgs.append(part_img) + part_stickmen.append(part_stickman) + img = np.concatenate(part_imgs, axis = 2) + stickman = np.concatenate(part_stickmen, axis = 2) + + """ + bpart = ["lshoulder","lhip","rhip","rshoulder"] + dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]]) + bpart_indices = [jo.index(b) for b in bpart] + part_src = np.float32(joints[bpart_indices]) + part_dst = np.float32(wh * dst) + + M = cv2.getPerspectiveTransform(part_src, part_dst) + img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE) + """ + + """ + # center of possible rescaling + c = joints[jo.index("cneck")] + + # find valid body part for scale estimation + a = joints[jo.index("lshoulder")] + b = joints[jo.index("lhip")] + target_length = 33.0 + if not valid_joints(a,b): + a = joints[jo.index("rshoulder")] + b = joints[jo.index("rhip")] + target_length = 33.0 + if not valid_joints(a,b): + a = joints[jo.index("rshoulder")] + b = joints[jo.index("relbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("lshoulder")] + b = joints[jo.index("lelbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("lwrist")] + b = joints[jo.index("lelbow")] + target_length = 33.0 / 2 + if not valid_joints(a,b): + a = joints[jo.index("rwrist")] + b = joints[jo.index("relbow")] + target_length = 33.0 / 2 + + if valid_joints(a,b): + body_length = np.linalg.norm(b - a) + factor = target_length / body_length + img = zoom(img, factor, center = c) + stickman = zoom(stickman, factor, center = c) + else: + factor = 0.25 + img = zoom(img, factor, center = c) + stickman = zoom(stickman, factor, center = c) + """ + + out_imgs.append(img) + out_stickmen.append(stickman) + out_imgs = np.stack(out_imgs) + out_stickmen = np.stack(out_stickmen) + return out_imgs, out_stickmen + + def make_mask_img(img_shape, jo, joints): scale_factor = img_shape[1] / 128 masks = 3*[None] @@ -310,10 +420,7 @@ def __init__( mask = True, fill_batches = True, shuffle = True, - return_keys = ["imgs", "joints"], - prefix = None, - seed = 1): - self.prng = RandomState(seed) + return_keys = ["imgs", "joints"]): self.shape = shape self.batch_size = self.shape[0] self.img_shape = self.shape[1:] @@ -327,8 +434,6 @@ def __init__( self.return_keys = return_keys self.jo = self.index["joint_order"] - if prefix is None: - prefix = "" self.indices = np.array( [i for i in range(len(self.index["train"])) if self._filter(i)]) @@ -420,11 +525,10 @@ def __next__(self): # apply mask to images batch["imgs"] = batch["imgs"] * batch["masks"] - valid_joints = ["lhip","rhip","lshoulder","rshoulder"] - valid_joint_indices = [self.jo.index(j) for j in valid_joints] - invalid_joint_indices = [i for i in range(len(self.jo)) if i not in valid_joint_indices] - for i in range(len(batch["joints_coordinates"])): - batch["joints_coordinates"][i][invalid_joint_indices,:] = -100.0 + + imgs, joints = normalize(batch["imgs"], batch["joints_coordinates"], batch["joints"], self.jo) + batch["norm_imgs"] = imgs + batch["norm_joints"] = joints batch_list = [batch[k] for k in self.return_keys] return batch_list @@ -433,7 +537,7 @@ def __next__(self): def shuffle(self): self.batch_start = 0 if self.shuffle_: - self.prng.shuffle(self.indices) + np.random.shuffle(self.indices) def get_batches( @@ -443,10 +547,9 @@ def get_batches( mask, fill_batches = True, shuffle = True, - return_keys = ["imgs", "joints"], - prefix = None): + return_keys = ["imgs", "joints", "norm_imgs", "norm_joints"]): """Buffered IndexFlow.""" - flow = IndexFlow(shape, index_path, train, mask, fill_batches, shuffle, return_keys,prefix) + flow = IndexFlow(shape, index_path, train, mask, fill_batches, shuffle, return_keys) return BufferedWrapper(flow) diff --git a/main.py b/main.py index 19f49882..39e20b01 100644 --- a/main.py +++ b/main.py @@ -58,6 +58,8 @@ class Model(object): def __init__(self, opt, out_dir, logger): self.batch_size = opt.batch_size self.img_shape = 2*[opt.spatial_size] + [3] + redux = 2 + self.imgn_shape = 2*[opt.spatial_size//(2**redux)] + [30] self.init_batches = opt.init_batches self.initial_lr = opt.lr @@ -83,17 +85,21 @@ def __init__(self, opt, out_dir, logger): def define_models(self): n_latent_scales = 2 - n_scales = 1 + int(np.round(np.log2(self.img_shape[0]))) + n_scales = 1 + int(np.round(np.log2(self.img_shape[0]))) - 2 + n_filters = 32 + redux = 2 self.enc_up_pass = models.make_model( "enc_up", models.enc_up, - n_scales = n_scales) + n_scales = n_scales - redux, + n_filters = n_filters*2**redux) self.enc_down_pass = models.make_model( "enc_down", models.enc_down, - n_scales = n_scales, + n_scales = n_scales - redux, n_latent_scales = n_latent_scales) self.dec_up_pass = models.make_model( "dec_up", models.dec_up, - n_scales = n_scales) + n_scales = n_scales, + n_filters = n_filters) self.dec_down_pass = models.make_model( "dec_down", models.dec_down, n_scales = n_scales, @@ -102,10 +108,10 @@ def define_models(self): "dec_params", models.dec_parameters) - def train_forward_pass(self, x, c, dropout_p, init = False): + def train_forward_pass(self, x, c, xn, cn, dropout_p, init = False): kwargs = {"init": init, "dropout_p": dropout_p} # encoder - hs = self.enc_up_pass(x, c, **kwargs) + hs = self.enc_up_pass(xn, cn, **kwargs) es, qs, zs_posterior = self.enc_down_pass(hs, **kwargs) # decoder gs = self.dec_up_pass(c, **kwargs) @@ -146,7 +152,7 @@ def sample(self, params, **kwargs): def likelihood_loss(self, x, params): - return self.vgg19.make_loss_op(x, params) + return 5.0*self.vgg19.make_loss_op(x, params) def define_graph(self): @@ -162,9 +168,9 @@ def define_graph(self): 0.0, self.initial_lr) kl_weight = nn.make_linear_var( global_step, - self.lr_decay_begin, self.lr_decay_end // 2, - 1e-3, 1.0, - 1e-3, 1.0) + self.lr_decay_end // 2, 3 * self.lr_decay_end // 4, + 1e-6, 1.0, + 1e-6, 1.0) #kl_weight = tf.to_float(0.1) # initialization @@ -174,7 +180,16 @@ def define_graph(self): self.c_init = tf.placeholder( tf.float32, shape = [self.init_batches * self.batch_size] + self.img_shape) - _ = self.train_forward_pass(self.x_init, self.c_init, dropout_p = self.dropout_p, init = True) + self.xn_init = tf.placeholder( + tf.float32, + shape = [self.init_batches * self.batch_size] + self.imgn_shape) + self.cn_init = tf.placeholder( + tf.float32, + shape = [self.init_batches * self.batch_size] + self.imgn_shape) + _ = self.train_forward_pass( + self.x_init, self.c_init, + self.xn_init, self.cn_init, + dropout_p = self.dropout_p, init = True) # training self.x = tf.placeholder( @@ -183,8 +198,17 @@ def define_graph(self): self.c = tf.placeholder( tf.float32, shape = [self.batch_size] + self.img_shape) + self.xn = tf.placeholder( + tf.float32, + shape = [self.batch_size] + self.imgn_shape) + self.cn = tf.placeholder( + tf.float32, + shape = [self.batch_size] + self.imgn_shape) # compute parameters of model distribution - params, qs, ps, activations = self.train_forward_pass(self.x, self.c, dropout_p = self.dropout_p) + params, qs, ps, activations = self.train_forward_pass( + self.x, self.c, + self.xn, self.cn, + dropout_p = self.dropout_p) # sample from model distribution sample = self.sample(params) # maximize likelihood @@ -200,7 +224,10 @@ def define_graph(self): test_sample = self.sample(test_forward) # reconstruction - reconstruction_params, _, _, _ = self.train_forward_pass(self.x, self.c, dropout_p = 0.0) + reconstruction_params, _, _, _ = self.train_forward_pass( + self.x, self.c, + self.xn, self.cn, + dropout_p = 0.0) self.reconstruction = self.sample(reconstruction_params) # optimization @@ -224,6 +251,8 @@ def define_graph(self): self.img_ops["test_sample"] = test_sample self.img_ops["x"] = self.x self.img_ops["c"] = self.c + for i in range(10): + self.img_ops["xn{}".format(i)] = self.xn[:,:,:,i*3:(i+1)*3] for i, l in enumerate(self.vgg19.losses): self.log_ops["vgg_loss_{}".format(i)] = l @@ -253,6 +282,8 @@ def init_graph(self, init_batch): self.saver = tf.train.Saver(self.variables) initializer_op = tf.variables_initializer(self.variables) session.run(initializer_op, { + self.xn_init: init_batch[2], + self.cn_init: init_batch[3], self.x_init: init_batch[0], self.c_init: init_batch[1]}) self.logger.info("Initialized model from scratch") @@ -271,8 +302,10 @@ def fit(self, batches, valid_batches = None): start_step = self.log_ops["global_step"].eval(session) self.valid_batches = valid_batches for batch in trange(start_step, self.lr_decay_end): - X_batch, C_batch = next(batches) + X_batch, C_batch, XN_batch, CN_batch = next(batches) feed_dict = { + self.xn: XN_batch, + self.cn: CN_batch, self.x: X_batch, self.c: C_batch} fetch_dict = {"train": self.train_op} @@ -301,8 +334,10 @@ def log_result(self, result, **kwargs): if self.valid_batches is not None: # validation run - X_batch, C_batch = next(self.valid_batches) + X_batch, C_batch, XN_batch, CN_batch = next(self.valid_batches) feed_dict = { + self.xn: XN_batch, + self.cn: CN_batch, self.x: X_batch, self.c: C_batch} fetch_dict = dict() @@ -329,7 +364,7 @@ def log_result(self, result, **kwargs): if global_step % self.test_frequency == 0: if self.valid_batches is not None: # testing - X_batch, C_batch = next(self.valid_batches) + X_batch, C_batch, XN_batch, CN_batch = next(self.valid_batches) x_gen = self.test(C_batch) for k in x_gen: plot_batch(x_gen[k], os.path.join( @@ -342,9 +377,10 @@ def log_result(self, result, **kwargs): for r in range(bs): imgs.append(C_batch[r,...]) for i in range(bs): - x_infer = X_batch[i,...] - c_infer = C_batch[i,...] - imgs.append(x_infer) + x_infer = XN_batch[i,...] + c_infer = CN_batch[i,...] + #imgs.append(x_infer) + imgs.append(X_batch[i,...]) x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) c_infer_batch = c_infer[None,...].repeat(bs, axis = 0) @@ -406,8 +442,8 @@ def transfer(self, x_encode, c_encode, c_decode): self.c_generator = tf.placeholder( tf.float32, shape = [self.batch_size] + self.img_shape) - infer_x = self.x - infer_c = self.c + infer_x = self.xn + infer_c = self.cn generate_c = self.c_generator transfer_params = self.transfer_pass(infer_x, infer_c, generate_c) self.transfer_mean_sample = self.sample(transfer_params) @@ -415,8 +451,8 @@ def transfer(self, x_encode, c_encode, c_decode): return session.run( self.transfer_mean_sample, { - self.x: x_encode, - self.c: c_encode, + self.xn: x_encode, + self.cn: c_encode, self.c_generator: c_decode}) @@ -428,10 +464,10 @@ def transfer(self, x_encode, c_encode, c_decode): parser.add_argument("--mode", default = "train", choices=["train", "test", "mcmc", "add_reconstructions", "transfer"]) parser.add_argument("--log_dir", default = default_log_dir, help = "path to log into") - parser.add_argument("--batch_size", default = 16, type = int, help = "batch size") + parser.add_argument("--batch_size", default = 8, type = int, help = "batch size") parser.add_argument("--init_batches", default = 4, type = int, help = "number of batches for initialization") parser.add_argument("--checkpoint", help = "path to checkpoint to restore") - parser.add_argument("--spatial_size", default = 128, type = int, help = "spatial size to resize images to") + parser.add_argument("--spatial_size", default = 256, type = int, help = "spatial size to resize images to") parser.add_argument("--lr", default = 1e-3, type = float, help = "initial learning rate") parser.add_argument("--lr_decay_begin", default = 1000, type = int, help = "steps after which to begin linear lr decay") parser.add_argument("--lr_decay_end", default = 100000, type = int, help = "step at which lr is zero, i.e. number of training steps") @@ -537,7 +573,7 @@ def process_batches(batches): elif opt.mode == "transfer": if not opt.checkpoint: - opt.checkpoint = "log/2017-10-19T23:41:03/checkpoints/model.ckpt-100000" + opt.checkpoint = "log/2017-10-24T16:34:09/checkpoints/model.ckpt-100000" batch_size = opt.batch_size img_shape = 2*[opt.spatial_size] + [3] data_shape = [batch_size] + img_shape @@ -548,16 +584,16 @@ def process_batches(batches): ids = ["00038", "00281", "01166", "x", "06909", "y", "07586", "07607", "z", "09874"] for step in trange(10): - X_batch, C_batch = next(valid_batches) + X_batch, C_batch, XN_batch, CN_batch = next(valid_batches) bs = X_batch.shape[0] imgs = list() imgs.append(np.zeros_like(X_batch[0,...])) for r in range(bs): imgs.append(C_batch[r,...]) for i in range(bs): - x_infer = X_batch[i,...] - c_infer = C_batch[i,...] - imgs.append(x_infer) + x_infer = XN_batch[i,...] + c_infer = CN_batch[i,...] + imgs.append(X_batch[i,...]) x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) c_infer_batch = c_infer[None,...].repeat(bs, axis = 0) diff --git a/maindiff.py b/maindiff.py index 79647a28..f572d168 100644 --- a/maindiff.py +++ b/maindiff.py @@ -34,3 +34,41 @@ plot_batch(imgs, os.path.join( out_dir, "transfer_{}.png".format(ids[step]))) + + + +# boxnormalized + elif opt.mode == "transfer": + if not opt.checkpoint: + opt.checkpoint = "log/2017-10-24T16:34:09/checkpoints/model.ckpt-100000" + batch_size = opt.batch_size + img_shape = 2*[opt.spatial_size] + [3] + data_shape = [batch_size] + img_shape + valid_batches = get_batches(data_shape, opt.data_index, + mask = opt.mask, train = False) + model = Model(opt, out_dir, logger) + model.restore_graph(opt.checkpoint) + + ids = ["00038", "00281", "01166", "x", "06909", "y", "07586", "07607", "z", "09874"] + for step in trange(10): + X_batch, C_batch, XN_batch, CN_batch = next(valid_batches) + bs = X_batch.shape[0] + imgs = list() + imgs.append(np.zeros_like(X_batch[0,...])) + for r in range(bs): + imgs.append(C_batch[r,...]) + for i in range(bs): + x_infer = XN_batch[i,...] + c_infer = CN_batch[i,...] + imgs.append(X_batch[i,...]) + + x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) + c_infer_batch = c_infer[None,...].repeat(bs, axis = 0) + c_generate_batch = C_batch + results = model.transfer(x_infer_batch, c_infer_batch, c_generate_batch) + for j in range(bs): + imgs.append(results[j,...]) + imgs = np.stack(imgs, axis = 0) + plot_batch(imgs, os.path.join( + out_dir, + "transfer_{}.png".format(ids[step]))) diff --git a/models.py b/models.py index 0fac0b3f..b0023b89 100644 --- a/models.py +++ b/models.py @@ -250,7 +250,8 @@ def enc_up( # outputs hs = [] # prepare input - xc = tf.concat([x,c], axis = -1) + #xc = tf.concat([x,c], axis = -1) + xc = x h = nn.nin(xc, n_filters) for l in range(n_scales): # level module