diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 10f80d2a23..117beb2bdc 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -133,6 +133,7 @@ def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[Mu # SET THE EMPTY PROPERTIES self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None + self.train_loader, self.valid_loader = None, None self.ema = None self.ema_model = None self.sg_logger = None @@ -996,8 +997,15 @@ def forward(self, inputs, targets): global logger if training_params is None: training_params = dict() - self.train_loader = train_loader or self.train_loader - self.valid_loader = valid_loader or self.valid_loader + + self.train_loader = train_loader if train_loader is not None else self.train_loader + self.valid_loader = valid_loader if valid_loader is not None else self.valid_loader + + if self.train_loader is None: + raise ValueError("No `train_loader` found. Please provide a value for `train_loader`") + + if self.valid_loader is None: + raise ValueError("No `valid_loader` found. Please provide a value for `valid_loader`") if hasattr(self.train_loader, "batch_sampler") and self.train_loader.batch_sampler is not None: batch_size = self.train_loader.batch_sampler.batch_size