Skip to content

Commit

Permalink
fix final bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 6, 2021
1 parent 91a9339 commit 97401ad
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions deep_daze/deep_daze.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(

self.generate_size_schedule()

def forward(self, text, return_loss=True):
def forward(self, text, return_loss=True, dry_run=False):
out = self.model()
out = norm_siren_output(out)

Expand All @@ -153,7 +153,8 @@ def forward(self, text, return_loss=True):
image_embed = perceptor.encode_image(image)
text_embed = perceptor.encode_text(text)

self.num_batches_processed += self.batch_size
if not dry_run:
self.num_batches_processed += self.batch_size

loss = -self.loss_coef * torch.cosine_similarity(text_embed, image_embed, dim=-1).mean()
return loss
Expand Down Expand Up @@ -354,7 +355,7 @@ def forward(self):

tqdm.write(f'Imagining "{self.text}" from the depths of my weights...')

self.model(self.encoded_text) # do one warmup step due to potential issue with CLIP and CUDA
self.model(self.encoded_text, dry_run = True) # do one warmup step due to potential issue with CLIP and CUDA

if self.open_folder:
open_folder('./')
Expand Down
2 changes: 1 addition & 1 deletion deep_daze/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.7'
__version__ = '0.4.8'

0 comments on commit 97401ad

Please sign in to comment.