diff --git a/ludwig/backend/base.py b/ludwig/backend/base.py index 112a8afe35e..01ba196bef6 100644 --- a/ludwig/backend/base.py +++ b/ludwig/backend/base.py @@ -318,7 +318,7 @@ def create_trainer( ) -> BaseTrainer: # type: ignore[override] from ludwig.trainers.trainer import Trainer - return Trainer(distributed=self._distributed, **kwargs) + return Trainer(config, model, distributed=self._distributed, **kwargs) def create_predictor(self, model: BaseModel, **kwargs): from ludwig.models.predictor import get_predictor_cls