From f16c8e70c90185eed471653c80330744fe6e6e70 Mon Sep 17 00:00:00 2001 From: jloveric Date: Sun, 19 May 2024 21:10:30 -0700 Subject: [PATCH] Sampler almost there --- .../rendering.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/high_order_implicit_representation/rendering.py b/high_order_implicit_representation/rendering.py index 2160c39..3c45346 100644 --- a/high_order_implicit_representation/rendering.py +++ b/high_order_implicit_representation/rendering.py @@ -247,7 +247,7 @@ class Text2ImageSampler(Callback): def __init__(self, filenames, batch_size): self._dataset = Text2ImageRenderDataset(filenames) self._dataloader = DataLoader( - self._dataset, batch_size=batch_size, shuffle=False + self._dataset, batch_size=1, shuffle=False ) self._batch_size = batch_size @@ -257,23 +257,23 @@ def on_train_epoch_end( ) -> None: pl_module.eval() with torch.no_grad(): - - y_hat_list = [] - + print("We are calling this") image_count=0 for caption_embedding, flattened_position, image in self._dataloader: + flattened_position=flattened_position[0] size = len(flattened_position) + y_hat_list = [] for i in range(0, size, self._batch_size): - # embed = caption_embedding[i:(i+self._batch_size)] - # rgb = flattened_image[i:(i+self._batch_size)] - pos = flattened_position[i : (i + self._batch_size)] + - for i in range(size//self._batch_size): + embed_single = caption_embedding.to(pl_module.device) + + # rgb = flattened_image[i:(i+self._batch_size)] + pos = flattened_position[i : (i + self._batch_size)].to(pl_module.device) + embed = embed_single.repeat(pos.shape[0],1) res = pl_module( - caption_embedding, - flattened_position[ - i * self._batch_size : (i + 1) * self._batch_size - ], + embed, + pos, ) y_hat_list.append(res.detach().cpu()) @@ -281,14 +281,17 @@ def on_train_epoch_end( y_hat = torch.vstack(y_hat_list) ans = y_hat.reshape( - self._image.shape[0], self._image.shape[1], self._image.shape[2] + image.shape[1], image.shape[2], 3 ) + + print('ans.shape', ans.shape) + ans = 0.5 * (ans + 1.0) f, axarr = plt.subplots(1, 2) axarr[0].imshow(ans.detach().cpu().numpy()) axarr[0].set_title("fit") - axarr[1].imshow(self._image.cpu()) + axarr[1].imshow(self.image.cpu()) axarr[1].set_title("original") for i in range(2):