Skip to content

Commit

Permalink
Clean up logic for Ray datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
arnavgarg1 committed Dec 12, 2022
1 parent 9243a54 commit bd759ff
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ludwig/data/dataset/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def __init__(

self.features = features
self.columns = list(features.keys())
+ self._sample_feature_name = self.columns[0]
self.reshape_map = {
proc_column: training_set_metadata[feature[NAME]].get("reshape")
for proc_column, feature in features.items()
Expand Down Expand Up @@ -342,8 +343,9 @@ def _fetch_next_batch(self):
self._last_batch = False
try:
self._next_batch = next(self.dataset_batch_iter)
- print(self._next_batch)
- if self.ignore_last and len(self._next_batch) == 1:
+ # If the batch has only one row and self.ignore_last, skip the batch
+ # to prevent batchnorm / dropout related Torch errors
+ if self.ignore_last and len(self._next_batch[self._sample_feature_name]) == 1:
raise StopIteration
except StopIteration:
self._last_batch = True
Expand Down

0 comments on commit bd759ff

Please sign in to comment.