Skip to content

Commit

Permalink
fix: Use len() to get size of dataset
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Jun 18, 2021
1 parent eb39f9c commit ccc60d5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion py/trtorch/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def get_batch_size(self):


def get_batch(self, names):
if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
print("Current batch idx: ", self.current_batch_idx, " Dataset size: ", len(self.data_loader.dataset))
if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset):
return None

batch = self.dataset_iterator.next()
Expand Down

0 comments on commit ccc60d5

Please sign in to comment.