Skip to content

Commit

Permalink
put in an extra check to make sure does not exceed calculated total b…
Browse files Browse the repository at this point in the history
…atches
  • Loading branch information
lucidrains committed Feb 5, 2021
1 parent 0cf76c7 commit 91a9339
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion deep_daze/deep_daze.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def generate_size_schedule(self):
counter = 0
self.scheduled_sizes = []

while batches < self.total_batches:
while batches <= self.total_batches:
counter += 1
sizes = self.sample_sizes(counter)
batches += len(sizes)
Expand Down Expand Up @@ -366,6 +366,10 @@ def forward(self):
loss = self.train_step(epoch, i)
pbar.set_description(f'loss: {loss.item():.2f}')

if self.model.num_batches_processed > self.model.total_batches:
print('number of batches processed exceeds calculated total batches')
return

if terminate:
print('interrupted by keyboard, gracefully exiting')
return
2 changes: 1 addition & 1 deletion deep_daze/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.5'
__version__ = '0.4.7'

0 comments on commit 91a9339

Please sign in to comment.