Skip to content

Commit

Permalink
Merge branch 'master' of github.com:ludwig-ai/ludwig into increment_e…
Browse files Browse the repository at this point in the history
…pochs
  • Loading branch information
justinxzhao committed Sep 27, 2023
2 parents 19d5cdd + 4af5331 commit e7178ab
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
9 changes: 7 additions & 2 deletions ludwig/data/batcher/random_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def next_batch(self):
return sub_batch

def last_batch(self):
"""Returns whether we've exhausted all batches for this epoch.
If False, then there is at least 1 more batch available with next_batch().
"""
# If our current index in the dataset exceeds the size of the dataset,
# we've finished the epoch and can indicate that this is the last batch
if self.index >= self.total_size:
Expand All @@ -71,8 +75,9 @@ def last_batch(self):
# For e.g., batch size = 128 but the dataset only has 100 rows.
elif self.ignore_last and self.step:
# index += batch_size after each epoch. So, if our current index in total dataset is 1 less than the total
# dataset size, then the last batch will only have 1 row. Drop it if this happens.
if self.index - self.total_size == -1:
# dataset size, then the last batch will only have 1 row.
# If this happens, we drop the last batch, unless batch_size is 1.
if self.batch_size > 1 and self.index - self.total_size == -1:
logger.info("Last batch in epoch only has 1 sample and will be dropped.")
return True
return False
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ sacremoses
sentencepiece

# requirements for daft
getdaft ; platform_system != "Windows"
getdaft

# requirement for various paged and 8-bit optimizers
bitsandbytes<0.41.0
Expand Down
32 changes: 29 additions & 3 deletions tests/integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,7 @@ def test_api_callbacks_default_train_steps(tmpdir, csv_filename):


def test_api_callbacks_fixed_train_steps(tmpdir, csv_filename):
# If train_steps is set manually, epochs is ignored.
train_steps = 100
epochs = 2
batch_size = 8
num_examples = 80
mock_callback = mock.Mock(wraps=Callback())
Expand All @@ -576,7 +574,7 @@ def test_api_callbacks_fixed_train_steps(tmpdir, csv_filename):
"input_features": input_features,
"output_features": output_features,
"combiner": {"type": "concat", "output_size": 14},
TRAINER: {"epochs": epochs, "train_steps": train_steps, "batch_size": batch_size},
TRAINER: {"train_steps": train_steps, "batch_size": batch_size},
}
model = LudwigModel(config, callbacks=[mock_callback])
model.train(
Expand Down Expand Up @@ -616,6 +614,34 @@ def test_api_callbacks_fixed_train_steps_partial_epochs(tmpdir, csv_filename):
assert mock_callback.on_epoch_end.call_count == 9


def test_api_callbacks_batch_size_1(tmpdir, csv_filename):
epochs = 2
batch_size = 1
num_examples = 80
mock_callback = mock.Mock(wraps=Callback())

input_features = [sequence_feature(encoder={"reduce_output": "sum"})]
output_features = [category_feature(decoder={"vocab_size": 5}, reduce_input="sum")]
config = {
"input_features": input_features,
"output_features": output_features,
"combiner": {"type": "concat", "output_size": 14},
TRAINER: {"epochs": epochs, "batch_size": batch_size},
}
model = LudwigModel(config, callbacks=[mock_callback])
model.train(
training_set=generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=num_examples
)
)

# There are exactly 2 epoch starts, even with batch_size = 1.
assert mock_callback.on_epoch_start.call_count == 2
assert mock_callback.on_epoch_end.call_count == 2
assert mock_callback.on_batch_start.call_count == 160
assert mock_callback.on_batch_end.call_count == 160


def test_api_callbacks_fixed_train_steps_less_than_one_epoch(tmpdir, csv_filename):
# If train_steps is set manually, epochs is ignored.
train_steps = total_batches = 6
Expand Down

0 comments on commit e7178ab

Please sign in to comment.