Skip to content

Commit

Permalink
Sampler almost there
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 20, 2024
1 parent 196e5b7 commit f16c8e7
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions high_order_implicit_representation/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class Text2ImageSampler(Callback):
def __init__(self, filenames, batch_size):
self._dataset = Text2ImageRenderDataset(filenames)
self._dataloader = DataLoader(
self._dataset, batch_size=batch_size, shuffle=False
self._dataset, batch_size=1, shuffle=False
)
self._batch_size = batch_size

Expand All @@ -257,38 +257,41 @@ def on_train_epoch_end(
) -> None:
pl_module.eval()
with torch.no_grad():

y_hat_list = []

print("We are calling this")
image_count=0
for caption_embedding, flattened_position, image in self._dataloader:
flattened_position=flattened_position[0]
size = len(flattened_position)
y_hat_list = []
for i in range(0, size, self._batch_size):
# embed = caption_embedding[i:(i+self._batch_size)]
# rgb = flattened_image[i:(i+self._batch_size)]
pos = flattened_position[i : (i + self._batch_size)]


for i in range(size//self._batch_size):
embed_single = caption_embedding.to(pl_module.device)

# rgb = flattened_image[i:(i+self._batch_size)]
pos = flattened_position[i : (i + self._batch_size)].to(pl_module.device)
embed = embed_single.repeat(pos.shape[0],1)
res = pl_module(
caption_embedding,
flattened_position[
i * self._batch_size : (i + 1) * self._batch_size
],
embed,
pos,
)

y_hat_list.append(res.detach().cpu())

y_hat = torch.vstack(y_hat_list)

ans = y_hat.reshape(
self._image.shape[0], self._image.shape[1], self._image.shape[2]
image.shape[1], image.shape[2], 3
)

print('ans.shape', ans.shape)

ans = 0.5 * (ans + 1.0)

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(ans.detach().cpu().numpy())
axarr[0].set_title("fit")
axarr[1].imshow(self._image.cpu())
axarr[1].imshow(self.image.cpu())
axarr[1].set_title("original")

for i in range(2):
Expand Down

0 comments on commit f16c8e7

Please sign in to comment.