From 4e238e9f2dcb36ebccc8019a17be3a8f45fc4fe2 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Thu, 11 May 2023 12:28:35 +0300 Subject: [PATCH 1/3] fix --- .../training/sg_trainer/sg_trainer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 10f80d2a23..0972634af0 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -133,6 +133,7 @@ def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[Mu # SET THE EMPTY PROPERTIES self.net, self.architecture, self.arch_params, self.dataset_interface = None, None, None, None + self.train_loader, self.valid_loader = None, None self.ema = None self.ema_model = None self.sg_logger = None @@ -996,8 +997,15 @@ def forward(self, inputs, targets): global logger if training_params is None: training_params = dict() - self.train_loader = train_loader or self.train_loader - self.valid_loader = valid_loader or self.valid_loader + + self.train_loader = train_loader if train_loader is not None else self.train_loader + self.valid_loader = valid_loader if train_loader is not None else self.valid_loader + + if self.train_loader is None: + raise ValueError("No `train_loader` found. Please provide a value `train_loader`") + + if self.valid_loader is None: + raise ValueError("No `valid_loader` found. Please provide a value for `valid_loader`") if hasattr(self.train_loader, "batch_sampler") and self.train_loader.batch_sampler is not None: batch_size = self.train_loader.batch_sampler.batch_size From a5fe1301986ea2b708a4b934fca4892e252e4632 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Thu, 11 May 2023 12:32:35 +0300 Subject: [PATCH 2/3] fix typo --- src/super_gradients/training/sg_trainer/sg_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 0972634af0..9ef441a00d 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -999,7 +999,7 @@ def forward(self, inputs, targets): training_params = dict() self.train_loader = train_loader if train_loader is not None else self.train_loader - self.valid_loader = valid_loader if train_loader is not None else self.valid_loader + self.valid_loader = valid_loader if valid_loader is not None else self.valid_loader if self.train_loader is None: raise ValueError("No `train_loader` found. Please provide a value `train_loader`") From 7991484fbbdb81be705925a21691d1c6860b5fce Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Thu, 11 May 2023 12:34:54 +0300 Subject: [PATCH 3/3] typo --- src/super_gradients/training/sg_trainer/sg_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 9ef441a00d..117beb2bdc 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -1002,7 +1002,7 @@ def forward(self, inputs, targets): self.valid_loader = valid_loader if valid_loader is not None else self.valid_loader if self.train_loader is None: - raise ValueError("No `train_loader` found. Please provide a value `train_loader`") + raise ValueError("No `train_loader` found. Please provide a value for `train_loader`") if self.valid_loader is None: raise ValueError("No `valid_loader` found. Please provide a value for `valid_loader`")