From 4b78ba0007c72d6e8a25675c0141c7a9db0e2803 Mon Sep 17 00:00:00 2001 From: hakuryuu96 Date: Thu, 19 Oct 2023 13:53:14 +0000 Subject: [PATCH 1/3] Added max_batches support to training log and tqdm progress bar. --- .../training/sg_trainer/sg_trainer.py | 42 +++++++++++++------ .../training/utils/sg_trainer_utils.py | 22 ++++++++-- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index bbb1f2ea00..c8f30d465d 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -458,8 +458,12 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl # SET THE MODEL IN training STATE self.net.train() + expected_iterations = len(self.train_loader) if self.max_train_batches is None else self.max_train_batches + # THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS - with tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_train_loader: + with tqdm( + self.train_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode + ) as progress_bar_train_loader: progress_bar_train_loader.set_description(f"Train epoch {context.epoch}") # RESET/INIT THE METRIC LOGGERS @@ -471,6 +475,9 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics) for batch_idx, batch_items in enumerate(progress_bar_train_loader): + if self.max_train_batches is not None and self.max_train_batches <= batch_idx: + break + batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True) inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items) @@ -510,9 +517,6 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl progress_bar_train_loader.set_postfix(**pbar_message_dict) self.phase_callback_handler.on_train_batch_end(context) - if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx: - break - self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict( monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict ) @@ -1331,21 +1335,23 @@ def forward(self, inputs, targets): self.ckpt_best_name = self.training_params.ckpt_best_name + self.max_train_batches = self.training_params.max_train_batches + self.max_valid_batches = self.training_params.max_valid_batches + if self.training_params.max_train_batches is not None: if self.training_params.max_train_batches > len(self.train_loader): logger.warning("max_train_batches is greater than len(self.train_loader) and will have no effect.") + self.max_train_batches = len(self.train_loader) elif self.training_params.max_train_batches <= 0: raise ValueError("max_train_batches must be positive.") if self.training_params.max_valid_batches is not None: if self.training_params.max_valid_batches > len(self.valid_loader): logger.warning("max_valid_batches is greater than len(self.valid_loader) and will have no effect.") + self.max_valid_batches = len(self.valid_loader) elif self.training_params.max_valid_batches <= 0: raise ValueError("max_valid_batches must be positive.") - self.max_train_batches = self.training_params.max_train_batches - self.max_valid_batches = 
self.training_params.max_valid_batches - # STATE ATTRIBUTE SET HERE FOR SUBSEQUENT TRAIN() CALLS self._first_backward = True @@ -1394,6 +1400,7 @@ def forward(self, inputs, targets): batch_accumulate=self.batch_accumulate, train_dataset_length=len(self.train_loader.dataset), train_dataloader_len=len(self.train_loader), + max_train_batches=self.max_train_batches, ) processing_params = self._get_preprocessing_from_valid_loader() @@ -2014,7 +2021,12 @@ def _validate_epoch(self, context: PhaseContext, silent_mode: bool = False) -> D self._reset_metrics() self.valid_metrics.to(device_config.device) return self.evaluate( - data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=context.epoch, silent_mode=silent_mode + data_loader=self.valid_loader, + metrics=self.valid_metrics, + evaluation_type=EvaluationType.VALIDATION, + epoch=context.epoch, + silent_mode=silent_mode, + max_batches=self.max_valid_batches, ) def _test_epoch(self, data_loader: DataLoader, context: PhaseContext, silent_mode: bool = False, dataset_name: str = "") -> Dict[str, float]: @@ -2047,6 +2059,7 @@ def evaluate( silent_mode: bool = False, metrics_progress_verbose: bool = False, dataset_name: str = "", + max_batches: Optional[int] = None, ) -> Dict[str, float]: """ Evaluates the model on given dataloader and metrics. @@ -2081,7 +2094,11 @@ def evaluate( loss_logging_items_names=self.loss_logging_items_names, ) - with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader: + expected_iterations = len(data_loader) if max_batches is None else max_batches + + with tqdm( + data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode + ) as progress_bar_data_loader: if not silent_mode: # PRINT TITLES @@ -2091,9 +2108,11 @@ def evaluate( if epoch: pbar_start_msg += f" epoch {epoch}" progress_bar_data_loader.set_description(pbar_start_msg) - with torch.no_grad(): for batch_idx, batch_items in enumerate(progress_bar_data_loader): + if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches <= batch_idx: + break + batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True) inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items) @@ -2128,9 +2147,6 @@ def evaluate( progress_bar_data_loader.set_postfix(**pbar_message_dict) - if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches - 1 <= batch_idx: - break - logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion) # NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET if not metrics_progress_verbose: diff --git a/src/super_gradients/training/utils/sg_trainer_utils.py b/src/super_gradients/training/utils/sg_trainer_utils.py index 03f8c957b0..072dbcf8d6 100644 --- a/src/super_gradients/training/utils/sg_trainer_utils.py +++ b/src/super_gradients/training/utils/sg_trainer_utils.py @@ -447,19 +447,33 @@ def get_callable_param_names(obj: callable) -> Tuple[str]: def log_main_training_params( - multi_gpu: MultiGPUMode, num_gpus: int, batch_size: int, batch_accumulate: int, train_dataset_length: int, train_dataloader_len: int + multi_gpu: MultiGPUMode, + num_gpus: int, + batch_size: int, + batch_accumulate: int, + train_dataset_length: int, + train_dataloader_len: int, + 
max_train_batches: Optional[int] = None,
 ):
     """Log training parameters"""
+
+    iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
+    gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
+
     msg = (
         "TRAINING PARAMETERS:\n"
         f" - Mode: {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
         f" - Number of GPUs: {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
-        f" - Dataset size: {train_dataset_length:<10} (len(train_set))\n"
+        f" - Full dataset size: {train_dataset_length:<10} (len(train_set))\n"
         f" - Batch size per GPU: {batch_size:<10} (batch_size)\n"
         f" - Batch Accumulate: {batch_accumulate:<10} (batch_accumulate)\n"
         f" - Total batch size: {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
         f" - Effective Batch size: {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
-        f" - Iterations per epoch: {int(train_dataloader_len):<10} (len(train_loader))\n"
-        f" - Gradient updates per epoch: {int(train_dataloader_len / batch_accumulate):<10} (len(train_loader) / batch_accumulate)\n"
+        f" - Iterations per epoch: {iterations_per_epoch:<10} (len(train_loader) OR max_train_batches)\n"
+        f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} (len(train_loader) OR max_train_batches / batch_accumulate)\n"
     )
+
     logger.info(msg)
+
+    if max_train_batches:
+        logger.warning(f"max_train_batches is set to {max_train_batches}. This limits the number of iterations per epoch and gradient updates per epoch.")

From 572776a8345f63ecb0f881079f4583752909c764 Mon Sep 17 00:00:00 2001
From: hakuryuu96
Date: Thu, 19 Oct 2023 15:40:55 +0000
Subject: [PATCH 2/3] Adjusted the logged string to reflect which parameter is used (len(train_loader) or max_train_batches)

---
 src/super_gradients/training/utils/sg_trainer_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/training/utils/sg_trainer_utils.py b/src/super_gradients/training/utils/sg_trainer_utils.py
index 072dbcf8d6..1d269f60d9 100644
--- a/src/super_gradients/training/utils/sg_trainer_utils.py
+++ b/src/super_gradients/training/utils/sg_trainer_utils.py
@@ -459,6 +459,7 @@ def log_main_training_params(
 
     iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
     gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
+    what_used_str = "len(train_loader)" if max_train_batches is None else "max_train_batches"
 
     msg = (
         "TRAINING PARAMETERS:\n"
@@ -469,8 +470,8 @@ def log_main_training_params(
         f" - Batch Accumulate: {batch_accumulate:<10} (batch_accumulate)\n"
         f" - Total batch size: {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
         f" - Effective Batch size: {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
-        f" - Iterations per epoch: {iterations_per_epoch:<10} (len(train_loader) OR max_train_batches)\n"
-        f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} (len(train_loader) OR max_train_batches / batch_accumulate)\n"
+        f" - Iterations per epoch: {iterations_per_epoch:<10} ({what_used_str})\n"
+        f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} ({what_used_str} / batch_accumulate)\n"
     )
 
     logger.info(msg)

From 723d0470fa0c27b50276d392037a1d058e104d71 Mon Sep 17 00:00:00 2001
From: hakuryuu96
Date: Fri, 20 Oct 2023 13:07:18 +0000
Subject: [PATCH 3/3] Replaced the per-epoch stopping condition with the precomputed expected_iterations

---
src/super_gradients/training/sg_trainer/sg_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index c8f30d465d..c1ca383f93 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -475,7 +475,7 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics) for batch_idx, batch_items in enumerate(progress_bar_train_loader): - if self.max_train_batches is not None and self.max_train_batches <= batch_idx: + if expected_iterations <= batch_idx: break batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True) @@ -2110,7 +2110,7 @@ def evaluate( progress_bar_data_loader.set_description(pbar_start_msg) with torch.no_grad(): for batch_idx, batch_items in enumerate(progress_bar_data_loader): - if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches <= batch_idx: + if evaluation_type == EvaluationType.VALIDATION and expected_iterations <= batch_idx: break batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
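The pattern these patches introduce can be summarized in one self-contained sketch: compute the cap once (`expected_iterations`), pass it as tqdm's `total` so the bar length matches what will actually run, and break out of the loop once `batch_idx` reaches the cap. The sketch below is plain PyTorch + tqdm, not SuperGradients code; `run_capped_epoch` and the toy loader are illustrative names only.

```python
# Minimal sketch (assumption: plain PyTorch + tqdm, not SuperGradients code) of the
# batch-capping pattern introduced by these patches.
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


def run_capped_epoch(loader: DataLoader, max_batches: Optional[int] = None) -> int:
    # The cap drives both the tqdm `total` (so the bar ends where the loop ends)
    # and the explicit break inside the loop.
    expected_iterations = len(loader) if max_batches is None else max_batches
    processed = 0
    with tqdm(loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}") as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            if expected_iterations <= batch_idx:  # same stopping condition as PATCH 3/3
                break
            processed += 1  # forward / backward / metric updates would go here
    return processed


loader = DataLoader(TensorDataset(torch.randn(64, 3), torch.zeros(64)), batch_size=8)
print(run_capped_epoch(loader))                 # 8 -> full epoch, len(loader) batches
print(run_capped_epoch(loader, max_batches=3))  # 3 -> epoch capped at max_batches
```

Tying the bar total and the break to the same `expected_iterations` value keeps the displayed progress consistent with the number of batches actually processed. In the trainer itself, the cap reaches this code path through `training_params` (`self.max_train_batches` / `self.max_valid_batches`), as shown in PATCH 1/3.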