Skip to content

Commit

Permalink
Merge branch 'master' into fast-im-read
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffreyangus committed Aug 4, 2022
2 parents b572f1c + e9774d1 commit e8b0160
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions ludwig/schema/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,16 @@ class GBMTrainerConfig(BaseTrainerConfig):
allow_none=False,
)

# NOTE: Overwritten here to provide a default value. In many places, we fall back to eval_batch_size if batch_size
# is not specified. GBM does not have a value for batch_size, so we need to specify eval_batch_size here.
eval_batch_size: Union[None, int, str] = schema_utils.IntegerOrAutoField(
default=128,
allow_none=True,
min_exclusive=0,
description=("Size of batch to pass to the model for evaluation."),
parameter_metadata=TRAINER_METADATA["eval_batch_size"],
)

# LightGBM core parameters (https://lightgbm.readthedocs.io/en/latest/Parameters.html)
boosting_type: str = schema_utils.StringOptions(
["gbdt", "rf", "dart", "goss"],
Expand Down
2 changes: 1 addition & 1 deletion ludwig/trainers/trainer_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.skip_save_progress = skip_save_progress
self.skip_save_model = skip_save_model

self.eval_batch_size = config.eval_batch_size or 128
self.eval_batch_size = config.eval_batch_size
self._validation_field = config.validation_field
self._validation_metric = config.validation_metric
self.evaluate_training_set = config.evaluate_training_set
Expand Down

0 comments on commit e8b0160

Please sign in to comment.