Skip to content

Commit

Permalink
Increment epochs based on last_batch() instead of at the end of the train loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
justinxzhao committed Sep 26, 2023
1 parent bd6d34f commit b75170d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
14 changes: 9 additions & 5 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def train(

self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path))

# Trains over a full epoch of data.
# Trains over a full epoch of data or up to the last training step, whichever is sooner.
should_break = self._train_loop(
batcher,
progress_tracker,
Expand All @@ -934,10 +934,6 @@ def train(
profiler,
)

# ================ Post Training Epoch ================
progress_tracker.epoch += 1
self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))

if self.is_coordinator():
# ========== Save training progress ==========
logger.debug(
Expand Down Expand Up @@ -1114,8 +1110,16 @@ def _train_loop(
# batch duration measurements when using timer callbacks.
self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path, sync_step=should_step))

if batcher.last_batch():
# We have completed an epoch, so we need to increment the epoch counter. It's important to do this here
# instead of outside of the train loop since it's possible the train loop will exit early due to
# early stopping, or step-based training.
progress_tracker.epoch += 1
self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))

if progress_tracker.steps % final_steps_per_checkpoint == 0:
if not self.skip_all_evaluation:
# Publishes metrics to MLFLow if there are any MLFlow callbacks.
should_break = self.run_evaluation(
training_set,
validation_set,
Expand Down
27 changes: 27 additions & 0 deletions tests/integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,33 @@ def test_api_callbacks_fixed_train_steps(tmpdir, csv_filename):
assert mock_callback.on_epoch_start.call_count == 10


def test_api_callbacks_fixed_train_steps_partial_epochs(tmpdir, csv_filename):
    # A manually specified train_steps takes precedence over epochs.
    steps = 95
    n_rows = 80
    in_feats = [sequence_feature(encoder={"reduce_output": "sum"})]
    out_feats = [category_feature(decoder={"vocab_size": 5}, reduce_input="sum")]
    callback_spy = mock.Mock(wraps=Callback())

    model = LudwigModel(
        {
            "input_features": in_feats,
            "output_features": out_feats,
            "combiner": {"type": "concat", "output_size": 14},
            TRAINER: {"epochs": 2, "train_steps": steps, "batch_size": 8},
        },
        callbacks=[callback_spy],
    )
    dataset_path = os.path.join(tmpdir, csv_filename)
    model.train(training_set=generate_data(in_feats, out_feats, dataset_path, num_examples=n_rows))

    # 80 examples at batch_size 8 is 10 steps per epoch, so 95 steps cover 9 full epochs.
    assert callback_spy.on_epoch_end.call_count == 9


def test_api_callbacks_fixed_train_steps_less_than_one_epoch(tmpdir, csv_filename):
# If train_steps is set manually, epochs is ignored.
train_steps = total_batches = 6
Expand Down

0 comments on commit b75170d

Please sign in to comment.