Skip to content

Commit

Permalink
transfer example
Browse files Browse the repository at this point in the history
  • Loading branch information
pesser committed Apr 1, 2019
1 parent db88509 commit 0a7db40
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,5 +474,41 @@ def transfer(self, x_encode, c_encode, c_decode):
if opt.retrain:
model.reset_global_step()
model.fit(batches, valid_batches)
elif opt.mode == "transfer":
batch_size = config["batch_size"]
img_shape = 2*[config["spatial_size"]] + [3]
data_shape = [batch_size] + img_shape
box_factor = config["box_factor"]
data_index = config["data_index"]

valid_batches = get_batches(data_shape, data_index,
box_factor = box_factor, train = False)

model = Model(config, out_dir, logger)
assert opt.checkpoint is not None
model.restore_graph(opt.checkpoint)

for step in trange(10):
X_batch, C_batch, XN_batch, CN_batch = next(valid_batches)
bs = X_batch.shape[0]
imgs = list()
imgs.append(np.zeros_like(X_batch[0,...]))
for r in range(bs):
imgs.append(C_batch[r,...])
for i in range(bs):
x_infer = XN_batch[i,...]
c_infer = CN_batch[i,...]
imgs.append(X_batch[i,...])

x_infer_batch = x_infer[None,...].repeat(bs, axis = 0)
c_infer_batch = c_infer[None,...].repeat(bs, axis = 0)
c_generate_batch = C_batch
results = model.transfer(x_infer_batch, c_infer_batch, c_generate_batch)
for j in range(bs):
imgs.append(results[j,...])
imgs = np.stack(imgs, axis = 0)
plot_batch(imgs, os.path.join(
out_dir,
"transfer_{}.png".format(step)))
else:
raise NotImplemented()

0 comments on commit 0a7db40

Please sign in to comment.