Skip to content

Commit

Permalink
Add skip_all_evaluation as a mechanic to skip all evaluation. (#3543)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinxzhao authored Aug 28, 2023
1 parent 5ef0878 commit f34c272
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 49 deletions.
6 changes: 5 additions & 1 deletion ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,11 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
save_json(training_stats_fn, train_stats)

# results of the model with highest validation test performance
if self.backend.is_coordinator() and validation_set is not None:
if (
self.backend.is_coordinator()
and validation_set is not None
and not self.config_obj.trainer.skip_all_evaluation
):
print_boxed("TRAINING REPORT")
training_report = get_training_report(
trainer.validation_field,
Expand Down
9 changes: 9 additions & 0 deletions ludwig/schema/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ class BaseTrainerConfig(schema_utils.BaseMarshmallowConfig, ABC):
),
)

skip_all_evaluation: bool = schema_utils.Boolean(
default=False,
description=(
"Whether to skip evaluation entirely. If you are training a model with a well-known configuration on a "
"well-known dataset and are confident about the expected results, you might skip all evaluation. Moreover, "
"evaluating a model, especially on large validation or test sets, can be time-consuming."
),
)

def can_tune_batch_size(self) -> bool:
return True

Expand Down
47 changes: 30 additions & 17 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
self.steps_per_checkpoint = config.steps_per_checkpoint
self.checkpoints_per_epoch = config.checkpoints_per_epoch
self.evaluate_training_set = config.evaluate_training_set
self.skip_all_evaluation = config.skip_all_evaluation
self.increase_batch_size_on_plateau = config.increase_batch_size_on_plateau
self.increase_batch_size_on_plateau_patience = config.increase_batch_size_on_plateau_patience
self.increase_batch_size_on_plateau_rate = config.increase_batch_size_on_plateau_rate
Expand Down Expand Up @@ -836,6 +837,10 @@ def train(
if self.is_coordinator():
progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

if not self.skip_save_model and self.skip_all_evaluation:
# All evaluation was skipped, so save the current step as the best so far.
checkpoint_manager.save_best(progress_tracker.steps)

# Early stop if needed.
if should_break:
break
Expand All @@ -853,6 +858,10 @@ def train(
if test_summary_writer is not None:
test_summary_writer.close()

if not self.skip_save_model and self.skip_all_evaluation:
# All evaluation was skipped, so save the current step as the best so far.
checkpoint_manager.save_best(progress_tracker.steps)

if not self.skip_save_progress:
checkpoint_manager.close()

Expand Down Expand Up @@ -933,6 +942,7 @@ def _train_loop(
}

loss, all_losses = self.train_step(inputs, targets, should_step=should_step)
logger.info(f"Train loss for step {progress_tracker.steps}: {loss:.3f}")

if should_step:
# Update LR schduler here instead of train loop to avoid updating during batch size tuning, etc.
Expand Down Expand Up @@ -961,23 +971,26 @@ def _train_loop(
self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path, sync_step=should_step))

if progress_tracker.steps % final_steps_per_checkpoint == 0:
should_break = self.run_evaluation(
training_set,
validation_set,
test_set,
progress_tracker,
train_summary_writer,
validation_summary_writer,
test_summary_writer,
model_hyperparameters_path,
output_features,
metrics_names,
save_path,
loss,
all_losses,
early_stopping_steps,
checkpoint_manager,
)
if not self.skip_all_evaluation:
should_break = self.run_evaluation(
training_set,
validation_set,
test_set,
progress_tracker,
train_summary_writer,
validation_summary_writer,
test_summary_writer,
model_hyperparameters_path,
output_features,
metrics_names,
save_path,
loss,
all_losses,
early_stopping_steps,
checkpoint_manager,
)
else:
should_break = False

# Checkpoint the model.
# NOTE: Ideally we would do this before evaluation, but for some reason DeepSpeed will complain
Expand Down
34 changes: 19 additions & 15 deletions ludwig/trainers/trainer_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
self._validation_field = config.validation_field
self._validation_metric = config.validation_metric
self.evaluate_training_set = config.evaluate_training_set
self.skip_all_evaluation = config.skip_all_evaluation
try:
base_learning_rate = float(config.learning_rate)
except ValueError:
Expand Down Expand Up @@ -380,21 +381,24 @@ def _train_loop(
loss = evals_result["train"][loss_name][-1]
loss = torch.tensor(loss, dtype=torch.float32)

should_break = self.run_evaluation(
training_set,
validation_set,
test_set,
progress_tracker,
train_summary_writer,
validation_summary_writer,
test_summary_writer,
output_features,
metrics_names,
save_path,
loss,
{output_feature_name: loss},
early_stopping_steps,
)
if not self.skip_all_evaluation:
should_break = self.run_evaluation(
training_set,
validation_set,
test_set,
progress_tracker,
train_summary_writer,
validation_summary_writer,
test_summary_writer,
output_features,
metrics_names,
save_path,
loss,
{output_feature_name: loss},
early_stopping_steps,
)
else:
should_break = False

self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path))

Expand Down
51 changes: 35 additions & 16 deletions ludwig/trainers/trainer_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ def __init__(
self.checkpoints_per_epoch = self.config.checkpoints_per_epoch
self.early_stop = self.config.early_stop
self.evaluate_training_set = self.config.evaluate_training_set
self.skip_all_evaluation = self.config.skip_all_evaluation

def close_writers(
self, progress_tracker, save_path, train_summary_writer, validation_summary_writer, test_summary_writer
):
# ================ Finished Training ================
self.callback(
lambda c: c.on_trainer_train_teardown(self, progress_tracker, save_path, self.is_coordinator()),
coordinator_only=False,
)

if train_summary_writer is not None:
train_summary_writer.close()
if validation_summary_writer is not None:
validation_summary_writer.close()
if test_summary_writer is not None:
test_summary_writer.close()

def train(
self,
Expand Down Expand Up @@ -151,6 +168,22 @@ def train(
output_features=output_features,
)

# When running with Ray, we only need to return the state dict, as it's faster and cheaper to send the
# state dict over the network than to load the model state here, serialize it back to a state dict, then
# load it back on the head node.
return_value = self.model if not return_state_dict else self.model.cpu().state_dict()

if self.skip_all_evaluation:
self.close_writers(
progress_tracker, save_path, train_summary_writer, validation_summary_writer, test_summary_writer
)
return (
return_value,
progress_tracker.train_metrics,
progress_tracker.validation_metrics,
progress_tracker.test_metrics,
)

try:
self.run_evaluation(
training_set,
Expand All @@ -164,24 +197,10 @@ def train(
save_path,
)
finally:
# ================ Finished Training ================
self.callback(
lambda c: c.on_trainer_train_teardown(self, progress_tracker, save_path, self.is_coordinator()),
coordinator_only=False,
self.close_writers(
progress_tracker, save_path, train_summary_writer, validation_summary_writer, test_summary_writer
)

if train_summary_writer is not None:
train_summary_writer.close()
if validation_summary_writer is not None:
validation_summary_writer.close()
if test_summary_writer is not None:
test_summary_writer.close()

# When running with Ray, we only need to return the state dict, as it's faster and cheaper to send the
# state dict over the network than to load the model state here, serialize it back to a state dict, then
# load it back on the head node.
return_value = self.model if not return_state_dict else self.model.cpu().state_dict()

return (
return_value,
progress_tracker.train_metrics,
Expand Down

0 comments on commit f34c272

Please sign in to comment.