diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index bbb1f2ea00..c1ca383f93 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -458,8 +458,12 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
         # SET THE MODEL IN training STATE
         self.net.train()
 
+        expected_iterations = len(self.train_loader) if self.max_train_batches is None else self.max_train_batches
+
         # THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
-        with tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_train_loader:
+        with tqdm(
+            self.train_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
+        ) as progress_bar_train_loader:
             progress_bar_train_loader.set_description(f"Train epoch {context.epoch}")
 
             # RESET/INIT THE METRIC LOGGERS
@@ -471,6 +475,9 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
             context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics)
 
             for batch_idx, batch_items in enumerate(progress_bar_train_loader):
+                if expected_iterations <= batch_idx:
+                    break
+
                 batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
                 inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
 
@@ -510,9 +517,6 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
                 progress_bar_train_loader.set_postfix(**pbar_message_dict)
                 self.phase_callback_handler.on_train_batch_end(context)
 
-                if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx:
-                    break
-
         self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
             monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict
         )
@@ -1331,21 +1335,23 @@ def forward(self, inputs, targets):
 
         self.ckpt_best_name = self.training_params.ckpt_best_name
 
+        self.max_train_batches = self.training_params.max_train_batches
+        self.max_valid_batches = self.training_params.max_valid_batches
+
         if self.training_params.max_train_batches is not None:
             if self.training_params.max_train_batches > len(self.train_loader):
                 logger.warning("max_train_batches is greater than len(self.train_loader) and will have no effect.")
+                self.max_train_batches = len(self.train_loader)
             elif self.training_params.max_train_batches <= 0:
                 raise ValueError("max_train_batches must be positive.")
 
         if self.training_params.max_valid_batches is not None:
             if self.training_params.max_valid_batches > len(self.valid_loader):
                 logger.warning("max_valid_batches is greater than len(self.valid_loader) and will have no effect.")
+                self.max_valid_batches = len(self.valid_loader)
             elif self.training_params.max_valid_batches <= 0:
                 raise ValueError("max_valid_batches must be positive.")
 
-        self.max_train_batches = self.training_params.max_train_batches
-        self.max_valid_batches = self.training_params.max_valid_batches
-
         # STATE ATTRIBUTE SET HERE FOR SUBSEQUENT TRAIN() CALLS
         self._first_backward = True
 
@@ -1394,6 +1400,7 @@ def forward(self, inputs, targets):
                 batch_accumulate=self.batch_accumulate,
                 train_dataset_length=len(self.train_loader.dataset),
                 train_dataloader_len=len(self.train_loader),
+                max_train_batches=self.max_train_batches,
             )
 
             processing_params = self._get_preprocessing_from_valid_loader()
@@ -2014,7 +2021,12 @@ def _validate_epoch(self, context: PhaseContext, silent_mode: bool = False) -> D
         self._reset_metrics()
         self.valid_metrics.to(device_config.device)
         return self.evaluate(
-            data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=context.epoch, silent_mode=silent_mode
+            data_loader=self.valid_loader,
+            metrics=self.valid_metrics,
+            evaluation_type=EvaluationType.VALIDATION,
+            epoch=context.epoch,
+            silent_mode=silent_mode,
+            max_batches=self.max_valid_batches,
         )
 
     def _test_epoch(self, data_loader: DataLoader, context: PhaseContext, silent_mode: bool = False, dataset_name: str = "") -> Dict[str, float]:
@@ -2047,6 +2059,7 @@ def evaluate(
         silent_mode: bool = False,
         metrics_progress_verbose: bool = False,
         dataset_name: str = "",
+        max_batches: Optional[int] = None,
     ) -> Dict[str, float]:
         """
         Evaluates the model on given dataloader and metrics.
@@ -2081,7 +2094,11 @@ def evaluate(
             loss_logging_items_names=self.loss_logging_items_names,
         )
 
-        with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
+        expected_iterations = len(data_loader) if max_batches is None else max_batches
+
+        with tqdm(
+            data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
+        ) as progress_bar_data_loader:
 
             if not silent_mode:
                 # PRINT TITLES
@@ -2091,9 +2108,11 @@ def evaluate(
                 if epoch:
                     pbar_start_msg += f" epoch {epoch}"
                 progress_bar_data_loader.set_description(pbar_start_msg)
-
             with torch.no_grad():
                 for batch_idx, batch_items in enumerate(progress_bar_data_loader):
+                    if evaluation_type == EvaluationType.VALIDATION and expected_iterations <= batch_idx:
+                        break
+
                     batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
                     inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)
 
@@ -2128,9 +2147,6 @@ def evaluate(
                     progress_bar_data_loader.set_postfix(**pbar_message_dict)
 
-                    if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches - 1 <= batch_idx:
-                        break
-
         logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion)
         # NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET
         if not metrics_progress_verbose:
diff --git a/src/super_gradients/training/utils/sg_trainer_utils.py b/src/super_gradients/training/utils/sg_trainer_utils.py
index 03f8c957b0..1d269f60d9 100644
--- a/src/super_gradients/training/utils/sg_trainer_utils.py
+++ b/src/super_gradients/training/utils/sg_trainer_utils.py
@@ -447,19 +447,34 @@ def get_callable_param_names(obj: callable) -> Tuple[str]:
 
 
 def log_main_training_params(
-    multi_gpu: MultiGPUMode, num_gpus: int, batch_size: int, batch_accumulate: int, train_dataset_length: int, train_dataloader_len: int
+    multi_gpu: MultiGPUMode,
+    num_gpus: int,
+    batch_size: int,
+    batch_accumulate: int,
+    train_dataset_length: int,
+    train_dataloader_len: int,
+    max_train_batches: Optional[int] = None,
 ):
     """Log training parameters"""
+
+    iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
+    gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
+    what_used_str = "len(train_loader)" if max_train_batches is None else "max_train_batches"
+
     msg = (
         "TRAINING PARAMETERS:\n"
         f"    - Mode:                         {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
f" - Number of GPUs: {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n" - f" - Dataset size: {train_dataset_length:<10} (len(train_set))\n" + f" - Full dataset size: {train_dataset_length:<10} (len(train_set))\n" f" - Batch size per GPU: {batch_size:<10} (batch_size)\n" f" - Batch Accumulate: {batch_accumulate:<10} (batch_accumulate)\n" f" - Total batch size: {num_gpus * batch_size:<10} (num_gpus * batch_size)\n" f" - Effective Batch size: {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n" - f" - Iterations per epoch: {int(train_dataloader_len):<10} (len(train_loader))\n" - f" - Gradient updates per epoch: {int(train_dataloader_len / batch_accumulate):<10} (len(train_loader) / batch_accumulate)\n" + f" - Iterations per epoch: {iterations_per_epoch:<10} ({what_used_str})\n" + f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} ({what_used_str} / batch_accumulate)\n" ) + logger.info(msg) + + if max_train_batches: + logger.warning(f"max_train_batch is set to {max_train_batches}. This limits the number of iterations per epoch and gradient updates per epoch.")